package merkle import ( "crypto/sha256" . "github.com/tendermint/tendermint/binary" "io" ) // Node type IAVLNode struct { key interface{} value interface{} size uint64 height uint8 hash []byte leftHash []byte leftNode *IAVLNode rightHash []byte rightNode *IAVLNode persisted bool } func NewIAVLNode(key interface{}, value interface{}) *IAVLNode { return &IAVLNode{ key: key, value: value, size: 1, } } func ReadIAVLNode(t *IAVLTree, r io.Reader, n *int64, err *error) *IAVLNode { node := &IAVLNode{} // node header & key node.height = ReadUInt8(r, n, err) node.size = ReadUInt64(r, n, err) node.key = t.keyCodec.Decode(r, n, err) if *err != nil { panic(*err) } // node value or children. if node.height == 0 { node.value = t.valueCodec.Decode(r, n, err) } else { node.leftHash = ReadByteSlice(r, n, err) node.rightHash = ReadByteSlice(r, n, err) } if *err != nil { panic(*err) } return node } func (node *IAVLNode) _copy() *IAVLNode { if node.height == 0 { panic("Why are you copying a value node?") } return &IAVLNode{ key: node.key, size: node.size, height: node.height, hash: nil, // Going to be mutated anyways. leftHash: node.leftHash, leftNode: node.leftNode, rightHash: node.rightHash, rightNode: node.rightNode, persisted: node.persisted, } } func (node *IAVLNode) has(t *IAVLTree, key interface{}) (has bool) { if t.keyCodec.Compare(node.key, key) == 0 { return true } if node.height == 0 { return false } else { if t.keyCodec.Compare(key, node.key) < 0 { return node.getLeftNode(t).has(t, key) } else { return node.getRightNode(t).has(t, key) } } } func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value interface{}) { if node.height == 0 { if t.keyCodec.Compare(node.key, key) == 0 { return 0, node.value } else { return 0, nil } } else { if t.keyCodec.Compare(key, node.key) < 0 { return node.getLeftNode(t).get(t, key) } else { rightNode := node.getRightNode(t) index, value = rightNode.get(t, key) index += node.size - rightNode.size return index, value } } } func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, value interface{}) { if node.height == 0 { if index == 0 { return node.key, node.value } else { panic("getByIndex asked for invalid index") } } else { // TODO: could improve this by storing the // sizes as well as left/right hash. leftNode := node.getLeftNode(t) if index < leftNode.size { return leftNode.getByIndex(t, index) } else { return node.getRightNode(t).getByIndex(t, index-leftNode.size) } } } // NOTE: sets hashes recursively func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint64) { if node.hash != nil { return node.hash, 0 } hasher := sha256.New() _, hashCount, err := node.writeToCountHashes(t, hasher) if err != nil { panic(err) } node.hash = hasher.Sum(nil) return node.hash, hashCount + 1 } // NOTE: sets hashes recursively // NOTE: clears leftNode/rightNode recursively func (node *IAVLNode) save(t *IAVLTree) []byte { if node.hash == nil { node.hash, _ = node.hashWithCount(t) } if node.persisted { return node.hash } // save children if node.leftNode != nil { node.leftHash = node.leftNode.save(t) node.leftNode = nil } if node.rightNode != nil { node.rightHash = node.rightNode.save(t) node.rightNode = nil } // save node t.ndb.SaveNode(t, node) return node.hash } func (node *IAVLNode) set(t *IAVLTree, key interface{}, value interface{}) (newSelf *IAVLNode, updated bool) { if node.height == 0 { cmp := t.keyCodec.Compare(key, node.key) if cmp < 0 { return &IAVLNode{ key: node.key, height: 1, size: 2, leftNode: NewIAVLNode(key, value), rightNode: node, }, false } else if cmp == 0 { return NewIAVLNode(key, value), true } else { return &IAVLNode{ key: key, height: 1, size: 2, leftNode: node, rightNode: NewIAVLNode(key, value), }, false } } else { node = node._copy() if t.keyCodec.Compare(key, node.key) < 0 { node.leftNode, updated = node.getLeftNode(t).set(t, key, value) node.leftHash = nil } else { node.rightNode, updated = node.getRightNode(t).set(t, key, value) node.rightHash = nil } if updated { return node, updated } else { node.calcHeightAndSize(t) return node.balance(t), updated } } } // newHash/newNode: The new hash or node to replace node after remove. // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. // value: removed value. func (node *IAVLNode) remove(t *IAVLTree, key interface{}) ( newHash []byte, newNode *IAVLNode, newKey interface{}, value interface{}, removed bool) { if node.height == 0 { if t.keyCodec.Compare(key, node.key) == 0 { return nil, nil, nil, node.value, true } else { return nil, node, nil, nil, false } } else { if t.keyCodec.Compare(key, node.key) < 0 { var newLeftHash []byte var newLeftNode *IAVLNode newLeftHash, newLeftNode, newKey, value, removed = node.getLeftNode(t).remove(t, key) if !removed { return nil, node, nil, value, false } else if newLeftHash == nil && newLeftNode == nil { // left node held value, was removed return node.rightHash, node.rightNode, node.key, value, true } node = node._copy() node.leftHash, node.leftNode = newLeftHash, newLeftNode node.calcHeightAndSize(t) return nil, node.balance(t), newKey, value, true } else { var newRightHash []byte var newRightNode *IAVLNode newRightHash, newRightNode, newKey, value, removed = node.getRightNode(t).remove(t, key) if !removed { return nil, node, nil, value, false } else if newRightHash == nil && newRightNode == nil { // right node held value, was removed return node.leftHash, node.leftNode, nil, value, true } node = node._copy() node.rightHash, node.rightNode = newRightHash, newRightNode if newKey != nil { node.key = newKey newKey = nil } node.calcHeightAndSize(t) return nil, node.balance(t), newKey, value, true } } } // NOTE: sets hashes recursively func (node *IAVLNode) writeToCountHashes(t *IAVLTree, w io.Writer) (n int64, hashCount uint64, err error) { // height & size & key WriteUInt8(w, node.height, &n, &err) WriteUInt64(w, node.size, &n, &err) t.keyCodec.Encode(node.key, w, &n, &err) if err != nil { return } if node.height == 0 { // value t.valueCodec.Encode(node.value, w, &n, &err) } else { // left if node.leftNode != nil { leftHash, leftCount := node.leftNode.hashWithCount(t) node.leftHash = leftHash hashCount += leftCount } if node.leftHash == nil { panic("node.leftHash was nil in save") } WriteByteSlice(w, node.leftHash, &n, &err) // right if node.rightNode != nil { rightHash, rightCount := node.rightNode.hashWithCount(t) node.rightHash = rightHash hashCount += rightCount } if node.rightHash == nil { panic("node.rightHash was nil in save") } WriteByteSlice(w, node.rightHash, &n, &err) } return } func (node *IAVLNode) getLeftNode(t *IAVLTree) *IAVLNode { if node.leftNode != nil { return node.leftNode } else { return t.ndb.GetNode(t, node.leftHash) } } func (node *IAVLNode) getRightNode(t *IAVLTree) *IAVLNode { if node.rightNode != nil { return node.rightNode } else { return t.ndb.GetNode(t, node.rightHash) } } func (node *IAVLNode) rotateRight(t *IAVLTree) *IAVLNode { node = node._copy() sl := node.getLeftNode(t)._copy() slrHash, slrCached := sl.rightHash, sl.rightNode sl.rightHash, sl.rightNode = nil, node node.leftHash, node.leftNode = slrHash, slrCached node.calcHeightAndSize(t) sl.calcHeightAndSize(t) return sl } func (node *IAVLNode) rotateLeft(t *IAVLTree) *IAVLNode { node = node._copy() sr := node.getRightNode(t)._copy() srlHash, srlCached := sr.leftHash, sr.leftNode sr.leftHash, sr.leftNode = nil, node node.rightHash, node.rightNode = srlHash, srlCached node.calcHeightAndSize(t) sr.calcHeightAndSize(t) return sr } // NOTE: mutates height and size func (node *IAVLNode) calcHeightAndSize(t *IAVLTree) { node.height = maxUint8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1 node.size = node.getLeftNode(t).size + node.getRightNode(t).size } func (node *IAVLNode) calcBalance(t *IAVLTree) int { return int(node.getLeftNode(t).height) - int(node.getRightNode(t).height) } func (node *IAVLNode) balance(t *IAVLTree) (newSelf *IAVLNode) { balance := node.calcBalance(t) if balance > 1 { if node.getLeftNode(t).calcBalance(t) >= 0 { // Left Left Case return node.rotateRight(t) } else { // Left Right Case node = node._copy() node.leftHash, node.leftNode = nil, node.getLeftNode(t).rotateLeft(t) //node.calcHeightAndSize() return node.rotateRight(t) } } if balance < -1 { if node.getRightNode(t).calcBalance(t) <= 0 { // Right Right Case return node.rotateLeft(t) } else { // Right Left Case node = node._copy() node.rightHash, node.rightNode = nil, node.getRightNode(t).rotateRight(t) //node.calcHeightAndSize() return node.rotateLeft(t) } } // Nothing changed return node } func (node *IAVLNode) traverse(t *IAVLTree, cb func(*IAVLNode) bool) bool { stop := cb(node) if stop { return stop } if node.height > 0 { stop = node.getLeftNode(t).traverse(t, cb) if stop { return stop } stop = node.getRightNode(t).traverse(t, cb) if stop { return stop } } return false } // Only used in testing... func (node *IAVLNode) lmd(t *IAVLTree) *IAVLNode { if node.height == 0 { return node } return node.getLeftNode(t).lmd(t) } // Only used in testing... func (node *IAVLNode) rmd(t *IAVLTree) *IAVLNode { if node.height == 0 { return node } return node.getRightNode(t).rmd(t) }