diff --git a/CHANGELOG.md b/CHANGELOG.md index 62f0ae124..0fb74788a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -139,6 +139,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (errors) [\#8845](https://github.com/cosmos/cosmos-sdk/pull/8845) Add `Error.Wrap` handy method * [\#8518](https://github.com/cosmos/cosmos-sdk/pull/8518) Help users of multisig wallets debug signature issues. * [\#9573](https://github.com/cosmos/cosmos-sdk/pull/9573) ADR 040 implementation: New DB interface +* [\#9952](https://github.com/cosmos/cosmos-sdk/pull/9952) ADR 040: Implement in-memory DB backend ### Client Breaking Changes diff --git a/db/README.md b/db/README.md new file mode 100644 index 000000000..2fda2a85b --- /dev/null +++ b/db/README.md @@ -0,0 +1,39 @@ +# Key-Value Database + +Databases supporting mappings of arbitrary byte sequences. + +## Interfaces + +The database interface types consist of objects to encapsulate the singular connection to the DB, transactions being made to it, historical version state, and iteration. + +### `DBConnection` + +This interface represents a connection to a versioned key-value database. All versioning operations are performed using methods on this type. + * The `Versions` method returns a `VersionSet` which represents an immutable view of the version history at the current state. + * Version history is modified via the `{Save,Delete}Version` methods. + * Operations on version history do not modify any database contents. + +### `DBReader`, `DBWriter`, and `DBReadWriter` + +These types represent transactions on the database contents. Their methods provide CRUD operations as well as iteration. + * Writeable transactions call `Commit` flushes operations to the source DB. + * All open transactions must be closed with `Discard` or `Commit` before a new version can be saved on the source DB. + * The maximum number of safely concurrent transactions is dependent on the backend implementation. + * A single transaction object is not safe for concurrent use. + * Write conflicts on concurrent transactions will cause an error at commit time (optimistic concurrency control). + +#### `Iterator` + + * An iterator is invalidated by any writes within its `Domain` to the source transaction while it is open. + * An iterator must call `Close` before its source transaction is closed. + +### `VersionSet` + +This represents a self-contained and immutable view of a database's version history state. It is therefore safe to retain and conccurently access any instance of this object. + +## Implementations + +### In-memory DB + +The in-memory DB in the `db/memdb` package cannot be persisted to disk. It is implemented using the Google [btree](https://pkg.go.dev/github.com/google/btree) library. + * This currently does not perform write conflict detection, so it only supports a single open write-transaction at a time. Multiple and concurrent read-transactions are supported. diff --git a/db/dbtest/benchmark.go b/db/dbtest/benchmark.go new file mode 100644 index 000000000..680dcabbc --- /dev/null +++ b/db/dbtest/benchmark.go @@ -0,0 +1,109 @@ +package dbtest + +import ( + "bytes" + "encoding/binary" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + dbm "github.com/cosmos/cosmos-sdk/db" +) + +func Int64ToBytes(i int64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +func BytesToInt64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} + +func BenchmarkRangeScans(b *testing.B, db dbm.DBReadWriter, dbSize int64) { + b.StopTimer() + + rangeSize := int64(10000) + if dbSize < rangeSize { + b.Errorf("db size %v cannot be less than range size %v", dbSize, rangeSize) + } + + for i := int64(0); i < dbSize; i++ { + bytes := Int64ToBytes(i) + err := db.Set(bytes, bytes) + if err != nil { + // require.NoError() is very expensive (according to profiler), so check manually + b.Fatal(b, err) + } + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + start := rand.Int63n(dbSize - rangeSize) // nolint: gosec + end := start + rangeSize + iter, err := db.Iterator(Int64ToBytes(start), Int64ToBytes(end)) + require.NoError(b, err) + count := 0 + for iter.Next() { + count++ + } + iter.Close() + require.EqualValues(b, rangeSize, count) + } +} + +func BenchmarkRandomReadsWrites(b *testing.B, db dbm.DBReadWriter) { + b.StopTimer() + + // create dummy data + const numItems = int64(1000000) + internal := map[int64]int64{} + for i := 0; i < int(numItems); i++ { + internal[int64(i)] = int64(0) + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + { + idx := rand.Int63n(numItems) // nolint: gosec + internal[idx]++ + val := internal[idx] + idxBytes := Int64ToBytes(idx) + valBytes := Int64ToBytes(val) + err := db.Set(idxBytes, valBytes) + if err != nil { + // require.NoError() is very expensive (according to profiler), so check manually + b.Fatal(b, err) + } + } + + { + idx := rand.Int63n(numItems) // nolint: gosec + valExp := internal[idx] + idxBytes := Int64ToBytes(idx) + valBytes, err := db.Get(idxBytes) + if err != nil { + b.Fatal(b, err) + } + if valExp == 0 { + if !bytes.Equal(valBytes, nil) { + b.Errorf("Expected %v for %v, got %X", nil, idx, valBytes) + break + } + } else { + if len(valBytes) != 8 { + b.Errorf("Expected length 8 for %v, got %X", idx, valBytes) + break + } + valGot := BytesToInt64(valBytes) + if valExp != valGot { + b.Errorf("Expected %v for %v, got %v", valExp, idx, valGot) + break + } + } + } + + } +} diff --git a/db/dbtest/testcases.go b/db/dbtest/testcases.go new file mode 100644 index 000000000..475950111 --- /dev/null +++ b/db/dbtest/testcases.go @@ -0,0 +1,431 @@ +package dbtest + +import ( + "fmt" + "sort" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + dbm "github.com/cosmos/cosmos-sdk/db" +) + +type Loader func(*testing.T, string) dbm.DBConnection + +func ikey(i int) []byte { return []byte(fmt.Sprintf("key-%03d", i)) } +func ival(i int) []byte { return []byte(fmt.Sprintf("val-%03d", i)) } + +func DoTestGetSetHasDelete(t *testing.T, load Loader) { + t.Helper() + db := load(t, t.TempDir()) + + var txn dbm.DBReadWriter + var view dbm.DBReader + + view = db.Reader() + require.NotNil(t, view) + + // A nonexistent key should return nil. + value, err := view.Get([]byte("a")) + require.NoError(t, err) + require.Nil(t, value) + + ok, err := view.Has([]byte("a")) + require.NoError(t, err) + require.False(t, ok) + + txn = db.ReadWriter() + + // Set and get a value. + err = txn.Set([]byte("a"), []byte{0x01}) + require.NoError(t, err) + + ok, err = txn.Has([]byte("a")) + require.NoError(t, err) + require.True(t, ok) + + value, err = txn.Get([]byte("a")) + require.NoError(t, err) + require.Equal(t, []byte{0x01}, value) + + // New value is not visible from another txn. + ok, err = view.Has([]byte("a")) + require.NoError(t, err) + require.False(t, ok) + + // Deleting a non-existent value is fine. + err = txn.Delete([]byte("x")) + require.NoError(t, err) + + // Delete a value. + err = txn.Delete([]byte("a")) + require.NoError(t, err) + + value, err = txn.Get([]byte("a")) + require.NoError(t, err) + require.Nil(t, value) + + err = txn.Set([]byte("b"), []byte{0x02}) + require.NoError(t, err) + + view.Discard() + require.NoError(t, txn.Commit()) + + txn = db.ReadWriter() + + // Verify committed values. + value, err = txn.Get([]byte("b")) + require.NoError(t, err) + require.Equal(t, []byte{0x02}, value) + + ok, err = txn.Has([]byte("a")) + require.NoError(t, err) + require.False(t, ok) + + // Setting, getting, and deleting an empty key should error. + _, err = txn.Get([]byte{}) + require.Equal(t, dbm.ErrKeyEmpty, err) + _, err = txn.Get(nil) + require.Equal(t, dbm.ErrKeyEmpty, err) + + _, err = txn.Has([]byte{}) + require.Equal(t, dbm.ErrKeyEmpty, err) + _, err = txn.Has(nil) + require.Equal(t, dbm.ErrKeyEmpty, err) + + err = txn.Set([]byte{}, []byte{0x01}) + require.Equal(t, dbm.ErrKeyEmpty, err) + err = txn.Set(nil, []byte{0x01}) + require.Equal(t, dbm.ErrKeyEmpty, err) + + err = txn.Delete([]byte{}) + require.Equal(t, dbm.ErrKeyEmpty, err) + err = txn.Delete(nil) + require.Equal(t, dbm.ErrKeyEmpty, err) + + // Setting a nil value should error, but an empty value is fine. + err = txn.Set([]byte("x"), nil) + require.Equal(t, dbm.ErrValueNil, err) + + err = txn.Set([]byte("x"), []byte{}) + require.NoError(t, err) + + value, err = txn.Get([]byte("x")) + require.NoError(t, err) + require.Equal(t, []byte{}, value) + + require.NoError(t, txn.Commit()) + + require.NoError(t, db.Close()) +} + +func DoTestIterators(t *testing.T, load Loader) { + t.Helper() + db := load(t, t.TempDir()) + type entry struct { + key []byte + val string + } + entries := []entry{ + {[]byte{0}, "0"}, + {[]byte{0, 0}, "0 0"}, + {[]byte{0, 1}, "0 1"}, + {[]byte{0, 2}, "0 2"}, + {[]byte{1}, "1"}, + } + txn := db.ReadWriter() + for _, e := range entries { + require.NoError(t, txn.Set(e.key, []byte(e.val))) + } + require.NoError(t, txn.Commit()) + + testRange := func(t *testing.T, iter dbm.Iterator, expected []string) { + i := 0 + for ; iter.Next(); i++ { + expectedValue := expected[i] + value := iter.Value() + require.EqualValues(t, expectedValue, string(value), "i=%v", i) + } + require.Equal(t, len(expected), i) + } + + type testCase struct { + start, end []byte + expected []string + } + + view := db.Reader() + + iterCases := []testCase{ + {nil, nil, []string{"0", "0 0", "0 1", "0 2", "1"}}, + {[]byte{0x00}, nil, []string{"0", "0 0", "0 1", "0 2", "1"}}, + {[]byte{0x00}, []byte{0x00, 0x01}, []string{"0", "0 0"}}, + {[]byte{0x00}, []byte{0x01}, []string{"0", "0 0", "0 1", "0 2"}}, + {[]byte{0x00, 0x01}, []byte{0x01}, []string{"0 1", "0 2"}}, + {nil, []byte{0x01}, []string{"0", "0 0", "0 1", "0 2"}}, + } + for i, tc := range iterCases { + t.Logf("Iterator case %d: [%v, %v)", i, tc.start, tc.end) + it, err := view.Iterator(tc.start, tc.end) + require.NoError(t, err) + testRange(t, it, tc.expected) + it.Close() + } + + reverseCases := []testCase{ + {nil, nil, []string{"1", "0 2", "0 1", "0 0", "0"}}, + {[]byte{0x00}, nil, []string{"1", "0 2", "0 1", "0 0", "0"}}, + {[]byte{0x00}, []byte{0x00, 0x01}, []string{"0 0", "0"}}, + {[]byte{0x00}, []byte{0x01}, []string{"0 2", "0 1", "0 0", "0"}}, + {[]byte{0x00, 0x01}, []byte{0x01}, []string{"0 2", "0 1"}}, + {nil, []byte{0x01}, []string{"0 2", "0 1", "0 0", "0"}}, + } + for i, tc := range reverseCases { + t.Logf("ReverseIterator case %d: [%v, %v)", i, tc.start, tc.end) + it, err := view.ReverseIterator(tc.start, tc.end) + require.NoError(t, err) + testRange(t, it, tc.expected) + it.Close() + } + + view.Discard() + require.NoError(t, db.Close()) +} + +func DoTestVersioning(t *testing.T, load Loader) { + t.Helper() + db := load(t, t.TempDir()) + view := db.Reader() + require.NotNil(t, view) + + // Write, then read different versions + txn := db.ReadWriter() + require.NoError(t, txn.Set([]byte("0"), []byte("a"))) + require.NoError(t, txn.Set([]byte("1"), []byte("b"))) + require.NoError(t, txn.Commit()) + v1, err := db.SaveNextVersion() + require.NoError(t, err) + + txn = db.ReadWriter() + require.NoError(t, txn.Set([]byte("0"), []byte("c"))) + require.NoError(t, txn.Delete([]byte("1"))) + require.NoError(t, txn.Set([]byte("2"), []byte("c"))) + require.NoError(t, txn.Commit()) + v2, err := db.SaveNextVersion() + require.NoError(t, err) + + // Skip to a future version + v3 := (v2 + 2) + require.NoError(t, db.SaveVersion(v3)) + + // Try to save to a past version + err = db.SaveVersion(v2) + require.Error(t, err) + + // Verify existing versions + versions, err := db.Versions() + require.NoError(t, err) + require.Equal(t, 3, versions.Count()) + var all []uint64 + for it := versions.Iterator(); it.Next(); { + all = append(all, it.Value()) + } + sort.Slice(all, func(i, j int) bool { return all[i] < all[j] }) + require.Equal(t, []uint64{v1, v2, v3}, all) + require.Equal(t, v3, versions.Last()) + + view, err = db.ReaderAt(v1) + require.NoError(t, err) + require.NotNil(t, view) + val, err := view.Get([]byte("0")) + require.Equal(t, []byte("a"), val) + require.NoError(t, err) + val, err = view.Get([]byte("1")) + require.Equal(t, []byte("b"), val) + require.NoError(t, err) + has, err := view.Has([]byte("2")) + require.False(t, has) + + view, err = db.ReaderAt(v2) + require.NoError(t, err) + require.NotNil(t, view) + val, err = view.Get([]byte("0")) + require.Equal(t, []byte("c"), val) + require.NoError(t, err) + val, err = view.Get([]byte("2")) + require.Equal(t, []byte("c"), val) + require.NoError(t, err) + has, err = view.Has([]byte("1")) + require.False(t, has) + + // Try to read an invalid version + view, err = db.ReaderAt(versions.Last() + 1) + require.Equal(t, dbm.ErrVersionDoesNotExist, err) + + require.NoError(t, db.DeleteVersion(v2)) + // Try to read a deleted version + view, err = db.ReaderAt(v2) + require.Equal(t, dbm.ErrVersionDoesNotExist, err) + + // Ensure latest version is accurate + prev := v3 + for i := 0; i < 10; i++ { + w := db.Writer() + require.NoError(t, w.Set(ikey(i), ival(i))) + require.NoError(t, w.Commit()) + ver, err := db.SaveNextVersion() + require.NoError(t, err) + require.Equal(t, prev+1, ver) + versions, err := db.Versions() + require.NoError(t, err) + require.Equal(t, ver, versions.Last()) + prev = ver + } + + require.NoError(t, db.Close()) +} + +func DoTestTransactions(t *testing.T, load Loader, multipleWriters bool) { + t.Helper() + db := load(t, t.TempDir()) + // Both methods should work in a DBWriter context + writerFuncs := []func() dbm.DBWriter{ + db.Writer, + func() dbm.DBWriter { return db.ReadWriter() }, + } + + for _, getWriter := range writerFuncs { + // Uncommitted records are not saved + t.Run("no commit", func(t *testing.T) { + t.Helper() + view := db.Reader() + defer view.Discard() + tx := getWriter() + defer tx.Discard() + require.NoError(t, tx.Set([]byte("0"), []byte("a"))) + v, err := view.Get([]byte("0")) + require.NoError(t, err) + require.Nil(t, v) + }) + + // Try to commit version with open txns + t.Run("open transactions", func(t *testing.T) { + t.Helper() + tx := getWriter() + tx.Set([]byte("2"), []byte("a")) + _, err := db.SaveNextVersion() + require.Equal(t, dbm.ErrOpenTransactions, err) + tx.Discard() + }) + + // Continue only if the backend supports multiple concurrent writers + if !multipleWriters { + continue + } + + // Writing separately to same key causes a conflict + t.Run("write conflict", func(t *testing.T) { + t.Helper() + tx1 := getWriter() + tx2 := db.ReadWriter() + tx2.Get([]byte("1")) + require.NoError(t, tx1.Set([]byte("1"), []byte("b"))) + require.NoError(t, tx2.Set([]byte("1"), []byte("c"))) + require.NoError(t, tx1.Commit()) + require.Error(t, tx2.Commit()) + }) + + // Writing from concurrent txns + t.Run("concurrent transactions", func(t *testing.T) { + t.Helper() + var wg sync.WaitGroup + setkv := func(k, v []byte) { + defer wg.Done() + tx := getWriter() + require.NoError(t, tx.Set(k, v)) + require.NoError(t, tx.Commit()) + } + n := 10 + wg.Add(n) + for i := 0; i < n; i++ { + go setkv(ikey(i), ival(i)) + } + wg.Wait() + view := db.Reader() + defer view.Discard() + v, err := view.Get(ikey(0)) + require.NoError(t, err) + require.Equal(t, ival(0), v) + }) + + } + require.NoError(t, db.Close()) +} + +// Tests reloading a saved DB from disk. +func DoTestReloadDB(t *testing.T, load Loader) { + t.Helper() + dirname := t.TempDir() + db := load(t, dirname) + + txn := db.Writer() + for i := 0; i < 100; i++ { + require.NoError(t, txn.Set(ikey(i), ival(i))) + } + require.NoError(t, txn.Commit()) + first, err := db.SaveNextVersion() + require.NoError(t, err) + + txn = db.Writer() + for i := 0; i < 50; i++ { // overwrite some values + require.NoError(t, txn.Set(ikey(i), ival(i*10))) + } + require.NoError(t, txn.Commit()) + last, err := db.SaveNextVersion() + require.NoError(t, err) + + txn = db.Writer() + for i := 100; i < 150; i++ { + require.NoError(t, txn.Set(ikey(i), ival(i))) + } + require.NoError(t, txn.Commit()) + db.Close() + + // Reload and check each saved version + db = load(t, dirname) + + view, err := db.ReaderAt(first) + require.NoError(t, err) + for i := 0; i < 100; i++ { + v, err := view.Get(ikey(i)) + require.NoError(t, err) + require.Equal(t, ival(i), v) + } + view.Discard() + + view, err = db.ReaderAt(last) + require.NoError(t, err) + for i := 0; i < 50; i++ { + v, err := view.Get(ikey(i)) + require.NoError(t, err) + require.Equal(t, ival(i*10), v) + } + for i := 50; i < 100; i++ { + v, err := view.Get(ikey(i)) + require.NoError(t, err) + require.Equal(t, ival(i), v) + } + view.Discard() + + // Load working version + view = db.Reader() + for i := 100; i < 150; i++ { + v, err := view.Get(ikey(i)) + require.NoError(t, err) + require.Equal(t, ival(i), v) + } + view.Discard() + + require.NoError(t, db.Close()) +} diff --git a/db/go.mod b/db/go.mod index e5b6858d8..0d75f20d2 100644 --- a/db/go.mod +++ b/db/go.mod @@ -3,5 +3,6 @@ go 1.15 module github.com/cosmos/cosmos-sdk/db require ( + github.com/google/btree v1.0.0 github.com/stretchr/testify v1.7.0 ) diff --git a/db/go.sum b/db/go.sum index acb88a48f..c7801a435 100644 --- a/db/go.sum +++ b/db/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/db/memdb/db.go b/db/memdb/db.go new file mode 100644 index 000000000..9dd8af7c6 --- /dev/null +++ b/db/memdb/db.go @@ -0,0 +1,274 @@ +package memdb + +import ( + "bytes" + "fmt" + "sync" + "sync/atomic" + + dbm "github.com/cosmos/cosmos-sdk/db" + "github.com/google/btree" +) + +const ( + // The approximate number of items and children per B-tree node. Tuned with benchmarks. + bTreeDegree = 32 +) + +// MemDB is an in-memory database backend using a B-tree for storage. +// +// For performance reasons, all given and returned keys and values are pointers to the in-memory +// database, so modifying them will cause the stored values to be modified as well. All DB methods +// already specify that keys and values should be considered read-only, but this is especially +// important with MemDB. +// +// Versioning is implemented by maintaining references to copy-on-write clones of the backing btree. +// +// TODO: Currently transactions do not detect write conflicts, so writers cannot be used concurrently. +type MemDB struct { + btree *btree.BTree // Main contents + mtx sync.RWMutex // Guards version history + saved map[uint64]*btree.BTree // Past versions + vmgr *dbm.VersionManager // Mirrors version keys + openWriters int32 // Open writers +} + +type dbTxn struct { + btree *btree.BTree + db *MemDB +} +type dbWriter struct{ dbTxn } + +var ( + _ dbm.DBConnection = (*MemDB)(nil) + _ dbm.DBReader = (*dbTxn)(nil) + _ dbm.DBWriter = (*dbWriter)(nil) + _ dbm.DBReadWriter = (*dbWriter)(nil) +) + +// item is a btree.Item with byte slices as keys and values +type item struct { + key []byte + value []byte +} + +// NewDB creates a new in-memory database. +func NewDB() *MemDB { + return &MemDB{ + btree: btree.New(bTreeDegree), + saved: make(map[uint64]*btree.BTree), + vmgr: dbm.NewVersionManager(nil), + } +} + +func (db *MemDB) newTxn(tree *btree.BTree) dbTxn { + return dbTxn{tree, db} +} + +// Close implements DB. +// Close is a noop since for an in-memory database, we don't have a destination to flush +// contents to nor do we want any data loss on invoking Close(). +// See the discussion in https://github.com/tendermint/tendermint/libs/pull/56 +func (db *MemDB) Close() error { + return nil +} + +// Versions implements DBConnection. +func (db *MemDB) Versions() (dbm.VersionSet, error) { + db.mtx.RLock() + defer db.mtx.RUnlock() + return db.vmgr, nil +} + +// Reader implements DBConnection. +func (db *MemDB) Reader() dbm.DBReader { + db.mtx.RLock() + defer db.mtx.RUnlock() + ret := db.newTxn(db.btree) + return &ret +} + +// ReaderAt implements DBConnection. +func (db *MemDB) ReaderAt(version uint64) (dbm.DBReader, error) { + db.mtx.RLock() + defer db.mtx.RUnlock() + tree, ok := db.saved[version] + if !ok { + return nil, dbm.ErrVersionDoesNotExist + } + ret := db.newTxn(tree) + return &ret, nil +} + +// Writer implements DBConnection. +func (db *MemDB) Writer() dbm.DBWriter { + return db.ReadWriter() +} + +// ReadWriter implements DBConnection. +func (db *MemDB) ReadWriter() dbm.DBReadWriter { + db.mtx.RLock() + defer db.mtx.RUnlock() + atomic.AddInt32(&db.openWriters, 1) + // Clone creates a copy-on-write extension of the current tree + return &dbWriter{db.newTxn(db.btree.Clone())} +} + +func (db *MemDB) save(target uint64) (uint64, error) { + db.mtx.Lock() + defer db.mtx.Unlock() + if db.openWriters > 0 { + return 0, dbm.ErrOpenTransactions + } + + newVmgr := db.vmgr.Copy() + target, err := newVmgr.Save(target) + if err != nil { + return 0, err + } + db.saved[target] = db.btree + db.vmgr = newVmgr + return target, nil +} + +// SaveVersion implements DBConnection. +func (db *MemDB) SaveNextVersion() (uint64, error) { + return db.save(0) +} + +// SaveNextVersion implements DBConnection. +func (db *MemDB) SaveVersion(target uint64) error { + if target == 0 { + return dbm.ErrInvalidVersion + } + _, err := db.save(target) + return err +} + +// DeleteVersion implements DBConnection. +func (db *MemDB) DeleteVersion(target uint64) error { + db.mtx.Lock() + defer db.mtx.Unlock() + if _, has := db.saved[target]; !has { + return dbm.ErrVersionDoesNotExist + } + delete(db.saved, target) + db.vmgr = db.vmgr.Copy() + db.vmgr.Delete(target) + return nil +} + +// Get implements DBReader. +func (tx *dbTxn) Get(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, dbm.ErrKeyEmpty + } + i := tx.btree.Get(newKey(key)) + if i != nil { + return i.(*item).value, nil + } + return nil, nil +} + +// Has implements DBReader. +func (tx *dbTxn) Has(key []byte) (bool, error) { + if len(key) == 0 { + return false, dbm.ErrKeyEmpty + } + return tx.btree.Has(newKey(key)), nil +} + +// Set implements DBWriter. +func (tx *dbWriter) Set(key []byte, value []byte) error { + if len(key) == 0 { + return dbm.ErrKeyEmpty + } + if value == nil { + return dbm.ErrValueNil + } + tx.btree.ReplaceOrInsert(newPair(key, value)) + return nil +} + +// Delete implements DBWriter. +func (tx *dbWriter) Delete(key []byte) error { + if len(key) == 0 { + return dbm.ErrKeyEmpty + } + tx.btree.Delete(newKey(key)) + return nil +} + +// Iterator implements DBReader. +// Takes out a read-lock on the database until the iterator is closed. +func (tx *dbTxn) Iterator(start, end []byte) (dbm.Iterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, dbm.ErrKeyEmpty + } + return newMemDBIterator(tx, start, end, false), nil +} + +// ReverseIterator implements DBReader. +// Takes out a read-lock on the database until the iterator is closed. +func (tx *dbTxn) ReverseIterator(start, end []byte) (dbm.Iterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, dbm.ErrKeyEmpty + } + return newMemDBIterator(tx, start, end, true), nil +} + +// Commit implements DBWriter. +func (tx *dbWriter) Commit() error { + tx.db.mtx.Lock() + defer tx.db.mtx.Unlock() + defer tx.Discard() + tx.db.btree = tx.btree + return nil +} + +// Discard implements DBReader and DBWriter. +func (tx *dbTxn) Discard() {} +func (tx *dbWriter) Discard() { + atomic.AddInt32(&tx.db.openWriters, -1) +} + +// Print prints the database contents. +func (db *MemDB) Print() error { + db.mtx.RLock() + defer db.mtx.RUnlock() + + db.btree.Ascend(func(i btree.Item) bool { + item := i.(*item) + fmt.Printf("[%X]:\t[%X]\n", item.key, item.value) + return true + }) + return nil +} + +// Stats implements DBConnection. +func (db *MemDB) Stats() map[string]string { + db.mtx.RLock() + defer db.mtx.RUnlock() + + stats := make(map[string]string) + stats["database.type"] = "memDB" + stats["database.size"] = fmt.Sprintf("%d", db.btree.Len()) + return stats +} + +// Less implements btree.Item. +func (i *item) Less(other btree.Item) bool { + // this considers nil == []byte{}, but that's ok since we handle nil endpoints + // in iterators specially anyway + return bytes.Compare(i.key, other.(*item).key) == -1 +} + +// newKey creates a new key item. +func newKey(key []byte) *item { + return &item{key: key} +} + +// newPair creates a new pair item. +func newPair(key, value []byte) *item { + return &item{key: key, value: value} +} diff --git a/db/memdb/db_test.go b/db/memdb/db_test.go new file mode 100644 index 000000000..a3ac242ea --- /dev/null +++ b/db/memdb/db_test.go @@ -0,0 +1,49 @@ +package memdb + +import ( + "testing" + + dbm "github.com/cosmos/cosmos-sdk/db" + "github.com/cosmos/cosmos-sdk/db/dbtest" +) + +func BenchmarkMemDBRangeScans1M(b *testing.B) { + db := NewDB() + defer db.Close() + + dbtest.BenchmarkRangeScans(b, db.ReadWriter(), int64(1e6)) +} + +func BenchmarkMemDBRangeScans10M(b *testing.B) { + db := NewDB() + defer db.Close() + + dbtest.BenchmarkRangeScans(b, db.ReadWriter(), int64(10e6)) +} + +func BenchmarkMemDBRandomReadsWrites(b *testing.B) { + db := NewDB() + defer db.Close() + + dbtest.BenchmarkRandomReadsWrites(b, db.ReadWriter()) +} + +func load(t *testing.T, _ string) dbm.DBConnection { + return NewDB() +} + +func TestGetSetHasDelete(t *testing.T) { + dbtest.DoTestGetSetHasDelete(t, load) +} + +func TestIterators(t *testing.T) { + dbtest.DoTestIterators(t, load) +} + +func TestVersioning(t *testing.T) { + dbtest.DoTestVersioning(t, load) +} + +func TestTransactions(t *testing.T) { + dbtest.DoTestTransactions(t, load, false) +} diff --git a/db/memdb/iterator.go b/db/memdb/iterator.go new file mode 100644 index 000000000..fad55c792 --- /dev/null +++ b/db/memdb/iterator.go @@ -0,0 +1,139 @@ +package memdb + +import ( + "bytes" + "context" + + tmdb "github.com/cosmos/cosmos-sdk/db" + "github.com/google/btree" +) + +const ( + // Size of the channel buffer between traversal goroutine and iterator. Using an unbuffered + // channel causes two context switches per item sent, while buffering allows more work per + // context switch. Tuned with benchmarks. + chBufferSize = 64 +) + +// memDBIterator is a memDB iterator. +type memDBIterator struct { + ch <-chan *item + cancel context.CancelFunc + item *item + start []byte + end []byte +} + +var _ tmdb.Iterator = (*memDBIterator)(nil) + +// newMemDBIterator creates a new memDBIterator. +// A visitor is passed to the btree which streams items to the iterator over a channel. Advancing +// the iterator pulls items from the channel, returning execution to the visitor. +// The reverse case needs some special handling, since we use [start, end) while btree uses (start, end] +func newMemDBIterator(tx *dbTxn, start []byte, end []byte, reverse bool) *memDBIterator { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan *item, chBufferSize) + iter := &memDBIterator{ + ch: ch, + cancel: cancel, + start: start, + end: end, + } + + go func() { + defer close(ch) + // Because we use [start, end) for reverse ranges, while btree uses (start, end], we need + // the following variables to handle some reverse iteration conditions ourselves. + var ( + skipEqual []byte + abortLessThan []byte + ) + visitor := func(i btree.Item) bool { + item := i.(*item) + if skipEqual != nil && bytes.Equal(item.key, skipEqual) { + skipEqual = nil + return true + } + if abortLessThan != nil && bytes.Compare(item.key, abortLessThan) == -1 { + return false + } + select { + case <-ctx.Done(): + return false + case ch <- item: + return true + } + } + switch { + case start == nil && end == nil && !reverse: + tx.btree.Ascend(visitor) + case start == nil && end == nil && reverse: + tx.btree.Descend(visitor) + case end == nil && !reverse: + // must handle this specially, since nil is considered less than anything else + tx.btree.AscendGreaterOrEqual(newKey(start), visitor) + case !reverse: + tx.btree.AscendRange(newKey(start), newKey(end), visitor) + case end == nil: + // abort after start, since we use [start, end) while btree uses (start, end] + abortLessThan = start + tx.btree.Descend(visitor) + default: + // skip end and abort after start, since we use [start, end) while btree uses (start, end] + skipEqual = end + abortLessThan = start + tx.btree.DescendLessOrEqual(newKey(end), visitor) + } + }() + + return iter +} + +// Close implements Iterator. +func (i *memDBIterator) Close() error { + i.cancel() + for range i.ch { // drain channel + } + i.item = nil + return nil +} + +// Domain implements Iterator. +func (i *memDBIterator) Domain() ([]byte, []byte) { + return i.start, i.end +} + +// Next implements Iterator. +func (i *memDBIterator) Next() bool { + item, ok := <-i.ch + switch { + case ok: + i.item = item + default: + i.item = nil + } + return i.item != nil +} + +// Error implements Iterator. +func (i *memDBIterator) Error() error { + return nil +} + +// Key implements Iterator. +func (i *memDBIterator) Key() []byte { + i.assertIsValid() + return i.item.key +} + +// Value implements Iterator. +func (i *memDBIterator) Value() []byte { + i.assertIsValid() + return i.item.value +} + +func (i *memDBIterator) assertIsValid() { + if i.item == nil { + panic("iterator is invalid") + } +} diff --git a/db/types.go b/db/types.go index b9363fd50..69a494a1e 100644 --- a/db/types.go +++ b/db/types.go @@ -44,7 +44,7 @@ type DBConnection interface { // Opens a write-only transaction at the current version. Writer() DBWriter - // Returns all saved versions + // Returns all saved versions as an immutable set which is safe for concurrent access. Versions() (VersionSet, error) // Saves the current contents of the database and returns the next version ID, which will be diff --git a/db/version_manager.go b/db/version_manager.go new file mode 100644 index 000000000..55d88e97a --- /dev/null +++ b/db/version_manager.go @@ -0,0 +1,135 @@ +package db + +import ( + "fmt" +) + +// VersionManager encapsulates the current valid versions of a DB and computes +// the next version. +type VersionManager struct { + versions map[uint64]struct{} + initial, last uint64 +} + +var _ VersionSet = (*VersionManager)(nil) + +// NewVersionManager creates a VersionManager from a sorted slice of ascending version ids. +func NewVersionManager(versions []uint64) *VersionManager { + vmap := make(map[uint64]struct{}) + var init, last uint64 + for _, ver := range versions { + vmap[ver] = struct{}{} + } + if len(versions) > 0 { + init = versions[0] + last = versions[len(versions)-1] + } + return &VersionManager{versions: vmap, initial: init, last: last} +} + +// Exists implements VersionSet. +func (vm *VersionManager) Exists(version uint64) bool { + _, has := vm.versions[version] + return has +} + +// Last implements VersionSet. +func (vm *VersionManager) Last() uint64 { + return vm.last +} + +func (vm *VersionManager) Initial() uint64 { + return vm.initial +} + +func (vm *VersionManager) Next() uint64 { + return vm.Last() + 1 +} + +func (vm *VersionManager) Save(target uint64) (uint64, error) { + next := vm.Next() + if target == 0 { + target = next + } + if target < next { + return 0, fmt.Errorf( + "target version cannot be less than next sequential version (%v < %v)", target, next) + } + if _, has := vm.versions[target]; has { + return 0, fmt.Errorf("version exists: %v", target) + } + + vm.versions[target] = struct{}{} + vm.last = target + if len(vm.versions) == 1 { + vm.initial = target + } + return target, nil +} + +func findLimit(m map[uint64]struct{}, cmp func(uint64, uint64) bool, init uint64) uint64 { + for x, _ := range m { + if cmp(x, init) { + init = x + } + } + return init +} + +func (vm *VersionManager) Delete(target uint64) { + delete(vm.versions, target) + if target == vm.last { + vm.last = findLimit(vm.versions, func(x, max uint64) bool { return x > max }, 0) + } + if target == vm.initial { + vm.initial = findLimit(vm.versions, func(x, min uint64) bool { return x < min }, vm.last) + } +} + +type vmIterator struct { + ch <-chan uint64 + open bool + buf uint64 +} + +func (vi *vmIterator) Next() bool { + vi.buf, vi.open = <-vi.ch + return vi.open +} +func (vi *vmIterator) Value() uint64 { return vi.buf } + +// Iterator implements VersionSet. +func (vm *VersionManager) Iterator() VersionIterator { + ch := make(chan uint64) + go func() { + for ver, _ := range vm.versions { + ch <- ver + } + close(ch) + }() + return &vmIterator{ch: ch} +} + +// Count implements VersionSet. +func (vm *VersionManager) Count() int { return len(vm.versions) } + +// Equal implements VersionSet. +func (vm *VersionManager) Equal(that VersionSet) bool { + if vm.Count() != that.Count() { + return false + } + for it := that.Iterator(); it.Next(); { + if !vm.Exists(it.Value()) { + return false + } + } + return true +} + +func (vm *VersionManager) Copy() *VersionManager { + vmap := make(map[uint64]struct{}) + for ver, _ := range vm.versions { + vmap[ver] = struct{}{} + } + return &VersionManager{versions: vmap, initial: vm.initial, last: vm.last} +} diff --git a/db/version_manager_test.go b/db/version_manager_test.go new file mode 100644 index 000000000..53e8754a7 --- /dev/null +++ b/db/version_manager_test.go @@ -0,0 +1,59 @@ +package db_test + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/require" + + dbm "github.com/cosmos/cosmos-sdk/db" +) + +// Test that VersionManager satisfies the behavior of VersionSet +func TestVersionManager(t *testing.T) { + vm := dbm.NewVersionManager(nil) + require.Equal(t, uint64(0), vm.Last()) + require.Equal(t, 0, vm.Count()) + require.True(t, vm.Equal(vm)) + require.False(t, vm.Exists(0)) + + id, err := vm.Save(0) + require.NoError(t, err) + require.Equal(t, uint64(1), id) + require.True(t, vm.Exists(id)) + id2, err := vm.Save(0) + require.NoError(t, err) + require.True(t, vm.Exists(id2)) + id3, err := vm.Save(0) + require.NoError(t, err) + require.True(t, vm.Exists(id3)) + + id, err = vm.Save(id) // can't save existing id + require.Error(t, err) + + id, err = vm.Save(0) + require.NoError(t, err) + require.True(t, vm.Exists(id)) + vm.Delete(id) + require.False(t, vm.Exists(id)) + + vm.Delete(1) + require.False(t, vm.Exists(1)) + require.Equal(t, id2, vm.Initial()) + require.Equal(t, id3, vm.Last()) + + var all []uint64 + for it := vm.Iterator(); it.Next(); { + all = append(all, it.Value()) + } + sort.Slice(all, func(i, j int) bool { return all[i] < all[j] }) + require.Equal(t, []uint64{id2, id3}, all) + + vmc := vm.Copy() + id, err = vmc.Save(0) + require.NoError(t, err) + require.False(t, vm.Exists(id)) // true copy is made + + vm2 := dbm.NewVersionManager([]uint64{id2, id3}) + require.True(t, vm.Equal(vm2)) +}