Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions internal/files/files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,27 @@ func TestQuotaEnforced(t *testing.T) {
}
}

func TestWebSaveOverQuotaPreservesExistingFile(t *testing.T) {
svc, _, u := newTestService(t)
sess, _ := svc.newSession(u)
sess.quota = 5

if _, err := sess.webSave("/me/note.txt", strings.NewReader("ok")); err != nil {
t.Fatalf("initial save failed: %v", err)
}
if _, err := sess.webSave("/me/note.txt", strings.NewReader("too-large")); err != errQuota {
t.Fatalf("over-quota replace: got %v, want errQuota", err)
}

got, err := os.ReadFile(filepath.Join(svc.privRoot(u.Name), "note.txt"))
if err != nil {
t.Fatalf("existing file should remain after failed replace: %v", err)
}
if string(got) != "ok" {
t.Fatalf("existing file changed after failed replace: %q", got)
}
}

func TestUsage(t *testing.T) {
svc, _, u := newTestService(t)
if err := svc.ensureWorkspace(u.Name); err != nil {
Expand Down
40 changes: 38 additions & 2 deletions internal/files/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,13 @@ func (s *session) webSave(vpath string, r io.Reader) (int64, error) {
}
existing = fi.Size()
}
f, err := os.OpenFile(res.real, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
dir := filepath.Dir(res.real)
f, err := os.CreateTemp(dir, ".upload-*")
if err != nil {
return 0, err
}
tmpName := f.Name()
defer func() { _ = os.Remove(tmpName) }()
limit := int64(-1)
if metered(res.area) {
if limit = s.quota - (s.used.Load() - existing); limit < 0 {
Expand All @@ -318,18 +321,51 @@ func (s *session) webSave(vpath string, r io.Reader) (int64, error) {
n, werr := copyLimited(f, r, limit)
cerr := f.Close()
if werr != nil {
_ = os.Remove(res.real)
return 0, werr
}
if cerr != nil {
return 0, cerr
}
if err := replaceFile(tmpName, res.real); err != nil {
return 0, err
}
if metered(res.area) {
s.used.Add(n - existing)
}
return n, nil
}

func replaceFile(src, dst string) error {
renameErr := os.Rename(src, dst)
if renameErr == nil {
return nil
}
if _, err := os.Stat(dst); err != nil {
return renameErr
}
dir := filepath.Dir(dst)
backup, err := os.CreateTemp(dir, ".replace-*")
if err != nil {
return err
}
backupName := backup.Name()
if err := backup.Close(); err != nil {
_ = os.Remove(backupName)
return err
}
if err := os.Remove(backupName); err != nil {
return err
}
if err := os.Rename(dst, backupName); err != nil {
return err
}
if err := os.Rename(src, dst); err != nil {
_ = os.Rename(backupName, dst)
return err
}
return os.Remove(backupName)
}

// webMkdir and webRemove reuse the SFTP-side guards (root/writable checks).
func (s *session) webMkdir(vpath string) error { return s.mkdir(vpath) }
func (s *session) webRemove(vpath string) error { return s.remove(vpath) }
Expand Down
Loading