diff --git a/common/os.go b/common/os.go index 36fc969f..f1e07115 100644 --- a/common/os.go +++ b/common/os.go @@ -124,32 +124,32 @@ func MustWriteFile(filePath string, contents []byte, mode os.FileMode) { } } -// WriteFileAtomic writes newBytes to temp and atomically moves to filePath -// when everything else succeeds. -func WriteFileAtomic(filePath string, newBytes []byte, mode os.FileMode) error { - dir := filepath.Dir(filePath) - f, err := ioutil.TempFile(dir, "") +// WriteFileAtomic creates a temporary file with data and the perm given and +// swaps it atomically with filename if successful. +func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { + var ( + dir = filepath.Dir(filename) + tempFile = filepath.Join(dir, "write-file-atomic-"+RandStr(32)) + // Override in case it does exist, create in case it doesn't and force kernel + // flush, which still leaves the potential of lingering disk cache. + flag = os.O_WRONLY | os.O_CREATE | os.O_SYNC | os.O_TRUNC + ) + + f, err := os.OpenFile(tempFile, flag, perm) if err != nil { return err } - _, err = f.Write(newBytes) - if err == nil { - err = f.Sync() + // Clean up in any case. Defer stacking order is last-in-first-out. + defer os.Remove(f.Name()) + defer f.Close() + + if n, err := f.Write(data); err != nil { + return err + } else if n < len(data) { + return io.ErrShortWrite } - if closeErr := f.Close(); err == nil { - err = closeErr - } - if permErr := os.Chmod(f.Name(), mode); err == nil { - err = permErr - } - if err == nil { - err = os.Rename(f.Name(), filePath) - } - // any err should result in full cleanup - if err != nil { - os.Remove(f.Name()) - } - return err + + return os.Rename(f.Name(), filename) } //-------------------------------------------------------------------------------- diff --git a/common/os_test.go b/common/os_test.go index 126723aa..97ad672b 100644 --- a/common/os_test.go +++ b/common/os_test.go @@ -2,30 +2,52 @@ package common import ( "bytes" - "fmt" "io/ioutil" + "math/rand" "os" "testing" "time" ) func TestWriteFileAtomic(t *testing.T) { - data := []byte("Becatron") - fname := fmt.Sprintf("/tmp/write-file-atomic-test-%v.txt", time.Now().UnixNano()) - err := WriteFileAtomic(fname, data, 0664) + var ( + seed = rand.New(rand.NewSource(time.Now().UnixNano())) + data = []byte(RandStr(seed.Intn(2048))) + old = RandBytes(seed.Intn(2048)) + perm os.FileMode = 0600 + ) + + f, err := ioutil.TempFile("/tmp", "write-atomic-test-") if err != nil { t.Fatal(err) } - rData, err := ioutil.ReadFile(fname) + defer os.Remove(f.Name()) + + if err := ioutil.WriteFile(f.Name(), old, 0664); err != nil { + t.Fatal(err) + } + + if err := WriteFileAtomic(f.Name(), data, perm); err != nil { + t.Fatal(err) + } + + rData, err := ioutil.ReadFile(f.Name()) if err != nil { t.Fatal(err) } + if !bytes.Equal(data, rData) { t.Fatalf("data mismatch: %v != %v", data, rData) } - if err := os.Remove(fname); err != nil { + + stat, err := os.Stat(f.Name()) + if err != nil { t.Fatal(err) } + + if have, want := stat.Mode().Perm(), perm; have != want { + t.Errorf("have %v, want %v", have, want) + } } func TestGoPath(t *testing.T) {