diff --git a/trie/iterator.go b/trie/iterator.go index 64110c6d9..00b890eb8 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -22,6 +22,7 @@ import ( "errors" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" ) // Iterator is a key-value trie iterator that traverses a Trie. @@ -55,31 +56,50 @@ func (it *Iterator) Next() bool { return false } +// Prove generates the Merkle proof for the leaf node the iterator is currently +// positioned on. +func (it *Iterator) Prove() [][]byte { + return it.nodeIt.LeafProof() +} + // NodeIterator is an iterator to traverse the trie pre-order. type NodeIterator interface { // Next moves the iterator to the next node. If the parameter is false, any child // nodes will be skipped. Next(bool) bool + // Error returns the error status of the iterator. Error() error // Hash returns the hash of the current node. Hash() common.Hash + // Parent returns the hash of the parent of the current node. The hash may be the one // grandparent if the immediate parent is an internal node with no hash. Parent() common.Hash + // Path returns the hex-encoded path to the current node. // Callers must not retain references to the return value after calling Next. // For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. Path() []byte // Leaf returns true iff the current node is a leaf node. - // LeafBlob, LeafKey return the contents and key of the leaf node. These - // method panic if the iterator is not positioned at a leaf. - // Callers must not retain references to their return value after calling Next Leaf() bool - LeafBlob() []byte + + // LeafKey returns the key of the leaf. The method panics if the iterator is not + // positioned at a leaf. Callers must not retain references to the value after + // calling Next. LeafKey() []byte + + // LeafBlob returns the content of the leaf. The method panics if the iterator + // is not positioned at a leaf. Callers must not retain references to the value + // after calling Next. + LeafBlob() []byte + + // LeafProof returns the Merkle proof of the leaf. The method panics if the + // iterator is not positioned at a leaf. Callers must not retain references + // to the value after calling Next. + LeafProof() [][]byte } // nodeIteratorState represents the iteration state at one particular node of the @@ -139,6 +159,15 @@ func (it *nodeIterator) Leaf() bool { return hasTerm(it.path) } +func (it *nodeIterator) LeafKey() []byte { + if len(it.stack) > 0 { + if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { + return hexToKeybytes(it.path) + } + } + panic("not at leaf") +} + func (it *nodeIterator) LeafBlob() []byte { if len(it.stack) > 0 { if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { @@ -148,10 +177,22 @@ func (it *nodeIterator) LeafBlob() []byte { panic("not at leaf") } -func (it *nodeIterator) LeafKey() []byte { +func (it *nodeIterator) LeafProof() [][]byte { if len(it.stack) > 0 { if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { - return hexToKeybytes(it.path) + hasher := newHasher(0, 0, nil) + proofs := make([][]byte, 0, len(it.stack)) + + for i, item := range it.stack[:len(it.stack)-1] { + // Gather nodes that end up as hash nodes (or the root) + node, _, _ := hasher.hashChildren(item.node, nil) + hashed, _ := hasher.store(node, nil, false) + if _, ok := hashed.(hashNode); ok || i == 0 { + enc, _ := rlp.EncodeToBytes(node) + proofs = append(proofs, enc) + } + } + return proofs } } panic("not at leaf") @@ -361,12 +402,16 @@ func (it *differenceIterator) Leaf() bool { return it.b.Leaf() } +func (it *differenceIterator) LeafKey() []byte { + return it.b.LeafKey() +} + func (it *differenceIterator) LeafBlob() []byte { return it.b.LeafBlob() } -func (it *differenceIterator) LeafKey() []byte { - return it.b.LeafKey() +func (it *differenceIterator) LeafProof() [][]byte { + return it.b.LeafProof() } func (it *differenceIterator) Path() []byte { @@ -464,12 +509,16 @@ func (it *unionIterator) Leaf() bool { return (*it.items)[0].Leaf() } +func (it *unionIterator) LeafKey() []byte { + return (*it.items)[0].LeafKey() +} + func (it *unionIterator) LeafBlob() []byte { return (*it.items)[0].LeafBlob() } -func (it *unionIterator) LeafKey() []byte { - return (*it.items)[0].LeafKey() +func (it *unionIterator) LeafProof() [][]byte { + return (*it.items)[0].LeafProof() } func (it *unionIterator) Path() []byte { @@ -509,12 +558,10 @@ func (it *unionIterator) Next(descend bool) bool { heap.Push(it.items, skipped) } } - if least.Next(descend) { it.count++ heap.Push(it.items, least) } - return len(*it.items) > 0 } diff --git a/trie/proof_test.go b/trie/proof_test.go index dee6f7d85..996f87478 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -32,20 +32,46 @@ func init() { mrand.Seed(time.Now().Unix()) } +// makeProvers creates Merkle trie provers based on different implementations to +// test all variations. +func makeProvers(trie *Trie) []func(key []byte) *ethdb.MemDatabase { + var provers []func(key []byte) *ethdb.MemDatabase + + // Create a direct trie based Merkle prover + provers = append(provers, func(key []byte) *ethdb.MemDatabase { + proof := ethdb.NewMemDatabase() + trie.Prove(key, 0, proof) + return proof + }) + // Create a leaf iterator based Merkle prover + provers = append(provers, func(key []byte) *ethdb.MemDatabase { + proof := ethdb.NewMemDatabase() + if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) { + for _, p := range it.Prove() { + proof.Put(crypto.Keccak256(p), p) + } + } + return proof + }) + return provers +} + func TestProof(t *testing.T) { trie, vals := randomTrie(500) root := trie.Hash() - for _, kv := range vals { - proofs := ethdb.NewMemDatabase() - if trie.Prove(kv.k, 0, proofs) != nil { - t.Fatalf("missing key %x while constructing proof", kv.k) - } - val, _, err := VerifyProof(root, kv.k, proofs) - if err != nil { - t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs) - } - if !bytes.Equal(val, kv.v) { - t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v) + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) + } + val, _, err := VerifyProof(root, kv.k, proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) + } } } } @@ -53,37 +79,66 @@ func TestProof(t *testing.T) { func TestOneElementProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") - proofs := ethdb.NewMemDatabase() - trie.Prove([]byte("k"), 0, proofs) - if len(proofs.Keys()) != 1 { - t.Error("proof should have one element") - } - val, _, err := VerifyProof(trie.Hash(), []byte("k"), proofs) - if err != nil { - t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys()) - } - if !bytes.Equal(val, []byte("v")) { - t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val) + for i, prover := range makeProvers(trie) { + proof := prover([]byte("k")) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + if proof.Len() != 1 { + t.Errorf("prover %d: proof should have one element", i) + } + val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if !bytes.Equal(val, []byte("v")) { + t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) + } } } -func TestVerifyBadProof(t *testing.T) { +func TestBadProof(t *testing.T) { trie, vals := randomTrie(800) root := trie.Hash() - for _, kv := range vals { - proofs := ethdb.NewMemDatabase() - trie.Prove(kv.k, 0, proofs) - if len(proofs.Keys()) == 0 { - t.Fatal("zero length proof") + for i, prover := range makeProvers(trie) { + for _, kv := range vals { + proof := prover(kv.k) + if proof == nil { + t.Fatalf("prover %d: nil proof", i) + } + key := proof.Keys()[mrand.Intn(proof.Len())] + val, _ := proof.Get(key) + proof.Delete(key) + + mutateByte(val) + proof.Put(crypto.Keccak256(val), val) + + if _, _, err := VerifyProof(root, kv.k, proof); err == nil { + t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) + } } - keys := proofs.Keys() - key := keys[mrand.Intn(len(keys))] - node, _ := proofs.Get(key) - proofs.Delete(key) - mutateByte(node) - proofs.Put(crypto.Keccak256(node), node) - if _, _, err := VerifyProof(root, kv.k, proofs); err == nil { - t.Fatalf("expected proof to fail for key %x", kv.k) + } +} + +// Tests that missing keys can also be proven. The test explicitly uses a single +// entry trie and checks for missing keys both before and after the single entry. +func TestMissingKeyProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + for i, key := range []string{"a", "j", "l", "z"} { + proof := ethdb.NewMemDatabase() + trie.Prove([]byte(key), 0, proof) + + if proof.Len() != 1 { + t.Errorf("test %d: proof should have one element", i) + } + val, _, err := VerifyProof(trie.Hash(), []byte(key), proof) + if err != nil { + t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if val != nil { + t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val) } } }