diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index a94b557d..a2865205 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -10,43 +10,71 @@ import ( // Node type IAVLNode struct { - key []byte - value []byte - size uint64 - height uint8 - hash []byte - left *IAVLNode - right *IAVLNode + key []byte + value []byte + size uint64 + height uint8 + hash []byte + leftHash []byte + rightHash []byte + persisted bool - // volatile - flags byte + // May or may not be persisted nodes, but they'll get cleared + // when this this node is saved. + leftCached *IAVLNode + rightCached *IAVLNode } -const ( - IAVLNODE_FLAG_PERSISTED = byte(0x01) - IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) -) - func NewIAVLNode(key []byte, value []byte) *IAVLNode { return &IAVLNode{ - key: key, - value: value, - size: 1, + key: key, + value: value, + size: 1, + persisted: false, } } +func ReadIAVLNode(r io.Reader, n *int64, err *error) *IAVLNode { + node := &IAVLNode{} + + // node header & key + node.height = ReadUInt8(r, &n, &err) + node.size = ReadUInt64(r, &n, &err) + node.key = ReadByteSlice(r, &n, &err) + if err != nil { + panic(err) + } + + // node value or children. + if node.height == 0 { + // value + node.value = ReadByteSlice(r, &n, &err) + } else { + // left + node.leftHash = ReadByteSlice(r, &n, &err) + // right + node.rightHash = ReadByteSlice(r, &n, &err) + } + if err != nil { + panic(err) + } + return node +} + func (self *IAVLNode) Copy() *IAVLNode { if self.height == 0 { panic("Why are you copying a value node?") } return &IAVLNode{ - key: self.key, - size: self.size, - height: self.height, - left: self.left, - right: self.right, - hash: nil, - flags: byte(0), + key: self.key, + size: self.size, + height: self.height, + hash: nil, // Going to be mutated anyways. + leftHash: self.leftHash, + rightHash: self.rightHash, + persisted: self.persisted, + leftCached: self.leftCached, + rightCached: self.rightCached, } } @@ -58,7 +86,7 @@ func (self *IAVLNode) Height() uint8 { return self.height } -func (self *IAVLNode) has(db Db, key []byte) (has bool) { +func (self *IAVLNode) has(ndb *IAVLNodeDB, key []byte) (has bool) { if bytes.Equal(self.key, key) { return true } @@ -66,14 +94,14 @@ func (self *IAVLNode) has(db Db, key []byte) (has bool) { return false } else { if bytes.Compare(key, self.key) == -1 { - return self.leftFilled(db).has(db, key) + return self.getLeft(ndb).has(ndb, key) } else { - return self.rightFilled(db).has(db, key) + return self.getRight(ndb).has(ndb, key) } } } -func (self *IAVLNode) get(db Db, key []byte) (value []byte) { +func (self *IAVLNode) get(ndb *IAVLNodeDB, key []byte) (value []byte) { if self.height == 0 { if bytes.Equal(self.key, key) { return self.value @@ -82,9 +110,9 @@ func (self *IAVLNode) get(db Db, key []byte) (value []byte) { } } else { if bytes.Compare(key, self.key) == -1 { - return self.leftFilled(db).get(db, key) + return self.getLeft(ndb).get(ndb, key) } else { - return self.rightFilled(db).get(db, key) + return self.getRight(ndb).get(ndb, key) } } } @@ -104,105 +132,108 @@ func (self *IAVLNode) HashWithCount() ([]byte, uint64) { return self.hash, hashCount + 1 } -func (self *IAVLNode) Save(db Db) { +func (self *IAVLNode) Save(ndb *IAVLNodeDB) []byte { if self.hash == nil { - panic("savee.hash can't be nil") + hash, _ := self.HashWithCount() + self.hash = hash } - if self.flags&IAVLNODE_FLAG_PERSISTED > 0 || - self.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { - return + if self.persisted { + return self.hash } // children - if self.height > 0 { - self.left.Save(db) - self.right.Save(db) + if self.leftCached != nil { + self.leftHash = self.leftCached.Save(ndb) + self.leftCached = nil + } + if self.rightCached != nil { + self.rightHash = self.rightCached.Save(ndb) + self.rightCached = nil } // save self - buf := bytes.NewBuffer(nil) - _, err := self.WriteTo(buf) - if err != nil { - panic(err) - } - db.Set([]byte(self.hash), buf.Bytes()) - - self.flags |= IAVLNODE_FLAG_PERSISTED + ndb.Save(self) + return self.hash } -func (self *IAVLNode) set(db Db, key []byte, value []byte) (_ *IAVLNode, updated bool) { +func (self *IAVLNode) set(ndb *IAVLNodeDB, key []byte, value []byte) (_ *IAVLNode, updated bool) { if self.height == 0 { if bytes.Compare(key, self.key) == -1 { return &IAVLNode{ - key: self.key, - height: 1, - size: 2, - left: NewIAVLNode(key, value), - right: self, + key: self.key, + height: 1, + size: 2, + leftCached: NewIAVLNode(key, value), + rightCached: self, }, false } else if bytes.Equal(self.key, key) { return NewIAVLNode(key, value), true } else { return &IAVLNode{ - key: key, - height: 1, - size: 2, - left: self, - right: NewIAVLNode(key, value), + key: key, + height: 1, + size: 2, + leftCached: self, + rightCached: NewIAVLNode(key, value), }, false } } else { self = self.Copy() if bytes.Compare(key, self.key) == -1 { - self.left, updated = self.leftFilled(db).set(db, key, value) + self.leftCached, updated = self.getLeft(ndb).set(ndb, key, value) + self.leftHash = nil } else { - self.right, updated = self.rightFilled(db).set(db, key, value) + self.rightCached, updated = self.getRight(ndb).set(ndb, key, value) + self.rightHash = nil } if updated { return self, updated } else { - self.calcHeightAndSize(db) - return self.balance(db), updated + self.calcHeightAndSize(ndb) + return self.balance(ndb), updated } } } // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. -func (self *IAVLNode) remove(db Db, key []byte) (newSelf *IAVLNode, newKey []byte, value []byte, err error) { +// only one of newSelfHash or newSelf is returned. +func (self *IAVLNode) remove(ndb *IAVLNodeDB, key []byte) (newSelfHash []byte, newSelf *IAVLNode, newKey []byte, value []byte, err error) { if self.height == 0 { if bytes.Equal(self.key, key) { - return nil, nil, self.value, nil + return nil, nil, nil, self.value, nil } else { - return self, nil, nil, NotFound(key) + return nil, self, nil, nil, NotFound(key) } } else { if bytes.Compare(key, self.key) == -1 { + var newLeftHash []byte var newLeft *IAVLNode - newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) + newLeftHash, newLeft, newKey, value, err = self.getLeft(ndb).remove(ndb, key) if err != nil { - return self, nil, value, err - } else if newLeft == nil { // left node held value, was removed - return self.right, self.key, value, nil + return nil, self, nil, value, err + } else if newLeftHash == nil && newLeft == nil { // left node held value, was removed + return self.rightHash, self.rightCached, self.key, value, nil } self = self.Copy() - self.left = newLeft + self.leftHash, self.leftCached = newLeftHash, newLeft } else { + var newRightHash []byte var newRight *IAVLNode - newRight, newKey, value, err = self.rightFilled(db).remove(db, key) + newRightHash, newRight, newKey, value, err = self.getRight(ndb).remove(ndb, key) if err != nil { - return self, nil, value, err - } else if newRight == nil { // right node held value, was removed - return self.left, nil, value, nil + return nil, self, nil, value, err + } else if newRightHash == nil && newRight == nil { // right node held value, was removed + return self.leftHash, self.leftCached, nil, value, nil } self = self.Copy() - self.right = newRight + self.rightHash, self.rightCached = newRightHash, newRight if newKey != nil { self.key = newKey newKey = nil } } - self.calcHeightAndSize(db) - return self.balance(db), newKey, value, err + self.calcHeightAndSize(ndb) + return nil, self.balance(ndb), newKey, value, err } } @@ -226,170 +257,133 @@ func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, WriteByteSlice(w, self.value, &n, &err) } else { // left - leftHash, leftCount := self.left.HashWithCount() - hashCount += leftCount - WriteByteSlice(w, leftHash, &n, &err) + if self.leftCached != nil { + leftHash, leftCount := self.left.HashWithCount() + self.leftHash = leftHash + hashCount += leftCount + } + WriteByteSlice(w, self.leftHash, &n, &err) // right - rightHash, rightCount := self.right.HashWithCount() - hashCount += rightCount - WriteByteSlice(w, rightHash, &n, &err) + if self.rightCached != nil { + rightHash, rightCount := self.right.HashWithCount() + self.rightHash = rightHash + hashCount += rightCount + } + WriteByteSlice(w, self.rightHash, &n, &err) } return } -// Given a placeholder node which has only the hash set, -// load the rest of the data from db. -// Not threadsafe. -func (self *IAVLNode) fill(db Db) { - if self.hash == nil { - panic("placeholder.hash can't be nil") - } - buf := db.Get(self.hash) - r := bytes.NewReader(buf) - var n int64 - var err error - - // node header & key - self.height = ReadUInt8(r, &n, &err) - self.size = ReadUInt64(r, &n, &err) - self.key = ReadByteSlice(r, &n, &err) - if err != nil { - panic(err) - } - - // node value or children. - if self.height == 0 { - // value - self.value = ReadByteSlice(r, &n, &err) +func (self *IAVLNode) getLeft(ndb *IAVLNodeDB) *IAVLNode { + if self.leftCached != nil { + return self.leftCached } else { - // left - leftHash := ReadByteSlice(r, &n, &err) - self.left = &IAVLNode{ - hash: leftHash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - // right - rightHash := ReadByteSlice(r, &n, &err) - self.right = &IAVLNode{ - hash: rightHash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - if r.Len() != 0 { - panic("buf not all consumed") - } + return ndb.Get(leftHash) } - if err != nil { - panic(err) - } - self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER } -func (self *IAVLNode) leftFilled(db Db) *IAVLNode { - if self.left.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { - self.left.fill(db) +func (self *IAVLNode) getRight(ndb *IAVLNodeDB) *IAVLNode { + if self.rightCached != nil { + return self.rightCached + } else { + return ndb.Get(rightHash) } - return self.left } -func (self *IAVLNode) rightFilled(db Db) *IAVLNode { - if self.right.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 { - self.right.fill(db) - } - return self.right -} - -func (self *IAVLNode) rotateRight(db Db) *IAVLNode { +func (self *IAVLNode) rotateRight(ndb *IAVLNodeDB) *IAVLNode { self = self.Copy() - sl := self.leftFilled(db).Copy() - slr := sl.right + sl := self.getLeft(ndb).Copy() - sl.right = self - self.left = slr + slrHash, slrCached := sl.rightHash, sl.rightCached + sl.rightHash, sl.rightCached = nil, self + self.leftHash, self.leftCached = slrHash, slrCached - self.calcHeightAndSize(db) - sl.calcHeightAndSize(db) + self.calcHeightAndSize(ndb) + sl.calcHeightAndSize(ndb) return sl } -func (self *IAVLNode) rotateLeft(db Db) *IAVLNode { +func (self *IAVLNode) rotateLeft(ndb *IAVLNodeDB) *IAVLNode { self = self.Copy() - sr := self.rightFilled(db).Copy() - srl := sr.left + sr := self.getRight(ndb).Copy() - sr.left = self - self.right = srl + srlHash, srlCached := sr.leftHash, sr.leftCached + sr.leftHash, sr.leftCached = nil, self + self.rightHash, self.rightCached = srlHash, srlCached - self.calcHeightAndSize(db) - sr.calcHeightAndSize(db) + self.calcHeightAndSize(ndb) + sr.calcHeightAndSize(ndb) return sr } -func (self *IAVLNode) calcHeightAndSize(db Db) { - self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1 - self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size() +func (self *IAVLNode) calcHeightAndSize(ndb *IAVLNodeDB) { + self.height = maxUint8(self.getLeft(ndb).Height(), self.getRight(ndb).Height()) + 1 + self.size = self.getLeft(ndb).Size() + self.getRight(ndb).Size() } -func (self *IAVLNode) calcBalance(db Db) int { - return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height()) +func (self *IAVLNode) calcBalance(ndb *IAVLNodeDB) int { + return int(self.getLeft(ndb).Height()) - int(self.getRight(ndb).Height()) } -func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) { - balance := self.calcBalance(db) +func (self *IAVLNode) balance(ndb *IAVLNodeDB) (newSelf *IAVLNode) { + balance := self.calcBalance(ndb) if balance > 1 { - if self.leftFilled(db).calcBalance(db) >= 0 { + if self.getLeft(ndb).calcBalance(ndb) >= 0 { // Left Left Case - return self.rotateRight(db) + return self.rotateRight(ndb) } else { // Left Right Case self = self.Copy() - self.left = self.leftFilled(db).rotateLeft(db) + self.leftHash, self.leftCached = nil, self.getLeft(ndb).rotateLeft(ndb) //self.calcHeightAndSize() - return self.rotateRight(db) + return self.rotateRight(ndb) } } if balance < -1 { - if self.rightFilled(db).calcBalance(db) <= 0 { + if self.getRight(ndb).calcBalance(ndb) <= 0 { // Right Right Case - return self.rotateLeft(db) + return self.rotateLeft(ndb) } else { // Right Left Case self = self.Copy() - self.right = self.rightFilled(db).rotateRight(db) + self.rightHash, self.rightCached = nil, self.getRight(ndb).rotateRight(ndb) //self.calcHeightAndSize() - return self.rotateLeft(db) + return self.rotateLeft(ndb) } } // Nothing changed return self } -func (self *IAVLNode) lmd(db Db) *IAVLNode { +// Only used in testing... +func (self *IAVLNode) lmd(ndb *IAVLNodeDB) *IAVLNode { if self.height == 0 { return self } - return self.leftFilled(db).lmd(db) + return self.getLeft(ndb).lmd(ndb) } -func (self *IAVLNode) rmd(db Db) *IAVLNode { +// Only used in testing... +func (self *IAVLNode) rmd(ndb *IAVLNodeDB) *IAVLNode { if self.height == 0 { return self } - return self.rightFilled(db).rmd(db) + return self.getRight(ndb).rmd(ndb) } -func (self *IAVLNode) traverse(db Db, cb func(*IAVLNode) bool) bool { +func (self *IAVLNode) traverse(ndb *IAVLNodeDB, cb func(*IAVLNode) bool) bool { stop := cb(self) if stop { return stop } if self.height > 0 { - stop = self.leftFilled(db).traverse(db, cb) + stop = self.getLeft(ndb).traverse(ndb, cb) if stop { return stop } - stop = self.rightFilled(db).traverse(db, cb) + stop = self.getRight(ndb).traverse(ndb, cb) if stop { return stop } diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index d35b64c8..58323e9b 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -9,38 +9,24 @@ This tree is not concurrency safe. You must wrap your calls with your own mutex. */ type IAVLTree struct { - db Db + ndb IAVLNodeDB root *IAVLNode } -func NewIAVLTree(db Db) *IAVLTree { +func NewIAVLTree(db DB) *IAVLTree { return &IAVLTree{ - db: db, + ndb: NewIAVLNodeDB(db), root: nil, } } -// TODO rename to Load. -func NewIAVLTreeFromHash(db Db, hash []byte) *IAVLTree { - root := &IAVLNode{ - hash: hash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - root.fill(db) - return &IAVLTree{db: db, root: root} -} - -func NewIAVLTreeFromKey(db Db, key string) *IAVLTree { - hash := db.Get([]byte(key)) - if hash == nil { +func LoadIAVLTreeFromHash(db DB, hash []byte) *IAVLTree { + ndb := NewIAVLNodeDB(db) + root := ndb.Get(hash) + if root == nil { return nil } - root := &IAVLNode{ - hash: hash, - flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, - } - root.fill(db) - return &IAVLTree{db: db, root: root} + return &IAVLTree{ndb: ndb, root: root} } func (t *IAVLTree) Size() uint64 { @@ -61,7 +47,7 @@ func (t *IAVLTree) Has(key []byte) bool { if t.root == nil { return false } - return t.root.has(t.db, key) + return t.root.has(t.ndb, key) } func (t *IAVLTree) Set(key []byte, value []byte) (updated bool) { @@ -69,7 +55,7 @@ func (t *IAVLTree) Set(key []byte, value []byte) (updated bool) { t.root = NewIAVLNode(key, value) return false } - t.root, updated = t.root.set(t.db, key, value) + t.root, updated = t.root.set(t.ndb, key, value) return updated } @@ -93,7 +79,7 @@ func (t *IAVLTree) Save() { return } t.root.HashWithCount() - t.root.Save(t.db) + t.root.Save(t.ndb) } func (t *IAVLTree) SaveKey(key string) { @@ -101,29 +87,70 @@ func (t *IAVLTree) SaveKey(key string) { return } hash, _ := t.root.HashWithCount() - t.root.Save(t.db) - t.db.Set([]byte(key), hash) + t.root.Save(t.ndb) + t.ndb.Set([]byte(key), hash) } func (t *IAVLTree) Get(key []byte) (value []byte) { if t.root == nil { return nil } - return t.root.get(t.db, key) + return t.root.get(t.ndb, key) } func (t *IAVLTree) Remove(key []byte) (value []byte, err error) { if t.root == nil { return nil, NotFound(key) } - newRoot, _, value, err := t.root.remove(t.db, key) + newRootHash, newRoot, _, value, err := t.root.remove(t.ndb, key) if err != nil { return nil, err } - t.root = newRoot + if newRoot == nil && newRootHash != nil { + t.root = t.ndb.Get(newRootHash) + } else { + t.root = newRoot + } return value, nil } func (t *IAVLTree) Copy() Tree { - return &IAVLTree{db: t.db, root: t.root} + return &IAVLTree{ndb: t.ndb, root: t.root} +} + +//----------------------------------------------------------------------------- + +type IAVLNodeDB struct { + db DB + cache map[string]*IAVLNode + // XXX expire entries +} + +func (ndb *IAVLNodeDB) Get(hash []byte) *IAVLNode { + buf := ndb.db.Get(hash) + r := bytes.NewReader(buf) + var n int64 + var err error + node := ReadIAVLNode(r, &n, &err) + if err != nil { + panic(err) + } + node.persisted = true + ndb.cache[string(hash)] = node + return node +} + +func (ndb *IAVLNodeDB) Save(node *IAVLNode) { + hash := node.hash + if hash != nil { + panic("Expected to find node.hash, but none found.") + } + buf := bytes.NewBuffer(nil) + _, err := self.WriteTo(buf) + if err != nil { + panic(err) + } + node.persisted = true + ndb.cache[string(hash)] = node + ndb.db.Set(hash, buf.Bytes()) } diff --git a/merkle/types.go b/merkle/types.go index 35d3ef42..396eb21b 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -4,7 +4,7 @@ import ( "fmt" ) -type Db interface { +type DB interface { Get([]byte) []byte Set([]byte, []byte) }