diff --git a/build/update-license.go b/build/update-license.go index f83a2b34b..96667be15 100644 --- a/build/update-license.go +++ b/build/update-license.go @@ -49,7 +49,6 @@ var ( // don't relicense vendored sources "crypto/sha3/", "crypto/ecies/", "logger/glog/", "crypto/secp256k1/curve.go", - "trie/arc.go", // don't license generated files "contracts/chequebook/contract/", "contracts/ens/contract/", diff --git a/core/blockchain.go b/core/blockchain.go index a5f146a2d..1fbcdfc6f 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -357,7 +357,12 @@ func (self *BlockChain) AuxValidator() pow.PoW { return self.pow } // State returns a new mutable state based on the current HEAD block. func (self *BlockChain) State() (*state.StateDB, error) { - return state.New(self.CurrentBlock().Root(), self.chainDb) + return self.StateAt(self.CurrentBlock().Root()) +} + +// StateAt returns a new mutable state based on a particular point in time. +func (self *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { + return self.stateCache.New(root) } // Reset purges the entire blockchain, restoring it to its genesis state. diff --git a/core/state/iterator.go b/core/state/iterator.go index 9d8a69b7c..14265b277 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -76,7 +76,7 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started if it.stateIt == nil { - it.stateIt = trie.NewNodeIterator(it.state.trie.Trie) + it.stateIt = it.state.trie.NodeIterator() } // If we had data nodes previously, we surely have at least state nodes if it.dataIt != nil { diff --git a/core/state/state_object.go b/core/state/state_object.go index 3496008a6..a54620d55 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -95,8 +95,6 @@ type Account struct { Balance *big.Int Root common.Hash // merkle root of the storage trie CodeHash []byte - - codeSize *int } // NewObject creates a state object. @@ -275,20 +273,9 @@ func (self *StateObject) Code(db trie.Database) []byte { return code } -// CodeSize returns the size of the contract code associated with this object. -func (self *StateObject) CodeSize(db trie.Database) int { - if self.data.codeSize == nil { - self.data.codeSize = new(int) - *self.data.codeSize = len(self.Code(db)) - } - return *self.data.codeSize -} - func (self *StateObject) SetCode(code []byte) { self.code = code self.data.CodeHash = crypto.Keccak256(code) - self.data.codeSize = new(int) - *self.data.codeSize = len(code) self.dirtyCode = true if self.onDirty != nil { self.onDirty(self.Address()) diff --git a/core/state/statedb.go b/core/state/statedb.go index 10f3f4652..5c51e3b59 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -28,23 +29,32 @@ import ( "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + lru "github.com/hashicorp/golang-lru" ) // The starting nonce determines the default nonce when new accounts are being // created. var StartingNonce uint64 +const ( + // Number of past tries to keep. The arbitrarily chosen value here + // is max uncle depth + 1. + maxJournalLength = 8 + + // Number of codehash->size associations to keep. + codeSizeCacheSize = 100000 +) + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: // * Contracts // * Accounts type StateDB struct { - db ethdb.Database - trie *trie.SecureTrie - - // This map caches canon state accounts. - all map[common.Address]Account + db ethdb.Database + trie *trie.SecureTrie + pastTries []*trie.SecureTrie + codeSizeCache *lru.Cache // This map holds 'live' objects, which will get modified while processing a state transition. stateObjects map[common.Address]*StateObject @@ -57,6 +67,8 @@ type StateDB struct { txIndex int logs map[common.Hash]vm.Logs logSize uint + + lock sync.Mutex } // Create a new state from a given trie @@ -65,10 +77,32 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { if err != nil { return nil, err } + csc, _ := lru.New(codeSizeCacheSize) return &StateDB{ db: db, trie: tr, - all: make(map[common.Address]Account), + codeSizeCache: csc, + stateObjects: make(map[common.Address]*StateObject), + stateObjectsDirty: make(map[common.Address]struct{}), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), + }, nil +} + +// New creates a new statedb by reusing any journalled tries to avoid costly +// disk io. +func (self *StateDB) New(root common.Hash) (*StateDB, error) { + self.lock.Lock() + defer self.lock.Unlock() + + tr, err := self.openTrie(root) + if err != nil { + return nil, err + } + return &StateDB{ + db: self.db, + trie: tr, + codeSizeCache: self.codeSizeCache, stateObjects: make(map[common.Address]*StateObject), stateObjectsDirty: make(map[common.Address]struct{}), refund: new(big.Int), @@ -79,27 +113,50 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { // Reset clears out all emphemeral state objects from the state db, but keeps // the underlying state trie to avoid reloading data for the next operations. func (self *StateDB) Reset(root common.Hash) error { - tr, err := trie.NewSecure(root, self.db) + self.lock.Lock() + defer self.lock.Unlock() + + tr, err := self.openTrie(root) if err != nil { return err } - all := self.all - if self.trie.Hash() != root { - // The root has changed, invalidate canon state. - all = make(map[common.Address]Account) - } - *self = StateDB{ - db: self.db, - trie: tr, - all: all, - stateObjects: make(map[common.Address]*StateObject), - stateObjectsDirty: make(map[common.Address]struct{}), - refund: new(big.Int), - logs: make(map[common.Hash]vm.Logs), - } + self.trie = tr + self.stateObjects = make(map[common.Address]*StateObject) + self.stateObjectsDirty = make(map[common.Address]struct{}) + self.refund = new(big.Int) + self.thash = common.Hash{} + self.bhash = common.Hash{} + self.txIndex = 0 + self.logs = make(map[common.Hash]vm.Logs) + self.logSize = 0 + return nil } +// openTrie creates a trie. It uses an existing trie if one is available +// from the journal if available. +func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) { + for i := len(self.pastTries) - 1; i >= 0; i-- { + if self.pastTries[i].Hash() == root { + tr := *self.pastTries[i] + return &tr, nil + } + } + return trie.NewSecure(root, self.db) +} + +func (self *StateDB) pushTrie(t *trie.SecureTrie) { + self.lock.Lock() + defer self.lock.Unlock() + + if len(self.pastTries) >= maxJournalLength { + copy(self.pastTries, self.pastTries[1:]) + self.pastTries[len(self.pastTries)-1] = t + } else { + self.pastTries = append(self.pastTries, t) + } +} + func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { self.thash = thash self.bhash = bhash @@ -165,17 +222,28 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.GetStateObject(addr) if stateObject != nil { - return stateObject.Code(self.db) + code := stateObject.Code(self.db) + key := common.BytesToHash(stateObject.CodeHash()) + self.codeSizeCache.Add(key, len(code)) + return code } return nil } func (self *StateDB) GetCodeSize(addr common.Address) int { stateObject := self.GetStateObject(addr) - if stateObject != nil { - return stateObject.CodeSize(self.db) + if stateObject == nil { + return 0 } - return 0 + key := common.BytesToHash(stateObject.CodeHash()) + if cached, ok := self.codeSizeCache.Get(key); ok { + return cached.(int) + } + size := len(stateObject.Code(self.db)) + if stateObject.dbErr == nil { + self.codeSizeCache.Add(key, size) + } + return size } func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { @@ -269,13 +337,6 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje return obj } - // Use cached account data from the canon state if possible. - if data, ok := self.all[addr]; ok { - obj := NewObject(addr, data, self.MarkStateObjectDirty) - self.SetStateObject(obj) - return obj - } - // Load the object from the database. enc := self.trie.Get(addr[:]) if len(enc) == 0 { @@ -286,10 +347,6 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje glog.Errorf("can't decode object at %x: %v", addr[:], err) return nil } - // Update the all cache. Content in DB always corresponds - // to the current head state so this is ok to do here. - // The object we just loaded has no storage trie and code yet. - self.all[addr] = data // Insert into the live set. obj := NewObject(addr, data, self.MarkStateObjectDirty) self.SetStateObject(obj) @@ -351,11 +408,15 @@ func (self *StateDB) CreateAccount(addr common.Address) vm.Account { // func (self *StateDB) Copy() *StateDB { + self.lock.Lock() + defer self.lock.Unlock() + // Copy all the basic fields, initialize the memory ones state := &StateDB{ db: self.db, trie: self.trie, - all: self.all, + pastTries: self.pastTries, + codeSizeCache: self.codeSizeCache, stateObjects: make(map[common.Address]*StateObject, len(self.stateObjectsDirty)), stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), refund: new(big.Int).Set(self.refund), @@ -375,11 +436,15 @@ func (self *StateDB) Copy() *StateDB { } func (self *StateDB) Set(state *StateDB) { + self.lock.Lock() + defer self.lock.Unlock() + + self.db = state.db self.trie = state.trie + self.pastTries = state.pastTries self.stateObjects = state.stateObjects self.stateObjectsDirty = state.stateObjectsDirty - self.all = state.all - + self.codeSizeCache = state.codeSizeCache self.refund = state.refund self.logs = state.logs self.logSize = state.logSize @@ -444,12 +509,6 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { s.refund = new(big.Int) - defer func() { - if err != nil { - // Committing failed, any updates to the canon state are invalid. - s.all = make(map[common.Address]Account) - } - }() // Commit objects to the trie. for addr, stateObject := range s.stateObjects { @@ -457,7 +516,6 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. s.DeleteStateObject(stateObject) - delete(s.all, addr) } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { @@ -472,12 +530,15 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) } // Update the object in the main account trie. s.UpdateStateObject(stateObject) - s.all[addr] = stateObject.data } delete(s.stateObjectsDirty, addr) } // Write trie changes. - return s.trie.CommitTo(dbw) + root, err = s.trie.CommitTo(dbw) + if err == nil { + s.pushTrie(s.trie) + } + return root, err } func (self *StateDB) Refunds() *big.Int { diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 715645c6c..c768781a4 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -62,9 +62,6 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { } root, _ := state.Commit() - // Remove any potentially cached data from the test state creation - trie.ClearGlobalCache() - // Return the generated state return db, root, accounts } @@ -72,9 +69,6 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { // checkStateAccounts cross references a reconstructed state with an expected // account array. func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { - // Remove any potentially cached data from the state synchronisation - trie.ClearGlobalCache() - // Check root availability and state contents state, err := New(root, db) if err != nil { @@ -98,9 +92,6 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou // checkStateConsistency checks that all nodes in a state trie are indeed present. func checkStateConsistency(db ethdb.Database, root common.Hash) error { - // Remove any potentially cached data from the test state creation or previous checks - trie.ClearGlobalCache() - // Create and iterate a state trie rooted in a sub-node if _, err := db.Get(root.Bytes()); err != nil { return nil // Consider a non existent state consistent diff --git a/eth/api.go b/eth/api.go index d6c0826ed..c2fdbe99c 100644 --- a/eth/api.go +++ b/eth/api.go @@ -293,7 +293,7 @@ func (api *PublicDebugAPI) DumpBlock(number uint64) (state.Dump, error) { if block == nil { return state.Dump{}, fmt.Errorf("block #%d not found", number) } - stateDb, err := state.New(block.Root(), api.eth.ChainDb()) + stateDb, err := api.eth.BlockChain().StateAt(block.Root()) if err != nil { return state.Dump{}, err } @@ -406,7 +406,7 @@ func (api *PrivateDebugAPI) traceBlock(block *types.Block, logConfig *vm.LogConf if err := core.ValidateHeader(api.config, blockchain.AuxValidator(), block.Header(), blockchain.GetHeader(block.ParentHash(), block.NumberU64()-1), true, false); err != nil { return false, structLogger.StructLogs(), err } - statedb, err := state.New(blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1).Root(), api.eth.ChainDb()) + statedb, err := blockchain.StateAt(blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) if err != nil { return false, structLogger.StructLogs(), err } @@ -501,7 +501,7 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, txHash common. if parent == nil { return nil, fmt.Errorf("block parent %x not found", block.ParentHash()) } - stateDb, err := state.New(parent.Root(), api.eth.ChainDb()) + stateDb, err := api.eth.BlockChain().StateAt(parent.Root()) if err != nil { return nil, err } diff --git a/eth/api_backend.go b/eth/api_backend.go index 4f8f06529..4adeb0aa0 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -81,7 +81,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(blockNr rpc.BlockNumber) (ethapi. if header == nil { return nil, nil, nil } - stateDb, err := state.New(header.Root, b.eth.chainDb) + stateDb, err := b.eth.BlockChain().StateAt(header.Root) return EthApiState{stateDb}, header, err } diff --git a/ethdb/database.go b/ethdb/database.go index f93731cfe..a4a27303a 100644 --- a/ethdb/database.go +++ b/ethdb/database.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/filter" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" @@ -84,6 +85,7 @@ func NewLDBDatabase(file string, cache int, handles int) (*LDBDatabase, error) { OpenFilesCacheCapacity: handles, BlockCacheCapacity: cache / 2 * opt.MiB, WriteBuffer: cache / 4 * opt.MiB, // Two of these are used internally + Filter: filter.NewBloomFilter(10), }) if _, corrupted := err.(*errors.ErrCorrupted); corrupted { db, err = leveldb.RecoverFile(file, nil) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 6480085dd..9a97be25f 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -454,6 +454,8 @@ type CallArgs struct { } func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr rpc.BlockNumber) (string, *big.Int, error) { + defer func(start time.Time) { glog.V(logger.Debug).Infof("call took %v", time.Since(start)) }(time.Now()) + state, header, err := s.b.StateAndHeaderByNumber(blockNr) if state == nil || err != nil { return "0x", common.Big0, err diff --git a/light/state_test.go b/light/state_test.go index 90c38604a..d7014a2dc 100644 --- a/light/state_test.go +++ b/light/state_test.go @@ -42,7 +42,6 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { case *TrieRequest: t, _ := trie.New(req.root, odr.sdb) req.proof = t.Prove(req.key) - trie.ClearGlobalCache() case *NodeDataRequest: req.data, _ = odr.sdb.Get(req.hash[:]) } @@ -75,7 +74,6 @@ func TestLightStateOdr(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() for i := byte(0); i < 100; i++ { addr := common.Address{i} @@ -160,7 +158,6 @@ func TestLightStateSetCopy(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() for i := byte(0); i < 100; i++ { addr := common.Address{i} @@ -237,7 +234,6 @@ func TestLightStateDelete(t *testing.T) { odr := &testOdr{sdb: sdb, ldb: ldb} ls := NewLightState(root, odr) ctx := context.Background() - trie.ClearGlobalCache() addr := common.Address{42} diff --git a/miner/worker.go b/miner/worker.go index 1676036d8..ac1ef5ba3 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -361,7 +361,7 @@ func (self *worker) push(work *Work) { // makeCurrent creates a new environment for the current cycle. func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error { - state, err := state.New(parent.Root(), self.eth.ChainDb()) + state, err := self.chain.StateAt(parent.Root()) if err != nil { return err } diff --git a/trie/arc.go b/trie/arc.go deleted file mode 100644 index fc7a3259f..000000000 --- a/trie/arc.go +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) 2015 Hans Alexander Gugel -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// This file contains a modified version of package arc from -// https://github.com/alexanderGugel/arc -// -// It implements the ARC (Adaptive Replacement Cache) algorithm as detailed in -// https://www.usenix.org/legacy/event/fast03/tech/full_papers/megiddo/megiddo.pdf - -package trie - -import ( - "container/list" - "sync" -) - -type arc struct { - p int - c int - t1 *list.List - b1 *list.List - t2 *list.List - b2 *list.List - cache map[string]*entry - mutex sync.Mutex -} - -type entry struct { - key hashNode - value node - ll *list.List - el *list.Element -} - -// newARC returns a new Adaptive Replacement Cache with the -// given capacity. -func newARC(c int) *arc { - return &arc{ - c: c, - t1: list.New(), - b1: list.New(), - t2: list.New(), - b2: list.New(), - cache: make(map[string]*entry, c), - } -} - -// Clear clears the cache -func (a *arc) Clear() { - a.mutex.Lock() - defer a.mutex.Unlock() - a.p = 0 - a.t1 = list.New() - a.b1 = list.New() - a.t2 = list.New() - a.b2 = list.New() - a.cache = make(map[string]*entry, a.c) -} - -// Put inserts a new key-value pair into the cache. -// This optimizes future access to this entry (side effect). -func (a *arc) Put(key hashNode, value node) bool { - a.mutex.Lock() - defer a.mutex.Unlock() - ent, ok := a.cache[string(key)] - if ok != true { - ent = &entry{key: key, value: value} - a.req(ent) - a.cache[string(key)] = ent - } else { - ent.value = value - a.req(ent) - } - return ok -} - -// Get retrieves a previously via Set inserted entry. -// This optimizes future access to this entry (side effect). -func (a *arc) Get(key hashNode) (value node, ok bool) { - a.mutex.Lock() - defer a.mutex.Unlock() - ent, ok := a.cache[string(key)] - if ok { - a.req(ent) - return ent.value, ent.value != nil - } - return nil, false -} - -func (a *arc) req(ent *entry) { - if ent.ll == a.t1 || ent.ll == a.t2 { - // Case I - ent.setMRU(a.t2) - } else if ent.ll == a.b1 { - // Case II - // Cache Miss in t1 and t2 - - // Adaptation - var d int - if a.b1.Len() >= a.b2.Len() { - d = 1 - } else { - d = a.b2.Len() / a.b1.Len() - } - a.p = a.p + d - if a.p > a.c { - a.p = a.c - } - - a.replace(ent) - ent.setMRU(a.t2) - } else if ent.ll == a.b2 { - // Case III - // Cache Miss in t1 and t2 - - // Adaptation - var d int - if a.b2.Len() >= a.b1.Len() { - d = 1 - } else { - d = a.b1.Len() / a.b2.Len() - } - a.p = a.p - d - if a.p < 0 { - a.p = 0 - } - - a.replace(ent) - ent.setMRU(a.t2) - } else if ent.ll == nil { - // Case IV - - if a.t1.Len()+a.b1.Len() == a.c { - // Case A - if a.t1.Len() < a.c { - a.delLRU(a.b1) - a.replace(ent) - } else { - a.delLRU(a.t1) - } - } else if a.t1.Len()+a.b1.Len() < a.c { - // Case B - if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() >= a.c { - if a.t1.Len()+a.t2.Len()+a.b1.Len()+a.b2.Len() == 2*a.c { - a.delLRU(a.b2) - } - a.replace(ent) - } - } - - ent.setMRU(a.t1) - } -} - -func (a *arc) delLRU(list *list.List) { - lru := list.Back() - list.Remove(lru) - delete(a.cache, string(lru.Value.(*entry).key)) -} - -func (a *arc) replace(ent *entry) { - if a.t1.Len() > 0 && ((a.t1.Len() > a.p) || (ent.ll == a.b2 && a.t1.Len() == a.p)) { - lru := a.t1.Back().Value.(*entry) - lru.value = nil - lru.setMRU(a.b1) - } else { - lru := a.t2.Back().Value.(*entry) - lru.value = nil - lru.setMRU(a.b2) - } -} - -func (e *entry) setLRU(list *list.List) { - e.detach() - e.ll = list - e.el = e.ll.PushBack(e) -} - -func (e *entry) setMRU(list *list.List) { - e.detach() - e.ll = list - e.el = e.ll.PushFront(e) -} - -func (e *entry) detach() { - if e.ll != nil { - e.ll.Remove(e.el) - } -} diff --git a/trie/hasher.go b/trie/hasher.go new file mode 100644 index 000000000..87e02fb85 --- /dev/null +++ b/trie/hasher.go @@ -0,0 +1,157 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "bytes" + "hash" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/rlp" +) + +type hasher struct { + tmp *bytes.Buffer + sha hash.Hash +} + +// hashers live in a global pool. +var hasherPool = sync.Pool{ + New: func() interface{} { + return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} + }, +} + +func newHasher() *hasher { + return hasherPool.Get().(*hasher) +} + +func returnHasherToPool(h *hasher) { + hasherPool.Put(h) +} + +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialzied with the computed hash to replace the original one. +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { + // If we're not storing the node, just hashing, use avaialble cached data + if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { + return hash, n, nil + } + // Trie not processed yet or needs storage, walk the children + collapsed, cached, err := h.hashChildren(n, db) + if err != nil { + return hashNode{}, n, err + } + hashed, err := h.store(collapsed, db, force) + if err != nil { + return hashNode{}, n, err + } + // Cache the hash and RLP blob of the ndoe for later reuse + if hash, ok := hashed.(hashNode); ok && !force { + switch cached := cached.(type) { + case shortNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + case fullNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + } + } + return hashed, cached, nil +} + +// hashChildren replaces the children of a node with their hashes if the encoded +// size of the child is larger than a hash, returning the collapsed node as well +// as a replacement for the original node with the child hashes cached in. +func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { + var err error + + switch n := original.(type) { + case shortNode: + // Hash the short node's child, caching the newly hashed subtree + cached := n + cached.Key = common.CopyBytes(cached.Key) + + n.Key = compactEncode(n.Key) + if _, ok := n.Val.(valueNode); !ok { + if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { + return n, original, err + } + } + if n.Val == nil { + n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. + } + return n, cached, nil + + case fullNode: + // Hash the full node's children, caching the newly hashed subtrees + cached := fullNode{dirty: n.dirty} + + for i := 0; i < 16; i++ { + if n.Children[i] != nil { + if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { + return n, original, err + } + } else { + n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. + } + } + cached.Children[16] = n.Children[16] + if n.Children[16] == nil { + n.Children[16] = valueNode(nil) + } + return n, cached, nil + + default: + // Value and hash nodes don't have children so they're left as were + return n, original, nil + } +} + +func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { + // Don't store hashes or empty nodes. + if _, isHash := n.(hashNode); n == nil || isHash { + return n, nil + } + // Generate the RLP encoding of the node + h.tmp.Reset() + if err := rlp.Encode(h.tmp, n); err != nil { + panic("encode error: " + err.Error()) + } + if h.tmp.Len() < 32 && !force { + return n, nil // Nodes smaller than 32 bytes are stored inside their parent + } + // Larger nodes are replaced by their hash and stored in the database. + hash, _ := n.cache() + if hash == nil { + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + hash = hashNode(h.sha.Sum(nil)) + } + if db != nil { + return hash, db.Put(hash, h.tmp.Bytes()) + } + return hash, nil +} diff --git a/trie/iterator.go b/trie/iterator.go index 88c4cee7f..8cad51aff 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -16,18 +16,13 @@ package trie -import ( - "bytes" - "fmt" +import "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/logger/glog" -) - -// Iterator is a key-value trie iterator to traverse the data contents. +// Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { - trie *Trie + trie *Trie + nodeIt *NodeIterator + keyBuf []byte Key []byte // Current data key on which the iterator is positioned on Value []byte // Current data value on which the iterator is positioned on @@ -35,119 +30,45 @@ type Iterator struct { // NewIterator creates a new key-value iterator. func NewIterator(trie *Trie) *Iterator { - return &Iterator{trie: trie, Key: nil} + return &Iterator{ + trie: trie, + nodeIt: NewNodeIterator(trie), + keyBuf: make([]byte, 0, 64), + Key: nil, + } } -// Next moves the iterator forward with one key-value entry. -func (self *Iterator) Next() bool { - isIterStart := false - if self.Key == nil { - isIterStart = true - self.Key = make([]byte, 32) +// Next moves the iterator forward one key-value entry. +func (it *Iterator) Next() bool { + for it.nodeIt.Next() { + if it.nodeIt.Leaf { + it.Key = it.makeKey() + it.Value = it.nodeIt.LeafBlob + return true + } } - - key := remTerm(compactHexDecode(self.Key)) - k := self.next(self.trie.root, key, isIterStart) - - self.Key = []byte(decodeCompact(k)) - - return len(k) > 0 + it.Key = nil + it.Value = nil + return false } -func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byte { - if node == nil { - return nil +func (it *Iterator) makeKey() []byte { + key := it.keyBuf[:0] + for _, se := range it.nodeIt.stack { + switch node := se.node.(type) { + case fullNode: + if se.child <= 16 { + key = append(key, byte(se.child)) + } + case shortNode: + if hasTerm(node.Key) { + key = append(key, node.Key[:len(node.Key)-1]...) + } else { + key = append(key, node.Key...) + } + } } - - switch node := node.(type) { - case fullNode: - if len(key) > 0 { - k := self.next(node.Children[key[0]], key[1:], isIterStart) - if k != nil { - return append([]byte{key[0]}, k...) - } - } - - var r byte - if len(key) > 0 { - r = key[0] + 1 - } - - for i := r; i < 16; i++ { - k := self.key(node.Children[i]) - if k != nil { - return append([]byte{i}, k...) - } - } - - case shortNode: - k := remTerm(node.Key) - if vnode, ok := node.Val.(valueNode); ok { - switch bytes.Compare([]byte(k), key) { - case 0: - if isIterStart { - self.Value = vnode - return k - } - case 1: - self.Value = vnode - return k - } - } else { - cnode := node.Val - - var ret []byte - skey := key[len(k):] - if bytes.HasPrefix(key, k) { - ret = self.next(cnode, skey, isIterStart) - } else if bytes.Compare(k, key[:len(k)]) > 0 { - return self.key(node) - } - - if ret != nil { - return append(k, ret...) - } - } - - case hashNode: - rn, err := self.trie.resolveHash(node, nil, nil) - if err != nil && glog.V(logger.Error) { - glog.Errorf("Unhandled trie error: %v", err) - } - return self.next(rn, key, isIterStart) - } - return nil -} - -func (self *Iterator) key(node interface{}) []byte { - switch node := node.(type) { - case shortNode: - // Leaf node - k := remTerm(node.Key) - if vnode, ok := node.Val.(valueNode); ok { - self.Value = vnode - return k - } - return append(k, self.key(node.Val)...) - case fullNode: - if node.Children[16] != nil { - self.Value = node.Children[16].(valueNode) - return []byte{16} - } - for i := 0; i < 16; i++ { - k := self.key(node.Children[i]) - if k != nil { - return append([]byte{byte(i)}, k...) - } - } - case hashNode: - rn, err := self.trie.resolveHash(node, nil, nil) - if err != nil && glog.V(logger.Error) { - glog.Errorf("Unhandled trie error: %v", err) - } - return self.key(rn) - } - return nil + return decodeCompact(key) } // nodeIteratorState represents the iteration state at one particular node of the @@ -199,25 +120,27 @@ func (it *NodeIterator) Next() bool { // step moves the iterator to the next node of the trie. func (it *NodeIterator) step() error { - // Abort if we reached the end of the iteration if it.trie == nil { + // Abort if we reached the end of the iteration return nil } - // Initialize the iterator if we've just started, or pop off the old node otherwise if len(it.stack) == 0 { - // Always start with a collapsed root + // Initialize the iterator if we've just started. root := it.trie.Hash() - it.stack = append(it.stack, &nodeIteratorState{node: hashNode(root[:]), child: -1}) - if it.stack[0].node == nil { - return fmt.Errorf("root node missing: %x", it.trie.Hash()) + state := &nodeIteratorState{node: it.trie.root, child: -1} + if root != emptyRoot { + state.hash = root } + it.stack = append(it.stack, state) } else { + // Continue iterating at the previous node otherwise. it.stack = it.stack[:len(it.stack)-1] if len(it.stack) == 0 { it.trie = nil return nil } } + // Continue iteration to the next child for { parent := it.stack[len(it.stack)-1] @@ -232,7 +155,12 @@ func (it *NodeIterator) step() error { } for parent.child++; parent.child < len(node.Children); parent.child++ { if current := node.Children[parent.child]; current != nil { - it.stack = append(it.stack, &nodeIteratorState{node: current, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(node.hash), + node: current, + parent: ancestor, + child: -1, + }) break } } @@ -242,7 +170,12 @@ func (it *NodeIterator) step() error { break } parent.child++ - it.stack = append(it.stack, &nodeIteratorState{node: node.Val, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(node.hash), + node: node.Val, + parent: ancestor, + child: -1, + }) } else if hash, ok := parent.node.(hashNode); ok { // Hash node, resolve the hash child from the database, then the node itself if parent.child >= 0 { @@ -254,7 +187,12 @@ func (it *NodeIterator) step() error { if err != nil { return err } - it.stack = append(it.stack, &nodeIteratorState{hash: common.BytesToHash(hash), node: node, parent: ancestor, child: -1}) + it.stack = append(it.stack, &nodeIteratorState{ + hash: common.BytesToHash(hash), + node: node, + parent: ancestor, + child: -1, + }) } else { break } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index dc8276116..2bcc3700e 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -34,21 +34,60 @@ func TestIterator(t *testing.T) { {"dog", "puppy"}, {"somethingveryoddindeedthis is", "myothernodedata"}, } - v := make(map[string]bool) + all := make(map[string]string) for _, val := range vals { - v[val.k] = false + all[val.k] = val.v trie.Update([]byte(val.k), []byte(val.v)) } trie.Commit() + found := make(map[string]string) it := NewIterator(trie) for it.Next() { - v[string(it.Key)] = true + found[string(it.Key)] = string(it.Value) } - for k, found := range v { - if !found { - t.Error("iterator didn't find", k) + for k, v := range all { + if found[k] != v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) + } + } +} + +type kv struct { + k, v []byte + t bool +} + +func TestIteratorLargeData(t *testing.T) { + trie := newEmpty() + vals := make(map[string]*kv) + + for i := byte(0); i < 255; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + + it := NewIterator(trie) + for it.Next() { + vals[string(it.Key)].t = true + } + + var untouched []*kv + for _, value := range vals { + if !value.t { + untouched = append(untouched, value) + } + } + + if len(untouched) > 0 { + t.Errorf("Missed %d nodes", len(untouched)) + for _, value := range untouched { + t.Error(value) } } } diff --git a/trie/proof.go b/trie/proof.go index 5135de047..116c13a1b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -70,15 +70,13 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - if t.hasher == nil { - t.hasher = newHasher() - } + hasher := newHasher() proof := make([]rlp.RawValue, 0, len(nodes)) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. - n, _, _ = t.hasher.hashChildren(n, nil) - hn, _ := t.hasher.store(n, nil, false) + n, _, _ = hasher.hashChildren(n, nil) + hn, _ := hasher.store(n, nil, false) if _, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 1d027c102..2a8b57214 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -17,16 +17,15 @@ package trie import ( - "hash" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" ) var secureKeyPrefix = []byte("secure-key-") +const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash + // SecureTrie wraps a trie with key hashing. In a secure trie, all // access operations hash the key using keccak256. This prevents // calling code from creating long chains of nodes that @@ -38,12 +37,11 @@ var secureKeyPrefix = []byte("secure-key-") // // SecureTrie is not safe for concurrent use. type SecureTrie struct { - *Trie - - hash hash.Hash - hashKeyBuf []byte - secKeyBuf []byte - secKeyCache map[string][]byte + trie Trie + hashKeyBuf [secureKeyLength]byte + secKeyBuf [200]byte + secKeyCache map[string][]byte + secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch } // NewSecure creates a trie with an existing root node from db. @@ -61,8 +59,7 @@ func NewSecure(root common.Hash, db Database) (*SecureTrie, error) { return nil, err } return &SecureTrie{ - Trie: trie, - secKeyCache: make(map[string][]byte), + trie: *trie, }, nil } @@ -80,7 +77,7 @@ func (t *SecureTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.Trie.TryGet(t.hashKey(key)) + return t.trie.TryGet(t.hashKey(key)) } // Update associates key with value in the trie. Subsequent calls to @@ -105,11 +102,11 @@ func (t *SecureTrie) Update(key, value []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryUpdate(key, value []byte) error { hk := t.hashKey(key) - err := t.Trie.TryUpdate(hk, value) + err := t.trie.TryUpdate(hk, value) if err != nil { return err } - t.secKeyCache[string(hk)] = common.CopyBytes(key) + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) return nil } @@ -124,17 +121,17 @@ func (t *SecureTrie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryDelete(key []byte) error { hk := t.hashKey(key) - delete(t.secKeyCache, string(hk)) - return t.Trie.TryDelete(hk) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) } // GetKey returns the sha3 preimage of a hashed key that was // previously used to store a value. func (t *SecureTrie) GetKey(shaKey []byte) []byte { - if key, ok := t.secKeyCache[string(shaKey)]; ok { + if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.Trie.db.Get(t.secKey(shaKey)) + key, _ := t.trie.db.Get(t.secKey(shaKey)) return key } @@ -144,7 +141,23 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { // Committing flushes nodes from memory. Subsequent Get calls will load nodes // from the database. func (t *SecureTrie) Commit() (root common.Hash, err error) { - return t.CommitTo(t.db) + return t.CommitTo(t.trie.db) +} + +func (t *SecureTrie) Hash() common.Hash { + return t.trie.Hash() +} + +func (t *SecureTrie) Root() []byte { + return t.trie.Root() +} + +func (t *SecureTrie) Iterator() *Iterator { + return t.trie.Iterator() +} + +func (t *SecureTrie) NodeIterator() *NodeIterator { + return NewNodeIterator(&t.trie) } // CommitTo writes all nodes and the secure hash pre-images to the given database. @@ -154,7 +167,7 @@ func (t *SecureTrie) Commit() (root common.Hash, err error) { // the trie's database. Calling code must ensure that the changes made to db are // written back to the trie's attached database before using the trie. func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - if len(t.secKeyCache) > 0 { + if len(t.getSecKeyCache()) > 0 { for hk, key := range t.secKeyCache { if err := db.Put(t.secKey([]byte(hk)), key); err != nil { return common.Hash{}, err @@ -162,27 +175,37 @@ func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { } t.secKeyCache = make(map[string][]byte) } - n, clean, err := t.hashRoot(db) - if err != nil { - return (common.Hash{}), err - } - t.root = clean - return common.BytesToHash(n.(hashNode)), nil + return t.trie.CommitTo(db) } +// secKey returns the database key for the preimage of key, as an ephemeral buffer. +// The caller must not hold onto the return value because it will become +// invalid on the next call to hashKey or secKey. func (t *SecureTrie) secKey(key []byte) []byte { - t.secKeyBuf = append(t.secKeyBuf[:0], secureKeyPrefix...) - t.secKeyBuf = append(t.secKeyBuf, key...) - return t.secKeyBuf + buf := append(t.secKeyBuf[:0], secureKeyPrefix...) + buf = append(buf, key...) + return buf } +// hashKey returns the hash of key as an ephemeral buffer. +// The caller must not hold onto the return value because it will become +// invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { - if t.hash == nil { - t.hash = sha3.NewKeccak256() - t.hashKeyBuf = make([]byte, 32) - } - t.hash.Reset() - t.hash.Write(key) - t.hashKeyBuf = t.hash.Sum(t.hashKeyBuf[:0]) - return t.hashKeyBuf + h := newHasher() + h.sha.Reset() + h.sha.Write(key) + buf := h.sha.Sum(t.hashKeyBuf[:0]) + returnHasherToPool(h) + return buf +} + +// getSecKeyCache returns the current secure key cache, creating a new one if +// ownership changed (i.e. the current secure trie is a copy of another owning +// the actual cache). +func (t *SecureTrie) getSecKeyCache() map[string][]byte { + if t != t.secKeyCacheOwner { + t.secKeyCacheOwner = t + t.secKeyCache = make(map[string][]byte) + } + return t.secKeyCache } diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index 0be5b3d15..3171b8c31 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -18,6 +18,8 @@ package trie import ( "bytes" + "runtime" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -31,6 +33,37 @@ func newEmptySecure() *SecureTrie { return trie } +// makeTestSecureTrie creates a large enough secure trie for testing. +func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { + // Create an empty trie + db, _ := ethdb.NewMemDatabase() + trie, _ := NewSecure(common.Hash{}, db) + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate th trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + trie.Commit() + + // Return the generated trie + return db, trie, content +} + func TestSecureDelete(t *testing.T) { trie := newEmptySecure() vals := []struct{ k, v string }{ @@ -72,3 +105,41 @@ func TestSecureGetKey(t *testing.T) { t.Errorf("GetKey returned %q, want %q", k, key) } } + +func TestSecureTrieConcurrency(t *testing.T) { + // Create an initial trie and copy if for concurrent access + _, trie, _ := makeTestSecureTrie() + + threads := runtime.NumCPU() + tries := make([]*SecureTrie, threads) + for i := 0; i < threads; i++ { + cpy := *trie + tries[i] = &cpy + } + // Start a batch of goroutines interactng with the trie + pend := new(sync.WaitGroup) + pend.Add(threads) + for i := 0; i < threads; i++ { + go func(index int) { + defer pend.Done() + + for j := byte(0); j < 255; j++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} + tries[index].Update(key, val) + + key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} + tries[index].Update(key, val) + + // Add some other data to inflate the trie + for k := byte(3); k < 13; k++ { + key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} + tries[index].Update(key, val) + } + } + tries[index].Commit() + }(i) + } + // Wait for all threads to finish + pend.Wait() +} diff --git a/trie/sync_test.go b/trie/sync_test.go index a81f7650e..a763dc564 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -51,9 +51,6 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { } trie.Commit() - // Remove any potentially cached data from the test trie creation - globalCache.Clear() - // Return the generated trie return db, trie, content } @@ -61,9 +58,6 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { // checkTrieContents cross references a reconstructed trie with an expected data // content map. func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) { - // Remove any potentially cached data from the trie synchronisation - globalCache.Clear() - // Check root availability and trie contents trie, err := New(common.BytesToHash(root), db) if err != nil { @@ -81,9 +75,6 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin // checkTrieConsistency checks that all nodes in a trie are indeed present. func checkTrieConsistency(db Database, root common.Hash) error { - // Remove any potentially cached data from the test trie creation or previous checks - globalCache.Clear() - // Create and iterate a trie rooted in a subnode trie, err := New(root, db) if err != nil { diff --git a/trie/trie.go b/trie/trie.go index a530e7b2a..93e189e2e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -20,22 +20,14 @@ package trie import ( "bytes" "fmt" - "hash" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/crypto/sha3" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" - "github.com/ethereum/go-ethereum/rlp" ) -const defaultCacheCapacity = 800 - var ( - // The global cache stores decoded trie nodes by hash as they get loaded. - globalCache = newARC(defaultCacheCapacity) - // This is the known root hash of an empty trie. emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") @@ -43,11 +35,6 @@ var ( emptyState = crypto.Keccak256Hash(nil) ) -// ClearGlobalCache clears the global trie cache -func ClearGlobalCache() { - globalCache.Clear() -} - // Database must be implemented by backing stores for the trie. type Database interface { DatabaseWriter @@ -72,7 +59,6 @@ type Trie struct { root node db Database originalRoot common.Hash - *hasher } // New creates a trie with an existing root node from db. @@ -118,32 +104,50 @@ func (t *Trie) Get(key []byte) []byte { // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryGet(key []byte) ([]byte, error) { key = compactHexDecode(key) - pos := 0 - tn := t.root - for pos < len(key) { - switch n := tn.(type) { - case shortNode: - if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { - return nil, nil - } - tn = n.Val - pos += len(n.Key) - case fullNode: - tn = n.Children[key[pos]] - pos++ - case nil: - return nil, nil - case hashNode: - var err error - tn, err = t.resolveHash(n, key[:pos], key[pos:]) - if err != nil { - return nil, err - } - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) - } + value, newroot, didResolve, err := t.tryGet(t.root, key, 0) + if err == nil && didResolve { + t.root = newroot + } + return value, err +} + +func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { + switch n := (origNode).(type) { + case nil: + return nil, nil, false, nil + case valueNode: + return n, n, false, nil + case shortNode: + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + // key not found in trie + return nil, n, false, nil + } + value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + if err == nil && didResolve { + n.Val = newnode + return value, n, didResolve, err + } else { + return value, origNode, didResolve, err + } + case fullNode: + child := n.Children[key[pos]] + value, newnode, didResolve, err = t.tryGet(child, key, pos+1) + if err == nil && didResolve { + n.Children[key[pos]] = newnode + return value, n, didResolve, err + } else { + return value, origNode, didResolve, err + } + case hashNode: + child, err := t.resolveHash(n, key[:pos], key[pos:]) + if err != nil { + return nil, n, true, err + } + value, newnode, _, err := t.tryGet(child, key, pos) + return value, newnode, true, err + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } - return tn.(valueNode), nil } // Update associates key with value in the trie. Subsequent calls to @@ -410,9 +414,6 @@ func (t *Trie) resolve(n node, prefix, suffix []byte) (node, error) { } func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { - if v, ok := globalCache.Get(n); ok { - return v, nil - } enc, err := t.db.Get(n) if err != nil || enc == nil { return nil, &MissingNodeError{ @@ -424,9 +425,6 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) { } } dec := mustDecodeNode(n, enc) - if dec != nil { - globalCache.Put(n, dec) - } return dec, nil } @@ -474,127 +472,7 @@ func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } - if t.hasher == nil { - t.hasher = newHasher() - } - return t.hasher.hash(t.root, db, true) -} - -type hasher struct { - tmp *bytes.Buffer - sha hash.Hash -} - -func newHasher() *hasher { - return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} -} - -// hash collapses a node down into a hash node, also returning a copy of the -// original node initialzied with the computed hash to replace the original one. -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { - // If we're not storing the node, just hashing, use avaialble cached data - if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { - return hash, n, nil - } - // Trie not processed yet or needs storage, walk the children - collapsed, cached, err := h.hashChildren(n, db) - if err != nil { - return hashNode{}, n, err - } - hashed, err := h.store(collapsed, db, force) - if err != nil { - return hashNode{}, n, err - } - // Cache the hash and RLP blob of the ndoe for later reuse - if hash, ok := hashed.(hashNode); ok && !force { - switch cached := cached.(type) { - case shortNode: - cached.hash = hash - if db != nil { - cached.dirty = false - } - return hashed, cached, nil - case fullNode: - cached.hash = hash - if db != nil { - cached.dirty = false - } - return hashed, cached, nil - } - } - return hashed, cached, nil -} - -// hashChildren replaces the children of a node with their hashes if the encoded -// size of the child is larger than a hash, returning the collapsed node as well -// as a replacement for the original node with the child hashes cached in. -func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { - var err error - - switch n := original.(type) { - case shortNode: - // Hash the short node's child, caching the newly hashed subtree - cached := n - cached.Key = common.CopyBytes(cached.Key) - - n.Key = compactEncode(n.Key) - if _, ok := n.Val.(valueNode); !ok { - if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { - return n, original, err - } - } - if n.Val == nil { - n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. - } - return n, cached, nil - - case fullNode: - // Hash the full node's children, caching the newly hashed subtrees - cached := fullNode{dirty: n.dirty} - - for i := 0; i < 16; i++ { - if n.Children[i] != nil { - if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { - return n, original, err - } - } else { - n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. - } - } - cached.Children[16] = n.Children[16] - if n.Children[16] == nil { - n.Children[16] = valueNode(nil) - } - return n, cached, nil - - default: - // Value and hash nodes don't have children so they're left as were - return n, original, nil - } -} - -func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { - // Don't store hashes or empty nodes. - if _, isHash := n.(hashNode); n == nil || isHash { - return n, nil - } - // Generate the RLP encoding of the node - h.tmp.Reset() - if err := rlp.Encode(h.tmp, n); err != nil { - panic("encode error: " + err.Error()) - } - if h.tmp.Len() < 32 && !force { - return n, nil // Nodes smaller than 32 bytes are stored inside their parent - } - // Larger nodes are replaced by their hash and stored in the database. - hash, _ := n.cache() - if hash == nil { - h.sha.Reset() - h.sha.Write(h.tmp.Bytes()) - hash = hashNode(h.sha.Sum(nil)) - } - if db != nil { - return hash, db.Put(hash, h.tmp.Bytes()) - } - return hash, nil + h := newHasher() + defer returnHasherToPool(h) + return h.hash(t.root, db, true) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 121ba24c1..5a3ea1be9 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -76,8 +76,6 @@ func TestMissingNode(t *testing.T) { updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") root, _ := trie.Commit() - ClearGlobalCache() - trie, _ = New(root, db) _, err := trie.TryGet([]byte("120000")) if err != nil { @@ -109,7 +107,6 @@ func TestMissingNode(t *testing.T) { } db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) - ClearGlobalCache() trie, _ = New(root, db) _, err = trie.TryGet([]byte("120000")) @@ -362,44 +359,6 @@ func TestLargeValue(t *testing.T) { } -type kv struct { - k, v []byte - t bool -} - -func TestLargeData(t *testing.T) { - trie := newEmpty() - vals := make(map[string]*kv) - - for i := byte(0); i < 255; i++ { - value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} - value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} - trie.Update(value.k, value.v) - trie.Update(value2.k, value2.v) - vals[string(value.k)] = value - vals[string(value2.k)] = value2 - } - - it := NewIterator(trie) - for it.Next() { - vals[string(it.Key)].t = true - } - - var untouched []*kv - for _, value := range vals { - if !value.t { - untouched = append(untouched, value) - } - } - - if len(untouched) > 0 { - t.Errorf("Missed %d nodes", len(untouched)) - for _, value := range untouched { - t.Error(value) - } - } -} - func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }