diff --git a/.gitignore b/.gitignore index e0a06eaf..a2ebfde2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -*.swp +*.sw[opqr] vendor .glide diff --git a/CHANGELOG.md b/CHANGELOG.md index fe2c2fe9..89b841d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,37 @@ # Changelog +## 0.7.0 (TBD) + +BREAKING: + + - [db] Major API upgrade. See `db/types.go`. + - [common] added `Quit() <-chan struct{}` to Service interface. + The returned channel is closed when service is stopped. + - [common] Remove HTTP functions + - [common] Heap.Push takes an `int`, new Heap.PushComparable takes the comparable. + - [logger] Removed. Use `log` + - [merkle] Major API updade - uses cmn.KVPairs. + - [cli] WriteDemoConfig -> WriteConfigValues + - [all] Remove go-wire dependency! + +FEATURES: + + - [db] New FSDB that uses the filesystem directly + - [common] HexBytes + - [common] KVPair and KI64Pair (protobuf based key-value pair objects) + +IMPROVEMENTS: + + - [clist] add WaitChan() to CList, NextWaitChan() and PrevWaitChan() + to CElement. These can be used instead of blocking `*Wait()` methods + if you need to be able to send quit signal and not block forever + - [common] IsHex handles 0x-prefix + +BUG FIXES: + + - [common] BitArray check for nil arguments + - [common] Fix memory leak in RepeatTimer + ## 0.6.0 (December 29, 2017) BREAKING: diff --git a/Makefile b/Makefile index dd4711aa..ae2c7161 100644 --- a/Makefile +++ b/Makefile @@ -1,57 +1,121 @@ -.PHONY: all test get_vendor_deps ensure_tools - GOTOOLS = \ github.com/Masterminds/glide \ - github.com/alecthomas/gometalinter + github.com/gogo/protobuf/protoc-gen-gogo \ + github.com/gogo/protobuf/gogoproto + # github.com/alecthomas/gometalinter.v2 \ -PACKAGES=$(shell go list ./... | grep -v '/vendor/') -REPO:=github.com/tendermint/tmlibs +GOTOOLS_CHECK = glide gometalinter.v2 protoc protoc-gen-gogo +INCLUDE = -I=. -I=${GOPATH}/src -I=${GOPATH}/src/github.com/gogo/protobuf/protobuf -all: test +all: check get_vendor_deps protoc build test install metalinter -test: - @echo "--> Running linter" - @make metalinter_test - @echo "--> Running go test" - @go test $(PACKAGES) +check: check_tools -get_vendor_deps: ensure_tools +######################################## +### Build + +protoc: + ## If you get the following error, + ## "error while loading shared libraries: libprotobuf.so.14: cannot open shared object file: No such file or directory" + ## See https://stackoverflow.com/a/25518702 + protoc $(INCLUDE) --gogo_out=plugins=grpc:. common/*.proto + @echo "--> adding nolint declarations to protobuf generated files" + @awk '/package common/ { print "//nolint: gas"; print; next }1' common/types.pb.go > common/types.pb.go.new + @mv common/types.pb.go.new common/types.pb.go + +build: + # Nothing to build! + +install: + # Nothing to install! + + +######################################## +### Tools & dependencies + +check_tools: + @# https://stackoverflow.com/a/25668869 + @echo "Found tools: $(foreach tool,$(GOTOOLS_CHECK),\ + $(if $(shell which $(tool)),$(tool),$(error "No $(tool) in PATH")))" + +get_tools: + @echo "--> Installing tools" + go get -u -v $(GOTOOLS) + # @gometalinter.v2 --install + +get_protoc: + @# https://github.com/google/protobuf/releases + curl -L https://github.com/google/protobuf/releases/download/v3.4.1/protobuf-cpp-3.4.1.tar.gz | tar xvz && \ + cd protobuf-3.4.1 && \ + DIST_LANG=cpp ./configure && \ + make && \ + make install && \ + cd .. && \ + rm -rf protobuf-3.4.1 + +update_tools: + @echo "--> Updating tools" + @go get -u $(GOTOOLS) + +get_vendor_deps: @rm -rf vendor/ @echo "--> Running glide install" @glide install -ensure_tools: - go get $(GOTOOLS) - @gometalinter --install -metalinter: - gometalinter --vendor --deadline=600s --enable-all --disable=lll ./... +######################################## +### Testing -metalinter_test: - gometalinter --vendor --deadline=600s --disable-all \ +test: + go test -tags gcc `glide novendor` + +test100: + @for i in {1..100}; do make test; done + + +######################################## +### Formatting, linting, and vetting + +fmt: + @go fmt ./... + +metalinter: + @echo "==> Running linter" + gometalinter.v2 --vendor --deadline=600s --disable-all \ --enable=deadcode \ --enable=goconst \ + --enable=goimports \ --enable=gosimple \ - --enable=ineffassign \ - --enable=interfacer \ + --enable=ineffassign \ --enable=megacheck \ - --enable=misspell \ - --enable=staticcheck \ + --enable=misspell \ + --enable=staticcheck \ --enable=safesql \ - --enable=structcheck \ - --enable=unconvert \ + --enable=structcheck \ + --enable=unconvert \ --enable=unused \ - --enable=varcheck \ + --enable=varcheck \ --enable=vetshadow \ - --enable=vet \ ./... + #--enable=maligned \ #--enable=gas \ #--enable=aligncheck \ #--enable=dupl \ #--enable=errcheck \ #--enable=gocyclo \ - #--enable=goimports \ #--enable=golint \ <== comments on anything exported #--enable=gotype \ - #--enable=unparam \ + #--enable=interfacer \ + #--enable=unparam \ + #--enable=vet \ + +metalinter_all: + protoc $(INCLUDE) --lint_out=. types/*.proto + gometalinter.v2 --vendor --deadline=600s --enable-all --disable=lll ./... + + +# To avoid unintended conflicts with file names, always add to .PHONY +# unless there is a reason not to. +# https://www.gnu.org/software/make/manual/html_node/Phony-Targets.html +.PHONY: check protoc build check_tools get_tools get_protoc update_tools get_vendor_deps test fmt metalinter metalinter_all diff --git a/README.md b/README.md index d5a11c7b..9ea618db 100644 --- a/README.md +++ b/README.md @@ -36,10 +36,6 @@ Flowrate is a fork of https://github.com/mxk/go-flowrate that added a `SetREMA` Log is a log package structured around key-value pairs that allows logging level to be set differently for different keys. -## logger - -Logger is DEPRECATED. It's a simple wrapper around `log15`. - ## merkle Merkle provides a simple static merkle tree and corresponding proofs. diff --git a/circle.yml b/circle.yml index 104cfa6f..390ffb03 100644 --- a/circle.yml +++ b/circle.yml @@ -15,7 +15,7 @@ dependencies: test: override: - - cd $PROJECT_PATH && make get_vendor_deps && bash ./test.sh + - cd $PROJECT_PATH && make get_tools && make get_vendor_deps && bash ./test.sh post: - cd "$PROJECT_PATH" && bash <(curl -s https://codecov.io/bash) -f coverage.txt - cd "$PROJECT_PATH" && mv coverage.txt "${CIRCLE_ARTIFACTS}" diff --git a/cli/helper.go b/cli/helper.go index 845c17db..878cf26e 100644 --- a/cli/helper.go +++ b/cli/helper.go @@ -9,21 +9,15 @@ import ( "path/filepath" ) -// WriteDemoConfig writes a toml file with the given values. -// It returns the RootDir the config.toml file is stored in, -// or an error if writing was impossible -func WriteDemoConfig(vals map[string]string) (string, error) { - cdir, err := ioutil.TempDir("", "test-cli") - if err != nil { - return "", err - } +// WriteConfigVals writes a toml file with the given values. +// It returns an error if writing was impossible. +func WriteConfigVals(dir string, vals map[string]string) error { data := "" for k, v := range vals { data = data + fmt.Sprintf("%s = \"%s\"\n", k, v) } - cfile := filepath.Join(cdir, "config.toml") - err = ioutil.WriteFile(cfile, []byte(data), 0666) - return cdir, err + cfile := filepath.Join(dir, "config.toml") + return ioutil.WriteFile(cfile, []byte(data), 0666) } // RunWithArgs executes the given command with the specified command line args diff --git a/cli/setup.go b/cli/setup.go index 29547759..dc34abdf 100644 --- a/cli/setup.go +++ b/cli/setup.go @@ -3,14 +3,12 @@ package cli import ( "fmt" "os" + "path/filepath" "strings" "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/spf13/viper" - - data "github.com/tendermint/go-wire/data" - "github.com/tendermint/go-wire/data/base58" ) const ( @@ -42,7 +40,7 @@ func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor func PrepareMainCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor { cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)") cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)") - cmd.PersistentPreRunE = concatCobraCmdFuncs(setEncoding, validateOutput, cmd.PersistentPreRunE) + cmd.PersistentPreRunE = concatCobraCmdFuncs(validateOutput, cmd.PersistentPreRunE) return PrepareBaseCmd(cmd, envPrefix, defaultHome) } @@ -132,8 +130,9 @@ func bindFlagsLoadViper(cmd *cobra.Command, args []string) error { homeDir := viper.GetString(HomeFlag) viper.Set(HomeFlag, homeDir) - viper.SetConfigName("config") // name of config file (without extension) - viper.AddConfigPath(homeDir) // search root directory + viper.SetConfigName("config") // name of config file (without extension) + viper.AddConfigPath(homeDir) // search root directory + viper.AddConfigPath(filepath.Join(homeDir, "config")) // search root directory /config // If a config file is found, read it in. if err := viper.ReadInConfig(); err == nil { @@ -147,23 +146,6 @@ func bindFlagsLoadViper(cmd *cobra.Command, args []string) error { return nil } -// setEncoding reads the encoding flag -func setEncoding(cmd *cobra.Command, args []string) error { - // validate and set encoding - enc := viper.GetString("encoding") - switch enc { - case "hex": - data.Encoder = data.HexEncoder - case "b64": - data.Encoder = data.B64Encoder - case "btc": - data.Encoder = base58.BTCEncoder - default: - return errors.Errorf("Unsupported encoding: %s", enc) - } - return nil -} - func validateOutput(cmd *cobra.Command, args []string) error { // validate output format output := viper.GetString(OutputFlag) diff --git a/cli/setup_test.go b/cli/setup_test.go index e0fd75d8..04209e49 100644 --- a/cli/setup_test.go +++ b/cli/setup_test.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io/ioutil" "strconv" "strings" "testing" @@ -54,11 +55,20 @@ func TestSetupEnv(t *testing.T) { } } +func tempDir() string { + cdir, err := ioutil.TempDir("", "test-cli") + if err != nil { + panic(err) + } + return cdir +} + func TestSetupConfig(t *testing.T) { // we pre-create two config files we can refer to in the rest of // the test cases. cval1 := "fubble" - conf1, err := WriteDemoConfig(map[string]string{"boo": cval1}) + conf1 := tempDir() + err := WriteConfigVals(conf1, map[string]string{"boo": cval1}) require.Nil(t, err) cases := []struct { @@ -116,10 +126,12 @@ func TestSetupUnmarshal(t *testing.T) { // we pre-create two config files we can refer to in the rest of // the test cases. cval1, cval2 := "someone", "else" - conf1, err := WriteDemoConfig(map[string]string{"name": cval1}) + conf1 := tempDir() + err := WriteConfigVals(conf1, map[string]string{"name": cval1}) require.Nil(t, err) // even with some ignored fields, should be no problem - conf2, err := WriteDemoConfig(map[string]string{"name": cval2, "foo": "bar"}) + conf2 := tempDir() + err = WriteConfigVals(conf2, map[string]string{"name": cval2, "foo": "bar"}) require.Nil(t, err) // unused is not declared on a flag and remains from base diff --git a/clist/clist.go b/clist/clist.go index a52920f8..28d771a2 100644 --- a/clist/clist.go +++ b/clist/clist.go @@ -36,12 +36,14 @@ waiting on NextWait() (since it's just a read operation). */ type CElement struct { - mtx sync.RWMutex - prev *CElement - prevWg *sync.WaitGroup - next *CElement - nextWg *sync.WaitGroup - removed bool + mtx sync.RWMutex + prev *CElement + prevWg *sync.WaitGroup + prevWaitCh chan struct{} + next *CElement + nextWg *sync.WaitGroup + nextWaitCh chan struct{} + removed bool Value interface{} // immutable } @@ -84,6 +86,24 @@ func (e *CElement) PrevWait() *CElement { } } +// PrevWaitChan can be used to wait until Prev becomes not nil. Once it does, +// channel will be closed. +func (e *CElement) PrevWaitChan() <-chan struct{} { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.prevWaitCh +} + +// NextWaitChan can be used to wait until Next becomes not nil. Once it does, +// channel will be closed. +func (e *CElement) NextWaitChan() <-chan struct{} { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.nextWaitCh +} + // Nonblocking, may return nil if at the end. func (e *CElement) Next() *CElement { e.mtx.RLock() @@ -142,9 +162,11 @@ func (e *CElement) SetNext(newNext *CElement) { // events, new Add calls must happen after all previous Wait calls have // returned. e.nextWg = waitGroup1() // WaitGroups are difficult to re-use. + e.nextWaitCh = make(chan struct{}) } if oldNext == nil && newNext != nil { e.nextWg.Done() + close(e.nextWaitCh) } } @@ -158,9 +180,11 @@ func (e *CElement) SetPrev(newPrev *CElement) { e.prev = newPrev if oldPrev != nil && newPrev == nil { e.prevWg = waitGroup1() // WaitGroups are difficult to re-use. + e.prevWaitCh = make(chan struct{}) } if oldPrev == nil && newPrev != nil { e.prevWg.Done() + close(e.prevWaitCh) } } @@ -173,9 +197,11 @@ func (e *CElement) SetRemoved() { // This wakes up anyone waiting in either direction. if e.prev == nil { e.prevWg.Done() + close(e.prevWaitCh) } if e.next == nil { e.nextWg.Done() + close(e.nextWaitCh) } } @@ -185,11 +211,12 @@ func (e *CElement) SetRemoved() { // The zero value for CList is an empty list ready to use. // Operations are goroutine-safe. type CList struct { - mtx sync.RWMutex - wg *sync.WaitGroup - head *CElement // first element - tail *CElement // last element - len int // list length + mtx sync.RWMutex + wg *sync.WaitGroup + waitCh chan struct{} + head *CElement // first element + tail *CElement // last element + len int // list length } func (l *CList) Init() *CList { @@ -197,6 +224,7 @@ func (l *CList) Init() *CList { defer l.mtx.Unlock() l.wg = waitGroup1() + l.waitCh = make(chan struct{}) l.head = nil l.tail = nil l.len = 0 @@ -258,23 +286,35 @@ func (l *CList) BackWait() *CElement { } } +// WaitChan can be used to wait until Front or Back becomes not nil. Once it +// does, channel will be closed. +func (l *CList) WaitChan() <-chan struct{} { + l.mtx.Lock() + defer l.mtx.Unlock() + + return l.waitCh +} + func (l *CList) PushBack(v interface{}) *CElement { l.mtx.Lock() defer l.mtx.Unlock() // Construct a new element e := &CElement{ - prev: nil, - prevWg: waitGroup1(), - next: nil, - nextWg: waitGroup1(), - removed: false, - Value: v, + prev: nil, + prevWg: waitGroup1(), + prevWaitCh: make(chan struct{}), + next: nil, + nextWg: waitGroup1(), + nextWaitCh: make(chan struct{}), + removed: false, + Value: v, } // Release waiters on FrontWait/BackWait maybe if l.len == 0 { l.wg.Done() + close(l.waitCh) } l.len += 1 @@ -313,6 +353,7 @@ func (l *CList) Remove(e *CElement) interface{} { // If we're removing the only item, make CList FrontWait/BackWait wait. if l.len == 1 { l.wg = waitGroup1() // WaitGroups are difficult to re-use. + l.waitCh = make(chan struct{}) } // Update l.len diff --git a/clist/clist_test.go b/clist/clist_test.go index 9d5272de..31f82165 100644 --- a/clist/clist_test.go +++ b/clist/clist_test.go @@ -218,3 +218,76 @@ func TestScanRightDeleteRandom(t *testing.T) { t.Fatal("Failed to remove all elements from CList") } } + +func TestWaitChan(t *testing.T) { + l := New() + ch := l.WaitChan() + + // 1) add one element to an empty list + go l.PushBack(1) + <-ch + + // 2) and remove it + el := l.Front() + v := l.Remove(el) + if v != 1 { + t.Fatal("where is 1 coming from?") + } + + // 3) test iterating forward and waiting for Next (NextWaitChan and Next) + el = l.PushBack(0) + + done := make(chan struct{}) + pushed := 0 + go func() { + for i := 1; i < 100; i++ { + l.PushBack(i) + pushed++ + time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) + } + close(done) + }() + + next := el + seen := 0 +FOR_LOOP: + for { + select { + case <-next.NextWaitChan(): + next = next.Next() + seen++ + if next == nil { + continue + } + case <-done: + break FOR_LOOP + case <-time.After(10 * time.Second): + t.Fatal("max execution time") + } + } + + if pushed != seen { + t.Fatalf("number of pushed items (%d) not equal to number of seen items (%d)", pushed, seen) + } + + // 4) test iterating backwards (PrevWaitChan and Prev) + prev := next + seen = 0 +FOR_LOOP2: + for { + select { + case <-prev.PrevWaitChan(): + prev = prev.Prev() + seen++ + if prev == nil { + t.Fatal("expected PrevWaitChan to block forever on nil when reached first elem") + } + case <-time.After(5 * time.Second): + break FOR_LOOP2 + } + } + + if pushed != seen { + t.Fatalf("number of pushed items (%d) not equal to number of seen items (%d)", pushed, seen) + } +} diff --git a/common/bit_array.go b/common/bit_array.go index 848763b4..7cc84705 100644 --- a/common/bit_array.go +++ b/common/bit_array.go @@ -99,8 +99,14 @@ func (bA *BitArray) copyBits(bits int) *BitArray { // Returns a BitArray of larger bits size. func (bA *BitArray) Or(o *BitArray) *BitArray { - if bA == nil { - o.Copy() + if bA == nil && o == nil { + return nil + } + if bA == nil && o != nil { + return o.Copy() + } + if o == nil { + return bA.Copy() } bA.mtx.Lock() defer bA.mtx.Unlock() @@ -113,7 +119,7 @@ func (bA *BitArray) Or(o *BitArray) *BitArray { // Returns a BitArray of smaller bit size. func (bA *BitArray) And(o *BitArray) *BitArray { - if bA == nil { + if bA == nil || o == nil { return nil } bA.mtx.Lock() @@ -143,7 +149,8 @@ func (bA *BitArray) Not() *BitArray { } func (bA *BitArray) Sub(o *BitArray) *BitArray { - if bA == nil { + if bA == nil || o == nil { + // TODO: Decide if we should do 1's complement here? return nil } bA.mtx.Lock() @@ -306,7 +313,7 @@ func (bA *BitArray) Bytes() []byte { // so if necessary, caller must copy or lock o prior to calling Update. // If bA is nil, does nothing. func (bA *BitArray) Update(o *BitArray) { - if bA == nil { + if bA == nil || o == nil { return } bA.mtx.Lock() diff --git a/common/bit_array_test.go b/common/bit_array_test.go index 1c72882c..94a312b7 100644 --- a/common/bit_array_test.go +++ b/common/bit_array_test.go @@ -3,6 +3,8 @@ package common import ( "bytes" "testing" + + "github.com/stretchr/testify/require" ) func randBitArray(bits int) (*BitArray, []byte) { @@ -26,6 +28,11 @@ func TestAnd(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.And(bA2) + var bNil *BitArray + require.Equal(t, bNil.And(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.And(nil), (*BitArray)(nil)) + require.Equal(t, bNil.And(nil), (*BitArray)(nil)) + if bA3.Bits != 31 { t.Error("Expected min bits", bA3.Bits) } @@ -46,6 +53,11 @@ func TestOr(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.Or(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Or(bA1), bA1) + require.Equal(t, bA1.Or(nil), bA1) + require.Equal(t, bNil.Or(nil), (*BitArray)(nil)) + if bA3.Bits != 51 { t.Error("Expected max bits") } @@ -66,6 +78,11 @@ func TestSub1(t *testing.T) { bA2, _ := randBitArray(51) bA3 := bA1.Sub(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + if bA3.Bits != bA1.Bits { t.Error("Expected bA1 bits") } @@ -89,6 +106,11 @@ func TestSub2(t *testing.T) { bA2, _ := randBitArray(31) bA3 := bA1.Sub(bA2) + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + if bA3.Bits != bA1.Bits { t.Error("Expected bA1 bits") } @@ -164,3 +186,25 @@ func TestEmptyFull(t *testing.T) { } } } + +func TestUpdateNeverPanics(t *testing.T) { + newRandBitArray := func(n int) *BitArray { + ba, _ := randBitArray(n) + return ba + } + pairs := []struct { + a, b *BitArray + }{ + {nil, nil}, + {newRandBitArray(10), newRandBitArray(12)}, + {newRandBitArray(23), newRandBitArray(23)}, + {newRandBitArray(37), nil}, + {nil, NewBitArray(10)}, + } + + for _, pair := range pairs { + a, b := pair.a, pair.b + a.Update(b) + b.Update(a) + } +} diff --git a/common/bytes.go b/common/bytes.go new file mode 100644 index 00000000..ba81bbe9 --- /dev/null +++ b/common/bytes.go @@ -0,0 +1,53 @@ +package common + +import ( + "encoding/hex" + "fmt" + "strings" +) + +// The main purpose of HexBytes is to enable HEX-encoding for json/encoding. +type HexBytes []byte + +// Marshal needed for protobuf compatibility +func (bz HexBytes) Marshal() ([]byte, error) { + return bz, nil +} + +// Unmarshal needed for protobuf compatibility +func (bz *HexBytes) Unmarshal(data []byte) error { + *bz = data + return nil +} + +// This is the point of Bytes. +func (bz HexBytes) MarshalJSON() ([]byte, error) { + s := strings.ToUpper(hex.EncodeToString(bz)) + jbz := make([]byte, len(s)+2) + jbz[0] = '"' + copy(jbz[1:], []byte(s)) + jbz[len(jbz)-1] = '"' + return jbz, nil +} + +// This is the point of Bytes. +func (bz *HexBytes) UnmarshalJSON(data []byte) error { + if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' { + return fmt.Errorf("Invalid hex string: %s", data) + } + bz2, err := hex.DecodeString(string(data[1 : len(data)-1])) + if err != nil { + return err + } + *bz = bz2 + return nil +} + +// Allow it to fulfill various interfaces in light-client, etc... +func (bz HexBytes) Bytes() []byte { + return bz +} + +func (bz HexBytes) String() string { + return strings.ToUpper(hex.EncodeToString(bz)) +} diff --git a/common/bytes_test.go b/common/bytes_test.go new file mode 100644 index 00000000..9e11988f --- /dev/null +++ b/common/bytes_test.go @@ -0,0 +1,65 @@ +package common + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This is a trivial test for protobuf compatibility. +func TestMarshal(t *testing.T) { + bz := []byte("hello world") + dataB := HexBytes(bz) + bz2, err := dataB.Marshal() + assert.Nil(t, err) + assert.Equal(t, bz, bz2) + + var dataB2 HexBytes + err = (&dataB2).Unmarshal(bz) + assert.Nil(t, err) + assert.Equal(t, dataB, dataB2) +} + +// Test that the hex encoding works. +func TestJSONMarshal(t *testing.T) { + + type TestStruct struct { + B1 []byte + B2 HexBytes + } + + cases := []struct { + input []byte + expected string + }{ + {[]byte(``), `{"B1":"","B2":""}`}, + {[]byte(`a`), `{"B1":"YQ==","B2":"61"}`}, + {[]byte(`abc`), `{"B1":"YWJj","B2":"616263"}`}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { + ts := TestStruct{B1: tc.input, B2: tc.input} + + // Test that it marshals correctly to JSON. + jsonBytes, err := json.Marshal(ts) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, string(jsonBytes), tc.expected) + + // TODO do fuzz testing to ensure that unmarshal fails + + // Test that unmarshaling works correctly. + ts2 := TestStruct{} + err = json.Unmarshal(jsonBytes, &ts2) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, ts2.B1, tc.input) + assert.Equal(t, ts2.B2, HexBytes(tc.input)) + }) + } +} diff --git a/common/errors.go b/common/errors.go index 039342a6..4710b9ee 100644 --- a/common/errors.go +++ b/common/errors.go @@ -22,6 +22,7 @@ func (se StackError) Error() string { // A panic resulting from a sanity check means there is a programmer error // and some guarantee is not satisfied. +// XXX DEPRECATED func PanicSanity(v interface{}) { panic(Fmt("Panicked on a Sanity Check: %v", v)) } @@ -29,17 +30,20 @@ func PanicSanity(v interface{}) { // A panic here means something has gone horribly wrong, in the form of data corruption or // failure of the operating system. In a correct/healthy system, these should never fire. // If they do, it's indicative of a much more serious problem. +// XXX DEPRECATED func PanicCrisis(v interface{}) { panic(Fmt("Panicked on a Crisis: %v", v)) } // Indicates a failure of consensus. Someone was malicious or something has // gone horribly wrong. These should really boot us into an "emergency-recover" mode +// XXX DEPRECATED func PanicConsensus(v interface{}) { panic(Fmt("Panicked on a Consensus Failure: %v", v)) } // For those times when we're not sure if we should panic +// XXX DEPRECATED func PanicQ(v interface{}) { panic(Fmt("Panicked questionably: %v", v)) } diff --git a/common/heap.go b/common/heap.go index 4a96d7aa..b3bcb9db 100644 --- a/common/heap.go +++ b/common/heap.go @@ -1,28 +1,25 @@ package common import ( + "bytes" "container/heap" ) -type Comparable interface { - Less(o interface{}) bool -} - -//----------------------------------------------------------------------------- - /* -Example usage: + Example usage: + + ``` h := NewHeap() - h.Push(String("msg1"), 1) - h.Push(String("msg3"), 3) - h.Push(String("msg2"), 2) + h.Push("msg1", 1) + h.Push("msg3", 3) + h.Push("msg2", 2) - fmt.Println(h.Pop()) - fmt.Println(h.Pop()) - fmt.Println(h.Pop()) + fmt.Println(h.Pop()) // msg1 + fmt.Println(h.Pop()) // msg2 + fmt.Println(h.Pop()) // msg3 + ``` */ - type Heap struct { pq priorityQueue } @@ -35,7 +32,15 @@ func (h *Heap) Len() int64 { return int64(len(h.pq)) } -func (h *Heap) Push(value interface{}, priority Comparable) { +func (h *Heap) Push(value interface{}, priority int) { + heap.Push(&h.pq, &pqItem{value: value, priority: cmpInt(priority)}) +} + +func (h *Heap) PushBytes(value interface{}, priority []byte) { + heap.Push(&h.pq, &pqItem{value: value, priority: cmpBytes(priority)}) +} + +func (h *Heap) PushComparable(value interface{}, priority Comparable) { heap.Push(&h.pq, &pqItem{value: value, priority: priority}) } @@ -56,8 +61,6 @@ func (h *Heap) Pop() interface{} { } //----------------------------------------------------------------------------- - -/////////////////////// // From: http://golang.org/pkg/container/heap/#example__priorityQueue type pqItem struct { @@ -101,3 +104,22 @@ func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority Compar item.priority = priority heap.Fix(pq, item.index) } + +//-------------------------------------------------------------------------------- +// Comparable + +type Comparable interface { + Less(o interface{}) bool +} + +type cmpInt int + +func (i cmpInt) Less(o interface{}) bool { + return int(i) < int(o.(cmpInt)) +} + +type cmpBytes []byte + +func (bz cmpBytes) Less(o interface{}) bool { + return bytes.Compare([]byte(bz), []byte(o.(cmpBytes))) < 0 +} diff --git a/common/http.go b/common/http.go deleted file mode 100644 index 56b5b6c6..00000000 --- a/common/http.go +++ /dev/null @@ -1,153 +0,0 @@ -package common - -import ( - "encoding/json" - "io" - "net/http" - - "gopkg.in/go-playground/validator.v9" - - "github.com/pkg/errors" -) - -type ErrorResponse struct { - Success bool `json:"success,omitempty"` - - // Err is the error message if Success is false - Err string `json:"error,omitempty"` - - // Code is set if Success is false - Code int `json:"code,omitempty"` -} - -// ErrorWithCode makes an ErrorResponse with the -// provided err's Error() content, and status code. -// It panics if err is nil. -func ErrorWithCode(err error, code int) *ErrorResponse { - return &ErrorResponse{ - Err: err.Error(), - Code: code, - } -} - -// Ensure that ErrorResponse implements error -var _ error = (*ErrorResponse)(nil) - -func (er *ErrorResponse) Error() string { - return er.Err -} - -// Ensure that ErrorResponse implements httpCoder -var _ httpCoder = (*ErrorResponse)(nil) - -func (er *ErrorResponse) HTTPCode() int { - return er.Code -} - -var errNilBody = errors.Errorf("expecting a non-nil body") - -// FparseJSON unmarshals into save, the body of the provided reader. -// Since it uses json.Unmarshal, save must be of a pointer type -// or compatible with json.Unmarshal. -func FparseJSON(r io.Reader, save interface{}) error { - if r == nil { - return errors.Wrap(errNilBody, "Reader") - } - - dec := json.NewDecoder(r) - if err := dec.Decode(save); err != nil { - return errors.Wrap(err, "Decode/Unmarshal") - } - return nil -} - -// ParseRequestJSON unmarshals into save, the body of the -// request. It closes the body of the request after parsing. -// Since it uses json.Unmarshal, save must be of a pointer type -// or compatible with json.Unmarshal. -func ParseRequestJSON(r *http.Request, save interface{}) error { - if r == nil || r.Body == nil { - return errNilBody - } - defer r.Body.Close() - - return FparseJSON(r.Body, save) -} - -// ParseRequestAndValidateJSON unmarshals into save, the body of the -// request and invokes a validator on the saved content. To ensure -// validation, make sure to set tags "validate" on your struct as -// per https://godoc.org/gopkg.in/go-playground/validator.v9. -// It closes the body of the request after parsing. -// Since it uses json.Unmarshal, save must be of a pointer type -// or compatible with json.Unmarshal. -func ParseRequestAndValidateJSON(r *http.Request, save interface{}) error { - if r == nil || r.Body == nil { - return errNilBody - } - defer r.Body.Close() - - return FparseAndValidateJSON(r.Body, save) -} - -// FparseAndValidateJSON like FparseJSON unmarshals into save, -// the body of the provided reader. However, it invokes the validator -// to check the set validators on your struct fields as per -// per https://godoc.org/gopkg.in/go-playground/validator.v9. -// Since it uses json.Unmarshal, save must be of a pointer type -// or compatible with json.Unmarshal. -func FparseAndValidateJSON(r io.Reader, save interface{}) error { - if err := FparseJSON(r, save); err != nil { - return err - } - return validate(save) -} - -var theValidator = validator.New() - -func validate(obj interface{}) error { - return errors.Wrap(theValidator.Struct(obj), "Validate") -} - -// WriteSuccess JSON marshals the content provided, to an HTTP -// response, setting the provided status code and setting header -// "Content-Type" to "application/json". -func WriteSuccess(w http.ResponseWriter, data interface{}) { - WriteCode(w, data, 200) -} - -// WriteCode JSON marshals content, to an HTTP response, -// setting the provided status code, and setting header -// "Content-Type" to "application/json". If JSON marshalling fails -// with an error, WriteCode instead writes out the error invoking -// WriteError. -func WriteCode(w http.ResponseWriter, out interface{}, code int) { - blob, err := json.MarshalIndent(out, "", " ") - if err != nil { - WriteError(w, err) - } else { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - w.Write(blob) - } -} - -type httpCoder interface { - HTTPCode() int -} - -// WriteError is a convenience function to write out an -// error to an http.ResponseWriter, to send out an error -// that's structured as JSON i.e the form -// {"error": sss, "code": ddd} -// If err implements the interface HTTPCode() int, -// it will use that status code otherwise, it will -// set code to be http.StatusBadRequest -func WriteError(w http.ResponseWriter, err error) { - code := http.StatusBadRequest - if httpC, ok := err.(httpCoder); ok { - code = httpC.HTTPCode() - } - - WriteCode(w, ErrorWithCode(err, code), code) -} diff --git a/common/http_test.go b/common/http_test.go deleted file mode 100644 index 4272f606..00000000 --- a/common/http_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package common_test - -import ( - "bytes" - "encoding/json" - "errors" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "reflect" - "strings" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/tendermint/tmlibs/common" -) - -func TestWriteSuccess(t *testing.T) { - w := httptest.NewRecorder() - common.WriteSuccess(w, "foo") - assert.Equal(t, w.Code, 200, "should get a 200") -} - -var blankErrResponse = new(common.ErrorResponse) - -func TestWriteError(t *testing.T) { - tests := [...]struct { - msg string - code int - }{ - 0: { - msg: "this is a message", - code: 419, - }, - } - - for i, tt := range tests { - w := httptest.NewRecorder() - msg := tt.msg - - // First check without a defined code, should send back a 400 - common.WriteError(w, errors.New(msg)) - assert.Equal(t, w.Code, http.StatusBadRequest, "#%d: should get a 400", i) - blob, err := ioutil.ReadAll(w.Body) - if err != nil { - assert.Fail(t, "expecting a successful ioutil.ReadAll", "#%d", i) - continue - } - - recv := new(common.ErrorResponse) - if err := json.Unmarshal(blob, recv); err != nil { - assert.Fail(t, "expecting a successful json.Unmarshal", "#%d", i) - continue - } - - assert.Equal(t, reflect.DeepEqual(recv, blankErrResponse), false, "expecting a non-blank error response") - - // Now test with an error that's .HTTPCode() int conforming - - // Reset w - w = httptest.NewRecorder() - - common.WriteError(w, common.ErrorWithCode(errors.New("foo"), tt.code)) - assert.Equal(t, w.Code, tt.code, "case #%d", i) - } -} - -type marshalFailer struct{} - -var errFooFailed = errors.New("foo failed here") - -func (mf *marshalFailer) MarshalJSON() ([]byte, error) { - return nil, errFooFailed -} - -func TestWriteCode(t *testing.T) { - codes := [...]int{ - 0: http.StatusOK, - 1: http.StatusBadRequest, - 2: http.StatusUnauthorized, - 3: http.StatusInternalServerError, - } - - for i, code := range codes { - w := httptest.NewRecorder() - common.WriteCode(w, "foo", code) - assert.Equal(t, w.Code, code, "#%d", i) - - // Then for the failed JSON marshaling - w = httptest.NewRecorder() - common.WriteCode(w, &marshalFailer{}, code) - wantCode := http.StatusBadRequest - assert.Equal(t, w.Code, wantCode, "#%d", i) - assert.True(t, strings.Contains(w.Body.String(), errFooFailed.Error()), - "#%d: expected %q in the error message", i, errFooFailed) - } -} - -type saver struct { - Foo int `json:"foo" validate:"min=10"` - Bar string `json:"bar"` -} - -type rcloser struct { - closeOnce sync.Once - body *bytes.Buffer - closeChan chan bool -} - -var errAlreadyClosed = errors.New("already closed") - -func (rc *rcloser) Close() error { - var err = errAlreadyClosed - rc.closeOnce.Do(func() { - err = nil - rc.closeChan <- true - close(rc.closeChan) - }) - return err -} - -func (rc *rcloser) Read(b []byte) (int, error) { - return rc.body.Read(b) -} - -var _ io.ReadCloser = (*rcloser)(nil) - -func makeReq(strBody string) (*http.Request, <-chan bool) { - closeChan := make(chan bool, 1) - buf := new(bytes.Buffer) - buf.Write([]byte(strBody)) - req := &http.Request{ - Header: make(http.Header), - Body: &rcloser{body: buf, closeChan: closeChan}, - } - return req, closeChan -} - -func TestParseRequestJSON(t *testing.T) { - tests := [...]struct { - body string - wantErr bool - useNil bool - }{ - 0: {wantErr: true, body: ``}, - 1: {body: `{}`}, - 2: {body: `{"foo": 2}`}, // Not that the validate tags don't matter here since we are just parsing - 3: {body: `{"foo": "abcd"}`, wantErr: true}, - 4: {useNil: true, wantErr: true}, - } - - for i, tt := range tests { - req, closeChan := makeReq(tt.body) - if tt.useNil { - req.Body = nil - } - sav := new(saver) - err := common.ParseRequestJSON(req, sav) - if tt.wantErr { - assert.NotEqual(t, err, nil, "#%d: want non-nil error", i) - continue - } - assert.Equal(t, err, nil, "#%d: want nil error", i) - wasClosed := <-closeChan - assert.Equal(t, wasClosed, true, "#%d: should have invoked close", i) - } -} - -func TestFparseJSON(t *testing.T) { - r1 := strings.NewReader(`{"foo": 1}`) - sav := new(saver) - require.Equal(t, common.FparseJSON(r1, sav), nil, "expecting successful parsing") - r2 := strings.NewReader(`{"bar": "blockchain"}`) - require.Equal(t, common.FparseJSON(r2, sav), nil, "expecting successful parsing") - require.Equal(t, reflect.DeepEqual(sav, &saver{Foo: 1, Bar: "blockchain"}), true, "should have parsed both") - - // Now with a nil body - require.NotEqual(t, nil, common.FparseJSON(nil, sav), "expecting a nil error report") -} - -func TestFparseAndValidateJSON(t *testing.T) { - r1 := strings.NewReader(`{"foo": 1}`) - sav := new(saver) - require.NotEqual(t, common.FparseAndValidateJSON(r1, sav), nil, "expecting validation to fail") - r1 = strings.NewReader(`{"foo": 100}`) - require.Equal(t, common.FparseJSON(r1, sav), nil, "expecting successful parsing") - r2 := strings.NewReader(`{"bar": "blockchain"}`) - require.Equal(t, common.FparseAndValidateJSON(r2, sav), nil, "expecting successful parsing") - require.Equal(t, reflect.DeepEqual(sav, &saver{Foo: 100, Bar: "blockchain"}), true, "should have parsed both") - - // Now with a nil body - require.NotEqual(t, nil, common.FparseJSON(nil, sav), "expecting a nil error report") -} - -var blankSaver = new(saver) - -func TestParseAndValidateRequestJSON(t *testing.T) { - tests := [...]struct { - body string - wantErr bool - useNil bool - }{ - 0: {wantErr: true, body: ``}, - 1: {body: `{}`, wantErr: true}, // Here it should fail since Foo doesn't meet the minimum value - 2: {body: `{"foo": 2}`, wantErr: true}, // Here validation should fail - 3: {body: `{"foo": "abcd"}`, wantErr: true}, - 4: {useNil: true, wantErr: true}, - 5: {body: `{"foo": 100}`}, // Must succeed - } - - for i, tt := range tests { - req, closeChan := makeReq(tt.body) - if tt.useNil { - req.Body = nil - } - sav := new(saver) - err := common.ParseRequestAndValidateJSON(req, sav) - if tt.wantErr { - assert.NotEqual(t, err, nil, "#%d: want non-nil error", i) - continue - } - - assert.Equal(t, err, nil, "#%d: want nil error", i) - assert.False(t, reflect.DeepEqual(blankSaver, sav), "#%d: expecting a set saver", i) - - wasClosed := <-closeChan - assert.Equal(t, wasClosed, true, "#%d: should have invoked close", i) - } -} - -func TestErrorWithCode(t *testing.T) { - tests := [...]struct { - code int - err error - }{ - 0: {code: 500, err: errors.New("funky")}, - 1: {code: 406, err: errors.New("purist")}, - } - - for i, tt := range tests { - errRes := common.ErrorWithCode(tt.err, tt.code) - assert.Equal(t, errRes.Error(), tt.err.Error(), "#%d: expecting the error values to be equal", i) - assert.Equal(t, errRes.Code, tt.code, "expecting the same status code", i) - assert.Equal(t, errRes.HTTPCode(), tt.code, "expecting the same status code", i) - } -} diff --git a/common/kvpair.go b/common/kvpair.go new file mode 100644 index 00000000..54c3a58c --- /dev/null +++ b/common/kvpair.go @@ -0,0 +1,67 @@ +package common + +import ( + "bytes" + "sort" +) + +//---------------------------------------- +// KVPair + +/* +Defined in types.proto + +type KVPair struct { + Key []byte + Value []byte +} +*/ + +type KVPairs []KVPair + +// Sorting +func (kvs KVPairs) Len() int { return len(kvs) } +func (kvs KVPairs) Less(i, j int) bool { + switch bytes.Compare(kvs[i].Key, kvs[j].Key) { + case -1: + return true + case 0: + return bytes.Compare(kvs[i].Value, kvs[j].Value) < 0 + case 1: + return false + default: + panic("invalid comparison result") + } +} +func (kvs KVPairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } +func (kvs KVPairs) Sort() { sort.Sort(kvs) } + +//---------------------------------------- +// KI64Pair + +/* +Defined in types.proto +type KI64Pair struct { + Key []byte + Value int64 +} +*/ + +type KI64Pairs []KI64Pair + +// Sorting +func (kvs KI64Pairs) Len() int { return len(kvs) } +func (kvs KI64Pairs) Less(i, j int) bool { + switch bytes.Compare(kvs[i].Key, kvs[j].Key) { + case -1: + return true + case 0: + return kvs[i].Value < kvs[j].Value + case 1: + return false + default: + panic("invalid comparison result") + } +} +func (kvs KI64Pairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } +func (kvs KI64Pairs) Sort() { sort.Sort(kvs) } diff --git a/common/repeat_timer.go b/common/repeat_timer.go index 2e6cb81c..5d049738 100644 --- a/common/repeat_timer.go +++ b/common/repeat_timer.go @@ -20,15 +20,17 @@ type Ticker interface { } //---------------------------------------- -// defaultTickerMaker +// defaultTicker + +var _ Ticker = (*defaultTicker)(nil) + +type defaultTicker time.Ticker func defaultTickerMaker(dur time.Duration) Ticker { ticker := time.NewTicker(dur) return (*defaultTicker)(ticker) } -type defaultTicker time.Ticker - // Implements Ticker func (t *defaultTicker) Chan() <-chan time.Time { return t.C @@ -80,13 +82,11 @@ func (t *logicalTicker) fireRoutine(interval time.Duration) { } // Init `lasttime` end - timeleft := interval for { select { case newtime := <-source: elapsed := newtime.Sub(lasttime) - timeleft -= elapsed - if timeleft <= 0 { + if interval <= elapsed { // Block for determinism until the ticker is stopped. select { case t.ch <- newtime: @@ -97,7 +97,7 @@ func (t *logicalTicker) fireRoutine(interval time.Duration) { // Don't try to "catch up" by sending more. // "Ticker adjusts the intervals or drops ticks to make up for // slow receivers" - https://golang.org/pkg/time/#Ticker - timeleft = interval + lasttime = newtime } case <-t.quit: return // done @@ -153,11 +153,16 @@ func NewRepeatTimerWithTickerMaker(name string, dur time.Duration, tm TickerMake return t } +// receive ticks on ch, send out on t.ch func (t *RepeatTimer) fireRoutine(ch <-chan time.Time, quit <-chan struct{}) { for { select { - case t_ := <-ch: - t.ch <- t_ + case tick := <-ch: + select { + case t.ch <- tick: + case <-quit: + return + } case <-quit: // NOTE: `t.quit` races. return } @@ -212,7 +217,6 @@ func (t *RepeatTimer) stop() { t.ticker.Stop() t.ticker = nil /* - XXX From https://golang.org/pkg/time/#Ticker: "Stop the ticker to release associated resources" "After Stop, no more ticks will be sent" diff --git a/common/repeat_timer_test.go b/common/repeat_timer_test.go index 5a3a4c0a..160f4394 100644 --- a/common/repeat_timer_test.go +++ b/common/repeat_timer_test.go @@ -1,9 +1,12 @@ package common import ( + "math/rand" + "sync" "testing" "time" + "github.com/fortytw2/leaktest" "github.com/stretchr/testify/assert" ) @@ -13,29 +16,42 @@ func TestDefaultTicker(t *testing.T) { ticker.Stop() } -func TestRepeat(t *testing.T) { +func TestRepeatTimer(t *testing.T) { ch := make(chan time.Time, 100) - lt := time.Time{} // zero time is year 1 + mtx := new(sync.Mutex) - // tick fires `cnt` times for each second. - tick := func(cnt int) { - for i := 0; i < cnt; i++ { - lt = lt.Add(time.Second) - ch <- lt - } + // tick() fires from start to end + // (exclusive) in milliseconds with incr. + // It locks on mtx, so subsequent calls + // run in series. + tick := func(startMs, endMs, incrMs time.Duration) { + mtx.Lock() + go func() { + for tMs := startMs; tMs < endMs; tMs += incrMs { + lt := time.Time{} + lt = lt.Add(tMs * time.Millisecond) + ch <- lt + } + mtx.Unlock() + }() } - // tock consumes Ticker.Chan() events `cnt` times. - tock := func(t *testing.T, rt *RepeatTimer, cnt int) { - for i := 0; i < cnt; i++ { - timeout := time.After(time.Second * 10) - select { - case <-rt.Chan(): - case <-timeout: - panic("expected RepeatTimer to fire") - } + // tock consumes Ticker.Chan() events and checks them against the ms in "timesMs". + tock := func(t *testing.T, rt *RepeatTimer, timesMs []int64) { + + // Check against timesMs. + for _, timeMs := range timesMs { + tyme := <-rt.Chan() + sinceMs := tyme.Sub(time.Time{}) / time.Millisecond + assert.Equal(t, timeMs, int64(sinceMs)) } + + // TODO detect number of running + // goroutines to ensure that + // no other times will fire. + // See https://github.com/tendermint/tmlibs/issues/120. + time.Sleep(time.Millisecond * 100) done := true select { case <-rt.Chan(): @@ -46,47 +62,76 @@ func TestRepeat(t *testing.T) { } tm := NewLogicalTickerMaker(ch) - dur := time.Duration(10 * time.Millisecond) // less than a second - rt := NewRepeatTimerWithTickerMaker("bar", dur, tm) + rt := NewRepeatTimerWithTickerMaker("bar", time.Second, tm) - // Start at 0. - tock(t, rt, 0) - tick(1) // init time + /* NOTE: Useful for debugging deadlocks... + go func() { + time.Sleep(time.Second * 3) + trace := make([]byte, 102400) + count := runtime.Stack(trace, true) + fmt.Printf("Stack of %d bytes: %s\n", count, trace) + }() + */ - tock(t, rt, 0) - tick(1) // wait 1 periods - tock(t, rt, 1) - tick(2) // wait 2 periods - tock(t, rt, 2) - tick(3) // wait 3 periods - tock(t, rt, 3) - tick(4) // wait 4 periods - tock(t, rt, 4) + tick(0, 1000, 10) + tock(t, rt, []int64{}) + tick(1000, 2000, 10) + tock(t, rt, []int64{1000}) + tick(2005, 5000, 10) + tock(t, rt, []int64{2005, 3005, 4005}) + tick(5001, 5999, 1) + // Read 5005 instead of 5001 because + // it's 1 second greater than 4005. + tock(t, rt, []int64{5005}) + tick(6000, 7005, 1) + tock(t, rt, []int64{6005}) + tick(7033, 8032, 1) + tock(t, rt, []int64{7033}) - // Multiple resets leads to no firing. - for i := 0; i < 20; i++ { - time.Sleep(time.Millisecond) - rt.Reset() - } - - // After this, it works as new. - tock(t, rt, 0) - tick(1) // init time - - tock(t, rt, 0) - tick(1) // wait 1 periods - tock(t, rt, 1) - tick(2) // wait 2 periods - tock(t, rt, 2) - tick(3) // wait 3 periods - tock(t, rt, 3) - tick(4) // wait 4 periods - tock(t, rt, 4) + // After a reset, nothing happens + // until two ticks are received. + rt.Reset() + tock(t, rt, []int64{}) + tick(8040, 8041, 1) + tock(t, rt, []int64{}) + tick(9555, 9556, 1) + tock(t, rt, []int64{9555}) // After a stop, nothing more is sent. rt.Stop() - tock(t, rt, 0) + tock(t, rt, []int64{}) // Another stop panics. assert.Panics(t, func() { rt.Stop() }) } + +func TestRepeatTimerReset(t *testing.T) { + // check that we are not leaking any go-routines + defer leaktest.Check(t)() + + timer := NewRepeatTimer("test", 20*time.Millisecond) + defer timer.Stop() + + // test we don't receive tick before duration ms. + select { + case <-timer.Chan(): + t.Fatal("did not expect to receive tick") + default: + } + + timer.Reset() + + // test we receive tick after Reset is called + select { + case <-timer.Chan(): + // all good + case <-time.After(40 * time.Millisecond): + t.Fatal("expected to receive tick after reset") + } + + // just random calls + for i := 0; i < 100; i++ { + time.Sleep(time.Duration(rand.Intn(40)) * time.Millisecond) + timer.Reset() + } +} diff --git a/common/service.go b/common/service.go index d70d16a8..2502d671 100644 --- a/common/service.go +++ b/common/service.go @@ -35,9 +35,13 @@ type Service interface { // Return true if the service is running IsRunning() bool + // Quit returns a channel, which is closed once service is stopped. + Quit() <-chan struct{} + // String representation of the service String() string + // SetLogger sets a logger. SetLogger(log.Logger) } @@ -88,12 +92,13 @@ type BaseService struct { name string started uint32 // atomic stopped uint32 // atomic - Quit chan struct{} + quit chan struct{} // The "subclass" of BaseService impl Service } +// NewBaseService creates a new BaseService. func NewBaseService(logger log.Logger, name string, impl Service) *BaseService { if logger == nil { logger = log.NewNopLogger() @@ -102,16 +107,19 @@ func NewBaseService(logger log.Logger, name string, impl Service) *BaseService { return &BaseService{ Logger: logger, name: name, - Quit: make(chan struct{}), + quit: make(chan struct{}), impl: impl, } } +// SetLogger implements Service by setting a logger. func (bs *BaseService) SetLogger(l log.Logger) { bs.Logger = l } -// Implements Servce +// Start implements Service by calling OnStart (if defined). An error will be +// returned if the service is already running or stopped. Not to start the +// stopped service, you need to call Reset. func (bs *BaseService) Start() error { if atomic.CompareAndSwapUint32(&bs.started, 0, 1) { if atomic.LoadUint32(&bs.stopped) == 1 { @@ -133,17 +141,18 @@ func (bs *BaseService) Start() error { } } -// Implements Service +// OnStart implements Service by doing nothing. // NOTE: Do not put anything in here, // that way users don't need to call BaseService.OnStart() func (bs *BaseService) OnStart() error { return nil } -// Implements Service +// Stop implements Service by calling OnStop (if defined) and closing quit +// channel. An error will be returned if the service is already stopped. func (bs *BaseService) Stop() error { if atomic.CompareAndSwapUint32(&bs.stopped, 0, 1) { bs.Logger.Info(Fmt("Stopping %v", bs.name), "impl", bs.impl) bs.impl.OnStop() - close(bs.Quit) + close(bs.quit) return nil } else { bs.Logger.Debug(Fmt("Stopping %v (ignoring: already stopped)", bs.name), "impl", bs.impl) @@ -151,12 +160,13 @@ func (bs *BaseService) Stop() error { } } -// Implements Service +// OnStop implements Service by doing nothing. // NOTE: Do not put anything in here, // that way users don't need to call BaseService.OnStop() func (bs *BaseService) OnStop() {} -// Implements Service +// Reset implements Service by calling OnReset callback (if defined). An error +// will be returned if the service is running. func (bs *BaseService) Reset() error { if !atomic.CompareAndSwapUint32(&bs.stopped, 1, 0) { bs.Logger.Debug(Fmt("Can't reset %v. Not stopped", bs.name), "impl", bs.impl) @@ -166,41 +176,33 @@ func (bs *BaseService) Reset() error { // whether or not we've started, we can reset atomic.CompareAndSwapUint32(&bs.started, 1, 0) - bs.Quit = make(chan struct{}) + bs.quit = make(chan struct{}) return bs.impl.OnReset() } -// Implements Service +// OnReset implements Service by panicking. func (bs *BaseService) OnReset() error { PanicSanity("The service cannot be reset") return nil } -// Implements Service +// IsRunning implements Service by returning true or false depending on the +// service's state. func (bs *BaseService) IsRunning() bool { return atomic.LoadUint32(&bs.started) == 1 && atomic.LoadUint32(&bs.stopped) == 0 } +// Wait blocks until the service is stopped. func (bs *BaseService) Wait() { - <-bs.Quit + <-bs.quit } -// Implements Servce +// String implements Servce by returning a string representation of the service. func (bs *BaseService) String() string { return bs.name } -//---------------------------------------- - -type QuitService struct { - BaseService -} - -func NewQuitService(logger log.Logger, name string, impl Service) *QuitService { - if logger != nil { - logger.Info("QuitService is deprecated, use BaseService instead") - } - return &QuitService{ - BaseService: *NewBaseService(logger, name, impl), - } +// Quit Implements Service by returning a quit channel. +func (bs *BaseService) Quit() <-chan struct{} { + return bs.quit } diff --git a/common/string.go b/common/string.go index 6924e6a5..a6895eb2 100644 --- a/common/string.go +++ b/common/string.go @@ -29,7 +29,7 @@ func LeftPadString(s string, totalLength int) string { // IsHex returns true for non-empty hex-string prefixed with "0x" func IsHex(s string) bool { - if len(s) > 2 && s[:2] == "0x" { + if len(s) > 2 && strings.EqualFold(s[:2], "0x") { _, err := hex.DecodeString(s[2:]) return err == nil } diff --git a/common/string_test.go b/common/string_test.go index a82f1022..b8a917c1 100644 --- a/common/string_test.go +++ b/common/string_test.go @@ -12,3 +12,21 @@ func TestStringInSlice(t *testing.T) { assert.True(t, StringInSlice("", []string{""})) assert.False(t, StringInSlice("", []string{})) } + +func TestIsHex(t *testing.T) { + notHex := []string{ + "", " ", "a", "x", "0", "0x", "0X", "0x ", "0X ", "0X a", + "0xf ", "0x f", "0xp", "0x-", + "0xf", "0XBED", "0xF", "0xbed", // Odd lengths + } + for _, v := range notHex { + assert.False(t, IsHex(v), "%q is not hex", v) + } + hex := []string{ + "0x00", "0x0a", "0x0F", "0xFFFFFF", "0Xdeadbeef", "0x0BED", + "0X12", "0X0A", + } + for _, v := range hex { + assert.True(t, IsHex(v), "%q is hex", v) + } +} diff --git a/common/types.pb.go b/common/types.pb.go new file mode 100644 index 00000000..047b7aee --- /dev/null +++ b/common/types.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: common/types.proto + +/* +Package common is a generated protocol buffer package. + +It is generated from these files: + common/types.proto + +It has these top-level messages: + KVPair + KI64Pair +*/ +//nolint: gas +package common + +import proto "github.com/gogo/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +// Define these here for compatibility but use tmlibs/common.KVPair. +type KVPair struct { + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` +} + +func (m *KVPair) Reset() { *m = KVPair{} } +func (m *KVPair) String() string { return proto.CompactTextString(m) } +func (*KVPair) ProtoMessage() {} +func (*KVPair) Descriptor() ([]byte, []int) { return fileDescriptorTypes, []int{0} } + +func (m *KVPair) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *KVPair) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +// Define these here for compatibility but use tmlibs/common.KI64Pair. +type KI64Pair struct { + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value int64 `protobuf:"varint,2,opt,name=value,proto3" json:"value,omitempty"` +} + +func (m *KI64Pair) Reset() { *m = KI64Pair{} } +func (m *KI64Pair) String() string { return proto.CompactTextString(m) } +func (*KI64Pair) ProtoMessage() {} +func (*KI64Pair) Descriptor() ([]byte, []int) { return fileDescriptorTypes, []int{1} } + +func (m *KI64Pair) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *KI64Pair) GetValue() int64 { + if m != nil { + return m.Value + } + return 0 +} + +func init() { + proto.RegisterType((*KVPair)(nil), "common.KVPair") + proto.RegisterType((*KI64Pair)(nil), "common.KI64Pair") +} + +func init() { proto.RegisterFile("common/types.proto", fileDescriptorTypes) } + +var fileDescriptorTypes = []byte{ + // 137 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4a, 0xce, 0xcf, 0xcd, + 0xcd, 0xcf, 0xd3, 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, + 0x83, 0x88, 0x49, 0xe9, 0xa6, 0x67, 0x96, 0x64, 0x94, 0x26, 0xe9, 0x25, 0xe7, 0xe7, 0xea, 0xa7, + 0xe7, 0xa7, 0xe7, 0xeb, 0x83, 0xa5, 0x93, 0x4a, 0xd3, 0xc0, 0x3c, 0x30, 0x07, 0xcc, 0x82, 0x68, + 0x53, 0x32, 0xe0, 0x62, 0xf3, 0x0e, 0x0b, 0x48, 0xcc, 0x2c, 0x12, 0x12, 0xe0, 0x62, 0xce, 0x4e, + 0xad, 0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, 0x31, 0x85, 0x44, 0xb8, 0x58, 0xcb, 0x12, + 0x73, 0x4a, 0x53, 0x25, 0x98, 0xc0, 0x62, 0x10, 0x8e, 0x92, 0x11, 0x17, 0x87, 0xb7, 0xa7, 0x99, + 0x09, 0x31, 0x7a, 0x98, 0xa1, 0x7a, 0x92, 0xd8, 0xc0, 0x96, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0x5c, 0xb8, 0x46, 0xc5, 0xb9, 0x00, 0x00, 0x00, +} diff --git a/common/types.proto b/common/types.proto new file mode 100644 index 00000000..94abcccc --- /dev/null +++ b/common/types.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; +package common; + +// For more information on gogo.proto, see: +// https://github.com/gogo/protobuf/blob/master/extensions.md +// NOTE: Try really hard not to use custom types, +// it's often complicated, broken, nor not worth it. +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + + +//---------------------------------------- +// Abstract types + +// Define these here for compatibility but use tmlibs/common.KVPair. +message KVPair { + bytes key = 1; + bytes value = 2; +} + +// Define these here for compatibility but use tmlibs/common.KI64Pair. +message KI64Pair { + bytes key = 1; + int64 value = 2; +} diff --git a/db/backend_test.go b/db/backend_test.go new file mode 100644 index 00000000..80fbbb14 --- /dev/null +++ b/db/backend_test.go @@ -0,0 +1,151 @@ +package db + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" +) + +func cleanupDBDir(dir, name string) { + os.RemoveAll(filepath.Join(dir, name) + ".db") +} + +func testBackendGetSetDelete(t *testing.T, backend DBBackendType) { + // Default + dir, dirname := cmn.Tempdir(fmt.Sprintf("test_backend_%s_", backend)) + defer dir.Close() + db := NewDB("testdb", backend, dirname) + + // A nonexistent key should return nil, even if the key is empty + require.Nil(t, db.Get([]byte(""))) + + // A nonexistent key should return nil, even if the key is nil + require.Nil(t, db.Get(nil)) + + // A nonexistent key should return nil. + key := []byte("abc") + require.Nil(t, db.Get(key)) + + // Set empty value. + db.Set(key, []byte("")) + require.NotNil(t, db.Get(key)) + require.Empty(t, db.Get(key)) + + // Set nil value. + db.Set(key, nil) + require.NotNil(t, db.Get(key)) + require.Empty(t, db.Get(key)) + + // Delete. + db.Delete(key) + require.Nil(t, db.Get(key)) +} + +func TestBackendsGetSetDelete(t *testing.T) { + for dbType, _ := range backends { + testBackendGetSetDelete(t, dbType) + } +} + +func withDB(t *testing.T, creator dbCreator, fn func(DB)) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db, err := creator(name, "") + defer cleanupDBDir("", name) + assert.Nil(t, err) + fn(db) + db.Close() +} + +func TestBackendsNilKeys(t *testing.T) { + + // Test all backends. + for dbType, creator := range backends { + withDB(t, creator, func(db DB) { + t.Run(fmt.Sprintf("Testing %s", dbType), func(t *testing.T) { + + // Nil keys are treated as the empty key for most operations. + expect := func(key, value []byte) { + if len(key) == 0 { // nil or empty + assert.Equal(t, db.Get(nil), db.Get([]byte(""))) + assert.Equal(t, db.Has(nil), db.Has([]byte(""))) + } + assert.Equal(t, db.Get(key), value) + assert.Equal(t, db.Has(key), value != nil) + } + + // Not set + expect(nil, nil) + + // Set nil value + db.Set(nil, nil) + expect(nil, []byte("")) + + // Set empty value + db.Set(nil, []byte("")) + expect(nil, []byte("")) + + // Set nil, Delete nil + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set nil, Delete empty + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // Set empty, Delete nil + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set empty, Delete empty + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // SetSync nil, DeleteSync nil + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync nil, DeleteSync empty + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + + // SetSync empty, DeleteSync nil + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync empty, DeleteSync empty + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + }) + }) + } +} + +func TestGoLevelDBBackend(t *testing.T) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, GoLevelDBBackend, "") + defer cleanupDBDir("", name) + + _, ok := db.(*GoLevelDB) + assert.True(t, ok) +} diff --git a/db/c_level_db.go b/db/c_level_db.go index b1ae49a1..a5913788 100644 --- a/db/c_level_db.go +++ b/db/c_level_db.go @@ -3,22 +3,23 @@ package db import ( + "bytes" "fmt" - "path" + "path/filepath" "github.com/jmhodges/levigo" - - . "github.com/tendermint/tmlibs/common" ) func init() { dbCreator := func(name string, dir string) (DB, error) { return NewCLevelDB(name, dir) } - registerDBCreator(LevelDBBackendStr, dbCreator, true) - registerDBCreator(CLevelDBBackendStr, dbCreator, false) + registerDBCreator(LevelDBBackend, dbCreator, true) + registerDBCreator(CLevelDBBackend, dbCreator, false) } +var _ DB = (*CLevelDB)(nil) + type CLevelDB struct { db *levigo.DB ro *levigo.ReadOptions @@ -27,7 +28,7 @@ type CLevelDB struct { } func NewCLevelDB(name string, dir string) (*CLevelDB, error) { - dbPath := path.Join(dir, name+".db") + dbPath := filepath.Join(dir, name+".db") opts := levigo.NewOptions() opts.SetCache(levigo.NewLRUCache(1 << 30)) @@ -49,39 +50,56 @@ func NewCLevelDB(name string, dir string) (*CLevelDB, error) { return database, nil } +// Implements DB. func (db *CLevelDB) Get(key []byte) []byte { + key = nonNilBytes(key) res, err := db.db.Get(db.ro, key) if err != nil { - PanicCrisis(err) + panic(err) } return res } +// Implements DB. +func (db *CLevelDB) Has(key []byte) bool { + return db.Get(key) != nil +} + +// Implements DB. func (db *CLevelDB) Set(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(db.wo, key, value) if err != nil { - PanicCrisis(err) + panic(err) } } +// Implements DB. func (db *CLevelDB) SetSync(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(db.woSync, key, value) if err != nil { - PanicCrisis(err) + panic(err) } } +// Implements DB. func (db *CLevelDB) Delete(key []byte) { + key = nonNilBytes(key) err := db.db.Delete(db.wo, key) if err != nil { - PanicCrisis(err) + panic(err) } } +// Implements DB. func (db *CLevelDB) DeleteSync(key []byte) { + key = nonNilBytes(key) err := db.db.Delete(db.woSync, key) if err != nil { - PanicCrisis(err) + panic(err) } } @@ -89,6 +107,7 @@ func (db *CLevelDB) DB() *levigo.DB { return db.db } +// Implements DB. func (db *CLevelDB) Close() { db.db.Close() db.ro.Close() @@ -96,57 +115,165 @@ func (db *CLevelDB) Close() { db.woSync.Close() } +// Implements DB. func (db *CLevelDB) Print() { - iter := db.db.NewIterator(db.ro) - defer iter.Close() - for iter.Seek(nil); iter.Valid(); iter.Next() { - key := iter.Key() - value := iter.Value() + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + key := itr.Key() + value := itr.Value() fmt.Printf("[%X]:\t[%X]\n", key, value) } } +// Implements DB. func (db *CLevelDB) Stats() map[string]string { // TODO: Find the available properties for the C LevelDB implementation keys := []string{} stats := make(map[string]string) for _, key := range keys { - str, err := db.db.GetProperty(key) - if err == nil { - stats[key] = str - } + str := db.db.PropertyValue(key) + stats[key] = str } return stats } -func (db *CLevelDB) Iterator() Iterator { - return db.db.NewIterator(nil, nil) -} +//---------------------------------------- +// Batch +// Implements DB. func (db *CLevelDB) NewBatch() Batch { batch := levigo.NewWriteBatch() return &cLevelDBBatch{db, batch} } -//-------------------------------------------------------------------------------- - type cLevelDBBatch struct { db *CLevelDB batch *levigo.WriteBatch } +// Implements Batch. func (mBatch *cLevelDBBatch) Set(key, value []byte) { mBatch.batch.Put(key, value) } +// Implements Batch. func (mBatch *cLevelDBBatch) Delete(key []byte) { mBatch.batch.Delete(key) } +// Implements Batch. func (mBatch *cLevelDBBatch) Write() { err := mBatch.db.db.Write(mBatch.db.wo, mBatch.batch) if err != nil { - PanicCrisis(err) + panic(err) + } +} + +//---------------------------------------- +// Iterator +// NOTE This is almost identical to db/go_level_db.Iterator +// Before creating a third version, refactor. + +func (db *CLevelDB) Iterator(start, end []byte) Iterator { + itr := db.db.NewIterator(db.ro) + return newCLevelDBIterator(itr, start, end, false) +} + +func (db *CLevelDB) ReverseIterator(start, end []byte) Iterator { + panic("not implemented yet") // XXX +} + +var _ Iterator = (*cLevelDBIterator)(nil) + +type cLevelDBIterator struct { + source *levigo.Iterator + start, end []byte + isReverse bool + isInvalid bool +} + +func newCLevelDBIterator(source *levigo.Iterator, start, end []byte, isReverse bool) *cLevelDBIterator { + if isReverse { + panic("not implemented yet") // XXX + } + if start != nil { + source.Seek(start) + } else { + source.SeekToFirst() + } + return &cLevelDBIterator{ + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, + } +} + +func (itr cLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +func (itr cLevelDBIterator) Valid() bool { + + // Once invalid, forever invalid. + if itr.isInvalid { + return false + } + + // Panic on DB error. No way to recover. + itr.assertNoError() + + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false + } + + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true + return false + } + + // It's valid. + return true +} + +func (itr cLevelDBIterator) Key() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Key() +} + +func (itr cLevelDBIterator) Value() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Value() +} + +func (itr cLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + itr.source.Next() +} + +func (itr cLevelDBIterator) Close() { + itr.source.Close() +} + +func (itr cLevelDBIterator) assertNoError() { + if err := itr.source.GetError(); err != nil { + panic(err) + } +} + +func (itr cLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("cLevelDBIterator is invalid") } } diff --git a/db/c_level_db_test.go b/db/c_level_db_test.go index e7336cc5..34bb7227 100644 --- a/db/c_level_db_test.go +++ b/db/c_level_db_test.go @@ -7,7 +7,8 @@ import ( "fmt" "testing" - . "github.com/tendermint/tmlibs/common" + "github.com/stretchr/testify/assert" + cmn "github.com/tendermint/tmlibs/common" ) func BenchmarkRandomReadsWrites2(b *testing.B) { @@ -18,7 +19,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { for i := 0; i < int(numItems); i++ { internal[int64(i)] = int64(0) } - db, err := NewCLevelDB(Fmt("test_%x", RandStr(12)), "") + db, err := NewCLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") if err != nil { b.Fatal(err.Error()) return @@ -30,7 +31,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { for i := 0; i < b.N; i++ { // Write something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) internal[idx] += 1 val := internal[idx] idxBytes := int642Bytes(int64(idx)) @@ -43,7 +44,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { } // Read something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) val := internal[idx] idxBytes := int642Bytes(int64(idx)) valBytes := db.Get(idxBytes) @@ -84,3 +85,12 @@ func bytes2Int64(buf []byte) int64 { return int64(binary.BigEndian.Uint64(buf)) } */ + +func TestCLevelDBBackend(t *testing.T) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, LevelDBBackend, "") + defer cleanupDBDir("", name) + + _, ok := db.(*CLevelDB) + assert.True(t, ok) +} diff --git a/db/common_test.go b/db/common_test.go new file mode 100644 index 00000000..1b0f0041 --- /dev/null +++ b/db/common_test.go @@ -0,0 +1,155 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" +) + +func checkValid(t *testing.T, itr Iterator, expected bool) { + valid := itr.Valid() + require.Equal(t, expected, valid) +} + +func checkNext(t *testing.T, itr Iterator, expected bool) { + itr.Next() + valid := itr.Valid() + require.Equal(t, expected, valid) +} + +func checkNextPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Next() }, "checkNextPanics expected panic but didn't") +} + +func checkItem(t *testing.T, itr Iterator, key []byte, value []byte) { + k, v := itr.Key(), itr.Value() + assert.Exactly(t, key, k) + assert.Exactly(t, value, v) +} + +func checkInvalid(t *testing.T, itr Iterator) { + checkValid(t, itr, false) + checkKeyPanics(t, itr) + checkValuePanics(t, itr) + checkNextPanics(t, itr) +} + +func checkKeyPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Key() }, "checkKeyPanics expected panic but didn't") +} + +func checkValuePanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Key() }, "checkValuePanics expected panic but didn't") +} + +func newTempDB(t *testing.T, backend DBBackendType) (db DB) { + dir, dirname := cmn.Tempdir("test_go_iterator") + db = NewDB("testdb", backend, dirname) + dir.Close() + return db +} + +func TestDBIteratorSingleKey(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(nil, nil) + + checkValid(t, itr, true) + checkNext(t, itr, false) + checkValid(t, itr, false) + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorTwoKeys(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + db.SetSync(bz("2"), bz("value_1")) + + { // Fail by calling Next too much + itr := db.Iterator(nil, nil) + checkValid(t, itr, true) + + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkNext(t, itr, false) + checkValid(t, itr, false) + + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + } + }) + } +} + +func TestDBIteratorMany(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + keys := make([][]byte, 100) + for i := 0; i < 100; i++ { + keys[i] = []byte{byte(i)} + } + + value := []byte{5} + for _, k := range keys { + db.Set(k, value) + } + + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + assert.Equal(t, db.Get(itr.Key()), itr.Value()) + } + }) + } +} + +func TestDBIteratorEmpty(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(nil, nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorEmptyBeginAfter(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(bz("1"), nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorNonemptyBeginAfter(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(bz("2"), nil) + + checkInvalid(t, itr) + }) + } +} diff --git a/db/db.go b/db/db.go index 8156c1e9..86993766 100644 --- a/db/db.go +++ b/db/db.go @@ -1,53 +1,25 @@ package db -import . "github.com/tendermint/tmlibs/common" +import "fmt" -type DB interface { - Get([]byte) []byte - Set([]byte, []byte) - SetSync([]byte, []byte) - Delete([]byte) - DeleteSync([]byte) - Close() - NewBatch() Batch - Iterator() Iterator - IteratorPrefix([]byte) Iterator +//---------------------------------------- +// Main entry - // For debugging - Print() - Stats() map[string]string -} - -type Batch interface { - Set(key, value []byte) - Delete(key []byte) - Write() -} - -type Iterator interface { - Next() bool - - Key() []byte - Value() []byte - - Release() - Error() error -} - -//----------------------------------------------------------------------------- +type DBBackendType string const ( - LevelDBBackendStr = "leveldb" // legacy, defaults to goleveldb. - CLevelDBBackendStr = "cleveldb" - GoLevelDBBackendStr = "goleveldb" - MemDBBackendStr = "memdb" + LevelDBBackend DBBackendType = "leveldb" // legacy, defaults to goleveldb unless +gcc + CLevelDBBackend DBBackendType = "cleveldb" + GoLevelDBBackend DBBackendType = "goleveldb" + MemDBBackend DBBackendType = "memdb" + FSDBBackend DBBackendType = "fsdb" // using the filesystem naively ) type dbCreator func(name string, dir string) (DB, error) -var backends = map[string]dbCreator{} +var backends = map[DBBackendType]dbCreator{} -func registerDBCreator(backend string, creator dbCreator, force bool) { +func registerDBCreator(backend DBBackendType, creator dbCreator, force bool) { _, ok := backends[backend] if !force && ok { return @@ -55,10 +27,10 @@ func registerDBCreator(backend string, creator dbCreator, force bool) { backends[backend] = creator } -func NewDB(name string, backend string, dir string) DB { +func NewDB(name string, backend DBBackendType, dir string) DB { db, err := backends[backend](name, dir) if err != nil { - PanicSanity(Fmt("Error initializing DB: %v", err)) + panic(fmt.Sprintf("Error initializing DB: %v", err)) } return db } diff --git a/db/fsdb.go b/db/fsdb.go new file mode 100644 index 00000000..578c1785 --- /dev/null +++ b/db/fsdb.go @@ -0,0 +1,254 @@ +package db + +import ( + "fmt" + "io/ioutil" + "net/url" + "os" + "path/filepath" + "sort" + "sync" + + "github.com/pkg/errors" + cmn "github.com/tendermint/tmlibs/common" +) + +const ( + keyPerm = os.FileMode(0600) + dirPerm = os.FileMode(0700) +) + +func init() { + registerDBCreator(FSDBBackend, func(name string, dir string) (DB, error) { + dbPath := filepath.Join(dir, name+".db") + return NewFSDB(dbPath), nil + }, false) +} + +var _ DB = (*FSDB)(nil) + +// It's slow. +type FSDB struct { + mtx sync.Mutex + dir string +} + +func NewFSDB(dir string) *FSDB { + err := os.MkdirAll(dir, dirPerm) + if err != nil { + panic(errors.Wrap(err, "Creating FSDB dir "+dir)) + } + database := &FSDB{ + dir: dir, + } + return database +} + +func (db *FSDB) Get(key []byte) []byte { + db.mtx.Lock() + defer db.mtx.Unlock() + key = escapeKey(key) + + path := db.nameToPath(key) + value, err := read(path) + if os.IsNotExist(err) { + return nil + } else if err != nil { + panic(errors.Wrapf(err, "Getting key %s (0x%X)", string(key), key)) + } + return value +} + +func (db *FSDB) Has(key []byte) bool { + db.mtx.Lock() + defer db.mtx.Unlock() + key = escapeKey(key) + + path := db.nameToPath(key) + return cmn.FileExists(path) +} + +func (db *FSDB) Set(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +func (db *FSDB) SetSync(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// NOTE: Implements atomicSetDeleter. +func (db *FSDB) SetNoLock(key []byte, value []byte) { + key = escapeKey(key) + value = nonNilBytes(value) + path := db.nameToPath(key) + err := write(path, value) + if err != nil { + panic(errors.Wrapf(err, "Setting key %s (0x%X)", string(key), key)) + } +} + +func (db *FSDB) Delete(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +func (db *FSDB) DeleteSync(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// NOTE: Implements atomicSetDeleter. +func (db *FSDB) DeleteNoLock(key []byte) { + key = escapeKey(key) + path := db.nameToPath(key) + err := remove(path) + if os.IsNotExist(err) { + return + } else if err != nil { + panic(errors.Wrapf(err, "Removing key %s (0x%X)", string(key), key)) + } +} + +func (db *FSDB) Close() { + // Nothing to do. +} + +func (db *FSDB) Print() { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Print not yet implemented") +} + +func (db *FSDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Stats not yet implemented") +} + +func (db *FSDB) NewBatch() Batch { + db.mtx.Lock() + defer db.mtx.Unlock() + + // Not sure we would ever want to try... + // It doesn't seem easy for general filesystems. + panic("FSDB.NewBatch not yet implemented") +} + +func (db *FSDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +func (db *FSDB) Iterator(start, end []byte) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + keys, err := list(db.dir, start, end) + if err != nil { + panic(errors.Wrapf(err, "Listing keys in %s", db.dir)) + } + sort.Strings(keys) + return newMemDBIterator(db, keys, start, end) +} + +func (db *FSDB) ReverseIterator(start, end []byte) Iterator { + panic("not implemented yet") // XXX +} + +func (db *FSDB) nameToPath(name []byte) string { + n := url.PathEscape(string(name)) + return filepath.Join(db.dir, n) +} + +// Read some bytes to a file. +// CONTRACT: returns os errors directly without wrapping. +func read(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + d, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + return d, nil +} + +// Write some bytes from a file. +// CONTRACT: returns os errors directly without wrapping. +func write(path string, d []byte) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, keyPerm) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(d) + if err != nil { + return err + } + err = f.Sync() + return err +} + +// Remove a file. +// CONTRACT: returns os errors directly without wrapping. +func remove(path string) error { + return os.Remove(path) +} + +// List keys in a directory, stripping of escape sequences and dir portions. +// CONTRACT: returns os errors directly without wrapping. +func list(dirPath string, start, end []byte) ([]string, error) { + dir, err := os.Open(dirPath) + if err != nil { + return nil, err + } + defer dir.Close() + + names, err := dir.Readdirnames(0) + if err != nil { + return nil, err + } + var keys []string + for _, name := range names { + n, err := url.PathUnescape(name) + if err != nil { + return nil, fmt.Errorf("Failed to unescape %s while listing", name) + } + key := unescapeKey([]byte(n)) + if IsKeyInDomain(key, start, end, false) { + keys = append(keys, string(key)) + } + } + return keys, nil +} + +// To support empty or nil keys, while the file system doesn't allow empty +// filenames. +func escapeKey(key []byte) []byte { + return []byte("k_" + string(key)) +} +func unescapeKey(escKey []byte) []byte { + if len(escKey) < 2 { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + if string(escKey[:2]) != "k_" { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + return escKey[2:] +} diff --git a/db/go_level_db.go b/db/go_level_db.go index 4abd7611..9fed329b 100644 --- a/db/go_level_db.go +++ b/db/go_level_db.go @@ -1,14 +1,14 @@ package db import ( + "bytes" "fmt" - "path" + "path/filepath" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" - "github.com/syndtr/goleveldb/leveldb/util" . "github.com/tendermint/tmlibs/common" ) @@ -17,58 +17,79 @@ func init() { dbCreator := func(name string, dir string) (DB, error) { return NewGoLevelDB(name, dir) } - registerDBCreator(LevelDBBackendStr, dbCreator, false) - registerDBCreator(GoLevelDBBackendStr, dbCreator, false) + registerDBCreator(LevelDBBackend, dbCreator, false) + registerDBCreator(GoLevelDBBackend, dbCreator, false) } +var _ DB = (*GoLevelDB)(nil) + type GoLevelDB struct { db *leveldb.DB } func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { - dbPath := path.Join(dir, name+".db") + dbPath := filepath.Join(dir, name+".db") db, err := leveldb.OpenFile(dbPath, nil) if err != nil { return nil, err } - database := &GoLevelDB{db: db} + database := &GoLevelDB{ + db: db, + } return database, nil } +// Implements DB. func (db *GoLevelDB) Get(key []byte) []byte { + key = nonNilBytes(key) res, err := db.db.Get(key, nil) if err != nil { if err == errors.ErrNotFound { return nil } else { - PanicCrisis(err) + panic(err) } } return res } +// Implements DB. +func (db *GoLevelDB) Has(key []byte) bool { + return db.Get(key) != nil +} + +// Implements DB. func (db *GoLevelDB) Set(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(key, value, nil) if err != nil { PanicCrisis(err) } } +// Implements DB. func (db *GoLevelDB) SetSync(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(key, value, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) } } +// Implements DB. func (db *GoLevelDB) Delete(key []byte) { + key = nonNilBytes(key) err := db.db.Delete(key, nil) if err != nil { PanicCrisis(err) } } +// Implements DB. func (db *GoLevelDB) DeleteSync(key []byte) { + key = nonNilBytes(key) err := db.db.Delete(key, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) @@ -79,10 +100,12 @@ func (db *GoLevelDB) DB() *leveldb.DB { return db.db } +// Implements DB. func (db *GoLevelDB) Close() { db.db.Close() } +// Implements DB. func (db *GoLevelDB) Print() { str, _ := db.db.GetProperty("leveldb.stats") fmt.Printf("%v\n", str) @@ -95,6 +118,7 @@ func (db *GoLevelDB) Print() { } } +// Implements DB. func (db *GoLevelDB) Stats() map[string]string { keys := []string{ "leveldb.num-files-at-level{n}", @@ -117,71 +141,150 @@ func (db *GoLevelDB) Stats() map[string]string { return stats } -type goLevelDBIterator struct { - source iterator.Iterator -} - -// Key returns a copy of the current key. -func (it *goLevelDBIterator) Key() []byte { - key := it.source.Key() - k := make([]byte, len(key)) - copy(k, key) - - return k -} - -// Value returns a copy of the current value. -func (it *goLevelDBIterator) Value() []byte { - val := it.source.Value() - v := make([]byte, len(val)) - copy(v, val) - - return v -} - -func (it *goLevelDBIterator) Error() error { - return it.source.Error() -} - -func (it *goLevelDBIterator) Next() bool { - return it.source.Next() -} - -func (it *goLevelDBIterator) Release() { - it.source.Release() -} - -func (db *GoLevelDB) Iterator() Iterator { - return &goLevelDBIterator{db.db.NewIterator(nil, nil)} -} - -func (db *GoLevelDB) IteratorPrefix(prefix []byte) Iterator { - return &goLevelDBIterator{db.db.NewIterator(util.BytesPrefix(prefix), nil)} -} +//---------------------------------------- +// Batch +// Implements DB. func (db *GoLevelDB) NewBatch() Batch { batch := new(leveldb.Batch) return &goLevelDBBatch{db, batch} } -//-------------------------------------------------------------------------------- - type goLevelDBBatch struct { db *GoLevelDB batch *leveldb.Batch } +// Implements Batch. func (mBatch *goLevelDBBatch) Set(key, value []byte) { mBatch.batch.Put(key, value) } +// Implements Batch. func (mBatch *goLevelDBBatch) Delete(key []byte) { mBatch.batch.Delete(key) } +// Implements Batch. func (mBatch *goLevelDBBatch) Write() { err := mBatch.db.db.Write(mBatch.batch, nil) if err != nil { - PanicCrisis(err) + panic(err) + } +} + +//---------------------------------------- +// Iterator +// NOTE This is almost identical to db/c_level_db.Iterator +// Before creating a third version, refactor. + +// Implements DB. +func (db *GoLevelDB) Iterator(start, end []byte) Iterator { + itr := db.db.NewIterator(nil, nil) + return newGoLevelDBIterator(itr, start, end, false) +} + +// Implements DB. +func (db *GoLevelDB) ReverseIterator(start, end []byte) Iterator { + panic("not implemented yet") // XXX +} + +type goLevelDBIterator struct { + source iterator.Iterator + start []byte + end []byte + isReverse bool + isInvalid bool +} + +var _ Iterator = (*goLevelDBIterator)(nil) + +func newGoLevelDBIterator(source iterator.Iterator, start, end []byte, isReverse bool) *goLevelDBIterator { + if isReverse { + panic("not implemented yet") // XXX + } + source.Seek(start) + return &goLevelDBIterator{ + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, + } +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Valid() bool { + + // Once invalid, forever invalid. + if itr.isInvalid { + return false + } + + // Panic on DB error. No way to recover. + itr.assertNoError() + + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false + } + + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true + return false + } + + // Valid + return true +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Key() []byte { + // Key returns a copy of the current key. + // See https://github.com/syndtr/goleveldb/blob/52c212e6c196a1404ea59592d3f1c227c9f034b2/leveldb/iterator/iter.go#L88 + itr.assertNoError() + itr.assertIsValid() + return cp(itr.source.Key()) +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Value() []byte { + // Value returns a copy of the current value. + // See https://github.com/syndtr/goleveldb/blob/52c212e6c196a1404ea59592d3f1c227c9f034b2/leveldb/iterator/iter.go#L88 + itr.assertNoError() + itr.assertIsValid() + return cp(itr.source.Value()) +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + itr.source.Next() +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Close() { + itr.source.Release() +} + +func (itr *goLevelDBIterator) assertNoError() { + if err := itr.source.Error(); err != nil { + panic(err) + } +} + +func (itr goLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("goLevelDBIterator is invalid") } } diff --git a/db/go_level_db_test.go b/db/go_level_db_test.go index 2cd3192c..88b6730f 100644 --- a/db/go_level_db_test.go +++ b/db/go_level_db_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) func BenchmarkRandomReadsWrites(b *testing.B) { @@ -17,7 +17,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { for i := 0; i < int(numItems); i++ { internal[int64(i)] = int64(0) } - db, err := NewGoLevelDB(Fmt("test_%x", RandStr(12)), "") + db, err := NewGoLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") if err != nil { b.Fatal(err.Error()) return @@ -29,7 +29,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { for i := 0; i < b.N; i++ { // Write something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) internal[idx] += 1 val := internal[idx] idxBytes := int642Bytes(int64(idx)) @@ -42,7 +42,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { } // Read something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) val := internal[idx] idxBytes := int642Bytes(int64(idx)) valBytes := db.Get(idxBytes) diff --git a/db/mem_batch.go b/db/mem_batch.go new file mode 100644 index 00000000..7072d931 --- /dev/null +++ b/db/mem_batch.go @@ -0,0 +1,50 @@ +package db + +import "sync" + +type atomicSetDeleter interface { + Mutex() *sync.Mutex + SetNoLock(key, value []byte) + DeleteNoLock(key []byte) +} + +type memBatch struct { + db atomicSetDeleter + ops []operation +} + +type opType int + +const ( + opTypeSet opType = 1 + opTypeDelete opType = 2 +) + +type operation struct { + opType + key []byte + value []byte +} + +func (mBatch *memBatch) Set(key, value []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) +} + +func (mBatch *memBatch) Delete(key []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) +} + +func (mBatch *memBatch) Write() { + mtx := mBatch.db.Mutex() + mtx.Lock() + defer mtx.Unlock() + + for _, op := range mBatch.ops { + switch op.opType { + case opTypeSet: + mBatch.db.SetNoLock(op.key, op.value) + case opTypeDelete: + mBatch.db.DeleteNoLock(op.key) + } + } +} diff --git a/db/mem_db.go b/db/mem_db.go index 2f507321..f2c484fa 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -3,56 +3,96 @@ package db import ( "fmt" "sort" - "strings" "sync" ) func init() { - registerDBCreator(MemDBBackendStr, func(name string, dir string) (DB, error) { + registerDBCreator(MemDBBackend, func(name string, dir string) (DB, error) { return NewMemDB(), nil }, false) } +var _ DB = (*MemDB)(nil) + type MemDB struct { mtx sync.Mutex db map[string][]byte } func NewMemDB() *MemDB { - database := &MemDB{db: make(map[string][]byte)} + database := &MemDB{ + db: make(map[string][]byte), + } return database } +// Implements DB. func (db *MemDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() + key = nonNilBytes(key) + return db.db[string(key)] } +// Implements DB. +func (db *MemDB) Has(key []byte) bool { + db.mtx.Lock() + defer db.mtx.Unlock() + key = nonNilBytes(key) + + _, ok := db.db[string(key)] + return ok +} + +// Implements DB. func (db *MemDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - db.db[string(key)] = value + + db.SetNoLock(key, value) } +// Implements DB. func (db *MemDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// Implements atomicSetDeleter. +func (db *MemDB) SetNoLock(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + db.db[string(key)] = value } +// Implements DB. func (db *MemDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - delete(db.db, string(key)) + + db.DeleteNoLock(key) } +// Implements DB. func (db *MemDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// Implements atomicSetDeleter. +func (db *MemDB) DeleteNoLock(key []byte) { + key = nonNilBytes(key) + delete(db.db, string(key)) } +// Implements DB. func (db *MemDB) Close() { // Close is a noop since for an in-memory // database, we don't have a destination @@ -61,120 +101,143 @@ func (db *MemDB) Close() { // See the discussion in https://github.com/tendermint/tmlibs/pull/56 } +// Implements DB. func (db *MemDB) Print() { db.mtx.Lock() defer db.mtx.Unlock() + for key, value := range db.db { fmt.Printf("[%X]:\t[%X]\n", []byte(key), value) } } +// Implements DB. func (db *MemDB) Stats() map[string]string { - stats := make(map[string]string) - stats["database.type"] = "memDB" - return stats -} - -type memDBIterator struct { - last int - keys []string - db *MemDB -} - -func newMemDBIterator() *memDBIterator { - return &memDBIterator{} -} - -func (it *memDBIterator) Next() bool { - if it.last >= len(it.keys)-1 { - return false - } - it.last++ - return true -} - -func (it *memDBIterator) Key() []byte { - return []byte(it.keys[it.last]) -} - -func (it *memDBIterator) Value() []byte { - return it.db.Get(it.Key()) -} - -func (it *memDBIterator) Release() { - it.db = nil - it.keys = nil -} - -func (it *memDBIterator) Error() error { - return nil -} - -func (db *MemDB) Iterator() Iterator { - return db.IteratorPrefix([]byte{}) -} - -func (db *MemDB) IteratorPrefix(prefix []byte) Iterator { - it := newMemDBIterator() - it.db = db - it.last = -1 - db.mtx.Lock() defer db.mtx.Unlock() - // unfortunately we need a copy of all of the keys - for key, _ := range db.db { - if strings.HasPrefix(key, string(prefix)) { - it.keys = append(it.keys, key) - } - } - // and we need to sort them - sort.Strings(it.keys) - return it + stats := make(map[string]string) + stats["database.type"] = "memDB" + stats["database.size"] = fmt.Sprintf("%d", len(db.db)) + return stats } +//---------------------------------------- +// Batch + +// Implements DB. func (db *MemDB) NewBatch() Batch { - return &memDBBatch{db, nil} + db.mtx.Lock() + defer db.mtx.Unlock() + + return &memBatch{db, nil} } -//-------------------------------------------------------------------------------- - -type memDBBatch struct { - db *MemDB - ops []operation +func (db *MemDB) Mutex() *sync.Mutex { + return &(db.mtx) } -type opType int +//---------------------------------------- +// Iterator -const ( - opTypeSet = 1 - opTypeDelete = 2 -) +// Implements DB. +func (db *MemDB) Iterator(start, end []byte) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() -type operation struct { - opType - key []byte - value []byte + keys := db.getSortedKeys(start, end, false) + return newMemDBIterator(db, keys, start, end) } -func (mBatch *memDBBatch) Set(key, value []byte) { - mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) +// Implements DB. +func (db *MemDB) ReverseIterator(start, end []byte) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() + + keys := db.getSortedKeys(end, start, true) + return newMemDBIterator(db, keys, start, end) } -func (mBatch *memDBBatch) Delete(key []byte) { - mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) +// We need a copy of all of the keys. +// Not the best, but probably not a bottleneck depending. +type memDBIterator struct { + db DB + cur int + keys []string + start []byte + end []byte } -func (mBatch *memDBBatch) Write() { - mBatch.db.mtx.Lock() - defer mBatch.db.mtx.Unlock() +var _ Iterator = (*memDBIterator)(nil) - for _, op := range mBatch.ops { - if op.opType == opTypeSet { - mBatch.db.db[string(op.key)] = op.value - } else if op.opType == opTypeDelete { - delete(mBatch.db.db, string(op.key)) +// Keys is expected to be in reverse order for reverse iterators. +func newMemDBIterator(db DB, keys []string, start, end []byte) *memDBIterator { + return &memDBIterator{ + db: db, + cur: 0, + keys: keys, + start: start, + end: end, + } +} + +// Implements Iterator. +func (itr *memDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +// Implements Iterator. +func (itr *memDBIterator) Valid() bool { + return 0 <= itr.cur && itr.cur < len(itr.keys) +} + +// Implements Iterator. +func (itr *memDBIterator) Next() { + itr.assertIsValid() + itr.cur++ +} + +// Implements Iterator. +func (itr *memDBIterator) Key() []byte { + itr.assertIsValid() + return []byte(itr.keys[itr.cur]) +} + +// Implements Iterator. +func (itr *memDBIterator) Value() []byte { + itr.assertIsValid() + key := []byte(itr.keys[itr.cur]) + return itr.db.Get(key) +} + +// Implements Iterator. +func (itr *memDBIterator) Close() { + itr.keys = nil + itr.db = nil +} + +func (itr *memDBIterator) assertIsValid() { + if !itr.Valid() { + panic("memDBIterator is invalid") + } +} + +//---------------------------------------- +// Misc. + +func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string { + keys := []string{} + for key, _ := range db.db { + if IsKeyInDomain([]byte(key), start, end, false) { + keys = append(keys, key) } } - + sort.Strings(keys) + if reverse { + nkeys := len(keys) + for i := 0; i < nkeys/2; i++ { + keys[i] = keys[nkeys-i-1] + } + } + return keys } diff --git a/db/mem_db_test.go b/db/mem_db_test.go deleted file mode 100644 index 503e361f..00000000 --- a/db/mem_db_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package db - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMemDbIterator(t *testing.T) { - db := NewMemDB() - keys := make([][]byte, 100) - for i := 0; i < 100; i++ { - keys[i] = []byte{byte(i)} - } - - value := []byte{5} - for _, k := range keys { - db.Set(k, value) - } - - iter := db.Iterator() - i := 0 - for iter.Next() { - assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key") - i += 1 - } - assert.Equal(t, i, len(db.db), "iterator didnt cover whole db") -} - -func TestMemDBClose(t *testing.T) { - db := NewMemDB() - copyDB := func(orig map[string][]byte) map[string][]byte { - copy := make(map[string][]byte) - for k, v := range orig { - copy[k] = v - } - return copy - } - k, v := []byte("foo"), []byte("bar") - db.Set(k, v) - require.Equal(t, db.Get(k), v, "expecting a successful get") - copyBefore := copyDB(db.db) - db.Close() - require.Equal(t, db.Get(k), v, "Close is a noop, expecting a successful get") - copyAfter := copyDB(db.db) - require.Equal(t, copyBefore, copyAfter, "Close is a noop and shouldn't modify any internal data") -} diff --git a/db/types.go b/db/types.go new file mode 100644 index 00000000..07858087 --- /dev/null +++ b/db/types.go @@ -0,0 +1,133 @@ +package db + +type DB interface { + + // Get returns nil iff key doesn't exist. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Get([]byte) []byte + + // Has checks if a key exists. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Has(key []byte) bool + + // Set sets the key. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Set([]byte, []byte) + SetSync([]byte, []byte) + + // Delete deletes the key. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key readonly []byte + Delete([]byte) + DeleteSync([]byte) + + // Iterate over a domain of keys in ascending order. End is exclusive. + // Start must be less than end, or the Iterator is invalid. + // A nil start is interpreted as an empty byteslice. + // If end is nil, iterates up to the last item (inclusive). + // CONTRACT: No writes may happen within a domain while an iterator exists over it. + // CONTRACT: start, end readonly []byte + Iterator(start, end []byte) Iterator + + // Iterate over a domain of keys in descending order. End is exclusive. + // Start must be greater than end, or the Iterator is invalid. + // If start is nil, iterates from the last/greatest item (inclusive). + // If end is nil, iterates up to the first/least item (iclusive). + // CONTRACT: No writes may happen within a domain while an iterator exists over it. + // CONTRACT: start, end readonly []byte + ReverseIterator(start, end []byte) Iterator + + // Closes the connection. + Close() + + // Creates a batch for atomic updates. + NewBatch() Batch + + // For debugging + Print() + + // Stats returns a map of property values for all keys and the size of the cache. + Stats() map[string]string +} + +//---------------------------------------- +// Batch + +type Batch interface { + SetDeleter + Write() +} + +type SetDeleter interface { + Set(key, value []byte) // CONTRACT: key, value readonly []byte + Delete(key []byte) // CONTRACT: key readonly []byte +} + +//---------------------------------------- +// Iterator + +/* + Usage: + + var itr Iterator = ... + defer itr.Close() + + for ; itr.Valid(); itr.Next() { + k, v := itr.Key(); itr.Value() + // ... + } +*/ +type Iterator interface { + + // The start & end (exclusive) limits to iterate over. + // If end < start, then the Iterator goes in reverse order. + // + // A domain of ([]byte{12, 13}, []byte{12, 14}) will iterate + // over anything with the prefix []byte{12, 13}. + // + // The smallest key is the empty byte array []byte{} - see BeginningKey(). + // The largest key is the nil byte array []byte(nil) - see EndingKey(). + // CONTRACT: start, end readonly []byte + Domain() (start []byte, end []byte) + + // Valid returns whether the current position is valid. + // Once invalid, an Iterator is forever invalid. + Valid() bool + + // Next moves the iterator to the next sequential key in the database, as + // defined by order of iteration. + // + // If Valid returns false, this method will panic. + Next() + + // Key returns the key of the cursor. + // If Valid returns false, this method will panic. + // CONTRACT: key readonly []byte + Key() (key []byte) + + // Value returns the value of the cursor. + // If Valid returns false, this method will panic. + // CONTRACT: value readonly []byte + Value() (value []byte) + + // Close releases the Iterator. + Close() +} + +// For testing convenience. +func bz(s string) []byte { + return []byte(s) +} + +// We defensively turn nil keys or values into []byte{} for +// most operations. +func nonNilBytes(bz []byte) []byte { + if bz == nil { + return []byte{} + } else { + return bz + } +} diff --git a/db/util.go b/db/util.go new file mode 100644 index 00000000..b0ab7f6a --- /dev/null +++ b/db/util.go @@ -0,0 +1,60 @@ +package db + +import ( + "bytes" +) + +func IteratePrefix(db DB, prefix []byte) Iterator { + var start, end []byte + if len(prefix) == 0 { + start = nil + end = nil + } else { + start = cp(prefix) + end = cpIncr(prefix) + } + return db.Iterator(start, end) +} + +//---------------------------------------- + +func cp(bz []byte) (ret []byte) { + ret = make([]byte, len(bz)) + copy(ret, bz) + return ret +} + +// CONTRACT: len(bz) > 0 +func cpIncr(bz []byte) (ret []byte) { + ret = cp(bz) + for i := len(bz) - 1; i >= 0; i-- { + if ret[i] < byte(0xFF) { + ret[i] += 1 + return + } else { + ret[i] = byte(0x00) + } + } + return nil +} + +// See DB interface documentation for more information. +func IsKeyInDomain(key, start, end []byte, isReverse bool) bool { + if !isReverse { + if bytes.Compare(key, start) < 0 { + return false + } + if end != nil && bytes.Compare(end, key) <= 0 { + return false + } + return true + } else { + if start != nil && bytes.Compare(start, key) < 0 { + return false + } + if end != nil && bytes.Compare(key, end) <= 0 { + return false + } + return true + } +} diff --git a/db/util_test.go b/db/util_test.go new file mode 100644 index 00000000..854448af --- /dev/null +++ b/db/util_test.go @@ -0,0 +1,93 @@ +package db + +import ( + "fmt" + "testing" +) + +// Empty iterator for empty db. +func TestPrefixIteratorNoMatchNil(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + + checkInvalid(t, itr) + }) + } +} + +// Empty iterator for db populated after iterator created. +func TestPrefixIteratorNoMatch1(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + db.SetSync(bz("1"), bz("value_1")) + + checkInvalid(t, itr) + }) + } +} + +// Empty iterator for prefix starting after db entry. +func TestPrefixIteratorNoMatch2(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("3"), bz("value_3")) + itr := IteratePrefix(db, []byte("4")) + + checkInvalid(t, itr) + }) + } +} + +// Iterator with single val for db with single val, starting from that val. +func TestPrefixIteratorMatch1(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("2"), bz("value_2")) + itr := IteratePrefix(db, bz("2")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("2"), bz("value_2")) + checkNext(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Iterator with prefix iterates over everything with same prefix. +func TestPrefixIteratorMatches1N(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + // prefixed + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + + // not + db.SetSync(bz("b/3"), bz("value_3")) + db.SetSync(bz("a-3"), bz("value_3")) + db.SetSync(bz("a.3"), bz("value_3")) + db.SetSync(bz("abcdefg"), bz("value_3")) + itr := IteratePrefix(db, bz("a/")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + + // Bad! + checkNext(t, itr, false) + + //Once invalid... + checkInvalid(t, itr) + }) + } +} diff --git a/glide.lock b/glide.lock index 4b9c46c7..a0ada5a4 100644 --- a/glide.lock +++ b/glide.lock @@ -1,24 +1,24 @@ -hash: 1f3d3426e823e4a8e6d4473615fcc86c767bbea6da9114ea1e7e0a9f0ccfa129 -updated: 2017-12-05T23:47:13.202024407Z +hash: 98752078f39da926f655268b3b143f713d64edd379fc9fcb1210d9d8aa7ab4e0 +updated: 2018-02-03T01:28:00.221548057-05:00 imports: - name: github.com/fsnotify/fsnotify - version: 4da3e2cfbabc9f751898f250b49f2439785783a1 + version: c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9 - name: github.com/go-kit/kit - version: 53f10af5d5c7375d4655a3d6852457ed17ab5cc7 + version: 4dc7be5d2d12881735283bcab7352178e190fc71 subpackages: - log - log/level - log/term - name: github.com/go-logfmt/logfmt version: 390ab7935ee28ec6b286364bba9b4dd6410cb3d5 -- name: github.com/go-playground/locales - version: e4cbcb5d0652150d40ad0646651076b6bd2be4f6 - subpackages: - - currency -- name: github.com/go-playground/universal-translator - version: 71201497bace774495daed26a3874fd339e0b538 - name: github.com/go-stack/stack - version: 259ab82a6cad3992b4e21ff5cac294ccb06474bc + version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf +- name: github.com/gogo/protobuf + version: 1adfc126b41513cc696b209667c8656ea7aac67c + subpackages: + - gogoproto + - proto + - protoc-gen-gogo/descriptor - name: github.com/golang/snappy version: 553a641470496b2327abcac10b36396bd98e45c9 - name: github.com/hashicorp/hcl @@ -40,32 +40,28 @@ imports: version: b84e30acd515aadc4b783ad4ff83aff3299bdfe0 - name: github.com/magiconair/properties version: 49d762b9817ba1c2e9d0c69183c2b4a8b8f1d934 -- name: github.com/mattn/go-colorable - version: 6fcc0c1fd9b620311d821b106a400b35dc95c497 -- name: github.com/mattn/go-isatty - version: 6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c - name: github.com/mitchellh/mapstructure - version: 06020f85339e21b2478f756a78e295255ffa4d6a + version: b4575eea38cca1123ec2dc90c26529b5c5acfcff - name: github.com/pelletier/go-toml - version: 4e9e0ee19b60b13eb79915933f44d8ed5f268bdd + version: acdc4509485b587f5e675510c4f2c63e90ff68a8 - name: github.com/pkg/errors - version: f15c970de5b76fac0b59abb32d62c17cc7bed265 + version: 645ef00459ed84a119197bfb8d8205042c6df63d - name: github.com/spf13/afero - version: 8d919cbe7e2627e417f3e45c3c0e489a5b7e2536 + version: bb8f1927f2a9d3ab41c9340aa034f6b803f4359c subpackages: - mem - name: github.com/spf13/cast version: acbeb36b902d72a7a4c18e8f3241075e7ab763e4 - name: github.com/spf13/cobra - version: de2d9c4eca8f3c1de17d48b096b6504e0296f003 + version: 7b2c5ac9fc04fc5efafb60700713d4fa609b777b - name: github.com/spf13/jwalterweatherman - version: 12bd96e66386c1960ab0f74ced1362f66f552f7b + version: 7c0cea34c8ece3fbeb2b27ab9b59511d360fb394 - name: github.com/spf13/pflag - version: 4c012f6dcd9546820e378d0bdda4d8fc772cdfea + version: 97afa5e7ca8a08a383cb259e06636b5e2cc7897f - name: github.com/spf13/viper - version: 4dddf7c62e16bce5807744018f5b753bfe21bbd2 + version: 25b30aa063fc18e48662b86996252eabdcf2f0c7 - name: github.com/syndtr/goleveldb - version: adf24ef3f94bd13ec4163060b21a5678f22b429b + version: b89cc31ef7977104127d34c1bd31ebd1a9db2199 subpackages: - leveldb - leveldb/cache @@ -79,43 +75,34 @@ imports: - leveldb/storage - leveldb/table - leveldb/util -- name: github.com/tendermint/go-wire - version: 2baffcb6b690057568bc90ef1d457efb150b979a - subpackages: - - data - - data/base58 -- name: github.com/tendermint/log15 - version: f91285dece9f4875421b481da3e613d83d44f29b - subpackages: - - term - name: golang.org/x/crypto - version: 94eea52f7b742c7cbe0b03b22f0c4c8631ece122 + version: edd5e9b0879d13ee6970a50153d85b8fec9f7686 subpackages: - ripemd160 - name: golang.org/x/sys - version: 8b4580aae2a0dd0c231a45d3ccb8434ff533b840 + version: 37707fdb30a5b38865cfb95e5aab41707daec7fd subpackages: - unix - name: golang.org/x/text - version: 57961680700a5336d15015c8c50686ca5ba362a4 + version: c01e4764d870b77f8abe5096ee19ad20d80e8075 subpackages: - transform - unicode/norm -- name: gopkg.in/go-playground/validator.v9 - version: 61caf9d3038e1af346dbf5c2e16f6678e1548364 - name: gopkg.in/yaml.v2 - version: 287cf08546ab5e7e37d55a84f7ed3fd1db036de5 + version: d670f9405373e636a5a2765eea47fac0c9bc91a4 testImports: - name: github.com/davecgh/go-spew - version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + version: 346938d642f2ec3594ed81d874461961cd0faa76 subpackages: - spew +- name: github.com/fortytw2/leaktest + version: 3b724c3d7b8729a35bf4e577f71653aec6e53513 - name: github.com/pmezard/go-difflib version: d8ed2627bdf02c080bf22230dbb337003b7aba2d subpackages: - difflib - name: github.com/stretchr/testify - version: 2aa2c176b9dab406a6970f6a55f513e8a8c8b18f + version: 12b6f73e6084dad08a7c6e575284b177ecafbc71 subpackages: - assert - require diff --git a/glide.yaml b/glide.yaml index 0d722c85..cf3da346 100644 --- a/glide.yaml +++ b/glide.yaml @@ -1,31 +1,38 @@ package: github.com/tendermint/tmlibs import: - package: github.com/go-kit/kit + version: ^0.6.0 subpackages: - log - log/level - log/term - package: github.com/go-logfmt/logfmt + version: ^0.3.0 +- package: github.com/gogo/protobuf + version: ^1.0.0 + subpackages: + - gogoproto + - proto - package: github.com/jmhodges/levigo - package: github.com/pkg/errors + version: ^0.8.0 - package: github.com/spf13/cobra + version: ^0.0.1 - package: github.com/spf13/viper + version: ^1.0.0 - package: github.com/syndtr/goleveldb subpackages: - leveldb - leveldb/errors + - leveldb/iterator - leveldb/opt -- package: github.com/tendermint/go-wire - subpackages: - - data - - data/base58 -- package: github.com/tendermint/log15 - package: golang.org/x/crypto subpackages: - ripemd160 -- package: gopkg.in/go-playground/validator.v9 testImport: - package: github.com/stretchr/testify + version: ^1.2.1 subpackages: - assert - require +- package: github.com/fortytw2/leaktest diff --git a/logger/log.go b/logger/log.go deleted file mode 100644 index 2f4faef6..00000000 --- a/logger/log.go +++ /dev/null @@ -1,78 +0,0 @@ -// DEPRECATED! Use newer log package. -package logger - -import ( - "os" - - "github.com/tendermint/log15" - . "github.com/tendermint/tmlibs/common" -) - -var mainHandler log15.Handler -var bypassHandler log15.Handler - -func init() { - resetWithLogLevel("debug") -} - -func SetLogLevel(logLevel string) { - resetWithLogLevel(logLevel) -} - -func resetWithLogLevel(logLevel string) { - // main handler - //handlers := []log15.Handler{} - mainHandler = log15.LvlFilterHandler( - getLevel(logLevel), - log15.StreamHandler(os.Stdout, log15.TerminalFormat()), - ) - //handlers = append(handlers, mainHandler) - - // bypass handler for not filtering on global logLevel. - bypassHandler = log15.StreamHandler(os.Stdout, log15.TerminalFormat()) - //handlers = append(handlers, bypassHandler) - - // By setting handlers on the root, we handle events from all loggers. - log15.Root().SetHandler(mainHandler) -} - -// See go-wire/log for an example of usage. -func MainHandler() log15.Handler { - return mainHandler -} - -func New(ctx ...interface{}) log15.Logger { - return NewMain(ctx...) -} - -func BypassHandler() log15.Handler { - return bypassHandler -} - -func NewMain(ctx ...interface{}) log15.Logger { - return log15.Root().New(ctx...) -} - -func NewBypass(ctx ...interface{}) log15.Logger { - bypass := log15.New(ctx...) - bypass.SetHandler(bypassHandler) - return bypass -} - -func getLevel(lvlString string) log15.Lvl { - lvl, err := log15.LvlFromString(lvlString) - if err != nil { - Exit(Fmt("Invalid log level %v: %v", lvlString, err)) - } - return lvl -} - -//---------------------------------------- -// Exported from log15 - -var LvlFilterHandler = log15.LvlFilterHandler -var LvlDebug = log15.LvlDebug -var LvlInfo = log15.LvlInfo -var LvlNotice = log15.LvlNotice -var LvlWarn = log15.LvlWarn -var LvlError = log15.LvlError diff --git a/merkle/simple_map.go b/merkle/simple_map.go new file mode 100644 index 00000000..b59e3b4b --- /dev/null +++ b/merkle/simple_map.go @@ -0,0 +1,84 @@ +package merkle + +import ( + cmn "github.com/tendermint/tmlibs/common" + "golang.org/x/crypto/ripemd160" +) + +type SimpleMap struct { + kvs cmn.KVPairs + sorted bool +} + +func NewSimpleMap() *SimpleMap { + return &SimpleMap{ + kvs: nil, + sorted: false, + } +} + +func (sm *SimpleMap) Set(key string, value Hasher) { + sm.sorted = false + + // Hash the key to blind it... why not? + khash := SimpleHashFromBytes([]byte(key)) + + // And the value is hashed too, so you can + // check for equality with a cached value (say) + // and make a determination to fetch or not. + vhash := value.Hash() + + sm.kvs = append(sm.kvs, cmn.KVPair{ + Key: khash, + Value: vhash, + }) +} + +// Merkle root hash of items sorted by key +// (UNSTABLE: and by value too if duplicate key). +func (sm *SimpleMap) Hash() []byte { + sm.Sort() + return hashKVPairs(sm.kvs) +} + +func (sm *SimpleMap) Sort() { + if sm.sorted { + return + } + sm.kvs.Sort() + sm.sorted = true +} + +// Returns a copy of sorted KVPairs. +func (sm *SimpleMap) KVPairs() cmn.KVPairs { + sm.Sort() + kvs := make(cmn.KVPairs, len(sm.kvs)) + copy(kvs, sm.kvs) + return kvs +} + +//---------------------------------------- + +// A local extension to KVPair that can be hashed. +type kvPair cmn.KVPair + +func (kv kvPair) Hash() []byte { + hasher := ripemd160.New() + err := encodeByteSlice(hasher, kv.Key) + if err != nil { + panic(err) + } + err = encodeByteSlice(hasher, kv.Value) + if err != nil { + panic(err) + } + return hasher.Sum(nil) +} + +func hashKVPairs(kvs cmn.KVPairs) []byte { + kvsH := make([]Hasher, 0, len(kvs)) + for _, kvp := range kvs { + kvsH = append(kvsH, kvPair(kvp)) + } + return SimpleHashFromHashers(kvsH) +} diff --git a/merkle/simple_map_test.go b/merkle/simple_map_test.go new file mode 100644 index 00000000..61210132 --- /dev/null +++ b/merkle/simple_map_test.go @@ -0,0 +1,53 @@ +package merkle + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type strHasher string + +func (str strHasher) Hash() []byte { + return SimpleHashFromBytes([]byte(str)) +} + +func TestSimpleMap(t *testing.T) { + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + assert.Equal(t, "19618304d1ad2635c4238bce87f72331b22a11a1", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value2")) + assert.Equal(t, "51cb96d3d41e1714def72eb4bacc211de9ddf284", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + db.Set("key2", strHasher("value2")) + assert.Equal(t, "58a0a99d5019fdcad4bcf55942e833b2dfab9421", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", strHasher("value2")) // NOTE: out of order + db.Set("key1", strHasher("value1")) + assert.Equal(t, "58a0a99d5019fdcad4bcf55942e833b2dfab9421", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + db.Set("key2", strHasher("value2")) + db.Set("key3", strHasher("value3")) + assert.Equal(t, "cb56db3c7993e977f4c2789559ae3e5e468a6e9b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", strHasher("value2")) // NOTE: out of order + db.Set("key1", strHasher("value1")) + db.Set("key3", strHasher("value3")) + assert.Equal(t, "cb56db3c7993e977f4c2789559ae3e5e468a6e9b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } +} diff --git a/merkle/simple_proof.go b/merkle/simple_proof.go new file mode 100644 index 00000000..83f89e59 --- /dev/null +++ b/merkle/simple_proof.go @@ -0,0 +1,131 @@ +package merkle + +import ( + "bytes" + "fmt" +) + +type SimpleProof struct { + Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. +} + +// proofs[0] is the proof for items[0]. +func SimpleProofsFromHashers(items []Hasher) (rootHash []byte, proofs []*SimpleProof) { + trails, rootSPN := trailsFromHashers(items) + rootHash = rootSPN.Hash + proofs = make([]*SimpleProof, len(items)) + for i, trail := range trails { + proofs[i] = &SimpleProof{ + Aunts: trail.FlattenAunts(), + } + } + return +} + +// Verify that leafHash is a leaf hash of the simple-merkle-tree +// which hashes to rootHash. +func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { + computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) + return computedHash != nil && bytes.Equal(computedHash, rootHash) +} + +func (sp *SimpleProof) String() string { + return sp.StringIndented("") +} + +func (sp *SimpleProof) StringIndented(indent string) string { + return fmt.Sprintf(`SimpleProof{ +%s Aunts: %X +%s}`, + indent, sp.Aunts, + indent) +} + +// Use the leafHash and innerHashes to get the root merkle hash. +// If the length of the innerHashes slice isn't exactly correct, the result is nil. +func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { + // Recursive impl. + if index >= total { + return nil + } + switch total { + case 0: + panic("Cannot call computeHashFromAunts() with 0 total") + case 1: + if len(innerHashes) != 0 { + return nil + } + return leafHash + default: + if len(innerHashes) == 0 { + return nil + } + numLeft := (total + 1) / 2 + if index < numLeft { + leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if leftHash == nil { + return nil + } + return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) + } else { + rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if rightHash == nil { + return nil + } + return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) + } + } +} + +// Helper structure to construct merkle proof. +// The node and the tree is thrown away afterwards. +// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. +// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +type SimpleProofNode struct { + Hash []byte + Parent *SimpleProofNode + Left *SimpleProofNode // Left sibling (only one of Left,Right is set) + Right *SimpleProofNode // Right sibling (only one of Left,Right is set) +} + +// Starting from a leaf SimpleProofNode, FlattenAunts() will return +// the inner hashes for the item corresponding to the leaf. +func (spn *SimpleProofNode) FlattenAunts() [][]byte { + // Nonrecursive impl. + innerHashes := [][]byte{} + for spn != nil { + if spn.Left != nil { + innerHashes = append(innerHashes, spn.Left.Hash) + } else if spn.Right != nil { + innerHashes = append(innerHashes, spn.Right.Hash) + } else { + break + } + spn = spn.Parent + } + return innerHashes +} + +// trails[0].Hash is the leaf hash for items[0]. +// trails[i].Parent.Parent....Parent == root for all i. +func trailsFromHashers(items []Hasher) (trails []*SimpleProofNode, root *SimpleProofNode) { + // Recursive impl. + switch len(items) { + case 0: + return nil, nil + case 1: + trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} + return []*SimpleProofNode{trail}, trail + default: + lefts, leftRoot := trailsFromHashers(items[:(len(items)+1)/2]) + rights, rightRoot := trailsFromHashers(items[(len(items)+1)/2:]) + rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) + root := &SimpleProofNode{rootHash, nil, nil, nil} + leftRoot.Parent = root + leftRoot.Right = rightRoot + rightRoot.Parent = root + rightRoot.Left = leftRoot + return append(lefts, rights...), root + } +} diff --git a/merkle/simple_tree.go b/merkle/simple_tree.go index 8106246d..a363ea8e 100644 --- a/merkle/simple_tree.go +++ b/merkle/simple_tree.go @@ -25,24 +25,15 @@ For larger datasets, use IAVLTree. package merkle import ( - "bytes" - "fmt" - "sort" - "golang.org/x/crypto/ripemd160" - - "github.com/tendermint/go-wire" - . "github.com/tendermint/tmlibs/common" ) func SimpleHashFromTwoHashes(left []byte, right []byte) []byte { - var n int - var err error var hasher = ripemd160.New() - wire.WriteByteSlice(left, hasher, &n, &err) - wire.WriteByteSlice(right, hasher, &n, &err) + err := encodeByteSlice(hasher, left) + err = encodeByteSlice(hasher, right) if err != nil { - PanicCrisis(err) + panic(err) } return hasher.Sum(nil) } @@ -61,27 +52,25 @@ func SimpleHashFromHashes(hashes [][]byte) []byte { } } -// Convenience for SimpleHashFromHashes. -func SimpleHashFromBinaries(items []interface{}) []byte { - hashes := make([][]byte, len(items)) - for i, item := range items { - hashes[i] = SimpleHashFromBinary(item) +// NOTE: Do not implement this, use SimpleHashFromByteslices instead. +// type Byteser interface { Bytes() []byte } +// func SimpleHashFromBytesers(items []Byteser) []byte { ... } + +func SimpleHashFromByteslices(bzs [][]byte) []byte { + hashes := make([][]byte, len(bzs)) + for i, bz := range bzs { + hashes[i] = SimpleHashFromBytes(bz) } return SimpleHashFromHashes(hashes) } -// General Convenience -func SimpleHashFromBinary(item interface{}) []byte { - hasher, n, err := ripemd160.New(), new(int), new(error) - wire.WriteBinary(item, hasher, n, err) - if *err != nil { - PanicCrisis(err) - } +func SimpleHashFromBytes(bz []byte) []byte { + hasher := ripemd160.New() + hasher.Write(bz) return hasher.Sum(nil) } -// Convenience for SimpleHashFromHashes. -func SimpleHashFromHashables(items []Hashable) []byte { +func SimpleHashFromHashers(items []Hasher) []byte { hashes := make([][]byte, len(items)) for i, item := range items { hash := item.Hash() @@ -90,188 +79,10 @@ func SimpleHashFromHashables(items []Hashable) []byte { return SimpleHashFromHashes(hashes) } -// Convenience for SimpleHashFromHashes. -func SimpleHashFromMap(m map[string]interface{}) []byte { - kpPairsH := MakeSortedKVPairs(m) - return SimpleHashFromHashables(kpPairsH) -} - -//-------------------------------------------------------------------------------- - -/* Convenience struct for key-value pairs. -A list of KVPairs is hashed via `SimpleHashFromHashables`. -NOTE: Each `Value` is encoded for hashing without extra type information, -so the user is presumed to be aware of the Value types. -*/ -type KVPair struct { - Key string - Value interface{} -} - -func (kv KVPair) Hash() []byte { - hasher, n, err := ripemd160.New(), new(int), new(error) - wire.WriteString(kv.Key, hasher, n, err) - if kvH, ok := kv.Value.(Hashable); ok { - wire.WriteByteSlice(kvH.Hash(), hasher, n, err) - } else { - wire.WriteBinary(kv.Value, hasher, n, err) - } - if *err != nil { - PanicSanity(*err) - } - return hasher.Sum(nil) -} - -type KVPairs []KVPair - -func (kvps KVPairs) Len() int { return len(kvps) } -func (kvps KVPairs) Less(i, j int) bool { return kvps[i].Key < kvps[j].Key } -func (kvps KVPairs) Swap(i, j int) { kvps[i], kvps[j] = kvps[j], kvps[i] } -func (kvps KVPairs) Sort() { sort.Sort(kvps) } - -func MakeSortedKVPairs(m map[string]interface{}) []Hashable { - kvPairs := []KVPair{} +func SimpleHashFromMap(m map[string]Hasher) []byte { + sm := NewSimpleMap() for k, v := range m { - kvPairs = append(kvPairs, KVPair{k, v}) - } - KVPairs(kvPairs).Sort() - kvPairsH := []Hashable{} - for _, kvp := range kvPairs { - kvPairsH = append(kvPairsH, kvp) - } - return kvPairsH -} - -//-------------------------------------------------------------------------------- - -type SimpleProof struct { - Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. -} - -// proofs[0] is the proof for items[0]. -func SimpleProofsFromHashables(items []Hashable) (rootHash []byte, proofs []*SimpleProof) { - trails, rootSPN := trailsFromHashables(items) - rootHash = rootSPN.Hash - proofs = make([]*SimpleProof, len(items)) - for i, trail := range trails { - proofs[i] = &SimpleProof{ - Aunts: trail.FlattenAunts(), - } - } - return -} - -// Verify that leafHash is a leaf hash of the simple-merkle-tree -// which hashes to rootHash. -func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { - computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) - if computedHash == nil { - return false - } - if !bytes.Equal(computedHash, rootHash) { - return false - } - return true -} - -func (sp *SimpleProof) String() string { - return sp.StringIndented("") -} - -func (sp *SimpleProof) StringIndented(indent string) string { - return fmt.Sprintf(`SimpleProof{ -%s Aunts: %X -%s}`, - indent, sp.Aunts, - indent) -} - -// Use the leafHash and innerHashes to get the root merkle hash. -// If the length of the innerHashes slice isn't exactly correct, the result is nil. -func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { - // Recursive impl. - if index >= total { - return nil - } - switch total { - case 0: - PanicSanity("Cannot call computeHashFromAunts() with 0 total") - return nil - case 1: - if len(innerHashes) != 0 { - return nil - } - return leafHash - default: - if len(innerHashes) == 0 { - return nil - } - numLeft := (total + 1) / 2 - if index < numLeft { - leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) - if leftHash == nil { - return nil - } - return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) - } else { - rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) - if rightHash == nil { - return nil - } - return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) - } - } -} - -// Helper structure to construct merkle proof. -// The node and the tree is thrown away afterwards. -// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. -// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or -// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. -type SimpleProofNode struct { - Hash []byte - Parent *SimpleProofNode - Left *SimpleProofNode // Left sibling (only one of Left,Right is set) - Right *SimpleProofNode // Right sibling (only one of Left,Right is set) -} - -// Starting from a leaf SimpleProofNode, FlattenAunts() will return -// the inner hashes for the item corresponding to the leaf. -func (spn *SimpleProofNode) FlattenAunts() [][]byte { - // Nonrecursive impl. - innerHashes := [][]byte{} - for spn != nil { - if spn.Left != nil { - innerHashes = append(innerHashes, spn.Left.Hash) - } else if spn.Right != nil { - innerHashes = append(innerHashes, spn.Right.Hash) - } else { - break - } - spn = spn.Parent - } - return innerHashes -} - -// trails[0].Hash is the leaf hash for items[0]. -// trails[i].Parent.Parent....Parent == root for all i. -func trailsFromHashables(items []Hashable) (trails []*SimpleProofNode, root *SimpleProofNode) { - // Recursive impl. - switch len(items) { - case 0: - return nil, nil - case 1: - trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} - return []*SimpleProofNode{trail}, trail - default: - lefts, leftRoot := trailsFromHashables(items[:(len(items)+1)/2]) - rights, rightRoot := trailsFromHashables(items[(len(items)+1)/2:]) - rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) - root := &SimpleProofNode{rootHash, nil, nil, nil} - leftRoot.Parent = root - leftRoot.Right = rightRoot - rightRoot.Parent = root - rightRoot.Left = leftRoot - return append(lefts, rights...), root + sm.Set(k, v) } + return sm.Hash() } diff --git a/merkle/simple_tree_test.go b/merkle/simple_tree_test.go index 6299fa33..26f35c80 100644 --- a/merkle/simple_tree_test.go +++ b/merkle/simple_tree_test.go @@ -19,14 +19,14 @@ func TestSimpleProof(t *testing.T) { total := 100 - items := make([]Hashable, total) + items := make([]Hasher, total) for i := 0; i < total; i++ { items[i] = testItem(RandBytes(32)) } - rootHash := SimpleHashFromHashables(items) + rootHash := SimpleHashFromHashers(items) - rootHash2, proofs := SimpleProofsFromHashables(items) + rootHash2, proofs := SimpleProofsFromHashers(items) if !bytes.Equal(rootHash, rootHash2) { t.Errorf("Unmatched root hashes: %X vs %X", rootHash, rootHash2) diff --git a/merkle/types.go b/merkle/types.go index 93541eda..e0fe35fa 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -1,5 +1,10 @@ package merkle +import ( + "encoding/binary" + "io" +) + type Tree interface { Size() (size int) Height() (height int8) @@ -18,6 +23,25 @@ type Tree interface { IterateRange(start []byte, end []byte, ascending bool, fx func(key []byte, value []byte) (stop bool)) (stopped bool) } -type Hashable interface { +type Hasher interface { Hash() []byte } + +//----------------------------------------------------------------------- +// NOTE: these are duplicated from go-wire so we dont need go-wire as a dep + +func encodeByteSlice(w io.Writer, bz []byte) (err error) { + err = encodeVarint(w, int64(len(bz))) + if err != nil { + return + } + _, err = w.Write(bz) + return +} + +func encodeVarint(w io.Writer, i int64) (err error) { + var buf [10]byte + n := binary.PutVarint(buf[:], i) + _, err = w.Write(buf[0:n]) + return +} diff --git a/version/version.go b/version/version.go index 6cc88728..2c0474fa 100644 --- a/version/version.go +++ b/version/version.go @@ -1,3 +1,3 @@ package version -const Version = "0.6.0" +const Version = "0.7.0"