diff --git a/ethutil/trie.go b/ethutil/trie.go index 322f77647..83527d364 100644 --- a/ethutil/trie.go +++ b/ethutil/trie.go @@ -70,6 +70,12 @@ func (cache *Cache) Get(key []byte) *Value { return value } +func (cache *Cache) Delete(key []byte) { + delete(cache.nodes, string(key)) + + cache.db.Delete(key) +} + func (cache *Cache) Commit() { // Don't try to commit if it isn't dirty if !cache.IsDirty { @@ -413,3 +419,82 @@ func (t *Trie) Copy() *Trie { return trie } + +type TrieIterator struct { + trie *Trie + key string + value string + + shas [][]byte + values []string +} + +func (t *Trie) NewIterator() *TrieIterator { + return &TrieIterator{trie: t} +} + +// Some time in the near future this will need refactoring :-) +// XXX Note to self, IsSlice == inline node. Str == sha3 to node +func (it *TrieIterator) workNode(currentNode *Value) { + if currentNode.Len() == 2 { + k := CompactDecode(currentNode.Get(0).Str()) + + if currentNode.Get(1).IsSlice() { + it.workNode(currentNode.Get(1)) + } else { + if k[len(k)-1] == 16 { + it.values = append(it.values, currentNode.Get(1).Str()) + } else { + it.shas = append(it.shas, currentNode.Get(1).Bytes()) + it.getNode(currentNode.Get(1).Bytes()) + } + } + } else { + for i := 0; i < currentNode.Len(); i++ { + if i == 16 && currentNode.Get(i).Len() != 0 { + it.values = append(it.values, currentNode.Get(i).Str()) + } else { + if currentNode.Get(i).IsSlice() { + it.workNode(currentNode.Get(i)) + } else { + val := currentNode.Get(i).Str() + if val != "" { + it.shas = append(it.shas, currentNode.Get(1).Bytes()) + it.getNode([]byte(val)) + } + } + } + } + } +} + +func (it *TrieIterator) getNode(node []byte) { + currentNode := it.trie.cache.Get(node) + it.workNode(currentNode) +} + +func (it *TrieIterator) Collect() [][]byte { + if it.trie.Root == "" { + return nil + } + + it.getNode(NewValue(it.trie.Root).Bytes()) + + return it.shas +} + +func (it *TrieIterator) Purge() int { + shas := it.Collect() + for _, sha := range shas { + it.trie.cache.Delete(sha) + } + return len(it.values) +} + +func (it *TrieIterator) Key() string { + return "" +} + +func (it *TrieIterator) Value() string { + return "" +} diff --git a/ethutil/trie_test.go b/ethutil/trie_test.go index 9d2c8e19f..c3a8f224d 100644 --- a/ethutil/trie_test.go +++ b/ethutil/trie_test.go @@ -1,6 +1,7 @@ package ethutil import ( + "fmt" "reflect" "testing" ) @@ -21,6 +22,10 @@ func (db *MemDatabase) Put(key []byte, value []byte) { func (db *MemDatabase) Get(key []byte) ([]byte, error) { return db.db[string(key)], nil } +func (db *MemDatabase) Delete(key []byte) error { + delete(db.db, string(key)) + return nil +} func (db *MemDatabase) Print() {} func (db *MemDatabase) Close() {} func (db *MemDatabase) LastKnownTD() []byte { return nil } @@ -148,3 +153,22 @@ func TestTrieDeleteWithValue(t *testing.T) { } } + +func TestTrieIterator(t *testing.T) { + _, trie := New() + trie.Update("c", LONG_WORD) + trie.Update("ca", LONG_WORD) + trie.Update("cat", LONG_WORD) + + it := trie.NewIterator() + fmt.Println("purging") + fmt.Println("len =", it.Purge()) + /* + for it.Next() { + k := it.Key() + v := it.Value() + + fmt.Println(k, v) + } + */ +}