diff --git a/internal/files/files_test.go b/internal/files/files_test.go index 6a06b53..1e34cf4 100644 --- a/internal/files/files_test.go +++ b/internal/files/files_test.go @@ -141,6 +141,27 @@ func TestQuotaEnforced(t *testing.T) { } } +func TestWebSaveOverQuotaPreservesExistingFile(t *testing.T) { + svc, _, u := newTestService(t) + sess, _ := svc.newSession(u) + sess.quota = 5 + + if _, err := sess.webSave("/me/note.txt", strings.NewReader("ok")); err != nil { + t.Fatalf("initial save failed: %v", err) + } + if _, err := sess.webSave("/me/note.txt", strings.NewReader("too-large")); err != errQuota { + t.Fatalf("over-quota replace: got %v, want errQuota", err) + } + + got, err := os.ReadFile(filepath.Join(svc.privRoot(u.Name), "note.txt")) + if err != nil { + t.Fatalf("existing file should remain after failed replace: %v", err) + } + if string(got) != "ok" { + t.Fatalf("existing file changed after failed replace: %q", got) + } +} + func TestUsage(t *testing.T) { svc, _, u := newTestService(t) if err := svc.ensureWorkspace(u.Name); err != nil { diff --git a/internal/files/fs.go b/internal/files/fs.go index 267a095..12af575 100644 --- a/internal/files/fs.go +++ b/internal/files/fs.go @@ -305,10 +305,13 @@ func (s *session) webSave(vpath string, r io.Reader) (int64, error) { } existing = fi.Size() } - f, err := os.OpenFile(res.real, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + dir := filepath.Dir(res.real) + f, err := os.CreateTemp(dir, ".upload-*") if err != nil { return 0, err } + tmpName := f.Name() + defer func() { _ = os.Remove(tmpName) }() limit := int64(-1) if metered(res.area) { if limit = s.quota - (s.used.Load() - existing); limit < 0 { @@ -318,18 +321,51 @@ func (s *session) webSave(vpath string, r io.Reader) (int64, error) { n, werr := copyLimited(f, r, limit) cerr := f.Close() if werr != nil { - _ = os.Remove(res.real) return 0, werr } if cerr != nil { return 0, cerr } + if err := replaceFile(tmpName, res.real); err != nil { + return 0, err + } if metered(res.area) { s.used.Add(n - existing) } return n, nil } +func replaceFile(src, dst string) error { + renameErr := os.Rename(src, dst) + if renameErr == nil { + return nil + } + if _, err := os.Stat(dst); err != nil { + return renameErr + } + dir := filepath.Dir(dst) + backup, err := os.CreateTemp(dir, ".replace-*") + if err != nil { + return err + } + backupName := backup.Name() + if err := backup.Close(); err != nil { + _ = os.Remove(backupName) + return err + } + if err := os.Remove(backupName); err != nil { + return err + } + if err := os.Rename(dst, backupName); err != nil { + return err + } + if err := os.Rename(src, dst); err != nil { + _ = os.Rename(backupName, dst) + return err + } + return os.Remove(backupName) +} + // webMkdir and webRemove reuse the SFTP-side guards (root/writable checks). func (s *session) webMkdir(vpath string) error { return s.mkdir(vpath) } func (s *session) webRemove(vpath string) error { return s.remove(vpath) }