feat: ADR 040: Implement in-memory DB backend (#9952)

## Description

Implements an in-memory backend for the DB interface introduced by https://github.com/cosmos/cosmos-sdk/pull/9573 and specified by [ADR-040](eb7d939f86/docs/architecture/adr-040-storage-and-smt-state-commitments.md). This expands on the [btree](https://pkg.go.dev/github.com/google/btree)-based [`MemDB`](https://github.com/tendermint/tm-db/tree/master/memdb) from `tm-db` by using copy-on-write clones to implement versioning.
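
As an illustrative sketch (not part of this change) of the copy-on-write approach: the `Clone` method on a `google/btree` tree produces a cheap snapshot that shares nodes with the live tree until either side is written, which is what makes retaining one clone per saved version affordable. The `kv` type below exists only for the example:

```go
package main

import (
	"fmt"

	"github.com/google/btree"
)

// kv is a minimal btree.Item ordered by its key field.
type kv struct{ key, value string }

func (a kv) Less(b btree.Item) bool { return a.key < b.(kv).key }

func main() {
	live := btree.New(32)
	live.ReplaceOrInsert(kv{"a", "1"})

	// "Save a version": keep a clone; later writes to live do not affect it.
	v1 := live.Clone()
	live.ReplaceOrInsert(kv{"a", "2"})

	fmt.Println(v1.Get(kv{key: "a"}))   // {a 1}
	fmt.Println(live.Get(kv{key: "a"})) // {a 2}
}
```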

Resolves: https://github.com/vulcanize/cosmos-sdk/issues/2

Will move out of draft once https://github.com/cosmos/cosmos-sdk/pull/9573 is merged and this branch is rebased on it.

### Author Checklist

*All items are required. Please add a note to the item if the item is not applicable and
please add links to any relevant follow up issues.*

I have...

- [x] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [x] added `!` to the type prefix if API or client breaking change
- [x] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#pr-targeting))
- [x] provided a link to the relevant issue or specification
- [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/master/docs/building-modules) n/a
- [x] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#testing)
- [x] added a changelog entry to `CHANGELOG.md`
- [x] included comments for [documenting Go code](https://blog.golang.org/godoc)
- [x] updated the relevant documentation or specification
- [x] reviewed "Files changed" and left comments if necessary
- [ ] confirmed all CI checks have passed

### Reviewers Checklist

*All items are required. Please add a note if the item is not applicable and please add
your handle next to the items reviewed if you only reviewed selected items.*

I have...

- [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [ ] confirmed `!` in the type prefix if API or client breaking change
- [ ] confirmed all author checklist items have been addressed 
- [ ] reviewed state machine logic
- [ ] reviewed API design and naming
- [ ] reviewed documentation is accurate
- [ ] reviewed tests and test coverage
- [ ] manually tested (if applicable)

Roy Crihfield 2021-08-31 16:09:37 +08:00 committed by GitHub
parent f98dc675c9
commit 2c31451a55
12 changed files with 1240 additions and 1 deletions


@@ -139,6 +139,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* (errors) [\#8845](https://github.com/cosmos/cosmos-sdk/pull/8845) Add `Error.Wrap` handy method
* [\#8518](https://github.com/cosmos/cosmos-sdk/pull/8518) Help users of multisig wallets debug signature issues.
* [\#9573](https://github.com/cosmos/cosmos-sdk/pull/9573) ADR 040 implementation: New DB interface
* [\#9952](https://github.com/cosmos/cosmos-sdk/pull/9952) ADR 040: Implement in-memory DB backend
### Client Breaking Changes

db/README.md Normal file

@@ -0,0 +1,39 @@
# Key-Value Database
Databases supporting mappings between arbitrary byte-sequence keys and values.
## Interfaces
The database interface consists of types encapsulating the single connection to the DB, transactions on its contents, historical version state, and iteration.
### `DBConnection`
This interface represents a connection to a versioned key-value database. All versioning operations are performed using methods on this type.
* The `Versions` method returns a `VersionSet` which represents an immutable view of the version history at the current state.
* Version history is modified via the `{Save,Delete}Version` methods.
* Operations on version history do not modify any database contents.
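
A sketch of this versioning workflow, using the in-memory backend described below (error handling elided for brevity):

```go
package main

import (
	"fmt"

	"github.com/cosmos/cosmos-sdk/db/memdb"
)

func main() {
	db := memdb.NewDB()

	// Write some contents and commit them to the working state.
	w := db.Writer()
	_ = w.Set([]byte("k"), []byte("v1"))
	_ = w.Commit()

	// Snapshot the current contents under the next version ID.
	v1, _ := db.SaveNextVersion()

	// Versions returns an immutable view of the version history.
	vset, _ := db.Versions()
	fmt.Println(vset.Exists(v1), vset.Last()) // true 1

	// Deleting a version changes only the history, not the working contents.
	_ = db.DeleteVersion(v1)
}
```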
### `DBReader`, `DBWriter`, and `DBReadWriter`
These types represent transactions on the database contents. Their methods provide CRUD operations as well as iteration.
* Writable transactions call `Commit`, which flushes their operations to the source DB.
* All open transactions must be closed with `Discard` or `Commit` before a new version can be saved on the source DB.
* The maximum number of safely concurrent transactions depends on the backend implementation.
* A single transaction object is not safe for concurrent use.
* Write conflicts on concurrent transactions will cause an error at commit time (optimistic concurrency control).
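
A minimal sketch of the transaction lifecycle against any backend; the `dbexample` package and `WriteThenRead` names are illustrative only:

```go
package dbexample

import (
	dbm "github.com/cosmos/cosmos-sdk/db"
)

// WriteThenRead shows that a write is invisible to other open transactions
// until Commit, and that every transaction must be closed with Commit or Discard.
func WriteThenRead(db dbm.DBConnection) ([]byte, error) {
	rw := db.ReadWriter()
	if err := rw.Set([]byte("a"), []byte{0x01}); err != nil {
		rw.Discard()
		return nil, err
	}

	view := db.Reader()
	stale, _ := view.Get([]byte("a")) // nil: rw's write is not yet committed
	_ = stale
	view.Discard()

	if err := rw.Commit(); err != nil { // flush to the source DB and close rw
		return nil, err
	}

	view = db.Reader()
	defer view.Discard()
	return view.Get([]byte("a")) // []byte{0x01}
}
```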
#### `Iterator`
* While it is open, an iterator is invalidated by any write to its source transaction within the iterator's `Domain`.
* An iterator must call `Close` before its source transaction is closed.
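
For example, a sketch of iterating a reader over the half-open key range `[start, end)`; the helper name is illustrative only:

```go
package dbexample

import (
	"fmt"

	dbm "github.com/cosmos/cosmos-sdk/db"
)

// DumpRange prints every key-value pair in [start, end) and closes the
// iterator before returning, while the source transaction is still open.
func DumpRange(view dbm.DBReader, start, end []byte) error {
	it, err := view.Iterator(start, end)
	if err != nil {
		return err
	}
	for it.Next() { // Next must be called to advance to the first entry
		fmt.Printf("%s = %s\n", it.Key(), it.Value())
	}
	return it.Close()
}
```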
### `VersionSet`
This represents a self-contained and immutable view of a database's version history. It is therefore safe to retain and concurrently access any instance of this object.
## Implementations
### In-memory DB
The in-memory DB in the `db/memdb` package cannot be persisted to disk. It is implemented using the Google [btree](https://pkg.go.dev/github.com/google/btree) library.
* This currently does not perform write conflict detection, so it only supports a single open write-transaction at a time. Multiple and concurrent read-transactions are supported.
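
As a sketch of why this matters (assuming the current `memdb` behavior, where `Commit` simply replaces the database's tree with the transaction's clone), the last transaction to commit silently wins, discarding the other writer's changes:

```go
package main

import (
	"fmt"

	"github.com/cosmos/cosmos-sdk/db/memdb"
)

func main() {
	db := memdb.NewDB()

	// Two write transactions opened concurrently are not conflict-checked.
	tx1 := db.ReadWriter()
	tx2 := db.ReadWriter()

	_ = tx1.Set([]byte("k"), []byte("from tx1"))
	_ = tx2.Set([]byte("k"), []byte("from tx2"))

	_ = tx1.Commit()
	_ = tx2.Commit() // no conflict error: tx1's committed writes are replaced

	view := db.Reader()
	defer view.Discard()
	v, _ := view.Get([]byte("k"))
	fmt.Println(string(v)) // "from tx2"
}
```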

db/dbtest/benchmark.go Normal file

@@ -0,0 +1,109 @@
package dbtest
import (
"bytes"
"encoding/binary"
"math/rand"
"testing"
"github.com/stretchr/testify/require"
dbm "github.com/cosmos/cosmos-sdk/db"
)
func Int64ToBytes(i int64) []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(i))
return buf
}
func BytesToInt64(buf []byte) int64 {
return int64(binary.BigEndian.Uint64(buf))
}
func BenchmarkRangeScans(b *testing.B, db dbm.DBReadWriter, dbSize int64) {
b.StopTimer()
rangeSize := int64(10000)
if dbSize < rangeSize {
b.Errorf("db size %v cannot be less than range size %v", dbSize, rangeSize)
}
for i := int64(0); i < dbSize; i++ {
bytes := Int64ToBytes(i)
err := db.Set(bytes, bytes)
if err != nil {
// require.NoError() is very expensive (according to profiler), so check manually
b.Fatal(err)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
start := rand.Int63n(dbSize - rangeSize) // nolint: gosec
end := start + rangeSize
iter, err := db.Iterator(Int64ToBytes(start), Int64ToBytes(end))
require.NoError(b, err)
count := 0
for iter.Next() {
count++
}
iter.Close()
require.EqualValues(b, rangeSize, count)
}
}
func BenchmarkRandomReadsWrites(b *testing.B, db dbm.DBReadWriter) {
b.StopTimer()
// create dummy data
const numItems = int64(1000000)
internal := map[int64]int64{}
for i := 0; i < int(numItems); i++ {
internal[int64(i)] = int64(0)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
{
idx := rand.Int63n(numItems) // nolint: gosec
internal[idx]++
val := internal[idx]
idxBytes := Int64ToBytes(idx)
valBytes := Int64ToBytes(val)
err := db.Set(idxBytes, valBytes)
if err != nil {
// require.NoError() is very expensive (according to profiler), so check manually
b.Fatal(err)
}
}
{
idx := rand.Int63n(numItems) // nolint: gosec
valExp := internal[idx]
idxBytes := Int64ToBytes(idx)
valBytes, err := db.Get(idxBytes)
if err != nil {
b.Fatal(b, err)
}
if valExp == 0 {
if !bytes.Equal(valBytes, nil) {
b.Errorf("Expected %v for %v, got %X", nil, idx, valBytes)
break
}
} else {
if len(valBytes) != 8 {
b.Errorf("Expected length 8 for %v, got %X", idx, valBytes)
break
}
valGot := BytesToInt64(valBytes)
if valExp != valGot {
b.Errorf("Expected %v for %v, got %v", valExp, idx, valGot)
break
}
}
}
}
}

db/dbtest/testcases.go Normal file

@@ -0,0 +1,431 @@
package dbtest
import (
"fmt"
"sort"
"sync"
"testing"
"github.com/stretchr/testify/require"
dbm "github.com/cosmos/cosmos-sdk/db"
)
type Loader func(*testing.T, string) dbm.DBConnection
func ikey(i int) []byte { return []byte(fmt.Sprintf("key-%03d", i)) }
func ival(i int) []byte { return []byte(fmt.Sprintf("val-%03d", i)) }
func DoTestGetSetHasDelete(t *testing.T, load Loader) {
t.Helper()
db := load(t, t.TempDir())
var txn dbm.DBReadWriter
var view dbm.DBReader
view = db.Reader()
require.NotNil(t, view)
// A nonexistent key should return nil.
value, err := view.Get([]byte("a"))
require.NoError(t, err)
require.Nil(t, value)
ok, err := view.Has([]byte("a"))
require.NoError(t, err)
require.False(t, ok)
txn = db.ReadWriter()
// Set and get a value.
err = txn.Set([]byte("a"), []byte{0x01})
require.NoError(t, err)
ok, err = txn.Has([]byte("a"))
require.NoError(t, err)
require.True(t, ok)
value, err = txn.Get([]byte("a"))
require.NoError(t, err)
require.Equal(t, []byte{0x01}, value)
// New value is not visible from another txn.
ok, err = view.Has([]byte("a"))
require.NoError(t, err)
require.False(t, ok)
// Deleting a non-existent value is fine.
err = txn.Delete([]byte("x"))
require.NoError(t, err)
// Delete a value.
err = txn.Delete([]byte("a"))
require.NoError(t, err)
value, err = txn.Get([]byte("a"))
require.NoError(t, err)
require.Nil(t, value)
err = txn.Set([]byte("b"), []byte{0x02})
require.NoError(t, err)
view.Discard()
require.NoError(t, txn.Commit())
txn = db.ReadWriter()
// Verify committed values.
value, err = txn.Get([]byte("b"))
require.NoError(t, err)
require.Equal(t, []byte{0x02}, value)
ok, err = txn.Has([]byte("a"))
require.NoError(t, err)
require.False(t, ok)
// Setting, getting, and deleting an empty key should error.
_, err = txn.Get([]byte{})
require.Equal(t, dbm.ErrKeyEmpty, err)
_, err = txn.Get(nil)
require.Equal(t, dbm.ErrKeyEmpty, err)
_, err = txn.Has([]byte{})
require.Equal(t, dbm.ErrKeyEmpty, err)
_, err = txn.Has(nil)
require.Equal(t, dbm.ErrKeyEmpty, err)
err = txn.Set([]byte{}, []byte{0x01})
require.Equal(t, dbm.ErrKeyEmpty, err)
err = txn.Set(nil, []byte{0x01})
require.Equal(t, dbm.ErrKeyEmpty, err)
err = txn.Delete([]byte{})
require.Equal(t, dbm.ErrKeyEmpty, err)
err = txn.Delete(nil)
require.Equal(t, dbm.ErrKeyEmpty, err)
// Setting a nil value should error, but an empty value is fine.
err = txn.Set([]byte("x"), nil)
require.Equal(t, dbm.ErrValueNil, err)
err = txn.Set([]byte("x"), []byte{})
require.NoError(t, err)
value, err = txn.Get([]byte("x"))
require.NoError(t, err)
require.Equal(t, []byte{}, value)
require.NoError(t, txn.Commit())
require.NoError(t, db.Close())
}
func DoTestIterators(t *testing.T, load Loader) {
t.Helper()
db := load(t, t.TempDir())
type entry struct {
key []byte
val string
}
entries := []entry{
{[]byte{0}, "0"},
{[]byte{0, 0}, "0 0"},
{[]byte{0, 1}, "0 1"},
{[]byte{0, 2}, "0 2"},
{[]byte{1}, "1"},
}
txn := db.ReadWriter()
for _, e := range entries {
require.NoError(t, txn.Set(e.key, []byte(e.val)))
}
require.NoError(t, txn.Commit())
testRange := func(t *testing.T, iter dbm.Iterator, expected []string) {
i := 0
for ; iter.Next(); i++ {
expectedValue := expected[i]
value := iter.Value()
require.EqualValues(t, expectedValue, string(value), "i=%v", i)
}
require.Equal(t, len(expected), i)
}
type testCase struct {
start, end []byte
expected []string
}
view := db.Reader()
iterCases := []testCase{
{nil, nil, []string{"0", "0 0", "0 1", "0 2", "1"}},
{[]byte{0x00}, nil, []string{"0", "0 0", "0 1", "0 2", "1"}},
{[]byte{0x00}, []byte{0x00, 0x01}, []string{"0", "0 0"}},
{[]byte{0x00}, []byte{0x01}, []string{"0", "0 0", "0 1", "0 2"}},
{[]byte{0x00, 0x01}, []byte{0x01}, []string{"0 1", "0 2"}},
{nil, []byte{0x01}, []string{"0", "0 0", "0 1", "0 2"}},
}
for i, tc := range iterCases {
t.Logf("Iterator case %d: [%v, %v)", i, tc.start, tc.end)
it, err := view.Iterator(tc.start, tc.end)
require.NoError(t, err)
testRange(t, it, tc.expected)
it.Close()
}
reverseCases := []testCase{
{nil, nil, []string{"1", "0 2", "0 1", "0 0", "0"}},
{[]byte{0x00}, nil, []string{"1", "0 2", "0 1", "0 0", "0"}},
{[]byte{0x00}, []byte{0x00, 0x01}, []string{"0 0", "0"}},
{[]byte{0x00}, []byte{0x01}, []string{"0 2", "0 1", "0 0", "0"}},
{[]byte{0x00, 0x01}, []byte{0x01}, []string{"0 2", "0 1"}},
{nil, []byte{0x01}, []string{"0 2", "0 1", "0 0", "0"}},
}
for i, tc := range reverseCases {
t.Logf("ReverseIterator case %d: [%v, %v)", i, tc.start, tc.end)
it, err := view.ReverseIterator(tc.start, tc.end)
require.NoError(t, err)
testRange(t, it, tc.expected)
it.Close()
}
view.Discard()
require.NoError(t, db.Close())
}
func DoTestVersioning(t *testing.T, load Loader) {
t.Helper()
db := load(t, t.TempDir())
view := db.Reader()
require.NotNil(t, view)
// Write, then read different versions
txn := db.ReadWriter()
require.NoError(t, txn.Set([]byte("0"), []byte("a")))
require.NoError(t, txn.Set([]byte("1"), []byte("b")))
require.NoError(t, txn.Commit())
v1, err := db.SaveNextVersion()
require.NoError(t, err)
txn = db.ReadWriter()
require.NoError(t, txn.Set([]byte("0"), []byte("c")))
require.NoError(t, txn.Delete([]byte("1")))
require.NoError(t, txn.Set([]byte("2"), []byte("c")))
require.NoError(t, txn.Commit())
v2, err := db.SaveNextVersion()
require.NoError(t, err)
// Skip to a future version
v3 := (v2 + 2)
require.NoError(t, db.SaveVersion(v3))
// Try to save to a past version
err = db.SaveVersion(v2)
require.Error(t, err)
// Verify existing versions
versions, err := db.Versions()
require.NoError(t, err)
require.Equal(t, 3, versions.Count())
var all []uint64
for it := versions.Iterator(); it.Next(); {
all = append(all, it.Value())
}
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
require.Equal(t, []uint64{v1, v2, v3}, all)
require.Equal(t, v3, versions.Last())
view, err = db.ReaderAt(v1)
require.NoError(t, err)
require.NotNil(t, view)
val, err := view.Get([]byte("0"))
require.Equal(t, []byte("a"), val)
require.NoError(t, err)
val, err = view.Get([]byte("1"))
require.Equal(t, []byte("b"), val)
require.NoError(t, err)
has, err := view.Has([]byte("2"))
require.False(t, has)
view, err = db.ReaderAt(v2)
require.NoError(t, err)
require.NotNil(t, view)
val, err = view.Get([]byte("0"))
require.Equal(t, []byte("c"), val)
require.NoError(t, err)
val, err = view.Get([]byte("2"))
require.Equal(t, []byte("c"), val)
require.NoError(t, err)
has, err = view.Has([]byte("1"))
require.False(t, has)
// Try to read an invalid version
view, err = db.ReaderAt(versions.Last() + 1)
require.Equal(t, dbm.ErrVersionDoesNotExist, err)
require.NoError(t, db.DeleteVersion(v2))
// Try to read a deleted version
view, err = db.ReaderAt(v2)
require.Equal(t, dbm.ErrVersionDoesNotExist, err)
// Ensure latest version is accurate
prev := v3
for i := 0; i < 10; i++ {
w := db.Writer()
require.NoError(t, w.Set(ikey(i), ival(i)))
require.NoError(t, w.Commit())
ver, err := db.SaveNextVersion()
require.NoError(t, err)
require.Equal(t, prev+1, ver)
versions, err := db.Versions()
require.NoError(t, err)
require.Equal(t, ver, versions.Last())
prev = ver
}
require.NoError(t, db.Close())
}
func DoTestTransactions(t *testing.T, load Loader, multipleWriters bool) {
t.Helper()
db := load(t, t.TempDir())
// Both methods should work in a DBWriter context
writerFuncs := []func() dbm.DBWriter{
db.Writer,
func() dbm.DBWriter { return db.ReadWriter() },
}
for _, getWriter := range writerFuncs {
// Uncommitted records are not saved
t.Run("no commit", func(t *testing.T) {
t.Helper()
view := db.Reader()
defer view.Discard()
tx := getWriter()
defer tx.Discard()
require.NoError(t, tx.Set([]byte("0"), []byte("a")))
v, err := view.Get([]byte("0"))
require.NoError(t, err)
require.Nil(t, v)
})
// Try to commit version with open txns
t.Run("open transactions", func(t *testing.T) {
t.Helper()
tx := getWriter()
tx.Set([]byte("2"), []byte("a"))
_, err := db.SaveNextVersion()
require.Equal(t, dbm.ErrOpenTransactions, err)
tx.Discard()
})
// Continue only if the backend supports multiple concurrent writers
if !multipleWriters {
continue
}
// Writing separately to same key causes a conflict
t.Run("write conflict", func(t *testing.T) {
t.Helper()
tx1 := getWriter()
tx2 := db.ReadWriter()
tx2.Get([]byte("1"))
require.NoError(t, tx1.Set([]byte("1"), []byte("b")))
require.NoError(t, tx2.Set([]byte("1"), []byte("c")))
require.NoError(t, tx1.Commit())
require.Error(t, tx2.Commit())
})
// Writing from concurrent txns
t.Run("concurrent transactions", func(t *testing.T) {
t.Helper()
var wg sync.WaitGroup
setkv := func(k, v []byte) {
defer wg.Done()
tx := getWriter()
require.NoError(t, tx.Set(k, v))
require.NoError(t, tx.Commit())
}
n := 10
wg.Add(n)
for i := 0; i < n; i++ {
go setkv(ikey(i), ival(i))
}
wg.Wait()
view := db.Reader()
defer view.Discard()
v, err := view.Get(ikey(0))
require.NoError(t, err)
require.Equal(t, ival(0), v)
})
}
require.NoError(t, db.Close())
}
// Tests reloading a saved DB from disk.
func DoTestReloadDB(t *testing.T, load Loader) {
t.Helper()
dirname := t.TempDir()
db := load(t, dirname)
txn := db.Writer()
for i := 0; i < 100; i++ {
require.NoError(t, txn.Set(ikey(i), ival(i)))
}
require.NoError(t, txn.Commit())
first, err := db.SaveNextVersion()
require.NoError(t, err)
txn = db.Writer()
for i := 0; i < 50; i++ { // overwrite some values
require.NoError(t, txn.Set(ikey(i), ival(i*10)))
}
require.NoError(t, txn.Commit())
last, err := db.SaveNextVersion()
require.NoError(t, err)
txn = db.Writer()
for i := 100; i < 150; i++ {
require.NoError(t, txn.Set(ikey(i), ival(i)))
}
require.NoError(t, txn.Commit())
db.Close()
// Reload and check each saved version
db = load(t, dirname)
view, err := db.ReaderAt(first)
require.NoError(t, err)
for i := 0; i < 100; i++ {
v, err := view.Get(ikey(i))
require.NoError(t, err)
require.Equal(t, ival(i), v)
}
view.Discard()
view, err = db.ReaderAt(last)
require.NoError(t, err)
for i := 0; i < 50; i++ {
v, err := view.Get(ikey(i))
require.NoError(t, err)
require.Equal(t, ival(i*10), v)
}
for i := 50; i < 100; i++ {
v, err := view.Get(ikey(i))
require.NoError(t, err)
require.Equal(t, ival(i), v)
}
view.Discard()
// Load working version
view = db.Reader()
for i := 100; i < 150; i++ {
v, err := view.Get(ikey(i))
require.NoError(t, err)
require.Equal(t, ival(i), v)
}
view.Discard()
require.NoError(t, db.Close())
}


@@ -3,5 +3,6 @@ go 1.15
module github.com/cosmos/cosmos-sdk/db
require (
github.com/google/btree v1.0.0
github.com/stretchr/testify v1.7.0
)


@@ -1,5 +1,7 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

db/memdb/db.go Normal file

@@ -0,0 +1,274 @@
package memdb
import (
"bytes"
"fmt"
"sync"
"sync/atomic"
dbm "github.com/cosmos/cosmos-sdk/db"
"github.com/google/btree"
)
const (
// The approximate number of items and children per B-tree node. Tuned with benchmarks.
bTreeDegree = 32
)
// MemDB is an in-memory database backend using a B-tree for storage.
//
// For performance reasons, all given and returned keys and values are pointers to the in-memory
// database, so modifying them will cause the stored values to be modified as well. All DB methods
// already specify that keys and values should be considered read-only, but this is especially
// important with MemDB.
//
// Versioning is implemented by maintaining references to copy-on-write clones of the backing btree.
//
// TODO: Currently transactions do not detect write conflicts, so writers cannot be used concurrently.
type MemDB struct {
btree *btree.BTree // Main contents
mtx sync.RWMutex // Guards version history
saved map[uint64]*btree.BTree // Past versions
vmgr *dbm.VersionManager // Mirrors version keys
openWriters int32 // Open writers
}
type dbTxn struct {
btree *btree.BTree
db *MemDB
}
type dbWriter struct{ dbTxn }
var (
_ dbm.DBConnection = (*MemDB)(nil)
_ dbm.DBReader = (*dbTxn)(nil)
_ dbm.DBWriter = (*dbWriter)(nil)
_ dbm.DBReadWriter = (*dbWriter)(nil)
)
// item is a btree.Item with byte slices as keys and values
type item struct {
key []byte
value []byte
}
// NewDB creates a new in-memory database.
func NewDB() *MemDB {
return &MemDB{
btree: btree.New(bTreeDegree),
saved: make(map[uint64]*btree.BTree),
vmgr: dbm.NewVersionManager(nil),
}
}
func (db *MemDB) newTxn(tree *btree.BTree) dbTxn {
return dbTxn{tree, db}
}
// Close implements DB.
// Close is a noop since for an in-memory database, we don't have a destination to flush
// contents to nor do we want any data loss on invoking Close().
// See the discussion in https://github.com/tendermint/tendermint/libs/pull/56
func (db *MemDB) Close() error {
return nil
}
// Versions implements DBConnection.
func (db *MemDB) Versions() (dbm.VersionSet, error) {
db.mtx.RLock()
defer db.mtx.RUnlock()
return db.vmgr, nil
}
// Reader implements DBConnection.
func (db *MemDB) Reader() dbm.DBReader {
db.mtx.RLock()
defer db.mtx.RUnlock()
ret := db.newTxn(db.btree)
return &ret
}
// ReaderAt implements DBConnection.
func (db *MemDB) ReaderAt(version uint64) (dbm.DBReader, error) {
db.mtx.RLock()
defer db.mtx.RUnlock()
tree, ok := db.saved[version]
if !ok {
return nil, dbm.ErrVersionDoesNotExist
}
ret := db.newTxn(tree)
return &ret, nil
}
// Writer implements DBConnection.
func (db *MemDB) Writer() dbm.DBWriter {
return db.ReadWriter()
}
// ReadWriter implements DBConnection.
func (db *MemDB) ReadWriter() dbm.DBReadWriter {
db.mtx.RLock()
defer db.mtx.RUnlock()
atomic.AddInt32(&db.openWriters, 1)
// Clone creates a copy-on-write extension of the current tree
return &dbWriter{db.newTxn(db.btree.Clone())}
}
func (db *MemDB) save(target uint64) (uint64, error) {
db.mtx.Lock()
defer db.mtx.Unlock()
if db.openWriters > 0 {
return 0, dbm.ErrOpenTransactions
}
newVmgr := db.vmgr.Copy()
target, err := newVmgr.Save(target)
if err != nil {
return 0, err
}
db.saved[target] = db.btree
db.vmgr = newVmgr
return target, nil
}
// SaveNextVersion implements DBConnection.
func (db *MemDB) SaveNextVersion() (uint64, error) {
return db.save(0)
}
// SaveVersion implements DBConnection.
func (db *MemDB) SaveVersion(target uint64) error {
if target == 0 {
return dbm.ErrInvalidVersion
}
_, err := db.save(target)
return err
}
// DeleteVersion implements DBConnection.
func (db *MemDB) DeleteVersion(target uint64) error {
db.mtx.Lock()
defer db.mtx.Unlock()
if _, has := db.saved[target]; !has {
return dbm.ErrVersionDoesNotExist
}
delete(db.saved, target)
db.vmgr = db.vmgr.Copy()
db.vmgr.Delete(target)
return nil
}
// Get implements DBReader.
func (tx *dbTxn) Get(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, dbm.ErrKeyEmpty
}
i := tx.btree.Get(newKey(key))
if i != nil {
return i.(*item).value, nil
}
return nil, nil
}
// Has implements DBReader.
func (tx *dbTxn) Has(key []byte) (bool, error) {
if len(key) == 0 {
return false, dbm.ErrKeyEmpty
}
return tx.btree.Has(newKey(key)), nil
}
// Set implements DBWriter.
func (tx *dbWriter) Set(key []byte, value []byte) error {
if len(key) == 0 {
return dbm.ErrKeyEmpty
}
if value == nil {
return dbm.ErrValueNil
}
tx.btree.ReplaceOrInsert(newPair(key, value))
return nil
}
// Delete implements DBWriter.
func (tx *dbWriter) Delete(key []byte) error {
if len(key) == 0 {
return dbm.ErrKeyEmpty
}
tx.btree.Delete(newKey(key))
return nil
}
// Iterator implements DBReader.
// Takes out a read-lock on the database until the iterator is closed.
func (tx *dbTxn) Iterator(start, end []byte) (dbm.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, dbm.ErrKeyEmpty
}
return newMemDBIterator(tx, start, end, false), nil
}
// ReverseIterator implements DBReader.
// Takes out a read-lock on the database until the iterator is closed.
func (tx *dbTxn) ReverseIterator(start, end []byte) (dbm.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, dbm.ErrKeyEmpty
}
return newMemDBIterator(tx, start, end, true), nil
}
// Commit implements DBWriter.
func (tx *dbWriter) Commit() error {
tx.db.mtx.Lock()
defer tx.db.mtx.Unlock()
defer tx.Discard()
tx.db.btree = tx.btree
return nil
}
// Discard implements DBReader and DBWriter.
func (tx *dbTxn) Discard() {}
func (tx *dbWriter) Discard() {
atomic.AddInt32(&tx.db.openWriters, -1)
}
// Print prints the database contents.
func (db *MemDB) Print() error {
db.mtx.RLock()
defer db.mtx.RUnlock()
db.btree.Ascend(func(i btree.Item) bool {
item := i.(*item)
fmt.Printf("[%X]:\t[%X]\n", item.key, item.value)
return true
})
return nil
}
// Stats implements DBConnection.
func (db *MemDB) Stats() map[string]string {
db.mtx.RLock()
defer db.mtx.RUnlock()
stats := make(map[string]string)
stats["database.type"] = "memDB"
stats["database.size"] = fmt.Sprintf("%d", db.btree.Len())
return stats
}
// Less implements btree.Item.
func (i *item) Less(other btree.Item) bool {
// this considers nil == []byte{}, but that's ok since we handle nil endpoints
// in iterators specially anyway
return bytes.Compare(i.key, other.(*item).key) == -1
}
// newKey creates a new key item.
func newKey(key []byte) *item {
return &item{key: key}
}
// newPair creates a new pair item.
func newPair(key, value []byte) *item {
return &item{key: key, value: value}
}

db/memdb/db_test.go Normal file

@@ -0,0 +1,49 @@
package memdb
import (
"testing"
dbm "github.com/cosmos/cosmos-sdk/db"
"github.com/cosmos/cosmos-sdk/db/dbtest"
)
func BenchmarkMemDBRangeScans1M(b *testing.B) {
db := NewDB()
defer db.Close()
dbtest.BenchmarkRangeScans(b, db.ReadWriter(), int64(1e6))
}
func BenchmarkMemDBRangeScans10M(b *testing.B) {
db := NewDB()
defer db.Close()
dbtest.BenchmarkRangeScans(b, db.ReadWriter(), int64(10e6))
}
func BenchmarkMemDBRandomReadsWrites(b *testing.B) {
db := NewDB()
defer db.Close()
dbtest.BenchmarkRandomReadsWrites(b, db.ReadWriter())
}
func load(t *testing.T, _ string) dbm.DBConnection {
return NewDB()
}
func TestGetSetHasDelete(t *testing.T) {
dbtest.DoTestGetSetHasDelete(t, load)
}
func TestIterators(t *testing.T) {
dbtest.DoTestIterators(t, load)
}
func TestVersioning(t *testing.T) {
dbtest.DoTestVersioning(t, load)
}
func TestTransactions(t *testing.T) {
dbtest.DoTestTransactions(t, load, false)
}

db/memdb/iterator.go Normal file

@@ -0,0 +1,139 @@
package memdb
import (
"bytes"
"context"
tmdb "github.com/cosmos/cosmos-sdk/db"
"github.com/google/btree"
)
const (
// Size of the channel buffer between traversal goroutine and iterator. Using an unbuffered
// channel causes two context switches per item sent, while buffering allows more work per
// context switch. Tuned with benchmarks.
chBufferSize = 64
)
// memDBIterator is a memDB iterator.
type memDBIterator struct {
ch <-chan *item
cancel context.CancelFunc
item *item
start []byte
end []byte
}
var _ tmdb.Iterator = (*memDBIterator)(nil)
// newMemDBIterator creates a new memDBIterator.
// A visitor is passed to the btree which streams items to the iterator over a channel. Advancing
// the iterator pulls items from the channel, returning execution to the visitor.
// The reverse case needs some special handling, since we use [start, end) while btree uses (start, end]
func newMemDBIterator(tx *dbTxn, start []byte, end []byte, reverse bool) *memDBIterator {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan *item, chBufferSize)
iter := &memDBIterator{
ch: ch,
cancel: cancel,
start: start,
end: end,
}
go func() {
defer close(ch)
// Because we use [start, end) for reverse ranges, while btree uses (start, end], we need
// the following variables to handle some reverse iteration conditions ourselves.
var (
skipEqual []byte
abortLessThan []byte
)
visitor := func(i btree.Item) bool {
item := i.(*item)
if skipEqual != nil && bytes.Equal(item.key, skipEqual) {
skipEqual = nil
return true
}
if abortLessThan != nil && bytes.Compare(item.key, abortLessThan) == -1 {
return false
}
select {
case <-ctx.Done():
return false
case ch <- item:
return true
}
}
switch {
case start == nil && end == nil && !reverse:
tx.btree.Ascend(visitor)
case start == nil && end == nil && reverse:
tx.btree.Descend(visitor)
case end == nil && !reverse:
// must handle this specially, since nil is considered less than anything else
tx.btree.AscendGreaterOrEqual(newKey(start), visitor)
case !reverse:
tx.btree.AscendRange(newKey(start), newKey(end), visitor)
case end == nil:
// abort after start, since we use [start, end) while btree uses (start, end]
abortLessThan = start
tx.btree.Descend(visitor)
default:
// skip end and abort after start, since we use [start, end) while btree uses (start, end]
skipEqual = end
abortLessThan = start
tx.btree.DescendLessOrEqual(newKey(end), visitor)
}
}()
return iter
}
// Close implements Iterator.
func (i *memDBIterator) Close() error {
i.cancel()
for range i.ch { // drain channel
}
i.item = nil
return nil
}
// Domain implements Iterator.
func (i *memDBIterator) Domain() ([]byte, []byte) {
return i.start, i.end
}
// Next implements Iterator.
func (i *memDBIterator) Next() bool {
item, ok := <-i.ch
switch {
case ok:
i.item = item
default:
i.item = nil
}
return i.item != nil
}
// Error implements Iterator.
func (i *memDBIterator) Error() error {
return nil
}
// Key implements Iterator.
func (i *memDBIterator) Key() []byte {
i.assertIsValid()
return i.item.key
}
// Value implements Iterator.
func (i *memDBIterator) Value() []byte {
i.assertIsValid()
return i.item.value
}
func (i *memDBIterator) assertIsValid() {
if i.item == nil {
panic("iterator is invalid")
}
}


@@ -44,7 +44,7 @@ type DBConnection interface {
// Opens a write-only transaction at the current version.
Writer() DBWriter
// Returns all saved versions
// Returns all saved versions as an immutable set which is safe for concurrent access.
Versions() (VersionSet, error)
// Saves the current contents of the database and returns the next version ID, which will be

db/version_manager.go Normal file

@@ -0,0 +1,135 @@
package db
import (
"fmt"
)
// VersionManager encapsulates the current valid versions of a DB and computes
// the next version.
type VersionManager struct {
versions map[uint64]struct{}
initial, last uint64
}
var _ VersionSet = (*VersionManager)(nil)
// NewVersionManager creates a VersionManager from a sorted slice of ascending version ids.
func NewVersionManager(versions []uint64) *VersionManager {
vmap := make(map[uint64]struct{})
var init, last uint64
for _, ver := range versions {
vmap[ver] = struct{}{}
}
if len(versions) > 0 {
init = versions[0]
last = versions[len(versions)-1]
}
return &VersionManager{versions: vmap, initial: init, last: last}
}
// Exists implements VersionSet.
func (vm *VersionManager) Exists(version uint64) bool {
_, has := vm.versions[version]
return has
}
// Last implements VersionSet.
func (vm *VersionManager) Last() uint64 {
return vm.last
}
func (vm *VersionManager) Initial() uint64 {
return vm.initial
}
func (vm *VersionManager) Next() uint64 {
return vm.Last() + 1
}
func (vm *VersionManager) Save(target uint64) (uint64, error) {
next := vm.Next()
if target == 0 {
target = next
}
if target < next {
return 0, fmt.Errorf(
"target version cannot be less than next sequential version (%v < %v)", target, next)
}
if _, has := vm.versions[target]; has {
return 0, fmt.Errorf("version exists: %v", target)
}
vm.versions[target] = struct{}{}
vm.last = target
if len(vm.versions) == 1 {
vm.initial = target
}
return target, nil
}
func findLimit(m map[uint64]struct{}, cmp func(uint64, uint64) bool, init uint64) uint64 {
for x := range m {
if cmp(x, init) {
init = x
}
}
return init
}
func (vm *VersionManager) Delete(target uint64) {
delete(vm.versions, target)
if target == vm.last {
vm.last = findLimit(vm.versions, func(x, max uint64) bool { return x > max }, 0)
}
if target == vm.initial {
vm.initial = findLimit(vm.versions, func(x, min uint64) bool { return x < min }, vm.last)
}
}
type vmIterator struct {
ch <-chan uint64
open bool
buf uint64
}
func (vi *vmIterator) Next() bool {
vi.buf, vi.open = <-vi.ch
return vi.open
}
func (vi *vmIterator) Value() uint64 { return vi.buf }
// Iterator implements VersionSet.
func (vm *VersionManager) Iterator() VersionIterator {
ch := make(chan uint64)
go func() {
for ver := range vm.versions {
ch <- ver
}
close(ch)
}()
return &vmIterator{ch: ch}
}
// Count implements VersionSet.
func (vm *VersionManager) Count() int { return len(vm.versions) }
// Equal implements VersionSet.
func (vm *VersionManager) Equal(that VersionSet) bool {
if vm.Count() != that.Count() {
return false
}
for it := that.Iterator(); it.Next(); {
if !vm.Exists(it.Value()) {
return false
}
}
return true
}
func (vm *VersionManager) Copy() *VersionManager {
vmap := make(map[uint64]struct{})
for ver := range vm.versions {
vmap[ver] = struct{}{}
}
return &VersionManager{versions: vmap, initial: vm.initial, last: vm.last}
}


@@ -0,0 +1,59 @@
package db_test
import (
"sort"
"testing"
"github.com/stretchr/testify/require"
dbm "github.com/cosmos/cosmos-sdk/db"
)
// Test that VersionManager satisfies the behavior of VersionSet
func TestVersionManager(t *testing.T) {
vm := dbm.NewVersionManager(nil)
require.Equal(t, uint64(0), vm.Last())
require.Equal(t, 0, vm.Count())
require.True(t, vm.Equal(vm))
require.False(t, vm.Exists(0))
id, err := vm.Save(0)
require.NoError(t, err)
require.Equal(t, uint64(1), id)
require.True(t, vm.Exists(id))
id2, err := vm.Save(0)
require.NoError(t, err)
require.True(t, vm.Exists(id2))
id3, err := vm.Save(0)
require.NoError(t, err)
require.True(t, vm.Exists(id3))
id, err = vm.Save(id) // can't save existing id
require.Error(t, err)
id, err = vm.Save(0)
require.NoError(t, err)
require.True(t, vm.Exists(id))
vm.Delete(id)
require.False(t, vm.Exists(id))
vm.Delete(1)
require.False(t, vm.Exists(1))
require.Equal(t, id2, vm.Initial())
require.Equal(t, id3, vm.Last())
var all []uint64
for it := vm.Iterator(); it.Next(); {
all = append(all, it.Value())
}
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
require.Equal(t, []uint64{id2, id3}, all)
vmc := vm.Copy()
id, err = vmc.Save(0)
require.NoError(t, err)
require.False(t, vm.Exists(id)) // true copy is made
vm2 := dbm.NewVersionManager([]uint64{id2, id3})
require.True(t, vm.Equal(vm2))
}