tendermint/merkle/iavl_node.go

409 lines
9.8 KiB
Go
Raw Normal View History

package merkle
import (
2014-07-01 14:50:24 -07:00
"crypto/sha256"
"io"
2014-12-29 15:14:54 -08:00
"github.com/tendermint/tendermint/binary"
)
// IAVLNode is a single node of the IAVL tree.
//
// A node with height 0 is a leaf and carries a value. A node with a
// greater height is an inner node and refers to a left and a right child,
// each either by hash (leftHash/rightHash), by in-memory pointer
// (leftNode/rightNode), or both.
type IAVLNode struct {
	// key is the leaf's key, or the routing key of an inner node.
	key interface{}
	// value is only populated for leaves (height == 0).
	value interface{}

	// size is 1 for a fresh leaf and the sum of both children's sizes
	// for an inner node.
	size uint64
	// height is 0 for a leaf and 1 + the taller child's height otherwise.
	height uint8
	// hash caches the node's hash; nil means "not computed yet".
	hash []byte

	// Left child: hash reference and optional in-memory node.
	leftHash []byte
	leftNode *IAVLNode

	// Right child: hash reference and optional in-memory node.
	rightHash []byte
	rightNode *IAVLNode

	// persisted tells save() that this node needs no further saving.
	persisted bool
}
2014-10-11 00:52:29 -07:00
func NewIAVLNode(key interface{}, value interface{}) *IAVLNode {
2014-07-01 14:50:24 -07:00
return &IAVLNode{
2014-10-06 01:46:39 -07:00
key: key,
value: value,
size: 1,
2014-10-05 21:21:34 -07:00
}
}
func ReadIAVLNode(t *IAVLTree, r io.Reader, n *int64, err *error) *IAVLNode {
2014-10-05 21:21:34 -07:00
node := &IAVLNode{}
// node header & key
node.height = binary.ReadUint8(r, n, err)
node.size = binary.ReadUint64(r, n, err)
2014-10-11 00:52:29 -07:00
node.key = t.keyCodec.Decode(r, n, err)
2014-10-06 00:15:37 -07:00
if *err != nil {
panic(*err)
2014-10-05 21:21:34 -07:00
}
// node value or children.
if node.height == 0 {
2014-10-11 00:52:29 -07:00
node.value = t.valueCodec.Decode(r, n, err)
2014-10-05 21:21:34 -07:00
} else {
node.leftHash = binary.ReadByteSlice(r, n, err)
node.rightHash = binary.ReadByteSlice(r, n, err)
2014-07-01 14:50:24 -07:00
}
2014-10-06 00:15:37 -07:00
if *err != nil {
panic(*err)
2014-10-05 21:21:34 -07:00
}
return node
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) _copy() *IAVLNode {
if node.height == 0 {
2014-07-01 14:50:24 -07:00
panic("Why are you copying a value node?")
}
return &IAVLNode{
2014-10-11 00:52:29 -07:00
key: node.key,
size: node.size,
height: node.height,
2014-10-06 01:46:39 -07:00
hash: nil, // Going to be mutated anyways.
2014-10-11 00:52:29 -07:00
leftHash: node.leftHash,
leftNode: node.leftNode,
rightHash: node.rightHash,
rightNode: node.rightNode,
persisted: false, // Going to be mutated, so it can't already be persisted.
2014-07-01 14:50:24 -07:00
}
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) has(t *IAVLTree, key interface{}) (has bool) {
if t.keyCodec.Compare(node.key, key) == 0 {
2014-07-01 14:50:24 -07:00
return true
}
2014-10-11 00:52:29 -07:00
if node.height == 0 {
2014-07-01 14:50:24 -07:00
return false
} else {
2014-10-11 00:52:29 -07:00
if t.keyCodec.Compare(key, node.key) < 0 {
return node.getLeftNode(t).has(t, key)
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
return node.getRightNode(t).has(t, key)
2014-07-01 14:50:24 -07:00
}
}
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) get(t *IAVLTree, key interface{}) (index uint64, value interface{}) {
if node.height == 0 {
if t.keyCodec.Compare(node.key, key) == 0 {
return 0, node.value
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
return 0, nil
2014-07-01 14:50:24 -07:00
}
} else {
2014-10-11 00:52:29 -07:00
if t.keyCodec.Compare(key, node.key) < 0 {
return node.getLeftNode(t).get(t, key)
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
rightNode := node.getRightNode(t)
index, value = rightNode.get(t, key)
index += node.size - rightNode.size
return index, value
2014-07-01 14:50:24 -07:00
}
}
}
2014-10-11 00:52:29 -07:00
// getByIndex returns the key and value of the index-th leaf (0-based) of
// the subtree rooted at node.
//
// It panics when a leaf is reached with a non-zero remaining index.
func (node *IAVLNode) getByIndex(t *IAVLTree, index uint64) (key interface{}, value interface{}) {
	if node.height == 0 {
		if index != 0 {
			panic("getByIndex asked for invalid index")
		}
		return node.key, node.value
	}

	// TODO: could improve this by storing the
	// sizes as well as left/right hash.
	left := node.getLeftNode(t)
	if index >= left.size {
		// Skip the whole left subtree and re-base the index.
		return node.getRightNode(t).getByIndex(t, index-left.size)
	}
	return left.getByIndex(t, index)
}
2014-10-11 20:39:13 -07:00
// NOTE: sets hashes recursively
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) hashWithCount(t *IAVLTree) ([]byte, uint64) {
if node.hash != nil {
return node.hash, 0
2014-07-01 14:50:24 -07:00
}
hasher := sha256.New()
2014-10-11 00:52:29 -07:00
_, hashCount, err := node.writeToCountHashes(t, hasher)
2014-07-01 14:50:24 -07:00
if err != nil {
panic(err)
}
2014-10-11 00:52:29 -07:00
node.hash = hasher.Sum(nil)
2014-07-01 14:50:24 -07:00
2014-10-11 00:52:29 -07:00
return node.hash, hashCount + 1
2014-05-22 18:08:49 -07:00
}
2014-10-11 20:39:13 -07:00
// NOTE: sets hashes recursively
// NOTE: clears leftNode/rightNode recursively
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) save(t *IAVLTree) []byte {
if node.hash == nil {
node.hash, _ = node.hashWithCount(t)
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
if node.persisted {
return node.hash
2014-07-01 14:50:24 -07:00
}
2014-10-06 01:46:39 -07:00
// save children
2014-10-11 00:52:29 -07:00
if node.leftNode != nil {
node.leftHash = node.leftNode.save(t)
node.leftNode = nil
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
if node.rightNode != nil {
node.rightHash = node.rightNode.save(t)
node.rightNode = nil
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
// save node
2014-10-11 20:39:13 -07:00
t.ndb.SaveNode(t, node)
2014-10-11 00:52:29 -07:00
return node.hash
2014-05-23 17:49:28 -07:00
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) set(t *IAVLTree, key interface{}, value interface{}) (newSelf *IAVLNode, updated bool) {
if node.height == 0 {
cmp := t.keyCodec.Compare(key, node.key)
if cmp < 0 {
2014-07-01 14:50:24 -07:00
return &IAVLNode{
2014-10-11 00:52:29 -07:00
key: node.key,
2014-10-06 01:46:39 -07:00
height: 1,
size: 2,
leftNode: NewIAVLNode(key, value),
2014-10-11 00:52:29 -07:00
rightNode: node,
2014-07-01 14:50:24 -07:00
}, false
2014-10-11 00:52:29 -07:00
} else if cmp == 0 {
2014-07-01 14:50:24 -07:00
return NewIAVLNode(key, value), true
} else {
return &IAVLNode{
2014-10-06 01:46:39 -07:00
key: key,
height: 1,
size: 2,
2014-10-11 00:52:29 -07:00
leftNode: node,
2014-10-06 01:46:39 -07:00
rightNode: NewIAVLNode(key, value),
2014-07-01 14:50:24 -07:00
}, false
}
} else {
2014-10-11 00:52:29 -07:00
node = node._copy()
if t.keyCodec.Compare(key, node.key) < 0 {
node.leftNode, updated = node.getLeftNode(t).set(t, key, value)
node.leftHash = nil
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
node.rightNode, updated = node.getRightNode(t).set(t, key, value)
node.rightHash = nil
2014-07-01 14:50:24 -07:00
}
if updated {
2014-10-11 00:52:29 -07:00
return node, updated
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
node.calcHeightAndSize(t)
return node.balance(t), updated
2014-07-01 14:50:24 -07:00
}
}
}
2014-05-21 16:24:50 -07:00
2014-10-11 00:52:29 -07:00
// newHash/newNode: The new hash or node to replace node after remove.
// newKey: new leftmost leaf key for tree after successfully removing 'key' if changed.
2014-10-06 01:46:39 -07:00
// value: removed value.
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) remove(t *IAVLTree, key interface{}) (
newHash []byte, newNode *IAVLNode, newKey interface{}, value interface{}, removed bool) {
if node.height == 0 {
if t.keyCodec.Compare(key, node.key) == 0 {
return nil, nil, nil, node.value, true
2014-07-01 14:50:24 -07:00
} else {
2014-10-11 00:52:29 -07:00
return nil, node, nil, nil, false
2014-07-01 14:50:24 -07:00
}
} else {
2014-10-11 00:52:29 -07:00
if t.keyCodec.Compare(key, node.key) < 0 {
2014-10-05 21:21:34 -07:00
var newLeftHash []byte
2014-10-06 01:46:39 -07:00
var newLeftNode *IAVLNode
2014-10-11 00:52:29 -07:00
newLeftHash, newLeftNode, newKey, value, removed = node.getLeftNode(t).remove(t, key)
if !removed {
return nil, node, nil, value, false
2014-10-06 01:46:39 -07:00
} else if newLeftHash == nil && newLeftNode == nil { // left node held value, was removed
2014-10-11 00:52:29 -07:00
return node.rightHash, node.rightNode, node.key, value, true
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
node = node._copy()
node.leftHash, node.leftNode = newLeftHash, newLeftNode
node.calcHeightAndSize(t)
return nil, node.balance(t), newKey, value, true
2014-07-01 14:50:24 -07:00
} else {
2014-10-05 21:21:34 -07:00
var newRightHash []byte
2014-10-06 01:46:39 -07:00
var newRightNode *IAVLNode
2014-10-11 00:52:29 -07:00
newRightHash, newRightNode, newKey, value, removed = node.getRightNode(t).remove(t, key)
if !removed {
return nil, node, nil, value, false
2014-10-06 01:46:39 -07:00
} else if newRightHash == nil && newRightNode == nil { // right node held value, was removed
2014-10-11 00:52:29 -07:00
return node.leftHash, node.leftNode, nil, value, true
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
node = node._copy()
node.rightHash, node.rightNode = newRightHash, newRightNode
2014-07-01 14:50:24 -07:00
if newKey != nil {
2014-10-11 00:52:29 -07:00
node.key = newKey
2014-07-01 14:50:24 -07:00
newKey = nil
}
2014-10-11 00:52:29 -07:00
node.calcHeightAndSize(t)
return nil, node.balance(t), newKey, value, true
2014-07-01 14:50:24 -07:00
}
}
2014-05-22 18:08:49 -07:00
}
2014-10-11 20:39:13 -07:00
// NOTE: sets hashes recursively
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) writeToCountHashes(t *IAVLTree, w io.Writer) (n int64, hashCount uint64, err error) {
// height & size & key
binary.WriteUint8(node.height, w, &n, &err)
binary.WriteUint64(node.size, w, &n, &err)
2014-10-11 00:52:29 -07:00
t.keyCodec.Encode(node.key, w, &n, &err)
2014-08-10 16:35:08 -07:00
if err != nil {
return
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
if node.height == 0 {
2014-07-01 14:50:24 -07:00
// value
2014-10-11 00:52:29 -07:00
t.valueCodec.Encode(node.value, w, &n, &err)
2014-07-01 14:50:24 -07:00
} else {
// left
2014-10-11 00:52:29 -07:00
if node.leftNode != nil {
leftHash, leftCount := node.leftNode.hashWithCount(t)
node.leftHash = leftHash
2014-10-05 21:21:34 -07:00
hashCount += leftCount
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
if node.leftHash == nil {
panic("node.leftHash was nil in save")
2014-10-06 01:46:39 -07:00
}
binary.WriteByteSlice(node.leftHash, w, &n, &err)
2014-07-01 14:50:24 -07:00
// right
2014-10-11 00:52:29 -07:00
if node.rightNode != nil {
rightHash, rightCount := node.rightNode.hashWithCount(t)
node.rightHash = rightHash
2014-10-05 21:21:34 -07:00
hashCount += rightCount
2014-07-01 14:50:24 -07:00
}
2014-10-11 00:52:29 -07:00
if node.rightHash == nil {
panic("node.rightHash was nil in save")
2014-10-06 01:46:39 -07:00
}
binary.WriteByteSlice(node.rightHash, w, &n, &err)
2014-07-01 14:50:24 -07:00
}
2014-10-05 21:21:34 -07:00
return
2014-05-23 17:49:28 -07:00
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) getLeftNode(t *IAVLTree) *IAVLNode {
if node.leftNode != nil {
return node.leftNode
2014-10-05 21:21:34 -07:00
} else {
2014-10-11 20:39:13 -07:00
return t.ndb.GetNode(t, node.leftHash)
2014-07-01 14:50:24 -07:00
}
2014-05-22 18:08:49 -07:00
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) getRightNode(t *IAVLTree) *IAVLNode {
if node.rightNode != nil {
return node.rightNode
2014-10-05 21:21:34 -07:00
} else {
2014-10-11 20:39:13 -07:00
return t.ndb.GetNode(t, node.rightHash)
2014-07-01 14:50:24 -07:00
}
2014-05-21 16:24:50 -07:00
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) rotateRight(t *IAVLTree) *IAVLNode {
node = node._copy()
sl := node.getLeftNode(t)._copy()
2014-10-06 01:46:39 -07:00
slrHash, slrCached := sl.rightHash, sl.rightNode
2014-10-11 00:52:29 -07:00
sl.rightHash, sl.rightNode = nil, node
node.leftHash, node.leftNode = slrHash, slrCached
2014-10-11 00:52:29 -07:00
node.calcHeightAndSize(t)
sl.calcHeightAndSize(t)
2014-07-01 14:50:24 -07:00
return sl
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) rotateLeft(t *IAVLTree) *IAVLNode {
node = node._copy()
sr := node.getRightNode(t)._copy()
2014-10-06 01:46:39 -07:00
srlHash, srlCached := sr.leftHash, sr.leftNode
2014-10-11 00:52:29 -07:00
sr.leftHash, sr.leftNode = nil, node
node.rightHash, node.rightNode = srlHash, srlCached
2014-10-11 00:52:29 -07:00
node.calcHeightAndSize(t)
sr.calcHeightAndSize(t)
2014-07-01 14:50:24 -07:00
return sr
}
2014-10-11 20:39:13 -07:00
// NOTE: mutates height and size
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) calcHeightAndSize(t *IAVLTree) {
node.height = maxUint8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1
node.size = node.getLeftNode(t).size + node.getRightNode(t).size
}
2014-10-11 00:52:29 -07:00
// calcBalance returns the balance factor of node: left child height minus
// right child height. Positive means left-heavy, negative right-heavy.
func (node *IAVLNode) calcBalance(t *IAVLTree) int {
	leftHeight := int(node.getLeftNode(t).height)
	rightHeight := int(node.getRightNode(t).height)
	return leftHeight - rightHeight
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) balance(t *IAVLTree) (newSelf *IAVLNode) {
balance := node.calcBalance(t)
2014-07-01 14:50:24 -07:00
if balance > 1 {
2014-10-11 00:52:29 -07:00
if node.getLeftNode(t).calcBalance(t) >= 0 {
2014-07-01 14:50:24 -07:00
// Left Left Case
2014-10-11 00:52:29 -07:00
return node.rotateRight(t)
2014-07-01 14:50:24 -07:00
} else {
// Left Right Case
2014-10-11 00:52:29 -07:00
node = node._copy()
node.leftHash, node.leftNode = nil, node.getLeftNode(t).rotateLeft(t)
//node.calcHeightAndSize()
return node.rotateRight(t)
2014-07-01 14:50:24 -07:00
}
}
if balance < -1 {
2014-10-11 00:52:29 -07:00
if node.getRightNode(t).calcBalance(t) <= 0 {
2014-07-01 14:50:24 -07:00
// Right Right Case
2014-10-11 00:52:29 -07:00
return node.rotateLeft(t)
2014-07-01 14:50:24 -07:00
} else {
// Right Left Case
2014-10-11 00:52:29 -07:00
node = node._copy()
node.rightHash, node.rightNode = nil, node.getRightNode(t).rotateRight(t)
//node.calcHeightAndSize()
return node.rotateLeft(t)
2014-07-01 14:50:24 -07:00
}
}
// Nothing changed
2014-10-11 00:52:29 -07:00
return node
}
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) traverse(t *IAVLTree, cb func(*IAVLNode) bool) bool {
stop := cb(node)
2014-07-01 14:50:24 -07:00
if stop {
return stop
}
2014-10-11 00:52:29 -07:00
if node.height > 0 {
stop = node.getLeftNode(t).traverse(t, cb)
2014-07-01 14:50:24 -07:00
if stop {
return stop
}
2014-10-11 00:52:29 -07:00
stop = node.getRightNode(t).traverse(t, cb)
2014-07-01 14:50:24 -07:00
if stop {
return stop
}
}
return false
}
2014-10-06 01:46:39 -07:00
// Only used in testing...
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) lmd(t *IAVLTree) *IAVLNode {
if node.height == 0 {
return node
2014-10-06 01:46:39 -07:00
}
2014-10-11 00:52:29 -07:00
return node.getLeftNode(t).lmd(t)
2014-10-06 01:46:39 -07:00
}
// Only used in testing...
2014-10-11 00:52:29 -07:00
func (node *IAVLNode) rmd(t *IAVLTree) *IAVLNode {
if node.height == 0 {
return node
2014-10-06 01:46:39 -07:00
}
2014-10-11 00:52:29 -07:00
return node.getRightNode(t).rmd(t)
2014-10-06 01:46:39 -07:00
}