From 8747bd5a8be64f6a676cc22d05cb795801503c92 Mon Sep 17 00:00:00 2001 From: Ethan Frey Date: Fri, 14 Jul 2017 12:17:08 +0200 Subject: [PATCH] Add set struct to the store --- state/queue.go | 4 +- state/set.go | 152 ++++++++++++++++++++++++++++++++++++++++++++++ state/set_test.go | 77 +++++++++++++++++++++++ 3 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 state/set.go create mode 100644 state/set_test.go diff --git a/state/queue.go b/state/queue.go index db4869172..9460cf1d1 100644 --- a/state/queue.go +++ b/state/queue.go @@ -31,8 +31,8 @@ func (q *Queue) Tail() uint64 { } // Size returns how many elements are in the queue -func (q *Queue) Size() uint64 { - return q.tail - q.head +func (q *Queue) Size() int { + return int(q.tail - q.head) } // Push adds an element to the tail of the queue and returns it's location diff --git a/state/set.go b/state/set.go new file mode 100644 index 000000000..98d6a4095 --- /dev/null +++ b/state/set.go @@ -0,0 +1,152 @@ +package state + +import ( + "bytes" + "sort" + + wire "github.com/tendermint/go-wire" +) + +// Set allows us to add arbitrary k-v pairs, check existence, +// as well as iterate through the set (always in key order) +// +// If we had full access to the IAVL tree, this would be completely +// trivial and redundant +type Set struct { + store KVStore + keys KeyList +} + +var _ KVStore = &Set{} + +// NewSet loads or initializes a span of keys +func NewSet(store KVStore) *Set { + s := &Set{store: store} + s.loadKeys() + return s +} + +// Set puts a value at a given height. +// If the value is nil, or an empty slice, remove the key from the list +func (s *Set) Set(key []byte, value []byte) { + s.store.Set(makeBKey(key), value) + if len(value) > 0 { + s.addKey(key) + } else { + s.removeKey(key) + } + s.storeKeys() +} + +// Get returns the element with a key if it exists +func (s *Set) Get(key []byte) []byte { + return s.store.Get(makeBKey(key)) +} + +// Remove deletes this key from the set (same as setting value = nil) +func (s *Set) Remove(key []byte) { + s.store.Set(key, nil) +} + +// Exists checks for the existence of the key in the set +func (s *Set) Exists(key []byte) bool { + return len(s.Get(key)) > 0 +} + +// Size returns how many elements are in the set +func (s *Set) Size() int { + return len(s.keys) +} + +// List returns all keys in the set +// It makes a copy, so we don't modify this in place +func (s *Set) List() (keys KeyList) { + out := make([][]byte, len(s.keys)) + for i := range s.keys { + out[i] = append([]byte(nil), s.keys[i]...) + } + return out +} + +// addKey inserts this key, maintaining sorted order, no duplicates +func (s *Set) addKey(key []byte) { + for i, k := range s.keys { + cmp := bytes.Compare(k, key) + // don't add duplicates + if cmp == 0 { + return + } + // insert before the first key greater than input + if cmp > 0 { + // https://github.com/golang/go/wiki/SliceTricks + s.keys = append(s.keys, nil) + copy(s.keys[i+1:], s.keys[i:]) + s.keys[i] = key + return + } + } + // if it is higher than all (or empty keys), append + s.keys = append(s.keys, key) +} + +// removeKey removes this key if it is present, maintaining sorted order +func (s *Set) removeKey(key []byte) { + for i, k := range s.keys { + cmp := bytes.Compare(k, key) + // if there is a match, remove + if cmp == 0 { + s.keys = append(s.keys[:i], s.keys[i+1:]...) + return + } + // if we has the proper location, without finding it, abort + if cmp > 0 { + return + } + } +} + +func (s *Set) loadKeys() { + b := s.store.Get(keys) + if b == nil { + return + } + err := wire.ReadBinaryBytes(b, &s.keys) + // hahaha... just like i love to hate :) + if err != nil { + panic(err) + } +} + +func (s *Set) storeKeys() { + b := wire.BinaryBytes(s.keys) + s.store.Set(keys, b) +} + +// makeBKey prefixes the byte slice for the storage key +func makeBKey(key []byte) []byte { + return append(dataKey, key...) +} + +// KeyList is a sortable list of byte slices +type KeyList [][]byte + +//nolint +func (kl KeyList) Len() int { return len(kl) } +func (kl KeyList) Less(i, j int) bool { return bytes.Compare(kl[i], kl[j]) < 0 } +func (kl KeyList) Swap(i, j int) { kl[i], kl[j] = kl[j], kl[i] } + +var _ sort.Interface = KeyList{} + +// Equals checks for if the two lists have the same content... +// needed as == doesn't work for slices of slices +func (kl KeyList) Equals(kl2 KeyList) bool { + if len(kl) != len(kl2) { + return false + } + for i := range kl { + if !bytes.Equal(kl[i], kl2[i]) { + return false + } + } + return true +} diff --git a/state/set_test.go b/state/set_test.go new file mode 100644 index 000000000..2c2f3a5d2 --- /dev/null +++ b/state/set_test.go @@ -0,0 +1,77 @@ +package state + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +type pair struct { + k []byte + v []byte +} + +type setCase struct { + data []pair + // these are the tests to try out + gets []pair // for each item check the query matches + list KeyList // make sure the set returns the proper list +} + +func TestSet(t *testing.T) { + + a, b, c, d := []byte{0xaa}, []byte{0xbb}, []byte{0xcc}, []byte{0xdd} + + cases := []setCase{ + + // simplest queries + { + []pair{{a, a}, {b, b}, {c, c}}, + []pair{{c, c}, {d, nil}, {b, b}}, + KeyList{a, b, c}, + }, + // out of order + { + []pair{{c, a}, {a, b}, {d, c}, {b, d}}, + []pair{{a, b}, {b, d}}, + KeyList{a, b, c, d}, + }, + // duplicate and removing + { + []pair{{c, a}, {c, c}, {a, d}, {d, d}, {b, b}, {d, nil}, {a, nil}, {a, a}, {b, nil}}, + []pair{{a, a}, {c, c}, {b, nil}}, + KeyList{a, c}, + }, + } + + for i, tc := range cases { + store := NewMemKVStore() + + // initialize a queue and add items + s := NewSet(store) + for _, x := range tc.data { + s.Set(x.k, x.v) + } + + testSet(t, i, s, tc) + // reload and try the queries again + s2 := NewSet(store) + testSet(t, i+10, s2, tc) + } +} + +func testSet(t *testing.T, idx int, s *Set, tc setCase) { + assert := assert.New(t) + i := strconv.Itoa(idx) + + for _, g := range tc.gets { + v := s.Get(g.k) + assert.Equal(g.v, v, i) + e := s.Exists(g.k) + assert.Equal(e, (g.v != nil), i) + } + + l := s.List() + assert.True(tc.list.Equals(l), "%s: %v / %v", i, tc.list, l) +}