refactoring again, implementing b+tree

This commit is contained in:
Jae Kwon 2014-05-23 23:11:22 -07:00
parent 98c6181de0
commit ef480bb229
5 changed files with 659 additions and 109 deletions

View File

@ -4,8 +4,6 @@ import (
"crypto/sha256"
)
const HASH_BYTE_SIZE int = 4+32
// Immutable AVL Tree (wraps the Node root)
type IAVLTree struct {
@ -39,7 +37,7 @@ func (self *IAVLTree) Height() uint8 {
}
func (self *IAVLTree) Has(key Key) bool {
return self.root.Has(self.db, key)
return self.root.has(self.db, key)
}
func (self *IAVLTree) Put(key Key, value Value) (updated bool) {
@ -59,7 +57,7 @@ func (self *IAVLTree) Save() {
}
func (self *IAVLTree) Get(key Key) (value Value) {
return self.root.Get(self.db, key)
return self.root.get(self.db, key)
}
func (self *IAVLTree) Remove(key Key) (value Value, err error) {
@ -71,6 +69,36 @@ func (self *IAVLTree) Remove(key Key) (value Value, err error) {
return value, nil
}
// Iterator returns a NodeIterator that walks the tree in order: every call
// yields the next node, and nil is returned once all nodes have been visited.
func (self *IAVLTree) Iterator() NodeIterator {
	// pending holds nodes whose left subtree has been entered but which have
	// not been yielded yet.
	pending := make([]*IAVLNode, 0, 10)
	next := self.root
	return func() Node {
		if len(pending) == 0 && next == nil {
			return nil
		}
		// Descend along left children, remembering each node on the way.
		for ; next != nil; next = next.leftFilled(self.db) {
			pending = append(pending, next)
		}
		top := len(pending) - 1
		visit := pending[top]
		pending = pending[0:top]
		// After yielding visit, continue with its right subtree.
		next = visit.rightFilled(self.db)
		return visit
	}
}
// Node
type IAVLNode struct {
@ -127,16 +155,6 @@ func (self *IAVLNode) Value() Value {
return self.value
}
func (self *IAVLNode) Left(db Db) Node {
if self.left == nil { return nil }
return self.leftFilled(db)
}
func (self *IAVLNode) Right(db Db) Node {
if self.right == nil { return nil }
return self.rightFilled(db)
}
func (self *IAVLNode) Size() uint64 {
if self == nil { return 0 }
return self.size
@ -147,25 +165,25 @@ func (self *IAVLNode) Height() uint8 {
return self.height
}
func (self *IAVLNode) Has(db Db, key Key) (has bool) {
func (self *IAVLNode) has(db Db, key Key) (has bool) {
if self == nil { return false }
if self.key.Equals(key) {
return true
} else if key.Less(self.key) {
return self.leftFilled(db).Has(db, key)
return self.leftFilled(db).has(db, key)
} else {
return self.rightFilled(db).Has(db, key)
return self.rightFilled(db).has(db, key)
}
}
func (self *IAVLNode) Get(db Db, key Key) (value Value) {
func (self *IAVLNode) get(db Db, key Key) (value Value) {
if self == nil { return nil }
if self.key.Equals(key) {
return self.value
} else if key.Less(self.key) {
return self.leftFilled(db).Get(db, key)
return self.leftFilled(db).get(db, key)
} else {
return self.rightFilled(db).Get(db, key)
return self.rightFilled(db).get(db, key)
}
}
@ -210,13 +228,6 @@ func (self *IAVLNode) Save(db Db) {
self.flags |= IAVLNODE_FLAG_PERSISTED
}
// TODO: don't clear the hash if the value hasn't changed.
func (self *IAVLNode) Put(db Db, key Key, value Value) (_ Node, updated bool) {
node, updated := self.put(db, key, value)
if node == nil { panic("unexpected nil node in put") }
return node, updated
}
func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated bool) {
if self == nil {
return &IAVLNode{key: key, value: value, height: 1, size: 1, hash: nil}, false
@ -242,14 +253,6 @@ func (self *IAVLNode) put(db Db, key Key, value Value) (_ *IAVLNode, updated boo
}
}
func (self *IAVLNode) Remove(db Db, key Key) (newSelf Node, value Value, err error) {
newIAVLSelf, value, err := self.remove(db, key)
if newIAVLSelf != nil {
newSelf = newIAVLSelf
}
return
}
func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, value Value, err error) {
if self == nil { return nil, nil, NotFound(key) }
@ -571,10 +574,3 @@ func (self *IAVLNode) lmd(db Db) (*IAVLNode) {
func (self *IAVLNode) rmd(db Db) (*IAVLNode) {
return self._md(func(node *IAVLNode)*IAVLNode { return node.rightFilled(db) })
}
func maxUint8(a, b uint8) uint8 {
if a > b {
return a
}
return b
}

View File

@ -35,7 +35,7 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) {
}
records := make([]*record, 400)
var tree *IAVLNode
var node *IAVLNode
var err error
var val Value
var updated bool
@ -47,50 +47,50 @@ func TestImmutableAvlPutHasGetRemove(t *testing.T) {
for i := range records {
r := randomRecord()
records[i] = r
tree, updated = tree.put(nil, r.key, String(""))
node, updated = node.put(nil, r.key, String(""))
if updated {
t.Error("should have not been updated")
}
tree, updated = tree.put(nil, r.key, r.value)
node, updated = node.put(nil, r.key, r.value)
if !updated {
t.Error("should have been updated")
}
if tree.Size() != uint64(i+1) {
t.Error("size was wrong", tree.Size(), i+1)
if node.Size() != uint64(i+1) {
t.Error("size was wrong", node.Size(), i+1)
}
}
for _, r := range records {
if has := tree.Has(nil, r.key); !has {
if has := node.has(nil, r.key); !has {
t.Error("Missing key")
}
if has := tree.Has(nil, randstr(12)); has {
if has := node.has(nil, randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(nil, r.key); !(val.(String)).Equals(r.value) {
if val := node.get(nil, r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
for i, x := range records {
if tree, val, err = tree.remove(nil, x.key); err != nil {
if node, val, err = node.remove(nil, x.key); err != nil {
t.Error(err)
} else if !(val.(String)).Equals(x.value) {
t.Error("wrong value")
}
for _, r := range records[i+1:] {
if has := tree.Has(nil, r.key); !has {
if has := node.has(nil, r.key); !has {
t.Error("Missing key")
}
if has := tree.Has(nil, randstr(12)); has {
if has := node.has(nil, randstr(12)); has {
t.Error("Table has extra key")
}
if val := tree.Get(nil, r.key); !(val.(String)).Equals(r.value) {
if val := node.get(nil, r.key); !(val.(String)).Equals(r.value) {
t.Error("wrong value")
}
}
if tree.Size() != uint64(len(records) - (i+1)) {
t.Error("size was wrong", tree.Size(), (len(records) - (i+1)))
if node.Size() != uint64(len(records) - (i+1)) {
t.Error("size was wrong", node.Size(), (len(records) - (i+1)))
}
}
}
@ -138,7 +138,7 @@ func TestTraversals(t *testing.T) {
}
j := 0
itr := Iterator(T.Root());
itr := T.Iterator()
for node := itr(); node != nil; node = itr() {
if int(node.Key().(Int)) != data[j] {
t.Error("key in wrong spot in-order")
@ -185,7 +185,7 @@ func TestGriffin(t *testing.T) {
t.Fatalf("Expected %v new hashes, got %v", hashCount, count)
}
// nuke hashes and reconstruct hash, ensure it's the same.
itr := Iterator(n2)
itr := (&IAVLTree{root:n2}).Iterator()
for node:=itr(); node!=nil; node = itr() {
if node != nil {
node.(*IAVLNode).hash = nil

579
merkle/ibp.go Normal file
View File

@ -0,0 +1,579 @@
package merkle
import (
"crypto/sha256"
)
// IBPTree is the immutable B+ tree handle: it wraps the root node together
// with the Db that is passed to every node-level operation.
type IBPTree struct {
// db is forwarded to the root node's has/get/put/remove/Save calls.
db Db
// root is the current root node; it is nil for a tree made by NewIBPTree.
root *IBPNode
}
// NewIBPTree returns an empty tree (nil root) whose node operations use db.
func NewIBPTree(db Db) *IBPTree {
	tree := new(IBPTree)
	tree.db = db
	return tree
}
// NewIBPTreeFromHash returns a tree whose root is loaded from db by hash.
// The root starts as a persisted placeholder carrying only the hash and is
// filled immediately; its children stay placeholders until they are accessed.
func NewIBPTreeFromHash(db Db, hash ByteSlice) *IBPTree {
	placeholder := new(IBPNode)
	placeholder.hash = hash
	placeholder.flags = IBPNODE_FLAG_PERSISTED | IBPNODE_FLAG_PLACEHOLDER
	placeholder.fill(db)
	return &IBPTree{root: placeholder, db: db}
}
// Root returns the tree's root node as a Node.
// NOTE(review): self.root is a *IBPNode, so for an empty tree the returned
// interface wraps a nil pointer and does not compare equal to nil; confirm
// that no caller relies on `Root() == nil`.
func (self *IBPTree) Root() Node {
return self.root
}
// Size returns the root node's size; IBPNode.Size reports 0 for a nil root.
func (self *IBPTree) Size() uint64 {
return self.root.Size()
}
// Height returns the root node's height; IBPNode.Height reports 0 for a nil root.
func (self *IBPTree) Height() uint8 {
return self.root.Height()
}
// Has reports whether key is present, by searching from the root with the tree's db.
func (self *IBPTree) Has(key Key) bool {
return self.root.has(self.db, key)
}
// Put stores value under key and replaces the tree's root with the node
// returned by the root's put. updated is true when key already existed and
// its value was replaced, false when a new node was inserted.
func (self *IBPTree) Put(key Key, value Value) (updated bool) {
self.root, updated = self.root.put(self.db, key, value)
return updated
}
// Hash returns the root node's hash together with the number of node hashes
// that had to be computed by this call (0 when the root hash was cached).
func (self *IBPTree) Hash() (ByteSlice, uint64) {
return self.root.Hash()
}
// Save hashes any node that has no cached hash and then writes the tree to
// the tree's db.
//
// The root is only touched through IBPNode.Hash and IBPNode.Save, both of
// which accept a nil receiver, so saving an empty tree is a no-op rather than
// a nil-pointer dereference on the root's hash field.
func (self *IBPTree) Save() {
	// Hash returns immediately when the root already carries a hash, so no
	// explicit "is it hashed yet" check is needed here.
	self.root.Hash()
	self.root.Save(self.db)
}
// Get returns the value stored under key, or nil when the key is not found.
func (self *IBPTree) Get(key Key) (value Value) {
return self.root.get(self.db, key)
}
// Remove deletes key from the tree and returns the value it held.
// When the root's remove reports an error the tree is left untouched and
// (nil, err) is returned.
func (self *IBPTree) Remove(key Key) (value Value, err error) {
	replacement, removed, err := self.root.remove(self.db, key)
	if err == nil {
		self.root = replacement
		return removed, nil
	}
	return nil, err
}
// Iterator returns a NodeIterator that yields the tree's nodes in ascending
// key order (left subtree, node, right subtree). Once every node has been
// yielded, each further call returns nil.
//
// A usable function is always returned, even for an empty tree, so callers can
// invoke the result unconditionally.
func (self *IBPTree) Iterator() NodeIterator {
	// stack holds nodes whose left subtree has been entered but which have
	// not been yielded yet.
	stack := make([]*IBPNode, 0, 10)
	cur := self.root
	return func() Node {
		if cur == nil && len(stack) == 0 {
			return nil
		}
		// Walk to the leftmost unvisited node, loading placeholders on the way.
		for cur != nil {
			stack = append(stack, cur)
			cur = cur.leftFilled(self.db)
		}
		last := len(stack) - 1
		node := stack[last]
		stack[last] = nil // do not keep the yielded node reachable via the stack
		stack = stack[:last]
		// The right subtree of the yielded node comes next.
		cur = node.rightFilled(self.db)
		return node
	}
}
// IBPNode is a single tree node. Nodes are treated as immutable once shared:
// mutating operations work on a Copy of the node.
type IBPNode struct {
// key orders the node: smaller keys live in left, all others in right.
key Key
// value is the payload stored under key; it may be nil.
value Value
// size is 1 plus the sizes of both subtrees (see calcHeightAndSize).
size uint64
// height is 1 plus the larger child height (see calcHeightAndSize).
height uint8
// hash is the cached node hash; nil until Hash computes it, and reset by Copy.
hash ByteSlice
// left and right are the children; either may be nil or an unfilled placeholder.
left *IBPNode
right *IBPNode
// volatile
// flags holds the IBPNODE_FLAG_* bits; it is not part of the serialized form.
flags byte
}
// In-memory node flags (IBPNode.flags) and bits of the serialized node
// descriptor byte.
const (
// IBPNODE_FLAG_PERSISTED marks a node that Save has written to the db, or
// that was created as a reference to an already stored node.
IBPNODE_FLAG_PERSISTED = byte(0x01)
// IBPNODE_FLAG_PLACEHOLDER marks a node that only carries its hash; fill
// loads the remaining fields and clears this bit.
IBPNODE_FLAG_PLACEHOLDER = byte(0x02)
// IBPNODE_DESC_HAS_VALUE is set in the descriptor when a value is serialized.
IBPNODE_DESC_HAS_VALUE = byte(0x01)
// IBPNODE_DESC_HAS_LEFT is set in the descriptor when a left child hash follows.
IBPNODE_DESC_HAS_LEFT = byte(0x02)
// IBPNODE_DESC_HAS_RIGHT is set in the descriptor when a right child hash follows.
IBPNODE_DESC_HAS_RIGHT = byte(0x04)
)
// Copy returns a shallow duplicate of the node (children are shared, not
// cloned) with the cached hash dropped and all flags cleared, so the
// duplicate counts as new, unhashed and unpersisted. A nil receiver yields nil.
func (self *IBPNode) Copy() *IBPNode {
	if self == nil {
		return nil
	}
	dup := *self
	dup.hash = nil
	dup.flags = byte(0)
	return &dup
}
// Equals reports whether other is an *IBPNode whose hash equals this node's
// hash. Any other Binary implementation is never equal.
func (self *IBPNode) Equals(other Binary) bool {
	peer, isNode := other.(*IBPNode)
	if !isNode {
		return false
	}
	return self.hash.Equals(peer.hash)
}
// Key returns the node's key. Unlike Size and Height it does not guard
// against a nil receiver.
func (self *IBPNode) Key() Key {
return self.key
}
// Value returns the node's value. Unlike Size and Height it does not guard
// against a nil receiver.
func (self *IBPNode) Value() Value {
return self.value
}
// Size returns the stored subtree size; a nil node has size 0.
func (self *IBPNode) Size() uint64 {
if self == nil { return 0 }
return self.size
}
// Height returns the stored subtree height; a nil node has height 0.
func (self *IBPNode) Height() uint8 {
if self == nil { return 0 }
return self.height
}
// has reports whether key occurs in the subtree rooted at self. A nil
// receiver is an empty subtree. Placeholder children are loaded from db as
// the search descends.
func (self *IBPNode) has(db Db, key Key) (has bool) {
	for n := self; n != nil; {
		switch {
		case n.key.Equals(key):
			return true
		case key.Less(n.key):
			n = n.leftFilled(db)
		default:
			n = n.rightFilled(db)
		}
	}
	return false
}
// get returns the value stored under key in the subtree rooted at self, or
// nil when the key is absent. Placeholder children are loaded from db as the
// search descends.
func (self *IBPNode) get(db Db, key Key) (value Value) {
	for n := self; n != nil; {
		switch {
		case n.key.Equals(key):
			return n.value
		case key.Less(n.key):
			n = n.leftFilled(db)
		default:
			n = n.rightFilled(db)
		}
	}
	return nil
}
// Hash returns the node's hash and the number of hashes this call had to
// compute. A nil node yields (nil, 0) and an already hashed node yields its
// cached hash with a count of 0. Otherwise the node is serialized (which
// hashes unhashed children), the SHA-256 of the ByteSize-long buffer is
// cached, and the count is the children's total plus one for this node.
func (self *IBPNode) Hash() (ByteSlice, uint64) {
	switch {
	case self == nil:
		return nil, 0
	case self.hash != nil:
		return self.hash, 0
	}
	encoded := make([]byte, self.ByteSize())
	_, childHashes := self.saveToCountHashes(encoded)
	digest := sha256.Sum256(encoded)
	self.hash = digest[:]
	return self.hash, childHashes + 1
}
// Save writes the node and, recursively, its children to db under their
// hashes. Nil nodes are ignored; nodes already persisted or still
// placeholders are skipped together with their subtrees. The node must have
// been hashed beforehand, otherwise Save panics.
func (self *IBPNode) Save(db Db) {
	if self == nil {
		return
	}
	if self.hash == nil {
		panic("savee.hash can't be nil")
	}
	const alreadyStored = IBPNODE_FLAG_PERSISTED | IBPNODE_FLAG_PLACEHOLDER
	if self.flags&alreadyStored != 0 {
		return
	}

	// This node first, then both subtrees.
	encoded := make([]byte, self.ByteSize())
	self.SaveTo(encoded)
	db.Put([]byte(self.hash), encoded)
	self.left.Save(db)
	self.right.Save(db)

	self.flags |= IBPNODE_FLAG_PERSISTED
}
// put returns the root of a tree equal to the one rooted at self but with
// key mapped to value. The receiver is never modified: every node on the
// search path is copied. updated is true when key already existed (only its
// value changes, so no rebalancing is needed) and false when a new leaf was
// inserted, in which case heights/sizes are recomputed and the path is
// rebalanced on the way back up.
func (self *IBPNode) put(db Db, key Key, value Value) (_ *IBPNode, updated bool) {
	if self == nil {
		return &IBPNode{key: key, value: value, height: 1, size: 1}, false
	}

	clone := self.Copy()
	if clone.key.Equals(key) {
		clone.value = value
		return clone, true
	}

	if key.Less(clone.key) {
		clone.left, updated = clone.leftFilled(db).put(db, key, value)
	} else {
		clone.right, updated = clone.rightFilled(db).put(db, key, value)
	}

	if !updated {
		clone.calcHeightAndSize(db)
		return clone.balance(db), updated
	}
	return clone, updated
}
// remove deletes key from the subtree rooted at self.
//
// It returns the root of the resulting subtree, the value that was stored
// under key, and an error. When key is absent the original subtree root is
// returned unchanged together with a NotFound error. Nodes on the path to the
// removed key are copied rather than modified in place.
func (self *IBPNode) remove(db Db, key Key) (newSelf *IBPNode, value Value, err error) {
	if self == nil {
		return nil, nil, NotFound(key)
	}

	if self.key.Equals(key) {
		// With at most one child, that child (possibly nil) simply takes the
		// place of this node. These two checks cover every such case.
		if self.left == nil {
			return self.rightFilled(db), self.value, nil
		}
		if self.right == nil {
			return self.leftFilled(db), self.value, nil
		}

		// Two children: pop the in-order neighbour from the larger subtree
		// (leftmost of the right side, or rightmost of the left side) and put
		// it where this node was. popNode hands back a copy of self without
		// that neighbour, plus a detached copy of the neighbour.
		if self.leftFilled(db).Size() < self.rightFilled(db).Size() {
			self, newSelf = self.popNode(db, self.rightFilled(db).lmd(db))
		} else {
			self, newSelf = self.popNode(db, self.leftFilled(db).rmd(db))
		}
		newSelf.left = self.left
		newSelf.right = self.right
		newSelf.calcHeightAndSize(db)
		// NOTE(review): unlike the paths below, this result is returned
		// without a balance() call; confirm the height invariant is not
		// expected to be restored here.
		return newSelf, self.value, nil
	}

	if key.Less(self.key) {
		if self.left == nil {
			return self, nil, NotFound(key)
		}
		var newLeft *IBPNode
		newLeft, value, err = self.leftFilled(db).remove(db, key)
		if newLeft == self.leftFilled(db) { // not found
			return self, nil, err
		} else if err != nil { // some other error
			return self, value, err
		}
		self = self.Copy()
		self.left = newLeft
	} else {
		if self.right == nil {
			return self, nil, NotFound(key)
		}
		var newRight *IBPNode
		newRight, value, err = self.rightFilled(db).remove(db, key)
		if newRight == self.rightFilled(db) { // not found
			return self, nil, err
		} else if err != nil { // some other error
			return self, value, err
		}
		self = self.Copy()
		self.right = newRight
	}

	// A child subtree changed: refresh the cached metrics and rebalance.
	self.calcHeightAndSize(db)
	return self.balance(db), value, err
}
// ByteSize returns the exact number of bytes saveToCountHashes writes for
// this node. Hash and Save allocate their buffers from it, and fill requires
// a stored buffer to be consumed completely, so the two must agree
// byte-for-byte.
func (self *IBPNode) ByteSize() int {
	// 1 byte node descriptor
	// 1 byte node height
	// 8 bytes node size
	size := 10
	// key: 1 byte of type info followed by the encoded key
	size += 1
	size += self.key.ByteSize()
	// value: 1 byte of type info followed by the encoded value. A nil value
	// is represented only by the missing HAS_VALUE descriptor bit and takes
	// no bytes of its own.
	if self.value != nil {
		size += 1
		size += self.value.ByteSize()
	}
	// children: one serialized hash per existing child
	if self.left != nil {
		size += HASH_BYTE_SIZE
	}
	if self.right != nil {
		size += HASH_BYTE_SIZE
	}
	return size
}
// SaveTo serializes the node into buf and returns the number of bytes
// written. It delegates to saveToCountHashes and discards the hash count.
func (self *IBPNode) SaveTo(buf []byte) int {
written, _ := self.saveToCountHashes(buf)
return written
}
// saveToCountHashes serializes the node into buf and returns the number of
// bytes written plus the number of child hashes that had to be computed
// while doing so.
//
// Layout, in order: descriptor byte, height, size, key (type byte + data),
// value (type byte + data, only when non-nil), left child hash, right child
// hash (each only when that child exists). buf must be large enough; no
// bounds are checked here.
func (self *IBPNode) saveToCountHashes(buf []byte) (int, uint64) {
cur := 0
hashCount := uint64(0)
// node descriptor
// The descriptor records which optional parts follow, so fill knows what to read.
nodeDesc := byte(0)
if self.value != nil { nodeDesc |= IBPNODE_DESC_HAS_VALUE }
if self.left != nil { nodeDesc |= IBPNODE_DESC_HAS_LEFT }
if self.right != nil { nodeDesc |= IBPNODE_DESC_HAS_RIGHT }
cur += UInt8(nodeDesc).SaveTo(buf[cur:])
// node height & size
cur += UInt8(self.height).SaveTo(buf[cur:])
cur += UInt64(self.size).SaveTo(buf[cur:])
// node key
buf[cur] = GetBinaryType(self.key)
cur += 1
cur += self.key.SaveTo(buf[cur:])
// node value
// A nil value writes nothing at all.
if self.value != nil {
buf[cur] = GetBinaryType(self.value)
cur += 1
cur += self.value.SaveTo(buf[cur:])
}
// left child
// Only the child's hash is embedded; Hash computes it first if needed.
if self.left != nil {
leftHash, leftCount := self.left.Hash()
hashCount += leftCount
cur += leftHash.SaveTo(buf[cur:])
}
// right child
if self.right != nil {
rightHash, rightCount := self.right.Hash()
hashCount += rightCount
cur += rightHash.SaveTo(buf[cur:])
}
return cur, hashCount
}
// fill loads a placeholder node, which has only its hash set, from db: the
// stored bytes are fetched by hash and decoded into height, size, key, value
// and children. Children are created as persisted placeholders that carry
// only their hash. Finally the placeholder flag is cleared on this node.
//
// Panics when the receiver or its hash is nil, when the decoded key does not
// implement Key, or when the stored buffer is not consumed exactly.
// Not threadsafe.
func (self *IBPNode) fill(db Db) {
if self == nil {
panic("placeholder can't be nil")
} else if self.hash == nil {
panic("placeholder.hash can't be nil")
}
buf := db.Get(self.hash)
cur := 0
// node header
// Fixed layout: descriptor at offset 0, height at 1, size at 2..9.
nodeDesc := byte(LoadUInt8(buf))
self.height = uint8(LoadUInt8(buf[1:]))
self.size = uint64(LoadUInt64(buf[2:]))
// key
// The key starts right after the 10-byte header. `:=` only declares key
// here; cur is the variable declared above and is reassigned.
key, cur := LoadBinary(buf, 10)
self.key = key.(Key)
// value
if nodeDesc & IBPNODE_DESC_HAS_VALUE > 0 {
self.value, cur = LoadBinary(buf, cur)
}
// children
if nodeDesc & IBPNODE_DESC_HAS_LEFT > 0 {
var leftHash ByteSlice
leftHash, cur = LoadByteSlice(buf, cur)
self.left = &IBPNode{
hash: leftHash,
flags: IBPNODE_FLAG_PERSISTED | IBPNODE_FLAG_PLACEHOLDER,
}
}
if nodeDesc & IBPNODE_DESC_HAS_RIGHT > 0 {
var rightHash ByteSlice
rightHash, cur = LoadByteSlice(buf, cur)
self.right = &IBPNode{
hash: rightHash,
flags: IBPNODE_FLAG_PERSISTED | IBPNODE_FLAG_PLACEHOLDER,
}
}
// Every stored byte must have been accounted for.
if cur != len(buf) {
panic("buf not all consumed")
}
// The node is fully populated now; only the placeholder bit is cleared.
self.flags &= ^IBPNODE_FLAG_PLACEHOLDER
}
// leftFilled returns the left child, first loading it from db when it is
// still a placeholder. It returns nil when there is no left child.
func (self *IBPNode) leftFilled(db Db) *IBPNode {
	child := self.left
	if child != nil && child.flags&IBPNODE_FLAG_PLACEHOLDER != 0 {
		child.fill(db)
	}
	return child
}
// rightFilled returns the right child, first loading it from db when it is
// still a placeholder. It returns nil when there is no right child.
func (self *IBPNode) rightFilled(db Db) *IBPNode {
	child := self.right
	if child != nil && child.flags&IBPNODE_FLAG_PLACEHOLDER != 0 {
		child.fill(db)
	}
	return child
}
// popNode detaches node from the subtree rooted at self.
//
// It returns the root of the subtree without node (nodes on the path are
// copied, never modified) and a copy of node with both children cleared and
// height/size recomputed. When node is self, the returned subtree is node's
// only child, or nil when it has none.
//
// Only nodes with at most one child can be popped. Panics when self or node
// is nil, when node has two children, or when node is not reachable by key
// order from self.
func (self *IBPNode) popNode(db Db, node *IBPNode) (newSelf, new_node *IBPNode) {
	if self == nil {
		panic("self can't be nil")
	} else if node == nil {
		panic("node can't be nil")
	} else if node.left != nil && node.right != nil {
		panic("node can't have both left and right")
	}

	if self == node {
		// The single child (if any) takes node's place in the tree.
		var n *IBPNode
		if node.left != nil {
			n = node.leftFilled(db)
		} else if node.right != nil {
			n = node.rightFilled(db)
		}
		// Hand back a detached copy so the original node stays untouched.
		node = node.Copy()
		node.left = nil
		node.right = nil
		node.calcHeightAndSize(db)
		return n, node
	}

	// Descend by key order towards node, copying this node on the way.
	self = self.Copy()
	if node.key.Less(self.key) {
		self.left, node = self.leftFilled(db).popNode(db, node)
	} else {
		self.right, node = self.rightFilled(db).popNode(db, node)
	}
	self.calcHeightAndSize(db)
	return self, node
}
// rotateRight performs a right rotation on copies of this node and its left
// child: the left child becomes the new subtree root, this node becomes its
// right child, and the left child's former right subtree becomes this node's
// left subtree. The new root is returned; the originals are not modified.
func (self *IBPNode) rotateRight(db Db) *IBPNode {
	demoted := self.Copy()
	promoted := demoted.leftFilled(db).Copy()

	demoted.left, promoted.right = promoted.right, demoted

	// The demoted node sits below the promoted one, so refresh it first.
	demoted.calcHeightAndSize(db)
	promoted.calcHeightAndSize(db)
	return promoted
}
// rotateLeft performs a left rotation on copies of this node and its right
// child: the right child becomes the new subtree root, this node becomes its
// left child, and the right child's former left subtree becomes this node's
// right subtree. The new root is returned; the originals are not modified.
func (self *IBPNode) rotateLeft(db Db) *IBPNode {
	demoted := self.Copy()
	promoted := demoted.rightFilled(db).Copy()

	demoted.right, promoted.left = promoted.left, demoted

	// The demoted node sits below the promoted one, so refresh it first.
	demoted.calcHeightAndSize(db)
	promoted.calcHeightAndSize(db)
	return promoted
}
// calcHeightAndSize recomputes the cached height and size from the children:
// height is one more than the taller child, size is one more than both child
// sizes combined. Missing children count as height 0 and size 0.
func (self *IBPNode) calcHeightAndSize(db Db) {
	lhs, rhs := self.leftFilled(db), self.rightFilled(db)
	self.height = maxUint8(lhs.Height(), rhs.Height()) + 1
	self.size = lhs.Size() + rhs.Size() + 1
}
// calcBalance returns left height minus right height: positive when the node
// is left-heavy, negative when right-heavy. A nil node has balance 0.
func (self *IBPNode) calcBalance(db Db) int {
	if self == nil {
		return 0
	}
	leftHeight := int(self.leftFilled(db).Height())
	rightHeight := int(self.rightFilled(db).Height())
	return leftHeight - rightHeight
}
// balance restores the height balance of this node when its children's
// heights differ by more than one, using a single or double rotation, and
// returns the resulting subtree root. A node that is already within bounds
// is returned as is.
func (self *IBPNode) balance(db Db) (newSelf *IBPNode) {
	skew := self.calcBalance(db)
	switch {
	case skew > 1:
		// Left-heavy. When the left child leans right (left-right case) it is
		// rotated left first, on a copy of this node.
		if self.leftFilled(db).calcBalance(db) < 0 {
			self = self.Copy()
			self.left = self.leftFilled(db).rotateLeft(db)
		}
		return self.rotateRight(db)
	case skew < -1:
		// Right-heavy. When the right child leans left (right-left case) it
		// is rotated right first, on a copy of this node.
		if self.rightFilled(db).calcBalance(db) > 0 {
			self = self.Copy()
			self.right = self.rightFilled(db).rotateRight(db)
		}
		return self.rotateLeft(db)
	}
	// Nothing changed
	return self
}
// _md follows side repeatedly, starting at self, and returns the last node
// for which side yields nil (the "most" descendant in that direction).
// A nil receiver yields nil.
func (self *IBPNode) _md(side func(*IBPNode) *IBPNode) *IBPNode {
	if self == nil {
		return nil
	}
	deepest := self
	for child := side(deepest); child != nil; child = side(deepest) {
		deepest = child
	}
	return deepest
}
// lmd returns the leftmost descendant of self (self itself when it has no
// left child), loading placeholder nodes from db along the way.
func (self *IBPNode) lmd(db Db) (*IBPNode) {
return self._md(func(node *IBPNode)*IBPNode { return node.leftFilled(db) })
}
// rmd returns the rightmost descendant of self (self itself when it has no
// right child), loading placeholder nodes from db along the way.
func (self *IBPNode) rmd(db Db) (*IBPNode) {
return self._md(func(node *IBPNode)*IBPNode { return node.rightFilled(db) })
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
)
// HASH_BYTE_SIZE is the number of bytes a node's ByteSize reserves for each
// child hash in its serialized form.
// NOTE(review): presumably a 4-byte length prefix plus a 32-byte SHA-256
// digest; confirm against ByteSlice.SaveTo.
const HASH_BYTE_SIZE int = 4+32
type Binary interface {
ByteSize() int
SaveTo([]byte) int
@ -25,6 +27,21 @@ type Db interface {
Put([]byte, []byte)
}
// Node is the read-only view of a tree node shared by the tree
// implementations, plus hashing and persistence.
type Node interface {
Binary
// Key returns the key stored in the node.
Key() Key
// Value returns the value stored in the node.
Value() Value
// Size returns the node's subtree size.
Size() uint64
// Height returns the node's subtree height.
Height() uint8
// Hash returns the node's hash and a count of hashes computed by the call.
Hash() (ByteSlice, uint64)
// Save persists the node to the given Db.
Save(Db)
}
// NodeIterator yields one node per call; the tree iterators in this package
// return nil once every node has been yielded.
type NodeIterator func() Node
type Tree interface {
Root() Node
@ -38,30 +55,10 @@ type Tree interface {
Put(Key, Value) bool
Remove(Key) (Value, error)
Iterator() NodeIterator
}
type Node interface {
Binary
Key() Key
Value() Value
Left(Db) Node
Right(Db) Node
Size() uint64
Height() uint8
Has(Db, Key) bool
Get(Db, Key) Value
Hash() (ByteSlice, uint64)
Save(Db)
Put(Db, Key, Value) (Node, bool)
Remove(Db, Key) (Node, Value, error)
}
type NodeIterator func() Node
// NotFound returns the error reported when a key is absent from a tree.
// The message is fixed: key is accepted but not included in it.
func NotFound(key Key) error {
return fmt.Errorf("Key was not found.")
}

View File

@ -5,35 +5,6 @@ import (
"fmt"
)
func Iterator(node Node) NodeIterator {
stack := make([]Node, 0, 10)
var cur Node = node
var itr NodeIterator
itr = func()(tn Node) {
if len(stack) > 0 || cur != nil {
for cur != nil {
stack = append(stack, cur)
cur = cur.Left(nil)
}
stack, cur = pop(stack)
tn = cur
cur = cur.Right(nil)
return tn
} else {
return nil
}
}
return itr
}
func pop(stack []Node) ([]Node, Node) {
if len(stack) <= 0 {
return stack, nil
} else {
return stack[0:len(stack)-1], stack[len(stack)-1]
}
}
func PrintIAVLNode(node *IAVLNode) {
fmt.Println("==== NODE")
printIAVLNode(node, 0)
@ -68,3 +39,10 @@ func randstr(length int) String {
panic("unreachable")
}
// maxUint8 returns the larger of a and b.
func maxUint8(a, b uint8) uint8 {
	if b >= a {
		return b
	}
	return a
}