From bb65f097fdb0f0ba9652bb65682676243b255aea Mon Sep 17 00:00:00 2001 From: Alexander Simmerl Date: Mon, 19 Mar 2018 09:38:28 +0100 Subject: [PATCH] Simplify WriteFileAtomic We can make the implementation more robust by adjusting our assumptions and leverage explicit file modes for syncing. Additionally we going to assume that we want to clean up and can't really recover if thos operations (file close and removal) fail. * utilise file mode for majority of concerns * improve test coverage by covering more assumptions * signature parity with ioutil.WriteFile * always clean up Replaces #160 --- common/os.go | 44 ++++++++++++++++++++++---------------------- common/os_test.go | 34 ++++++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/common/os.go b/common/os.go index 36fc969f..f1e07115 100644 --- a/common/os.go +++ b/common/os.go @@ -124,32 +124,32 @@ func MustWriteFile(filePath string, contents []byte, mode os.FileMode) { } } -// WriteFileAtomic writes newBytes to temp and atomically moves to filePath -// when everything else succeeds. -func WriteFileAtomic(filePath string, newBytes []byte, mode os.FileMode) error { - dir := filepath.Dir(filePath) - f, err := ioutil.TempFile(dir, "") +// WriteFileAtomic creates a temporary file with data and the perm given and +// swaps it atomically with filename if successful. +func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { + var ( + dir = filepath.Dir(filename) + tempFile = filepath.Join(dir, "write-file-atomic-"+RandStr(32)) + // Override in case it does exist, create in case it doesn't and force kernel + // flush, which still leaves the potential of lingering disk cache. + flag = os.O_WRONLY | os.O_CREATE | os.O_SYNC | os.O_TRUNC + ) + + f, err := os.OpenFile(tempFile, flag, perm) if err != nil { return err } - _, err = f.Write(newBytes) - if err == nil { - err = f.Sync() + // Clean up in any case. Defer stacking order is last-in-first-out. + defer os.Remove(f.Name()) + defer f.Close() + + if n, err := f.Write(data); err != nil { + return err + } else if n < len(data) { + return io.ErrShortWrite } - if closeErr := f.Close(); err == nil { - err = closeErr - } - if permErr := os.Chmod(f.Name(), mode); err == nil { - err = permErr - } - if err == nil { - err = os.Rename(f.Name(), filePath) - } - // any err should result in full cleanup - if err != nil { - os.Remove(f.Name()) - } - return err + + return os.Rename(f.Name(), filename) } //-------------------------------------------------------------------------------- diff --git a/common/os_test.go b/common/os_test.go index 126723aa..97ad672b 100644 --- a/common/os_test.go +++ b/common/os_test.go @@ -2,30 +2,52 @@ package common import ( "bytes" - "fmt" "io/ioutil" + "math/rand" "os" "testing" "time" ) func TestWriteFileAtomic(t *testing.T) { - data := []byte("Becatron") - fname := fmt.Sprintf("/tmp/write-file-atomic-test-%v.txt", time.Now().UnixNano()) - err := WriteFileAtomic(fname, data, 0664) + var ( + seed = rand.New(rand.NewSource(time.Now().UnixNano())) + data = []byte(RandStr(seed.Intn(2048))) + old = RandBytes(seed.Intn(2048)) + perm os.FileMode = 0600 + ) + + f, err := ioutil.TempFile("/tmp", "write-atomic-test-") if err != nil { t.Fatal(err) } - rData, err := ioutil.ReadFile(fname) + defer os.Remove(f.Name()) + + if err := ioutil.WriteFile(f.Name(), old, 0664); err != nil { + t.Fatal(err) + } + + if err := WriteFileAtomic(f.Name(), data, perm); err != nil { + t.Fatal(err) + } + + rData, err := ioutil.ReadFile(f.Name()) if err != nil { t.Fatal(err) } + if !bytes.Equal(data, rData) { t.Fatalf("data mismatch: %v != %v", data, rData) } - if err := os.Remove(fname); err != nil { + + stat, err := os.Stat(f.Name()) + if err != nil { t.Fatal(err) } + + if have, want := stat.Mode().Perm(), perm; have != want { + t.Errorf("have %v, want %v", have, want) + } } func TestGoPath(t *testing.T) {