diff --git a/accounts/keystore/account_cache.go b/accounts/keystore/account_cache.go index dc6ac6ccb..4b08cc202 100644 --- a/accounts/keystore/account_cache.go +++ b/accounts/keystore/account_cache.go @@ -31,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" + "gopkg.in/fatih/set.v0" ) // Minimum amount of time between cache reloads. This limit applies if the platform does @@ -71,6 +72,14 @@ type accountCache struct { byAddr map[common.Address][]accounts.Account throttle *time.Timer notify chan struct{} + fileC fileCache +} + +// fileCache is a cache of files seen during scan of keystore +type fileCache struct { + all *set.SetNonTS // list of all files + mtime time.Time // latest mtime seen + mu sync.RWMutex } func newAccountCache(keydir string) (*accountCache, chan struct{}) { @@ -78,6 +87,7 @@ func newAccountCache(keydir string) (*accountCache, chan struct{}) { keydir: keydir, byAddr: make(map[common.Address][]accounts.Account), notify: make(chan struct{}, 1), + fileC: fileCache{all: set.NewNonTS()}, } ac.watcher = newWatcher(ac) return ac, ac.notify @@ -127,6 +137,23 @@ func (ac *accountCache) delete(removed accounts.Account) { } } +// deleteByFile removes an account referenced by the given path. +func (ac *accountCache) deleteByFile(path string) { + ac.mu.Lock() + defer ac.mu.Unlock() + i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path }) + + if i < len(ac.all) && ac.all[i].URL.Path == path { + removed := ac.all[i] + ac.all = append(ac.all[:i], ac.all[i+1:]...) + if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 { + delete(ac.byAddr, removed.Address) + } else { + ac.byAddr[removed.Address] = ba + } + } +} + func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.Account { for i := range slice { if slice[i] == elem { @@ -167,15 +194,16 @@ func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) { default: err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))} copy(err.Matches, matches) + sort.Sort(accountsByURL(err.Matches)) return accounts.Account{}, err } } func (ac *accountCache) maybeReload() { ac.mu.Lock() - defer ac.mu.Unlock() if ac.watcher.running { + ac.mu.Unlock() return // A watcher is running and will keep the cache up-to-date. } if ac.throttle == nil { @@ -184,12 +212,15 @@ func (ac *accountCache) maybeReload() { select { case <-ac.throttle.C: default: + ac.mu.Unlock() return // The cache was reloaded recently. } } + // No watcher running, start it. ac.watcher.start() - ac.reload() ac.throttle.Reset(minReloadInterval) + ac.mu.Unlock() + ac.scanAccounts() } func (ac *accountCache) close() { @@ -205,54 +236,76 @@ func (ac *accountCache) close() { ac.mu.Unlock() } -// reload caches addresses of existing accounts. -// Callers must hold ac.mu. -func (ac *accountCache) reload() { - accounts, err := ac.scan() +// scanFiles performs a new scan on the given directory, compares against the already +// cached filenames, and returns file sets: new, missing , modified +func (fc *fileCache) scanFiles(keyDir string) (set.Interface, set.Interface, set.Interface, error) { + t0 := time.Now() + files, err := ioutil.ReadDir(keyDir) + t1 := time.Now() if err != nil { - log.Debug("Failed to reload keystore contents", "err", err) + return nil, nil, nil, err } - ac.all = accounts - sort.Sort(ac.all) - for k := range ac.byAddr { - delete(ac.byAddr, k) - } - for _, a := range accounts { - ac.byAddr[a.Address] = append(ac.byAddr[a.Address], a) - } - select { - case ac.notify <- struct{}{}: - default: - } - log.Debug("Reloaded keystore contents", "accounts", len(ac.all)) -} + fc.mu.RLock() + prevMtime := fc.mtime + fc.mu.RUnlock() -func (ac *accountCache) scan() ([]accounts.Account, error) { - files, err := ioutil.ReadDir(ac.keydir) - if err != nil { - return nil, err - } - - var ( - buf = new(bufio.Reader) - addrs []accounts.Account - keyJSON struct { - Address string `json:"address"` - } - ) + filesNow := set.NewNonTS() + moddedFiles := set.NewNonTS() + var newMtime time.Time for _, fi := range files { - path := filepath.Join(ac.keydir, fi.Name()) + modTime := fi.ModTime() + path := filepath.Join(keyDir, fi.Name()) if skipKeyFile(fi) { log.Trace("Ignoring file on account scan", "path", path) continue } - logger := log.New("path", path) + filesNow.Add(path) + if modTime.After(prevMtime) { + moddedFiles.Add(path) + } + if modTime.After(newMtime) { + newMtime = modTime + } + } + t2 := time.Now() + fc.mu.Lock() + // Missing = previous - current + missing := set.Difference(fc.all, filesNow) + // New = current - previous + newFiles := set.Difference(filesNow, fc.all) + // Modified = modified - new + modified := set.Difference(moddedFiles, newFiles) + fc.all = filesNow + fc.mtime = newMtime + fc.mu.Unlock() + t3 := time.Now() + log.Debug("FS scan times", "list", t1.Sub(t0), "set", t2.Sub(t1), "diff", t3.Sub(t2)) + return newFiles, missing, modified, nil +} + +// scanAccounts checks if any changes have occurred on the filesystem, and +// updates the account cache accordingly +func (ac *accountCache) scanAccounts() error { + newFiles, missingFiles, modified, err := ac.fileC.scanFiles(ac.keydir) + t1 := time.Now() + if err != nil { + log.Debug("Failed to reload keystore contents", "err", err) + return err + } + var ( + buf = new(bufio.Reader) + keyJSON struct { + Address string `json:"address"` + } + ) + readAccount := func(path string) *accounts.Account { fd, err := os.Open(path) if err != nil { - logger.Trace("Failed to open keystore file", "err", err) - continue + log.Trace("Failed to open keystore file", "path", path, "err", err) + return nil } + defer fd.Close() buf.Reset(fd) // Parse the address. keyJSON.Address = "" @@ -260,15 +313,45 @@ func (ac *accountCache) scan() ([]accounts.Account, error) { addr := common.HexToAddress(keyJSON.Address) switch { case err != nil: - logger.Debug("Failed to decode keystore key", "err", err) + log.Debug("Failed to decode keystore key", "path", path, "err", err) case (addr == common.Address{}): - logger.Debug("Failed to decode keystore key", "err", "missing or zero address") + log.Debug("Failed to decode keystore key", "path", path, "err", "missing or zero address") default: - addrs = append(addrs, accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}}) + return &accounts.Account{Address: addr, URL: accounts.URL{Scheme: KeyStoreScheme, Path: path}} } - fd.Close() + return nil } - return addrs, err + + for _, p := range newFiles.List() { + path, _ := p.(string) + a := readAccount(path) + if a != nil { + ac.add(*a) + } + } + for _, p := range missingFiles.List() { + path, _ := p.(string) + ac.deleteByFile(path) + } + + for _, p := range modified.List() { + path, _ := p.(string) + a := readAccount(path) + ac.deleteByFile(path) + if a != nil { + ac.add(*a) + } + } + + t2 := time.Now() + + select { + case ac.notify <- struct{}{}: + default: + } + log.Trace("Handled keystore changes", "time", t2.Sub(t1)) + + return nil } func skipKeyFile(fi os.FileInfo) bool { diff --git a/accounts/keystore/account_cache_test.go b/accounts/keystore/account_cache_test.go index ab8aa9e6c..e3dc31065 100644 --- a/accounts/keystore/account_cache_test.go +++ b/accounts/keystore/account_cache_test.go @@ -18,6 +18,7 @@ package keystore import ( "fmt" + "io/ioutil" "math/rand" "os" "path/filepath" @@ -295,3 +296,101 @@ func TestCacheFind(t *testing.T) { } } } + +func waitForAccounts(wantAccounts []accounts.Account, ks *KeyStore) error { + var list []accounts.Account + for d := 200 * time.Millisecond; d < 8*time.Second; d *= 2 { + list = ks.Accounts() + if reflect.DeepEqual(list, wantAccounts) { + // ks should have also received change notifications + select { + case <-ks.changes: + default: + return fmt.Errorf("wasn't notified of new accounts") + } + return nil + } + time.Sleep(d) + } + return fmt.Errorf("\ngot %v\nwant %v", list, wantAccounts) +} + +// TestUpdatedKeyfileContents tests that updating the contents of a keystore file +// is noticed by the watcher, and the account cache is updated accordingly +func TestUpdatedKeyfileContents(t *testing.T) { + t.Parallel() + + // Create a temporary kesytore to test with + rand.Seed(time.Now().UnixNano()) + dir := filepath.Join(os.TempDir(), fmt.Sprintf("eth-keystore-watch-test-%d-%d", os.Getpid(), rand.Int())) + ks := NewKeyStore(dir, LightScryptN, LightScryptP) + + list := ks.Accounts() + if len(list) > 0 { + t.Error("initial account list not empty:", list) + } + time.Sleep(100 * time.Millisecond) + + // Create the directory and copy a key file into it. + os.MkdirAll(dir, 0700) + defer os.RemoveAll(dir) + file := filepath.Join(dir, "aaa") + + // Place one of our testfiles in there + if err := cp.CopyFile(file, cachetestAccounts[0].URL.Path); err != nil { + t.Fatal(err) + } + + // ks should see the account. + wantAccounts := []accounts.Account{cachetestAccounts[0]} + wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file} + if err := waitForAccounts(wantAccounts, ks); err != nil { + t.Error(err) + return + } + + // Now replace file contents + if err := forceCopyFile(file, cachetestAccounts[1].URL.Path); err != nil { + t.Fatal(err) + return + } + wantAccounts = []accounts.Account{cachetestAccounts[1]} + wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file} + if err := waitForAccounts(wantAccounts, ks); err != nil { + t.Errorf("First replacement failed") + t.Error(err) + return + } + + // Now replace file contents again + if err := forceCopyFile(file, cachetestAccounts[2].URL.Path); err != nil { + t.Fatal(err) + return + } + wantAccounts = []accounts.Account{cachetestAccounts[2]} + wantAccounts[0].URL = accounts.URL{Scheme: KeyStoreScheme, Path: file} + if err := waitForAccounts(wantAccounts, ks); err != nil { + t.Errorf("Second replacement failed") + t.Error(err) + return + } + // Now replace file contents with crap + if err := ioutil.WriteFile(file, []byte("foo"), 0644); err != nil { + t.Fatal(err) + return + } + if err := waitForAccounts([]accounts.Account{}, ks); err != nil { + t.Errorf("Emptying account file failed") + t.Error(err) + return + } +} + +// forceCopyFile is like cp.CopyFile, but doesn't complain if the destination exists. +func forceCopyFile(dst, src string) error { + data, err := ioutil.ReadFile(src) + if err != nil { + return err + } + return ioutil.WriteFile(dst, data, 0644) +} diff --git a/accounts/keystore/keystore_test.go b/accounts/keystore/keystore_test.go index 132b74439..6fb0a7808 100644 --- a/accounts/keystore/keystore_test.go +++ b/accounts/keystore/keystore_test.go @@ -272,82 +272,104 @@ func TestWalletNotifierLifecycle(t *testing.T) { t.Errorf("wallet notifier didn't terminate after unsubscribe") } +type walletEvent struct { + accounts.WalletEvent + a accounts.Account +} + // Tests that wallet notifications and correctly fired when accounts are added // or deleted from the keystore. func TestWalletNotifications(t *testing.T) { - // Create a temporary kesytore to test with dir, ks := tmpKeyStore(t, false) defer os.RemoveAll(dir) - // Subscribe to the wallet feed - updates := make(chan accounts.WalletEvent, 1) - sub := ks.Subscribe(updates) + // Subscribe to the wallet feed and collect events. + var ( + events []walletEvent + updates = make(chan accounts.WalletEvent) + sub = ks.Subscribe(updates) + ) defer sub.Unsubscribe() + go func() { + for { + select { + case ev := <-updates: + events = append(events, walletEvent{ev, ev.Wallet.Accounts()[0]}) + case <-sub.Err(): + close(updates) + return + } + } + }() - // Randomly add and remove account and make sure events and wallets are in sync - live := make(map[common.Address]accounts.Account) + // Randomly add and remove accounts. + var ( + live = make(map[common.Address]accounts.Account) + wantEvents []walletEvent + ) for i := 0; i < 1024; i++ { - // Execute a creation or deletion and ensure event arrival if create := len(live) == 0 || rand.Int()%4 > 0; create { // Add a new account and ensure wallet notifications arrives account, err := ks.NewAccount("") if err != nil { t.Fatalf("failed to create test account: %v", err) } - select { - case event := <-updates: - if event.Kind != accounts.WalletArrived { - t.Errorf("non-arrival event on account creation") - } - if event.Wallet.Accounts()[0] != account { - t.Errorf("account mismatch on created wallet: have %v, want %v", event.Wallet.Accounts()[0], account) - } - default: - t.Errorf("wallet arrival event not fired on account creation") - } live[account.Address] = account + wantEvents = append(wantEvents, walletEvent{accounts.WalletEvent{Kind: accounts.WalletArrived}, account}) } else { - // Select a random account to delete (crude, but works) + // Delete a random account. var account accounts.Account for _, a := range live { account = a break } - // Remove an account and ensure wallet notifiaction arrives if err := ks.Delete(account, ""); err != nil { t.Fatalf("failed to delete test account: %v", err) } - select { - case event := <-updates: - if event.Kind != accounts.WalletDropped { - t.Errorf("non-drop event on account deletion") - } - if event.Wallet.Accounts()[0] != account { - t.Errorf("account mismatch on deleted wallet: have %v, want %v", event.Wallet.Accounts()[0], account) - } - default: - t.Errorf("wallet departure event not fired on account creation") - } delete(live, account.Address) + wantEvents = append(wantEvents, walletEvent{accounts.WalletEvent{Kind: accounts.WalletDropped}, account}) } - // Retrieve the list of wallets and ensure it matches with our required live set - liveList := make([]accounts.Account, 0, len(live)) - for _, account := range live { - liveList = append(liveList, account) - } - sort.Sort(accountsByURL(liveList)) + } - wallets := ks.Wallets() - if len(liveList) != len(wallets) { - t.Errorf("wallet list doesn't match required accounts: have %v, want %v", wallets, liveList) - } else { - for j, wallet := range wallets { - if accs := wallet.Accounts(); len(accs) != 1 { - t.Errorf("wallet %d: contains invalid number of accounts: have %d, want 1", j, len(accs)) - } else if accs[0] != liveList[j] { - t.Errorf("wallet %d: account mismatch: have %v, want %v", j, accs[0], liveList[j]) - } + // Shut down the event collector and check events. + sub.Unsubscribe() + <-updates + checkAccounts(t, live, ks.Wallets()) + checkEvents(t, wantEvents, events) +} + +// checkAccounts checks that all known live accounts are present in the wallet list. +func checkAccounts(t *testing.T, live map[common.Address]accounts.Account, wallets []accounts.Wallet) { + if len(live) != len(wallets) { + t.Errorf("wallet list doesn't match required accounts: have %d, want %d", len(wallets), len(live)) + return + } + liveList := make([]accounts.Account, 0, len(live)) + for _, account := range live { + liveList = append(liveList, account) + } + sort.Sort(accountsByURL(liveList)) + for j, wallet := range wallets { + if accs := wallet.Accounts(); len(accs) != 1 { + t.Errorf("wallet %d: contains invalid number of accounts: have %d, want 1", j, len(accs)) + } else if accs[0] != liveList[j] { + t.Errorf("wallet %d: account mismatch: have %v, want %v", j, accs[0], liveList[j]) + } + } +} + +// checkEvents checks that all events in 'want' are present in 'have'. Events may be present multiple times. +func checkEvents(t *testing.T, want []walletEvent, have []walletEvent) { + for _, wantEv := range want { + nmatch := 0 + for ; len(have) > 0; nmatch++ { + if have[0].Kind != wantEv.Kind || have[0].a != wantEv.a { + break } + have = have[1:] + } + if nmatch == 0 { + t.Fatalf("can't find event with Kind=%v for %x", wantEv.Kind, wantEv.a.Address) } } } diff --git a/accounts/keystore/watch.go b/accounts/keystore/watch.go index f4d647791..602300b10 100644 --- a/accounts/keystore/watch.go +++ b/accounts/keystore/watch.go @@ -70,7 +70,6 @@ func (w *watcher) loop() { return } defer notify.Stop(w.ev) - logger.Trace("Started watching keystore folder") defer logger.Trace("Stopped watching keystore folder") @@ -82,9 +81,9 @@ func (w *watcher) loop() { // When an event occurs, the reload call is delayed a bit so that // multiple events arriving quickly only cause a single reload. var ( - debounce = time.NewTimer(0) - debounceDuration = 500 * time.Millisecond - inCycle, hadEvent bool + debounce = time.NewTimer(0) + debounceDuration = 500 * time.Millisecond + rescanTriggered = false ) defer debounce.Stop() for { @@ -92,22 +91,14 @@ func (w *watcher) loop() { case <-w.quit: return case <-w.ev: - if !inCycle { + // Trigger the scan (with delay), if not already triggered + if !rescanTriggered { debounce.Reset(debounceDuration) - inCycle = true - } else { - hadEvent = true + rescanTriggered = true } case <-debounce.C: - w.ac.mu.Lock() - w.ac.reload() - w.ac.mu.Unlock() - if hadEvent { - debounce.Reset(debounceDuration) - inCycle, hadEvent = true, false - } else { - inCycle, hadEvent = false, false - } + w.ac.scanAccounts() + rescanTriggered = false } } }