diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index c86c42f8..bb746e9b 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -16,21 +16,17 @@ type IAVLNode struct { height uint8 hash []byte leftHash []byte + leftNode *IAVLNode rightHash []byte + rightNode *IAVLNode persisted bool - - // May or may not be persisted nodes, but they'll get cleared - // when this this node is saved. - leftCached *IAVLNode - rightCached *IAVLNode } func NewIAVLNode(key []byte, value []byte) *IAVLNode { return &IAVLNode{ - key: key, - value: value, - size: 1, - persisted: false, + key: key, + value: value, + size: 1, } } @@ -47,12 +43,9 @@ func ReadIAVLNode(r io.Reader, n *int64, err *error) *IAVLNode { // node value or children. if node.height == 0 { - // value node.value = ReadByteSlice(r, n, err) } else { - // left node.leftHash = ReadByteSlice(r, n, err) - // right node.rightHash = ReadByteSlice(r, n, err) } if *err != nil { @@ -66,15 +59,15 @@ func (self *IAVLNode) Copy() *IAVLNode { panic("Why are you copying a value node?") } return &IAVLNode{ - key: self.key, - size: self.size, - height: self.height, - hash: nil, // Going to be mutated anyways. - leftHash: self.leftHash, - rightHash: self.rightHash, - persisted: self.persisted, - leftCached: self.leftCached, - rightCached: self.rightCached, + key: self.key, + size: self.size, + height: self.height, + hash: nil, // Going to be mutated anyways. + leftHash: self.leftHash, + leftNode: self.leftNode, + rightHash: self.rightHash, + rightNode: self.rightNode, + persisted: self.persisted, } } @@ -94,9 +87,9 @@ func (self *IAVLNode) has(ndb *IAVLNodeDB, key []byte) (has bool) { return false } else { if bytes.Compare(key, self.key) == -1 { - return self.getLeft(ndb).has(ndb, key) + return self.getLeftNode(ndb).has(ndb, key) } else { - return self.getRight(ndb).has(ndb, key) + return self.getRightNode(ndb).has(ndb, key) } } } @@ -110,9 +103,9 @@ func (self *IAVLNode) get(ndb *IAVLNodeDB, key []byte) (value []byte) { } } else { if bytes.Compare(key, self.key) == -1 { - return self.getLeft(ndb).get(ndb, key) + return self.getLeftNode(ndb).get(ndb, key) } else { - return self.getRight(ndb).get(ndb, key) + return self.getRightNode(ndb).get(ndb, key) } } } @@ -123,7 +116,7 @@ func (self *IAVLNode) HashWithCount() ([]byte, uint64) { } hasher := sha256.New() - _, hashCount, err := self.saveToCountHashes(hasher) + _, hashCount, err := self.writeToCountHashes(hasher) if err != nil { panic(err) } @@ -140,14 +133,14 @@ func (self *IAVLNode) Save(ndb *IAVLNodeDB) []byte { return self.hash } - // children - if self.leftCached != nil { - self.leftHash = self.leftCached.Save(ndb) - self.leftCached = nil + // save children + if self.leftNode != nil { + self.leftHash = self.leftNode.Save(ndb) + self.leftNode = nil } - if self.rightCached != nil { - self.rightHash = self.rightCached.Save(ndb) - self.rightCached = nil + if self.rightNode != nil { + self.rightHash = self.rightNode.Save(ndb) + self.rightNode = nil } // save self @@ -159,30 +152,30 @@ func (self *IAVLNode) set(ndb *IAVLNodeDB, key []byte, value []byte) (_ *IAVLNod if self.height == 0 { if bytes.Compare(key, self.key) == -1 { return &IAVLNode{ - key: self.key, - height: 1, - size: 2, - leftCached: NewIAVLNode(key, value), - rightCached: self, + key: self.key, + height: 1, + size: 2, + leftNode: NewIAVLNode(key, value), + rightNode: self, }, false } else if bytes.Equal(self.key, key) { return NewIAVLNode(key, value), true } else { return &IAVLNode{ - key: key, - height: 1, - size: 2, - leftCached: self, - rightCached: NewIAVLNode(key, value), + key: key, + height: 1, + size: 2, + leftNode: self, + rightNode: NewIAVLNode(key, value), }, false } } else { self = self.Copy() if bytes.Compare(key, self.key) == -1 { - self.leftCached, updated = self.getLeft(ndb).set(ndb, key, value) + self.leftNode, updated = self.getLeftNode(ndb).set(ndb, key, value) self.leftHash = nil } else { - self.rightCached, updated = self.getRight(ndb).set(ndb, key, value) + self.rightNode, updated = self.getRightNode(ndb).set(ndb, key, value) self.rightHash = nil } if updated { @@ -194,10 +187,11 @@ func (self *IAVLNode) set(ndb *IAVLNodeDB, key []byte, value []byte) (_ *IAVLNod } } +// newHash/newNode: The new hash or node to replace self after remove. // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. -// only one of newSelfHash or newSelf is returned. +// value: removed value. func (self *IAVLNode) remove(ndb *IAVLNodeDB, key []byte) ( - newSelfHash []byte, newSelf *IAVLNode, newKey []byte, value []byte, err error) { + newHash []byte, newNode *IAVLNode, newKey []byte, value []byte, err error) { if self.height == 0 { if bytes.Equal(self.key, key) { return nil, nil, nil, self.value, nil @@ -207,26 +201,26 @@ func (self *IAVLNode) remove(ndb *IAVLNodeDB, key []byte) ( } else { if bytes.Compare(key, self.key) == -1 { var newLeftHash []byte - var newLeft *IAVLNode - newLeftHash, newLeft, newKey, value, err = self.getLeft(ndb).remove(ndb, key) + var newLeftNode *IAVLNode + newLeftHash, newLeftNode, newKey, value, err = self.getLeftNode(ndb).remove(ndb, key) if err != nil { return nil, self, nil, value, err - } else if newLeftHash == nil && newLeft == nil { // left node held value, was removed - return self.rightHash, self.rightCached, self.key, value, nil + } else if newLeftHash == nil && newLeftNode == nil { // left node held value, was removed + return self.rightHash, self.rightNode, self.key, value, nil } self = self.Copy() - self.leftHash, self.leftCached = newLeftHash, newLeft + self.leftHash, self.leftNode = newLeftHash, newLeftNode } else { var newRightHash []byte - var newRight *IAVLNode - newRightHash, newRight, newKey, value, err = self.getRight(ndb).remove(ndb, key) + var newRightNode *IAVLNode + newRightHash, newRightNode, newKey, value, err = self.getRightNode(ndb).remove(ndb, key) if err != nil { return nil, self, nil, value, err - } else if newRightHash == nil && newRight == nil { // right node held value, was removed - return self.leftHash, self.leftCached, nil, value, nil + } else if newRightHash == nil && newRightNode == nil { // right node held value, was removed + return self.leftHash, self.leftNode, nil, value, nil } self = self.Copy() - self.rightHash, self.rightCached = newRightHash, newRight + self.rightHash, self.rightNode = newRightHash, newRightNode if newKey != nil { self.key = newKey newKey = nil @@ -238,11 +232,11 @@ func (self *IAVLNode) remove(ndb *IAVLNodeDB, key []byte) ( } func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { - n, _, err = self.saveToCountHashes(w) + n, _, err = self.writeToCountHashes(w) return } -func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) { +func (self *IAVLNode) writeToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) { // height & size & key WriteUInt8(w, self.height, &n, &err) WriteUInt64(w, self.size, &n, &err) @@ -251,40 +245,45 @@ func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, return } - // value or children if self.height == 0 { // value WriteByteSlice(w, self.value, &n, &err) } else { // left - if self.leftCached != nil { - leftHash, leftCount := self.leftCached.HashWithCount() + if self.leftNode != nil { + leftHash, leftCount := self.leftNode.HashWithCount() self.leftHash = leftHash hashCount += leftCount } + if self.leftHash == nil { + panic("self.leftHash was nil in save") + } WriteByteSlice(w, self.leftHash, &n, &err) // right - if self.rightCached != nil { - rightHash, rightCount := self.rightCached.HashWithCount() + if self.rightNode != nil { + rightHash, rightCount := self.rightNode.HashWithCount() self.rightHash = rightHash hashCount += rightCount } + if self.rightHash == nil { + panic("self.rightHash was nil in save") + } WriteByteSlice(w, self.rightHash, &n, &err) } return } -func (self *IAVLNode) getLeft(ndb *IAVLNodeDB) *IAVLNode { - if self.leftCached != nil { - return self.leftCached +func (self *IAVLNode) getLeftNode(ndb *IAVLNodeDB) *IAVLNode { + if self.leftNode != nil { + return self.leftNode } else { return ndb.Get(self.leftHash) } } -func (self *IAVLNode) getRight(ndb *IAVLNodeDB) *IAVLNode { - if self.rightCached != nil { - return self.rightCached +func (self *IAVLNode) getRightNode(ndb *IAVLNodeDB) *IAVLNode { + if self.rightNode != nil { + return self.rightNode } else { return ndb.Get(self.rightHash) } @@ -292,11 +291,11 @@ func (self *IAVLNode) getRight(ndb *IAVLNodeDB) *IAVLNode { func (self *IAVLNode) rotateRight(ndb *IAVLNodeDB) *IAVLNode { self = self.Copy() - sl := self.getLeft(ndb).Copy() + sl := self.getLeftNode(ndb).Copy() - slrHash, slrCached := sl.rightHash, sl.rightCached - sl.rightHash, sl.rightCached = nil, self - self.leftHash, self.leftCached = slrHash, slrCached + slrHash, slrCached := sl.rightHash, sl.rightNode + sl.rightHash, sl.rightNode = nil, self + self.leftHash, self.leftNode = slrHash, slrCached self.calcHeightAndSize(ndb) sl.calcHeightAndSize(ndb) @@ -306,11 +305,11 @@ func (self *IAVLNode) rotateRight(ndb *IAVLNodeDB) *IAVLNode { func (self *IAVLNode) rotateLeft(ndb *IAVLNodeDB) *IAVLNode { self = self.Copy() - sr := self.getRight(ndb).Copy() + sr := self.getRightNode(ndb).Copy() - srlHash, srlCached := sr.leftHash, sr.leftCached - sr.leftHash, sr.leftCached = nil, self - self.rightHash, self.rightCached = srlHash, srlCached + srlHash, srlCached := sr.leftHash, sr.leftNode + sr.leftHash, sr.leftNode = nil, self + self.rightHash, self.rightNode = srlHash, srlCached self.calcHeightAndSize(ndb) sr.calcHeightAndSize(ndb) @@ -319,36 +318,36 @@ func (self *IAVLNode) rotateLeft(ndb *IAVLNodeDB) *IAVLNode { } func (self *IAVLNode) calcHeightAndSize(ndb *IAVLNodeDB) { - self.height = maxUint8(self.getLeft(ndb).Height(), self.getRight(ndb).Height()) + 1 - self.size = self.getLeft(ndb).Size() + self.getRight(ndb).Size() + self.height = maxUint8(self.getLeftNode(ndb).Height(), self.getRightNode(ndb).Height()) + 1 + self.size = self.getLeftNode(ndb).Size() + self.getRightNode(ndb).Size() } func (self *IAVLNode) calcBalance(ndb *IAVLNodeDB) int { - return int(self.getLeft(ndb).Height()) - int(self.getRight(ndb).Height()) + return int(self.getLeftNode(ndb).Height()) - int(self.getRightNode(ndb).Height()) } func (self *IAVLNode) balance(ndb *IAVLNodeDB) (newSelf *IAVLNode) { balance := self.calcBalance(ndb) if balance > 1 { - if self.getLeft(ndb).calcBalance(ndb) >= 0 { + if self.getLeftNode(ndb).calcBalance(ndb) >= 0 { // Left Left Case return self.rotateRight(ndb) } else { // Left Right Case self = self.Copy() - self.leftHash, self.leftCached = nil, self.getLeft(ndb).rotateLeft(ndb) + self.leftHash, self.leftNode = nil, self.getLeftNode(ndb).rotateLeft(ndb) //self.calcHeightAndSize() return self.rotateRight(ndb) } } if balance < -1 { - if self.getRight(ndb).calcBalance(ndb) <= 0 { + if self.getRightNode(ndb).calcBalance(ndb) <= 0 { // Right Right Case return self.rotateLeft(ndb) } else { // Right Left Case self = self.Copy() - self.rightHash, self.rightCached = nil, self.getRight(ndb).rotateRight(ndb) + self.rightHash, self.rightNode = nil, self.getRightNode(ndb).rotateRight(ndb) //self.calcHeightAndSize() return self.rotateLeft(ndb) } @@ -357,12 +356,30 @@ func (self *IAVLNode) balance(ndb *IAVLNodeDB) (newSelf *IAVLNode) { return self } +func (self *IAVLNode) traverse(ndb *IAVLNodeDB, cb func(*IAVLNode) bool) bool { + stop := cb(self) + if stop { + return stop + } + if self.height > 0 { + stop = self.getLeftNode(ndb).traverse(ndb, cb) + if stop { + return stop + } + stop = self.getRightNode(ndb).traverse(ndb, cb) + if stop { + return stop + } + } + return false +} + // Only used in testing... func (self *IAVLNode) lmd(ndb *IAVLNodeDB) *IAVLNode { if self.height == 0 { return self } - return self.getLeft(ndb).lmd(ndb) + return self.getLeftNode(ndb).lmd(ndb) } // Only used in testing... @@ -370,23 +387,5 @@ func (self *IAVLNode) rmd(ndb *IAVLNodeDB) *IAVLNode { if self.height == 0 { return self } - return self.getRight(ndb).rmd(ndb) -} - -func (self *IAVLNode) traverse(ndb *IAVLNodeDB, cb func(*IAVLNode) bool) bool { - stop := cb(self) - if stop { - return stop - } - if self.height > 0 { - stop = self.getLeft(ndb).traverse(ndb, cb) - if stop { - return stop - } - stop = self.getRight(ndb).traverse(ndb, cb) - if stop { - return stop - } - } - return false + return self.getRightNode(ndb).rmd(ndb) } diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 1db7efef..a8d86352 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -37,9 +37,9 @@ func TestUnit(t *testing.T) { } n := &IAVLNode{ - key: right.lmd(nil).key, - leftCached: left, - rightCached: right, + key: right.lmd(nil).key, + leftNode: left, + rightNode: right, } n.calcHeightAndSize(nil) n.HashWithCount() @@ -52,7 +52,7 @@ func TestUnit(t *testing.T) { if n.height == 0 { return fmt.Sprintf("%v", n.key[0]) } else { - return fmt.Sprintf("(%v %v)", P(n.leftCached), P(n.rightCached)) + return fmt.Sprintf("(%v %v)", P(n.leftNode), P(n.rightNode)) } } diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 6bb5b73f..907022b5 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -7,6 +7,8 @@ import ( const defaultCacheCapacity = 1000 // TODO make configurable. +// XXX Make Codec tree. + /* Immutable AVL Tree (wraps the Node root) @@ -83,7 +85,6 @@ func (t *IAVLTree) Save() { if t.root == nil { return } - t.root.HashWithCount() t.root.Save(t.ndb) } diff --git a/merkle/util.go b/merkle/util.go index a67e12d6..44e1fedd 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -239,16 +239,16 @@ func printIAVLNode(node *IAVLNode, indent int) { indentPrefix += " " } - if node.rightCached != nil { - printIAVLNode(node.rightCached, indent+1) + if node.rightNode != nil { + printIAVLNode(node.rightNode, indent+1) } else if node.rightHash != nil { fmt.Printf("%s %X\n", indentPrefix, node.rightHash) } fmt.Printf("%s%v:%v\n", indentPrefix, node.key, node.height) - if node.leftCached != nil { - printIAVLNode(node.leftCached, indent+1) + if node.leftNode != nil { + printIAVLNode(node.leftNode, indent+1) } else if node.leftHash != nil { fmt.Printf("%s %X\n", indentPrefix, node.leftHash) } diff --git a/p2p/switch.go b/p2p/switch.go index 9d55c1f4..32cbcc8b 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -156,6 +156,7 @@ func (s *Switch) IsDialing(addr *NetAddress) bool { return s.dialing.Has(addr.String()) } +// XXX: This is wrong, we can't just ignore failures on TrySend. func (s *Switch) Broadcast(chId byte, msg Binary) (numSuccess, numFailure int) { if atomic.LoadUint32(&s.stopped) == 1 { return