diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..82f77436 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,19 @@ +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[Makefile] +indent_style = tab + +[*.sh] +indent_style = tab + +[*.proto] +indent_style = space +indent_size = 2 diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a2ebfde2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.sw[opqr] +vendor +.glide + +pubsub/query/fuzz_test/output diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..0f900c57 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,438 @@ +# Changelog + +## 0.9.0 + +*June 24th, 2018* + +BREAKING: + - [events, pubsub] Removed - moved to github.com/tendermint/tendermint + - [merkle] Use 20-bytes of SHA256 instead of RIPEMD160. NOTE: this package is + moving to github.com/tendermint/go-crypto ! + - [common] Remove gogoproto from KVPair types + - [common] Error simplification, #220 + +FEATURES: + + - [db/remotedb] New DB type using an external CLevelDB process via + GRPC + - [autofile] logjack command for piping stdin to a rotating file + - [bech32] New package. NOTE: should move out of here - it's just two small + functions + - [common] ColoredBytes([]byte) string for printing mixed ascii and bytes + - [db] DebugDB uses ColoredBytes() + +## 0.8.4 + +*June 5, 2018* + +IMPROVEMENTS: + + - [autofile] Flush on Stop; Close() method to Flush and close file + +## 0.8.3 + +*May 21, 2018* + +FEATURES: + + - [common] ASCIITrim() + +## 0.8.2 (April 23rd, 2018) + +FEATURES: + + - [pubsub] TagMap, NewTagMap + - [merkle] SimpleProofsFromMap() + - [common] IsASCIIText() + - [common] PrefixEndBytes // e.g. increment or nil + - [common] BitArray.MarshalJSON/.UnmarshalJSON + - [common] BitArray uses 'x' not 'X' for String() and above. + - [db] DebugDB shows better colorized output + +BUG FIXES: + + - [common] Fix TestParallelAbort nondeterministic failure #201/#202 + - [db] PrefixDB Iterator/ReverseIterator fixes + - [db] DebugDB fixes + +## 0.8.1 (April 5th, 2018) + +FEATURES: + + - [common] Error.Error() includes cause + - [common] IsEmpty() for 0 length + +## 0.8.0 (April 4th, 2018) + +BREAKING: + + - [merkle] `PutVarint->PutUvarint` in encodeByteSlice + - [db] batch.WriteSync() + - [common] Refactored and fixed `Parallel` function + - [common] Refactored `Rand` functionality + - [common] Remove unused `Right/LeftPadString` functions + - [common] Remove StackError, introduce Error interface (to replace use of pkg/errors) + +FEATURES: + + - [db] NewPrefixDB for a DB with all keys prefixed + - [db] NewDebugDB prints everything during operation + - [common] SplitAndTrim func + - [common] rand.Float64(), rand.Int63n(n), rand.Int31n(n) and global equivalents + - [common] HexBytes Format() + +BUG FIXES: + + - [pubsub] Fix unsubscribing + - [cli] Return config errors + - [common] Fix WriteFileAtomic Windows bug + +## 0.7.1 (March 22, 2018) + +IMPROVEMENTS: + + - glide -> dep + +BUG FIXES: + + - [common] Fix panic in NewBitArray for negative bits + - [common] Fix and simplify WriteFileAtomic so it cleans up properly + +## 0.7.0 (February 20, 2018) + +BREAKING: + + - [db] Major API upgrade. See `db/types.go`. + - [common] added `Quit() <-chan struct{}` to Service interface. + The returned channel is closed when service is stopped. + - [common] Remove HTTP functions + - [common] Heap.Push takes an `int`, new Heap.PushComparable takes the comparable. + - [logger] Removed. Use `log` + - [merkle] Major API updade - uses cmn.KVPairs. + - [cli] WriteDemoConfig -> WriteConfigValues + - [all] Remove go-wire dependency! + +FEATURES: + + - [db] New FSDB that uses the filesystem directly + - [common] HexBytes + - [common] KVPair and KI64Pair (protobuf based key-value pair objects) + +IMPROVEMENTS: + + - [clist] add WaitChan() to CList, NextWaitChan() and PrevWaitChan() + to CElement. These can be used instead of blocking `*Wait()` methods + if you need to be able to send quit signal and not block forever + - [common] IsHex handles 0x-prefix + +BUG FIXES: + + - [common] BitArray check for nil arguments + - [common] Fix memory leak in RepeatTimer + +## 0.6.0 (December 29, 2017) + +BREAKING: + - [cli] remove --root + - [pubsub] add String() method to Query interface + +IMPROVEMENTS: + - [common] use a thread-safe and well seeded non-crypto rng + +BUG FIXES + - [clist] fix misuse of wait group + - [common] introduce Ticker interface and logicalTicker for better testing of timers + +## 0.5.0 (December 5, 2017) + +BREAKING: + - [common] replace Service#Start, Service#Stop first return value (bool) with an + error (ErrAlreadyStarted, ErrAlreadyStopped) + - [common] replace Service#Reset first return value (bool) with an error + - [process] removed + +FEATURES: + - [common] IntInSlice and StringInSlice functions + - [pubsub/query] introduce `Condition` struct, expose `Operator`, and add `query.Conditions()` + +## 0.4.1 (November 27, 2017) + +FEATURES: + - [common] `Keys()` method on `CMap` + +IMPROVEMENTS: + - [log] complex types now encoded as "%+v" by default if `String()` method is undefined (previously resulted in error) + - [log] logger logs its own errors + +BUG FIXES: + - [common] fixed `Kill()` to build on Windows (Windows does not have `syscall.Kill`) + +## 0.4.0 (October 26, 2017) + +BREAKING: + - [common] GoPath is now a function + - [db] `DB` and `Iterator` interfaces have new methods to better support iteration + +FEATURES: + - [autofile] `Read([]byte)` and `Write([]byte)` methods on `Group` to support binary WAL + - [common] `Kill()` sends SIGTERM to the current process + +IMPROVEMENTS: + - comments and linting + +BUG FIXES: + - [events] fix allocation error prefixing cache with 1000 empty events + +## 0.3.2 (October 2, 2017) + +BUG FIXES: + +- [autofile] fix AutoFile.Sync() to open file if it's been closed +- [db] fix MemDb.Close() to not empty the database (ie. its just a noop) + + +## 0.3.1 (September 22, 2017) + +BUG FIXES: + +- [common] fix WriteFileAtomic to not use /tmp, which can be on another device + +## 0.3.0 (September 22, 2017) + +BREAKING CHANGES: + +- [log] logger functions no longer returns an error +- [common] NewBaseService takes the new logger +- [cli] RunCaptureWithArgs now captures stderr and stdout + - +func RunCaptureWithArgs(cmd Executable, args []string, env map[string]string) (stdout, stderr string, err error) + - -func RunCaptureWithArgs(cmd Executable, args []string, env map[string]string) (output string, err error) + +FEATURES: + +- [common] various common HTTP functionality +- [common] Date range parsing from string (ex. "2015-12-31:2017-12-31") +- [common] ProtocolAndAddress function +- [pubsub] New package for publish-subscribe with more advanced filtering + +BUG FIXES: + +- [common] fix atomicity of WriteFileAtomic by calling fsync +- [db] fix memDb iteration index out of range +- [autofile] fix Flush by calling fsync + +## 0.2.2 (June 16, 2017) + +FEATURES: + +- [common] IsHex and StripHex for handling `0x` prefixed hex strings +- [log] NewTracingLogger returns a logger that output error traces, ala `github.com/pkg/errors` + +IMPROVEMENTS: + +- [cli] Error handling for tests +- [cli] Support dashes in ENV variables + +BUG FIXES: + +- [flowrate] Fix non-deterministic test failures + +## 0.2.1 (June 2, 2017) + +FEATURES: + +- [cli] Log level parsing moved here from tendermint repo + +## 0.2.0 (May 18, 2017) + +BREAKING CHANGES: + +- [common] NewBaseService takes the new logger + + +FEATURES: + +- [cli] New library to standardize building command line tools +- [log] New logging library + +BUG FIXES: + +- [autofile] Close file before rotating + +## 0.1.0 (May 1, 2017) + +Initial release, combines what were previously independent repos: + +- go-autofile +- go-clist +- go-common +- go-db +- go-events +- go-flowrate +- go-logger +- go-merkle +- go-process + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..d2dddf85 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,3 @@ +* @melekes @ebuchman +*.md @zramsay +*.rst @zramsay diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 00000000..f0eaee18 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,281 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + branch = "master" + name = "github.com/btcsuite/btcutil" + packages = ["bech32"] + revision = "d4cc87b860166d00d6b5b9e0d3b3d71d6088d4d4" + +[[projects]] + name = "github.com/davecgh/go-spew" + packages = ["spew"] + revision = "346938d642f2ec3594ed81d874461961cd0faa76" + version = "v1.1.0" + +[[projects]] + branch = "master" + name = "github.com/fortytw2/leaktest" + packages = ["."] + revision = "3b724c3d7b8729a35bf4e577f71653aec6e53513" + +[[projects]] + name = "github.com/fsnotify/fsnotify" + packages = ["."] + revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" + version = "v1.4.7" + +[[projects]] + name = "github.com/go-kit/kit" + packages = [ + "log", + "log/level", + "log/term" + ] + revision = "4dc7be5d2d12881735283bcab7352178e190fc71" + version = "v0.6.0" + +[[projects]] + name = "github.com/go-logfmt/logfmt" + packages = ["."] + revision = "390ab7935ee28ec6b286364bba9b4dd6410cb3d5" + version = "v0.3.0" + +[[projects]] + name = "github.com/go-stack/stack" + packages = ["."] + revision = "817915b46b97fd7bb80e8ab6b69f01a53ac3eebf" + version = "v1.6.0" + +[[projects]] + name = "github.com/golang/protobuf" + packages = [ + "proto", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/timestamp" + ] + revision = "b4deda0973fb4c70b50d226b1af49f3da59f5265" + version = "v1.1.0" + +[[projects]] + branch = "master" + name = "github.com/golang/snappy" + packages = ["."] + revision = "553a641470496b2327abcac10b36396bd98e45c9" + +[[projects]] + branch = "master" + name = "github.com/hashicorp/hcl" + packages = [ + ".", + "hcl/ast", + "hcl/parser", + "hcl/scanner", + "hcl/strconv", + "hcl/token", + "json/parser", + "json/scanner", + "json/token" + ] + revision = "23c074d0eceb2b8a5bfdbb271ab780cde70f05a8" + +[[projects]] + name = "github.com/inconshreveable/mousetrap" + packages = ["."] + revision = "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75" + version = "v1.0" + +[[projects]] + branch = "master" + name = "github.com/jmhodges/levigo" + packages = ["."] + revision = "c42d9e0ca023e2198120196f842701bb4c55d7b9" + +[[projects]] + branch = "master" + name = "github.com/kr/logfmt" + packages = ["."] + revision = "b84e30acd515aadc4b783ad4ff83aff3299bdfe0" + +[[projects]] + name = "github.com/magiconair/properties" + packages = ["."] + revision = "49d762b9817ba1c2e9d0c69183c2b4a8b8f1d934" + +[[projects]] + name = "github.com/mitchellh/mapstructure" + packages = ["."] + revision = "b4575eea38cca1123ec2dc90c26529b5c5acfcff" + +[[projects]] + name = "github.com/pelletier/go-toml" + packages = ["."] + revision = "acdc4509485b587f5e675510c4f2c63e90ff68a8" + version = "v1.1.0" + +[[projects]] + name = "github.com/pkg/errors" + packages = ["."] + revision = "645ef00459ed84a119197bfb8d8205042c6df63d" + version = "v0.8.0" + +[[projects]] + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + name = "github.com/spf13/afero" + packages = [ + ".", + "mem" + ] + revision = "bb8f1927f2a9d3ab41c9340aa034f6b803f4359c" + version = "v1.0.2" + +[[projects]] + name = "github.com/spf13/cast" + packages = ["."] + revision = "acbeb36b902d72a7a4c18e8f3241075e7ab763e4" + version = "v1.1.0" + +[[projects]] + name = "github.com/spf13/cobra" + packages = ["."] + revision = "7b2c5ac9fc04fc5efafb60700713d4fa609b777b" + version = "v0.0.1" + +[[projects]] + branch = "master" + name = "github.com/spf13/jwalterweatherman" + packages = ["."] + revision = "7c0cea34c8ece3fbeb2b27ab9b59511d360fb394" + +[[projects]] + name = "github.com/spf13/pflag" + packages = ["."] + revision = "97afa5e7ca8a08a383cb259e06636b5e2cc7897f" + +[[projects]] + name = "github.com/spf13/viper" + packages = ["."] + revision = "25b30aa063fc18e48662b86996252eabdcf2f0c7" + version = "v1.0.0" + +[[projects]] + name = "github.com/stretchr/testify" + packages = [ + "assert", + "require" + ] + revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71" + version = "v1.2.1" + +[[projects]] + name = "github.com/syndtr/goleveldb" + packages = [ + "leveldb", + "leveldb/cache", + "leveldb/comparer", + "leveldb/errors", + "leveldb/filter", + "leveldb/iterator", + "leveldb/journal", + "leveldb/memdb", + "leveldb/opt", + "leveldb/storage", + "leveldb/table", + "leveldb/util" + ] + revision = "b89cc31ef7977104127d34c1bd31ebd1a9db2199" + +[[projects]] + branch = "master" + name = "golang.org/x/net" + packages = [ + "context", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace" + ] + revision = "d11bb6cd8e3c4e60239c9cb20ef68586d74500d0" + +[[projects]] + name = "golang.org/x/sys" + packages = ["unix"] + revision = "37707fdb30a5b38865cfb95e5aab41707daec7fd" + +[[projects]] + name = "golang.org/x/text" + packages = [ + "collate", + "collate/build", + "internal/colltab", + "internal/gen", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", + "transform", + "unicode/bidi", + "unicode/cldr", + "unicode/norm", + "unicode/rangetable" + ] + revision = "c01e4764d870b77f8abe5096ee19ad20d80e8075" + +[[projects]] + branch = "master" + name = "google.golang.org/genproto" + packages = ["googleapis/rpc/status"] + revision = "86e600f69ee4704c6efbf6a2a40a5c10700e76c2" + +[[projects]] + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "codes", + "connectivity", + "credentials", + "encoding", + "encoding/proto", + "grpclb/grpc_lb_v1/messages", + "grpclog", + "internal", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "stats", + "status", + "tap", + "transport" + ] + revision = "d11072e7ca9811b1100b80ca0269ac831f06d024" + version = "v1.11.3" + +[[projects]] + name = "gopkg.in/yaml.v2" + packages = ["."] + revision = "d670f9405373e636a5a2765eea47fac0c9bc91a4" + version = "v2.0.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "e0c0af880b57928787ea78a820abefd2759e6aee4cba18e67ab36b80e62ad581" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 00000000..ff42087f --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,69 @@ +# Gopkg.toml example +# +# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + branch = "master" + name = "github.com/fortytw2/leaktest" + +[[constraint]] + name = "github.com/go-kit/kit" + version = "0.6.0" + +[[constraint]] + name = "github.com/go-logfmt/logfmt" + version = "0.3.0" + +[[constraint]] + name = "github.com/gogo/protobuf" + version = "1.0.0" + +[[constraint]] + branch = "master" + name = "github.com/jmhodges/levigo" + +[[constraint]] + name = "github.com/pkg/errors" + version = "0.8.0" + +[[constraint]] + name = "github.com/spf13/cobra" + version = "0.0.1" + +[[constraint]] + name = "github.com/spf13/viper" + version = "1.0.0" + +[[constraint]] + name = "github.com/stretchr/testify" + version = "1.2.1" + +[[constraint]] + name = "github.com/btcsuite/btcutil" + branch ="master" +[prune] + go-tests = true + unused-packages = true diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..06bc5e1c --- /dev/null +++ b/LICENSE @@ -0,0 +1,193 @@ +Tendermint Libraries +Copyright (C) 2017 Tendermint + + + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a55bc139 --- /dev/null +++ b/Makefile @@ -0,0 +1,137 @@ +GOTOOLS = \ + github.com/golang/dep/cmd/dep \ + github.com/golang/protobuf/protoc-gen-go \ + github.com/square/certstrap + # github.com/alecthomas/gometalinter.v2 \ + +GOTOOLS_CHECK = dep gometalinter.v2 protoc protoc-gen-go +INCLUDE = -I=. -I=${GOPATH}/src + +all: check get_vendor_deps protoc grpc_dbserver build test install metalinter + +check: check_tools + +######################################## +### Build + +protoc: + ## If you get the following error, + ## "error while loading shared libraries: libprotobuf.so.14: cannot open shared object file: No such file or directory" + ## See https://stackoverflow.com/a/25518702 + protoc $(INCLUDE) --go_out=plugins=grpc:. common/*.proto + @echo "--> adding nolint declarations to protobuf generated files" + @awk '/package common/ { print "//nolint: gas"; print; next }1' common/types.pb.go > common/types.pb.go.new + @mv common/types.pb.go.new common/types.pb.go + +build: + # Nothing to build! + +install: + # Nothing to install! + + +######################################## +### Tools & dependencies + +check_tools: + @# https://stackoverflow.com/a/25668869 + @echo "Found tools: $(foreach tool,$(GOTOOLS_CHECK),\ + $(if $(shell which $(tool)),$(tool),$(error "No $(tool) in PATH")))" + +get_tools: + @echo "--> Installing tools" + go get -u -v $(GOTOOLS) + # @gometalinter.v2 --install + +get_protoc: + @# https://github.com/google/protobuf/releases + curl -L https://github.com/google/protobuf/releases/download/v3.4.1/protobuf-cpp-3.4.1.tar.gz | tar xvz && \ + cd protobuf-3.4.1 && \ + DIST_LANG=cpp ./configure && \ + make && \ + make install && \ + cd .. && \ + rm -rf protobuf-3.4.1 + +update_tools: + @echo "--> Updating tools" + @go get -u $(GOTOOLS) + +get_vendor_deps: + @rm -rf vendor/ + @echo "--> Running dep ensure" + @dep ensure + + +######################################## +### Testing + +gen_certs: clean_certs + ## Generating certificates for TLS testing... + certstrap init --common-name "tendermint.com" --passphrase "" + certstrap request-cert -ip "::" --passphrase "" + certstrap sign "::" --CA "tendermint.com" --passphrase "" + mv out/::.crt out/::.key db/remotedb + +clean_certs: + ## Cleaning TLS testing certificates... + rm -rf out + rm -f db/remotedb/::.crt db/remotedb/::.key + +test: gen_certs + GOCACHE=off go test -tags gcc $(shell go list ./... | grep -v vendor) + make clean_certs + +test100: + @for i in {1..100}; do make test; done + + +######################################## +### Formatting, linting, and vetting + +fmt: + @go fmt ./... + +metalinter: + @echo "==> Running linter" + gometalinter.v2 --vendor --deadline=600s --disable-all \ + --enable=deadcode \ + --enable=goconst \ + --enable=goimports \ + --enable=gosimple \ + --enable=ineffassign \ + --enable=megacheck \ + --enable=misspell \ + --enable=staticcheck \ + --enable=safesql \ + --enable=structcheck \ + --enable=unconvert \ + --enable=unused \ + --enable=varcheck \ + --enable=vetshadow \ + ./... + + #--enable=maligned \ + #--enable=gas \ + #--enable=aligncheck \ + #--enable=dupl \ + #--enable=errcheck \ + #--enable=gocyclo \ + #--enable=golint \ <== comments on anything exported + #--enable=gotype \ + #--enable=interfacer \ + #--enable=unparam \ + #--enable=vet \ + +metalinter_all: + protoc $(INCLUDE) --lint_out=. types/*.proto + gometalinter.v2 --vendor --deadline=600s --enable-all --disable=lll ./... + + +# To avoid unintended conflicts with file names, always add to .PHONY +# unless there is a reason not to. +# https://www.gnu.org/software/make/manual/html_node/Phony-Targets.html +.PHONY: check protoc build check_tools get_tools get_protoc update_tools get_vendor_deps test fmt metalinter metalinter_all gen_certs clean_certs + +grpc_dbserver: + protoc -I db/remotedb/proto/ db/remotedb/proto/defs.proto --go_out=plugins=grpc:db/remotedb/proto diff --git a/README.md b/README.md new file mode 100644 index 00000000..9ea618db --- /dev/null +++ b/README.md @@ -0,0 +1,49 @@ +# TMLIBS + +This repo is a home for various small packages. + +## autofile + +Autofile is file access with automatic log rotation. A group of files is maintained and rotation happens +when the leading file gets too big. Provides a reader for reading from the file group. + +## cli + +CLI wraps the `cobra` and `viper` packages and handles some common elements of building a CLI like flags and env vars for the home directory and the logger. + +## clist + +Clist provides a linekd list that is safe for concurrent access by many readers. + +## common + +Common provides a hodgepodge of useful functions. + +## db + +DB provides a database interface and a number of implementions, including ones using an in-memory map, the filesystem directory structure, +an implemention of LevelDB in Go, and the official LevelDB in C. + +## events + +Events is a synchronous PubSub package. + +## flowrate + +Flowrate is a fork of https://github.com/mxk/go-flowrate that added a `SetREMA` method. + +## log + +Log is a log package structured around key-value pairs that allows logging level to be set differently for different keys. + +## merkle + +Merkle provides a simple static merkle tree and corresponding proofs. + +## process + +Process is a simple utility for spawning OS processes. + +## pubsub + +PubSub is an asynchronous PubSub package. diff --git a/autofile/README.md b/autofile/README.md new file mode 100644 index 00000000..23799200 --- /dev/null +++ b/autofile/README.md @@ -0,0 +1 @@ +# go-autofile diff --git a/autofile/autofile.go b/autofile/autofile.go new file mode 100644 index 00000000..790be522 --- /dev/null +++ b/autofile/autofile.go @@ -0,0 +1,142 @@ +package autofile + +import ( + "os" + "sync" + "time" + + cmn "github.com/tendermint/tmlibs/common" +) + +/* AutoFile usage + +// Create/Append to ./autofile_test +af, err := OpenAutoFile("autofile_test") +if err != nil { + panic(err) +} + +// Stream of writes. +// During this time, the file may be moved e.g. by logRotate. +for i := 0; i < 60; i++ { + af.Write([]byte(Fmt("LOOP(%v)", i))) + time.Sleep(time.Second) +} + +// Close the AutoFile +err = af.Close() +if err != nil { + panic(err) +} +*/ + +const autoFileOpenDuration = 1000 * time.Millisecond + +// Automatically closes and re-opens file for writing. +// This is useful for using a log file with the logrotate tool. +type AutoFile struct { + ID string + Path string + ticker *time.Ticker + mtx sync.Mutex + file *os.File +} + +func OpenAutoFile(path string) (af *AutoFile, err error) { + af = &AutoFile{ + ID: cmn.RandStr(12) + ":" + path, + Path: path, + ticker: time.NewTicker(autoFileOpenDuration), + } + if err = af.openFile(); err != nil { + return + } + go af.processTicks() + sighupWatchers.addAutoFile(af) + return +} + +func (af *AutoFile) Close() error { + af.ticker.Stop() + err := af.closeFile() + sighupWatchers.removeAutoFile(af) + return err +} + +func (af *AutoFile) processTicks() { + for { + _, ok := <-af.ticker.C + if !ok { + return // Done. + } + af.closeFile() + } +} + +func (af *AutoFile) closeFile() (err error) { + af.mtx.Lock() + defer af.mtx.Unlock() + + file := af.file + if file == nil { + return nil + } + af.file = nil + return file.Close() +} + +func (af *AutoFile) Write(b []byte) (n int, err error) { + af.mtx.Lock() + defer af.mtx.Unlock() + + if af.file == nil { + if err = af.openFile(); err != nil { + return + } + } + + n, err = af.file.Write(b) + return +} + +func (af *AutoFile) Sync() error { + af.mtx.Lock() + defer af.mtx.Unlock() + + if af.file == nil { + if err := af.openFile(); err != nil { + return err + } + } + return af.file.Sync() +} + +func (af *AutoFile) openFile() error { + file, err := os.OpenFile(af.Path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0600) + if err != nil { + return err + } + af.file = file + return nil +} + +func (af *AutoFile) Size() (int64, error) { + af.mtx.Lock() + defer af.mtx.Unlock() + + if af.file == nil { + err := af.openFile() + if err != nil { + if err == os.ErrNotExist { + return 0, nil + } + return -1, err + } + } + stat, err := af.file.Stat() + if err != nil { + return -1, err + } + return stat.Size(), nil + +} diff --git a/autofile/autofile_test.go b/autofile/autofile_test.go new file mode 100644 index 00000000..8f453dd0 --- /dev/null +++ b/autofile/autofile_test.go @@ -0,0 +1,71 @@ +package autofile + +import ( + "os" + "sync/atomic" + "syscall" + "testing" + "time" + + cmn "github.com/tendermint/tmlibs/common" +) + +func TestSIGHUP(t *testing.T) { + + // First, create an AutoFile writing to a tempfile dir + file, name := cmn.Tempfile("sighup_test") + if err := file.Close(); err != nil { + t.Fatalf("Error creating tempfile: %v", err) + } + // Here is the actual AutoFile + af, err := OpenAutoFile(name) + if err != nil { + t.Fatalf("Error creating autofile: %v", err) + } + + // Write to the file. + _, err = af.Write([]byte("Line 1\n")) + if err != nil { + t.Fatalf("Error writing to autofile: %v", err) + } + _, err = af.Write([]byte("Line 2\n")) + if err != nil { + t.Fatalf("Error writing to autofile: %v", err) + } + + // Move the file over + err = os.Rename(name, name+"_old") + if err != nil { + t.Fatalf("Error moving autofile: %v", err) + } + + // Send SIGHUP to self. + oldSighupCounter := atomic.LoadInt32(&sighupCounter) + syscall.Kill(syscall.Getpid(), syscall.SIGHUP) + + // Wait a bit... signals are not handled synchronously. + for atomic.LoadInt32(&sighupCounter) == oldSighupCounter { + time.Sleep(time.Millisecond * 10) + } + + // Write more to the file. + _, err = af.Write([]byte("Line 3\n")) + if err != nil { + t.Fatalf("Error writing to autofile: %v", err) + } + _, err = af.Write([]byte("Line 4\n")) + if err != nil { + t.Fatalf("Error writing to autofile: %v", err) + } + if err := af.Close(); err != nil { + t.Fatalf("Error closing autofile") + } + + // Both files should exist + if body := cmn.MustReadFile(name + "_old"); string(body) != "Line 1\nLine 2\n" { + t.Errorf("Unexpected body %s", body) + } + if body := cmn.MustReadFile(name); string(body) != "Line 3\nLine 4\n" { + t.Errorf("Unexpected body %s", body) + } +} diff --git a/autofile/cmd/logjack.go b/autofile/cmd/logjack.go new file mode 100644 index 00000000..f2739a7e --- /dev/null +++ b/autofile/cmd/logjack.go @@ -0,0 +1,108 @@ +package main + +import ( + "flag" + "fmt" + "io" + "os" + "strconv" + "strings" + + auto "github.com/tendermint/tmlibs/autofile" + cmn "github.com/tendermint/tmlibs/common" +) + +const Version = "0.0.1" +const sleepSeconds = 1 // Every second +const readBufferSize = 1024 // 1KB at a time + +// Parse command-line options +func parseFlags() (headPath string, chopSize int64, limitSize int64, version bool) { + var flagSet = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + var chopSizeStr, limitSizeStr string + flagSet.StringVar(&headPath, "head", "logjack.out", "Destination (head) file.") + flagSet.StringVar(&chopSizeStr, "chop", "100M", "Move file if greater than this") + flagSet.StringVar(&limitSizeStr, "limit", "10G", "Only keep this much (for each specified file). Remove old files.") + flagSet.BoolVar(&version, "version", false, "Version") + flagSet.Parse(os.Args[1:]) + chopSize = parseBytesize(chopSizeStr) + limitSize = parseBytesize(limitSizeStr) + return +} + +func main() { + + // Read options + headPath, chopSize, limitSize, version := parseFlags() + if version { + fmt.Printf("logjack version %v\n", Version) + return + } + + // Open Group + group, err := auto.OpenGroup(headPath) + if err != nil { + fmt.Printf("logjack couldn't create output file %v\n", headPath) + os.Exit(1) + } + group.SetHeadSizeLimit(chopSize) + group.SetTotalSizeLimit(limitSize) + err = group.Start() + if err != nil { + fmt.Printf("logjack couldn't start with file %v\n", headPath) + os.Exit(1) + } + + go func() { + // Forever, read from stdin and write to AutoFile. + buf := make([]byte, readBufferSize) + for { + n, err := os.Stdin.Read(buf) + group.Write(buf[:n]) + group.Flush() + if err != nil { + group.Stop() + if err == io.EOF { + os.Exit(0) + } else { + fmt.Println("logjack errored") + os.Exit(1) + } + } + } + }() + + // Trap signal + cmn.TrapSignal(func() { + fmt.Println("logjack shutting down") + }) +} + +func parseBytesize(chopSize string) int64 { + // Handle suffix multiplier + var multiplier int64 = 1 + if strings.HasSuffix(chopSize, "T") { + multiplier = 1042 * 1024 * 1024 * 1024 + chopSize = chopSize[:len(chopSize)-1] + } + if strings.HasSuffix(chopSize, "G") { + multiplier = 1042 * 1024 * 1024 + chopSize = chopSize[:len(chopSize)-1] + } + if strings.HasSuffix(chopSize, "M") { + multiplier = 1042 * 1024 + chopSize = chopSize[:len(chopSize)-1] + } + if strings.HasSuffix(chopSize, "K") { + multiplier = 1042 + chopSize = chopSize[:len(chopSize)-1] + } + + // Parse the numeric part + chopSizeInt, err := strconv.Atoi(chopSize) + if err != nil { + panic(err) + } + + return int64(chopSizeInt) * multiplier +} diff --git a/autofile/group.go b/autofile/group.go new file mode 100644 index 00000000..1ae54503 --- /dev/null +++ b/autofile/group.go @@ -0,0 +1,747 @@ +package autofile + +import ( + "bufio" + "errors" + "fmt" + "io" + "log" + "os" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + "sync" + "time" + + cmn "github.com/tendermint/tmlibs/common" +) + +const ( + groupCheckDuration = 5000 * time.Millisecond + defaultHeadSizeLimit = 10 * 1024 * 1024 // 10MB + defaultTotalSizeLimit = 1 * 1024 * 1024 * 1024 // 1GB + maxFilesToRemove = 4 // needs to be greater than 1 +) + +/* +You can open a Group to keep restrictions on an AutoFile, like +the maximum size of each chunk, and/or the total amount of bytes +stored in the group. + +The first file to be written in the Group.Dir is the head file. + + Dir/ + - + +Once the Head file reaches the size limit, it will be rotated. + + Dir/ + - .000 // First rolled file + - // New head path, starts empty. + // The implicit index is 001. + +As more files are written, the index numbers grow... + + Dir/ + - .000 // First rolled file + - .001 // Second rolled file + - ... + - // New head path + +The Group can also be used to binary-search for some line, +assuming that marker lines are written occasionally. +*/ +type Group struct { + cmn.BaseService + + ID string + Head *AutoFile // The head AutoFile to write to + headBuf *bufio.Writer + Dir string // Directory that contains .Head + ticker *time.Ticker + mtx sync.Mutex + headSizeLimit int64 + totalSizeLimit int64 + minIndex int // Includes head + maxIndex int // Includes head, where Head will move to + + // TODO: When we start deleting files, we need to start tracking GroupReaders + // and their dependencies. +} + +// OpenGroup creates a new Group with head at headPath. It returns an error if +// it fails to open head file. +func OpenGroup(headPath string) (g *Group, err error) { + dir := path.Dir(headPath) + head, err := OpenAutoFile(headPath) + if err != nil { + return nil, err + } + + g = &Group{ + ID: "group:" + head.ID, + Head: head, + headBuf: bufio.NewWriterSize(head, 4096*10), + Dir: dir, + ticker: time.NewTicker(groupCheckDuration), + headSizeLimit: defaultHeadSizeLimit, + totalSizeLimit: defaultTotalSizeLimit, + minIndex: 0, + maxIndex: 0, + } + g.BaseService = *cmn.NewBaseService(nil, "Group", g) + + gInfo := g.readGroupInfo() + g.minIndex = gInfo.MinIndex + g.maxIndex = gInfo.MaxIndex + return +} + +// OnStart implements Service by starting the goroutine that checks file and +// group limits. +func (g *Group) OnStart() error { + go g.processTicks() + return nil +} + +// OnStop implements Service by stopping the goroutine described above. +// NOTE: g.Head must be closed separately using Close. +func (g *Group) OnStop() { + g.ticker.Stop() + g.Flush() // flush any uncommitted data +} + +// Close closes the head file. The group must be stopped by this moment. +func (g *Group) Close() { + g.Flush() // flush any uncommitted data + + g.mtx.Lock() + _ = g.Head.closeFile() + g.mtx.Unlock() +} + +// SetHeadSizeLimit allows you to overwrite default head size limit - 10MB. +func (g *Group) SetHeadSizeLimit(limit int64) { + g.mtx.Lock() + g.headSizeLimit = limit + g.mtx.Unlock() +} + +// HeadSizeLimit returns the current head size limit. +func (g *Group) HeadSizeLimit() int64 { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.headSizeLimit +} + +// SetTotalSizeLimit allows you to overwrite default total size limit of the +// group - 1GB. +func (g *Group) SetTotalSizeLimit(limit int64) { + g.mtx.Lock() + g.totalSizeLimit = limit + g.mtx.Unlock() +} + +// TotalSizeLimit returns total size limit of the group. +func (g *Group) TotalSizeLimit() int64 { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.totalSizeLimit +} + +// MaxIndex returns index of the last file in the group. +func (g *Group) MaxIndex() int { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.maxIndex +} + +// MinIndex returns index of the first file in the group. +func (g *Group) MinIndex() int { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.minIndex +} + +// Write writes the contents of p into the current head of the group. It +// returns the number of bytes written. If nn < len(p), it also returns an +// error explaining why the write is short. +// NOTE: Writes are buffered so they don't write synchronously +// TODO: Make it halt if space is unavailable +func (g *Group) Write(p []byte) (nn int, err error) { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.headBuf.Write(p) +} + +// WriteLine writes line into the current head of the group. It also appends "\n". +// NOTE: Writes are buffered so they don't write synchronously +// TODO: Make it halt if space is unavailable +func (g *Group) WriteLine(line string) error { + g.mtx.Lock() + defer g.mtx.Unlock() + _, err := g.headBuf.Write([]byte(line + "\n")) + return err +} + +// Flush writes any buffered data to the underlying file and commits the +// current content of the file to stable storage. +func (g *Group) Flush() error { + g.mtx.Lock() + defer g.mtx.Unlock() + err := g.headBuf.Flush() + if err == nil { + err = g.Head.Sync() + } + return err +} + +func (g *Group) processTicks() { + for { + _, ok := <-g.ticker.C + if !ok { + return // Done. + } + g.checkHeadSizeLimit() + g.checkTotalSizeLimit() + } +} + +// NOTE: for testing +func (g *Group) stopTicker() { + g.ticker.Stop() +} + +// NOTE: this function is called manually in tests. +func (g *Group) checkHeadSizeLimit() { + limit := g.HeadSizeLimit() + if limit == 0 { + return + } + size, err := g.Head.Size() + if err != nil { + panic(err) + } + if size >= limit { + g.RotateFile() + } +} + +func (g *Group) checkTotalSizeLimit() { + limit := g.TotalSizeLimit() + if limit == 0 { + return + } + + gInfo := g.readGroupInfo() + totalSize := gInfo.TotalSize + for i := 0; i < maxFilesToRemove; i++ { + index := gInfo.MinIndex + i + if totalSize < limit { + return + } + if index == gInfo.MaxIndex { + // Special degenerate case, just do nothing. + log.Println("WARNING: Group's head " + g.Head.Path + "may grow without bound") + return + } + pathToRemove := filePathForIndex(g.Head.Path, index, gInfo.MaxIndex) + fileInfo, err := os.Stat(pathToRemove) + if err != nil { + log.Println("WARNING: Failed to fetch info for file @" + pathToRemove) + continue + } + err = os.Remove(pathToRemove) + if err != nil { + log.Println(err) + return + } + totalSize -= fileInfo.Size() + } +} + +// RotateFile causes group to close the current head and assign it some index. +// Note it does not create a new head. +func (g *Group) RotateFile() { + g.mtx.Lock() + defer g.mtx.Unlock() + + headPath := g.Head.Path + + if err := g.Head.closeFile(); err != nil { + panic(err) + } + + indexPath := filePathForIndex(headPath, g.maxIndex, g.maxIndex+1) + if err := os.Rename(headPath, indexPath); err != nil { + panic(err) + } + + g.maxIndex++ +} + +// NewReader returns a new group reader. +// CONTRACT: Caller must close the returned GroupReader. +func (g *Group) NewReader(index int) (*GroupReader, error) { + r := newGroupReader(g) + err := r.SetIndex(index) + if err != nil { + return nil, err + } + return r, nil +} + +// Returns -1 if line comes after, 0 if found, 1 if line comes before. +type SearchFunc func(line string) (int, error) + +// Searches for the right file in Group, then returns a GroupReader to start +// streaming lines. +// Returns true if an exact match was found, otherwise returns the next greater +// line that starts with prefix. +// CONTRACT: Caller must close the returned GroupReader +func (g *Group) Search(prefix string, cmp SearchFunc) (*GroupReader, bool, error) { + g.mtx.Lock() + minIndex, maxIndex := g.minIndex, g.maxIndex + g.mtx.Unlock() + // Now minIndex/maxIndex may change meanwhile, + // but it shouldn't be a big deal + // (maybe we'll want to limit scanUntil though) + + for { + curIndex := (minIndex + maxIndex + 1) / 2 + + // Base case, when there's only 1 choice left. + if minIndex == maxIndex { + r, err := g.NewReader(maxIndex) + if err != nil { + return nil, false, err + } + match, err := scanUntil(r, prefix, cmp) + if err != nil { + r.Close() + return nil, false, err + } + return r, match, err + } + + // Read starting roughly at the middle file, + // until we find line that has prefix. + r, err := g.NewReader(curIndex) + if err != nil { + return nil, false, err + } + foundIndex, line, err := scanNext(r, prefix) + r.Close() + if err != nil { + return nil, false, err + } + + // Compare this line to our search query. + val, err := cmp(line) + if err != nil { + return nil, false, err + } + if val < 0 { + // Line will come later + minIndex = foundIndex + } else if val == 0 { + // Stroke of luck, found the line + r, err := g.NewReader(foundIndex) + if err != nil { + return nil, false, err + } + match, err := scanUntil(r, prefix, cmp) + if !match { + panic("Expected match to be true") + } + if err != nil { + r.Close() + return nil, false, err + } + return r, true, err + } else { + // We passed it + maxIndex = curIndex - 1 + } + } + +} + +// Scans and returns the first line that starts with 'prefix' +// Consumes line and returns it. +func scanNext(r *GroupReader, prefix string) (int, string, error) { + for { + line, err := r.ReadLine() + if err != nil { + return 0, "", err + } + if !strings.HasPrefix(line, prefix) { + continue + } + index := r.CurIndex() + return index, line, nil + } +} + +// Returns true iff an exact match was found. +// Pushes line, does not consume it. +func scanUntil(r *GroupReader, prefix string, cmp SearchFunc) (bool, error) { + for { + line, err := r.ReadLine() + if err != nil { + return false, err + } + if !strings.HasPrefix(line, prefix) { + continue + } + val, err := cmp(line) + if err != nil { + return false, err + } + if val < 0 { + continue + } else if val == 0 { + r.PushLine(line) + return true, nil + } else { + r.PushLine(line) + return false, nil + } + } +} + +// Searches backwards for the last line in Group with prefix. +// Scans each file forward until the end to find the last match. +func (g *Group) FindLast(prefix string) (match string, found bool, err error) { + g.mtx.Lock() + minIndex, maxIndex := g.minIndex, g.maxIndex + g.mtx.Unlock() + + r, err := g.NewReader(maxIndex) + if err != nil { + return "", false, err + } + defer r.Close() + + // Open files from the back and read +GROUP_LOOP: + for i := maxIndex; i >= minIndex; i-- { + err := r.SetIndex(i) + if err != nil { + return "", false, err + } + // Scan each line and test whether line matches + for { + line, err := r.ReadLine() + if err == io.EOF { + if found { + return match, found, nil + } + continue GROUP_LOOP + } else if err != nil { + return "", false, err + } + if strings.HasPrefix(line, prefix) { + match = line + found = true + } + if r.CurIndex() > i { + if found { + return match, found, nil + } + continue GROUP_LOOP + } + } + } + + return +} + +// GroupInfo holds information about the group. +type GroupInfo struct { + MinIndex int // index of the first file in the group, including head + MaxIndex int // index of the last file in the group, including head + TotalSize int64 // total size of the group + HeadSize int64 // size of the head +} + +// Returns info after scanning all files in g.Head's dir. +func (g *Group) ReadGroupInfo() GroupInfo { + g.mtx.Lock() + defer g.mtx.Unlock() + return g.readGroupInfo() +} + +// Index includes the head. +// CONTRACT: caller should have called g.mtx.Lock +func (g *Group) readGroupInfo() GroupInfo { + groupDir := filepath.Dir(g.Head.Path) + headBase := filepath.Base(g.Head.Path) + var minIndex, maxIndex int = -1, -1 + var totalSize, headSize int64 = 0, 0 + + dir, err := os.Open(groupDir) + if err != nil { + panic(err) + } + defer dir.Close() + fiz, err := dir.Readdir(0) + if err != nil { + panic(err) + } + + // For each file in the directory, filter by pattern + for _, fileInfo := range fiz { + if fileInfo.Name() == headBase { + fileSize := fileInfo.Size() + totalSize += fileSize + headSize = fileSize + continue + } else if strings.HasPrefix(fileInfo.Name(), headBase) { + fileSize := fileInfo.Size() + totalSize += fileSize + indexedFilePattern := regexp.MustCompile(`^.+\.([0-9]{3,})$`) + submatch := indexedFilePattern.FindSubmatch([]byte(fileInfo.Name())) + if len(submatch) != 0 { + // Matches + fileIndex, err := strconv.Atoi(string(submatch[1])) + if err != nil { + panic(err) + } + if maxIndex < fileIndex { + maxIndex = fileIndex + } + if minIndex == -1 || fileIndex < minIndex { + minIndex = fileIndex + } + } + } + } + + // Now account for the head. + if minIndex == -1 { + // If there were no numbered files, + // then the head is index 0. + minIndex, maxIndex = 0, 0 + } else { + // Otherwise, the head file is 1 greater + maxIndex++ + } + return GroupInfo{minIndex, maxIndex, totalSize, headSize} +} + +func filePathForIndex(headPath string, index int, maxIndex int) string { + if index == maxIndex { + return headPath + } + return fmt.Sprintf("%v.%03d", headPath, index) +} + +//-------------------------------------------------------------------------------- + +// GroupReader provides an interface for reading from a Group. +type GroupReader struct { + *Group + mtx sync.Mutex + curIndex int + curFile *os.File + curReader *bufio.Reader + curLine []byte +} + +func newGroupReader(g *Group) *GroupReader { + return &GroupReader{ + Group: g, + curIndex: 0, + curFile: nil, + curReader: nil, + curLine: nil, + } +} + +// Close closes the GroupReader by closing the cursor file. +func (gr *GroupReader) Close() error { + gr.mtx.Lock() + defer gr.mtx.Unlock() + + if gr.curReader != nil { + err := gr.curFile.Close() + gr.curIndex = 0 + gr.curReader = nil + gr.curFile = nil + gr.curLine = nil + return err + } + return nil +} + +// Read implements io.Reader, reading bytes from the current Reader +// incrementing index until enough bytes are read. +func (gr *GroupReader) Read(p []byte) (n int, err error) { + lenP := len(p) + if lenP == 0 { + return 0, errors.New("given empty slice") + } + + gr.mtx.Lock() + defer gr.mtx.Unlock() + + // Open file if not open yet + if gr.curReader == nil { + if err = gr.openFile(gr.curIndex); err != nil { + return 0, err + } + } + + // Iterate over files until enough bytes are read + var nn int + for { + nn, err = gr.curReader.Read(p[n:]) + n += nn + if err == io.EOF { + if n >= lenP { + return n, nil + } + // Open the next file + if err1 := gr.openFile(gr.curIndex + 1); err1 != nil { + return n, err1 + } + } else if err != nil { + return n, err + } else if nn == 0 { // empty file + return n, err + } + } +} + +// ReadLine reads a line (without delimiter). +// just return io.EOF if no new lines found. +func (gr *GroupReader) ReadLine() (string, error) { + gr.mtx.Lock() + defer gr.mtx.Unlock() + + // From PushLine + if gr.curLine != nil { + line := string(gr.curLine) + gr.curLine = nil + return line, nil + } + + // Open file if not open yet + if gr.curReader == nil { + err := gr.openFile(gr.curIndex) + if err != nil { + return "", err + } + } + + // Iterate over files until line is found + var linePrefix string + for { + bytesRead, err := gr.curReader.ReadBytes('\n') + if err == io.EOF { + // Open the next file + if err1 := gr.openFile(gr.curIndex + 1); err1 != nil { + return "", err1 + } + if len(bytesRead) > 0 && bytesRead[len(bytesRead)-1] == byte('\n') { + return linePrefix + string(bytesRead[:len(bytesRead)-1]), nil + } + linePrefix += string(bytesRead) + continue + } else if err != nil { + return "", err + } + return linePrefix + string(bytesRead[:len(bytesRead)-1]), nil + } +} + +// IF index > gr.Group.maxIndex, returns io.EOF +// CONTRACT: caller should hold gr.mtx +func (gr *GroupReader) openFile(index int) error { + + // Lock on Group to ensure that head doesn't move in the meanwhile. + gr.Group.mtx.Lock() + defer gr.Group.mtx.Unlock() + + if index > gr.Group.maxIndex { + return io.EOF + } + + curFilePath := filePathForIndex(gr.Head.Path, index, gr.Group.maxIndex) + curFile, err := os.Open(curFilePath) + if err != nil { + return err + } + curReader := bufio.NewReader(curFile) + + // Update gr.cur* + if gr.curFile != nil { + gr.curFile.Close() // TODO return error? + } + gr.curIndex = index + gr.curFile = curFile + gr.curReader = curReader + gr.curLine = nil + return nil +} + +// PushLine makes the given line the current one, so the next time somebody +// calls ReadLine, this line will be returned. +// panics if called twice without calling ReadLine. +func (gr *GroupReader) PushLine(line string) { + gr.mtx.Lock() + defer gr.mtx.Unlock() + + if gr.curLine == nil { + gr.curLine = []byte(line) + } else { + panic("PushLine failed, already have line") + } +} + +// CurIndex returns cursor's file index. +func (gr *GroupReader) CurIndex() int { + gr.mtx.Lock() + defer gr.mtx.Unlock() + return gr.curIndex +} + +// SetIndex sets the cursor's file index to index by opening a file at this +// position. +func (gr *GroupReader) SetIndex(index int) error { + gr.mtx.Lock() + defer gr.mtx.Unlock() + return gr.openFile(index) +} + +//-------------------------------------------------------------------------------- + +// A simple SearchFunc that assumes that the marker is of form +// . +// For example, if prefix is '#HEIGHT:', the markers of expected to be of the form: +// +// #HEIGHT:1 +// ... +// #HEIGHT:2 +// ... +func MakeSimpleSearchFunc(prefix string, target int) SearchFunc { + return func(line string) (int, error) { + if !strings.HasPrefix(line, prefix) { + return -1, errors.New(cmn.Fmt("Marker line did not have prefix: %v", prefix)) + } + i, err := strconv.Atoi(line[len(prefix):]) + if err != nil { + return -1, errors.New(cmn.Fmt("Failed to parse marker line: %v", err.Error())) + } + if target < i { + return 1, nil + } else if target == i { + return 0, nil + } else { + return -1, nil + } + } +} diff --git a/autofile/group_test.go b/autofile/group_test.go new file mode 100644 index 00000000..2ffedcc2 --- /dev/null +++ b/autofile/group_test.go @@ -0,0 +1,438 @@ +package autofile + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + cmn "github.com/tendermint/tmlibs/common" +) + +// NOTE: Returned group has ticker stopped +func createTestGroup(t *testing.T, headSizeLimit int64) *Group { + testID := cmn.RandStr(12) + testDir := "_test_" + testID + err := cmn.EnsureDir(testDir, 0700) + require.NoError(t, err, "Error creating dir") + headPath := testDir + "/myfile" + g, err := OpenGroup(headPath) + require.NoError(t, err, "Error opening Group") + g.SetHeadSizeLimit(headSizeLimit) + g.stopTicker() + require.NotEqual(t, nil, g, "Failed to create Group") + return g +} + +func destroyTestGroup(t *testing.T, g *Group) { + g.Close() + err := os.RemoveAll(g.Dir) + require.NoError(t, err, "Error removing test Group directory") +} + +func assertGroupInfo(t *testing.T, gInfo GroupInfo, minIndex, maxIndex int, totalSize, headSize int64) { + assert.Equal(t, minIndex, gInfo.MinIndex) + assert.Equal(t, maxIndex, gInfo.MaxIndex) + assert.Equal(t, totalSize, gInfo.TotalSize) + assert.Equal(t, headSize, gInfo.HeadSize) +} + +func TestCheckHeadSizeLimit(t *testing.T) { + g := createTestGroup(t, 1000*1000) + + // At first, there are no files. + assertGroupInfo(t, g.ReadGroupInfo(), 0, 0, 0, 0) + + // Write 1000 bytes 999 times. + for i := 0; i < 999; i++ { + err := g.WriteLine(cmn.RandStr(999)) + require.NoError(t, err, "Error appending to head") + } + g.Flush() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 0, 999000, 999000) + + // Even calling checkHeadSizeLimit manually won't rotate it. + g.checkHeadSizeLimit() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 0, 999000, 999000) + + // Write 1000 more bytes. + err := g.WriteLine(cmn.RandStr(999)) + require.NoError(t, err, "Error appending to head") + g.Flush() + + // Calling checkHeadSizeLimit this time rolls it. + g.checkHeadSizeLimit() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 1, 1000000, 0) + + // Write 1000 more bytes. + err = g.WriteLine(cmn.RandStr(999)) + require.NoError(t, err, "Error appending to head") + g.Flush() + + // Calling checkHeadSizeLimit does nothing. + g.checkHeadSizeLimit() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 1, 1001000, 1000) + + // Write 1000 bytes 999 times. + for i := 0; i < 999; i++ { + err = g.WriteLine(cmn.RandStr(999)) + require.NoError(t, err, "Error appending to head") + } + g.Flush() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 1, 2000000, 1000000) + + // Calling checkHeadSizeLimit rolls it again. + g.checkHeadSizeLimit() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 2, 2000000, 0) + + // Write 1000 more bytes. + _, err = g.Head.Write([]byte(cmn.RandStr(999) + "\n")) + require.NoError(t, err, "Error appending to head") + g.Flush() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 2, 2001000, 1000) + + // Calling checkHeadSizeLimit does nothing. + g.checkHeadSizeLimit() + assertGroupInfo(t, g.ReadGroupInfo(), 0, 2, 2001000, 1000) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestSearch(t *testing.T) { + g := createTestGroup(t, 10*1000) + + // Create some files in the group that have several INFO lines in them. + // Try to put the INFO lines in various spots. + for i := 0; i < 100; i++ { + // The random junk at the end ensures that this INFO linen + // is equally likely to show up at the end. + _, err := g.Head.Write([]byte(fmt.Sprintf("INFO %v %v\n", i, cmn.RandStr(123)))) + require.NoError(t, err, "Failed to write to head") + g.checkHeadSizeLimit() + for j := 0; j < 10; j++ { + _, err1 := g.Head.Write([]byte(cmn.RandStr(123) + "\n")) + require.NoError(t, err1, "Failed to write to head") + g.checkHeadSizeLimit() + } + } + + // Create a search func that searches for line + makeSearchFunc := func(target int) SearchFunc { + return func(line string) (int, error) { + parts := strings.Split(line, " ") + if len(parts) != 3 { + return -1, errors.New("Line did not have 3 parts") + } + i, err := strconv.Atoi(parts[1]) + if err != nil { + return -1, errors.New("Failed to parse INFO: " + err.Error()) + } + if target < i { + return 1, nil + } else if target == i { + return 0, nil + } else { + return -1, nil + } + } + } + + // Now search for each number + for i := 0; i < 100; i++ { + t.Log("Testing for i", i) + gr, match, err := g.Search("INFO", makeSearchFunc(i)) + require.NoError(t, err, "Failed to search for line") + assert.True(t, match, "Expected Search to return exact match") + line, err := gr.ReadLine() + require.NoError(t, err, "Failed to read line after search") + if !strings.HasPrefix(line, fmt.Sprintf("INFO %v ", i)) { + t.Fatal("Failed to get correct line") + } + // Make sure we can continue to read from there. + cur := i + 1 + for { + line, err := gr.ReadLine() + if err == io.EOF { + if cur == 99+1 { + // OK! + break + } else { + t.Fatal("Got EOF after the wrong INFO #") + } + } else if err != nil { + t.Fatal("Error reading line", err) + } + if !strings.HasPrefix(line, "INFO ") { + continue + } + if !strings.HasPrefix(line, fmt.Sprintf("INFO %v ", cur)) { + t.Fatalf("Unexpected INFO #. Expected %v got:\n%v", cur, line) + } + cur++ + } + gr.Close() + } + + // Now search for something that is too small. + // We should get the first available line. + { + gr, match, err := g.Search("INFO", makeSearchFunc(-999)) + require.NoError(t, err, "Failed to search for line") + assert.False(t, match, "Expected Search to not return exact match") + line, err := gr.ReadLine() + require.NoError(t, err, "Failed to read line after search") + if !strings.HasPrefix(line, "INFO 0 ") { + t.Error("Failed to fetch correct line, which is the earliest INFO") + } + err = gr.Close() + require.NoError(t, err, "Failed to close GroupReader") + } + + // Now search for something that is too large. + // We should get an EOF error. + { + gr, _, err := g.Search("INFO", makeSearchFunc(999)) + assert.Equal(t, io.EOF, err) + assert.Nil(t, gr) + } + + // Cleanup + destroyTestGroup(t, g) +} + +func TestRotateFile(t *testing.T) { + g := createTestGroup(t, 0) + g.WriteLine("Line 1") + g.WriteLine("Line 2") + g.WriteLine("Line 3") + g.Flush() + g.RotateFile() + g.WriteLine("Line 4") + g.WriteLine("Line 5") + g.WriteLine("Line 6") + g.Flush() + + // Read g.Head.Path+"000" + body1, err := ioutil.ReadFile(g.Head.Path + ".000") + assert.NoError(t, err, "Failed to read first rolled file") + if string(body1) != "Line 1\nLine 2\nLine 3\n" { + t.Errorf("Got unexpected contents: [%v]", string(body1)) + } + + // Read g.Head.Path + body2, err := ioutil.ReadFile(g.Head.Path) + assert.NoError(t, err, "Failed to read first rolled file") + if string(body2) != "Line 4\nLine 5\nLine 6\n" { + t.Errorf("Got unexpected contents: [%v]", string(body2)) + } + + // Cleanup + destroyTestGroup(t, g) +} + +func TestFindLast1(t *testing.T) { + g := createTestGroup(t, 0) + + g.WriteLine("Line 1") + g.WriteLine("Line 2") + g.WriteLine("# a") + g.WriteLine("Line 3") + g.Flush() + g.RotateFile() + g.WriteLine("Line 4") + g.WriteLine("Line 5") + g.WriteLine("Line 6") + g.WriteLine("# b") + g.Flush() + + match, found, err := g.FindLast("#") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "# b", match) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestFindLast2(t *testing.T) { + g := createTestGroup(t, 0) + + g.WriteLine("Line 1") + g.WriteLine("Line 2") + g.WriteLine("Line 3") + g.Flush() + g.RotateFile() + g.WriteLine("# a") + g.WriteLine("Line 4") + g.WriteLine("Line 5") + g.WriteLine("# b") + g.WriteLine("Line 6") + g.Flush() + + match, found, err := g.FindLast("#") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "# b", match) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestFindLast3(t *testing.T) { + g := createTestGroup(t, 0) + + g.WriteLine("Line 1") + g.WriteLine("# a") + g.WriteLine("Line 2") + g.WriteLine("# b") + g.WriteLine("Line 3") + g.Flush() + g.RotateFile() + g.WriteLine("Line 4") + g.WriteLine("Line 5") + g.WriteLine("Line 6") + g.Flush() + + match, found, err := g.FindLast("#") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "# b", match) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestFindLast4(t *testing.T) { + g := createTestGroup(t, 0) + + g.WriteLine("Line 1") + g.WriteLine("Line 2") + g.WriteLine("Line 3") + g.Flush() + g.RotateFile() + g.WriteLine("Line 4") + g.WriteLine("Line 5") + g.WriteLine("Line 6") + g.Flush() + + match, found, err := g.FindLast("#") + assert.NoError(t, err) + assert.False(t, found) + assert.Empty(t, match) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestWrite(t *testing.T) { + g := createTestGroup(t, 0) + + written := []byte("Medusa") + g.Write(written) + g.Flush() + + read := make([]byte, len(written)) + gr, err := g.NewReader(0) + require.NoError(t, err, "failed to create reader") + + _, err = gr.Read(read) + assert.NoError(t, err, "failed to read data") + assert.Equal(t, written, read) + + // Cleanup + destroyTestGroup(t, g) +} + +// test that Read reads the required amount of bytes from all the files in the +// group and returns no error if n == size of the given slice. +func TestGroupReaderRead(t *testing.T) { + g := createTestGroup(t, 0) + + professor := []byte("Professor Monster") + g.Write(professor) + g.Flush() + g.RotateFile() + frankenstein := []byte("Frankenstein's Monster") + g.Write(frankenstein) + g.Flush() + + totalWrittenLength := len(professor) + len(frankenstein) + read := make([]byte, totalWrittenLength) + gr, err := g.NewReader(0) + require.NoError(t, err, "failed to create reader") + + n, err := gr.Read(read) + assert.NoError(t, err, "failed to read data") + assert.Equal(t, totalWrittenLength, n, "not enough bytes read") + professorPlusFrankenstein := professor + professorPlusFrankenstein = append(professorPlusFrankenstein, frankenstein...) + assert.Equal(t, professorPlusFrankenstein, read) + + // Cleanup + destroyTestGroup(t, g) +} + +// test that Read returns an error if number of bytes read < size of +// the given slice. Subsequent call should return 0, io.EOF. +func TestGroupReaderRead2(t *testing.T) { + g := createTestGroup(t, 0) + + professor := []byte("Professor Monster") + g.Write(professor) + g.Flush() + g.RotateFile() + frankenstein := []byte("Frankenstein's Monster") + frankensteinPart := []byte("Frankenstein") + g.Write(frankensteinPart) // note writing only a part + g.Flush() + + totalLength := len(professor) + len(frankenstein) + read := make([]byte, totalLength) + gr, err := g.NewReader(0) + require.NoError(t, err, "failed to create reader") + + // 1) n < (size of the given slice), io.EOF + n, err := gr.Read(read) + assert.Equal(t, io.EOF, err) + assert.Equal(t, len(professor)+len(frankensteinPart), n, "Read more/less bytes than it is in the group") + + // 2) 0, io.EOF + n, err = gr.Read([]byte("0")) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) + + // Cleanup + destroyTestGroup(t, g) +} + +func TestMinIndex(t *testing.T) { + g := createTestGroup(t, 0) + + assert.Zero(t, g.MinIndex(), "MinIndex should be zero at the beginning") + + // Cleanup + destroyTestGroup(t, g) +} + +func TestMaxIndex(t *testing.T) { + g := createTestGroup(t, 0) + + assert.Zero(t, g.MaxIndex(), "MaxIndex should be zero at the beginning") + + g.WriteLine("Line 1") + g.Flush() + g.RotateFile() + + assert.Equal(t, 1, g.MaxIndex(), "MaxIndex should point to the last file") + + // Cleanup + destroyTestGroup(t, g) +} diff --git a/autofile/sighup_watcher.go b/autofile/sighup_watcher.go new file mode 100644 index 00000000..56fbd4d8 --- /dev/null +++ b/autofile/sighup_watcher.go @@ -0,0 +1,63 @@ +package autofile + +import ( + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" +) + +func init() { + initSighupWatcher() +} + +var sighupWatchers *SighupWatcher +var sighupCounter int32 // For testing + +func initSighupWatcher() { + sighupWatchers = newSighupWatcher() + + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP) + + go func() { + for range c { + sighupWatchers.closeAll() + atomic.AddInt32(&sighupCounter, 1) + } + }() +} + +// Watchces for SIGHUP events and notifies registered AutoFiles +type SighupWatcher struct { + mtx sync.Mutex + autoFiles map[string]*AutoFile +} + +func newSighupWatcher() *SighupWatcher { + return &SighupWatcher{ + autoFiles: make(map[string]*AutoFile, 10), + } +} + +func (w *SighupWatcher) addAutoFile(af *AutoFile) { + w.mtx.Lock() + w.autoFiles[af.ID] = af + w.mtx.Unlock() +} + +// If AutoFile isn't registered or was already removed, does nothing. +func (w *SighupWatcher) removeAutoFile(af *AutoFile) { + w.mtx.Lock() + delete(w.autoFiles, af.ID) + w.mtx.Unlock() +} + +func (w *SighupWatcher) closeAll() { + w.mtx.Lock() + for _, af := range w.autoFiles { + af.closeFile() + } + w.mtx.Unlock() +} diff --git a/bech32/bech32.go b/bech32/bech32.go new file mode 100644 index 00000000..a4db86d5 --- /dev/null +++ b/bech32/bech32.go @@ -0,0 +1,29 @@ +package bech32 + +import ( + "github.com/btcsuite/btcutil/bech32" + "github.com/pkg/errors" +) + +//ConvertAndEncode converts from a base64 encoded byte string to base32 encoded byte string and then to bech32 +func ConvertAndEncode(hrp string, data []byte) (string, error) { + converted, err := bech32.ConvertBits(data, 8, 5, true) + if err != nil { + return "", errors.Wrap(err, "encoding bech32 failed") + } + return bech32.Encode(hrp, converted) + +} + +//DecodeAndConvert decodes a bech32 encoded string and converts to base64 encoded bytes +func DecodeAndConvert(bech string) (string, []byte, error) { + hrp, data, err := bech32.Decode(bech) + if err != nil { + return "", nil, errors.Wrap(err, "decoding bech32 failed") + } + converted, err := bech32.ConvertBits(data, 5, 8, false) + if err != nil { + return "", nil, errors.Wrap(err, "decoding bech32 failed") + } + return hrp, converted, nil +} diff --git a/bech32/bech32_test.go b/bech32/bech32_test.go new file mode 100644 index 00000000..7cdebba2 --- /dev/null +++ b/bech32/bech32_test.go @@ -0,0 +1,31 @@ +package bech32_test + +import ( + "bytes" + "crypto/sha256" + "testing" + + "github.com/tendermint/tmlibs/bech32" +) + +func TestEncodeAndDecode(t *testing.T) { + + sum := sha256.Sum256([]byte("hello world\n")) + + bech, err := bech32.ConvertAndEncode("shasum", sum[:]) + + if err != nil { + t.Error(err) + } + hrp, data, err := bech32.DecodeAndConvert(bech) + + if err != nil { + t.Error(err) + } + if hrp != "shasum" { + t.Error("Invalid hrp") + } + if bytes.Compare(data, sum[:]) != 0 { + t.Error("Invalid decode") + } +} diff --git a/circle.yml b/circle.yml new file mode 100644 index 00000000..390ffb03 --- /dev/null +++ b/circle.yml @@ -0,0 +1,21 @@ +machine: + environment: + GOPATH: "${HOME}/.go_workspace" + PROJECT_PARENT_PATH: "$GOPATH/src/github.com/$CIRCLE_PROJECT_USERNAME" + PROJECT_PATH: $GOPATH/src/github.com/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME + hosts: + localhost: 127.0.0.1 + +dependencies: + override: + - mkdir -p "$PROJECT_PARENT_PATH" + - ln -sf "$HOME/$CIRCLE_PROJECT_REPONAME/" "$PROJECT_PATH" + post: + - go version + +test: + override: + - cd $PROJECT_PATH && make get_tools && make get_vendor_deps && bash ./test.sh + post: + - cd "$PROJECT_PATH" && bash <(curl -s https://codecov.io/bash) -f coverage.txt + - cd "$PROJECT_PATH" && mv coverage.txt "${CIRCLE_ARTIFACTS}" diff --git a/cli/flags/log_level.go b/cli/flags/log_level.go new file mode 100644 index 00000000..ee4825cf --- /dev/null +++ b/cli/flags/log_level.go @@ -0,0 +1,86 @@ +package flags + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/tendermint/tmlibs/log" +) + +const ( + defaultLogLevelKey = "*" +) + +// ParseLogLevel parses complex log level - comma-separated +// list of module:level pairs with an optional *:level pair (* means +// all other modules). +// +// Example: +// ParseLogLevel("consensus:debug,mempool:debug,*:error", log.NewTMLogger(os.Stdout), "info") +func ParseLogLevel(lvl string, logger log.Logger, defaultLogLevelValue string) (log.Logger, error) { + if lvl == "" { + return nil, errors.New("Empty log level") + } + + l := lvl + + // prefix simple one word levels (e.g. "info") with "*" + if !strings.Contains(l, ":") { + l = defaultLogLevelKey + ":" + l + } + + options := make([]log.Option, 0) + + isDefaultLogLevelSet := false + var option log.Option + var err error + + list := strings.Split(l, ",") + for _, item := range list { + moduleAndLevel := strings.Split(item, ":") + + if len(moduleAndLevel) != 2 { + return nil, fmt.Errorf("Expected list in a form of \"module:level\" pairs, given pair %s, list %s", item, list) + } + + module := moduleAndLevel[0] + level := moduleAndLevel[1] + + if module == defaultLogLevelKey { + option, err = log.AllowLevel(level) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Failed to parse default log level (pair %s, list %s)", item, l)) + } + options = append(options, option) + isDefaultLogLevelSet = true + } else { + switch level { + case "debug": + option = log.AllowDebugWith("module", module) + case "info": + option = log.AllowInfoWith("module", module) + case "error": + option = log.AllowErrorWith("module", module) + case "none": + option = log.AllowNoneWith("module", module) + default: + return nil, fmt.Errorf("Expected either \"info\", \"debug\", \"error\" or \"none\" log level, given %s (pair %s, list %s)", level, item, list) + } + options = append(options, option) + + } + } + + // if "*" is not provided, set default global level + if !isDefaultLogLevelSet { + option, err = log.AllowLevel(defaultLogLevelValue) + if err != nil { + return nil, err + } + options = append(options, option) + } + + return log.NewFilter(logger, options...), nil +} diff --git a/cli/flags/log_level_test.go b/cli/flags/log_level_test.go new file mode 100644 index 00000000..faf9b19d --- /dev/null +++ b/cli/flags/log_level_test.go @@ -0,0 +1,94 @@ +package flags_test + +import ( + "bytes" + "strings" + "testing" + + tmflags "github.com/tendermint/tmlibs/cli/flags" + "github.com/tendermint/tmlibs/log" +) + +const ( + defaultLogLevelValue = "info" +) + +func TestParseLogLevel(t *testing.T) { + var buf bytes.Buffer + jsonLogger := log.NewTMJSONLogger(&buf) + + correctLogLevels := []struct { + lvl string + expectedLogLines []string + }{ + {"mempool:error", []string{ + ``, // if no default is given, assume info + ``, + `{"_msg":"Mesmero","level":"error","module":"mempool"}`, + `{"_msg":"Mind","level":"info","module":"state"}`, // if no default is given, assume info + ``}}, + + {"mempool:error,*:debug", []string{ + `{"_msg":"Kingpin","level":"debug","module":"wire"}`, + ``, + `{"_msg":"Mesmero","level":"error","module":"mempool"}`, + `{"_msg":"Mind","level":"info","module":"state"}`, + `{"_msg":"Gideon","level":"debug"}`}}, + + {"*:debug,wire:none", []string{ + ``, + `{"_msg":"Kitty Pryde","level":"info","module":"mempool"}`, + `{"_msg":"Mesmero","level":"error","module":"mempool"}`, + `{"_msg":"Mind","level":"info","module":"state"}`, + `{"_msg":"Gideon","level":"debug"}`}}, + } + + for _, c := range correctLogLevels { + logger, err := tmflags.ParseLogLevel(c.lvl, jsonLogger, defaultLogLevelValue) + if err != nil { + t.Fatal(err) + } + + buf.Reset() + + logger.With("module", "wire").Debug("Kingpin") + if have := strings.TrimSpace(buf.String()); c.expectedLogLines[0] != have { + t.Errorf("\nwant '%s'\nhave '%s'\nlevel '%s'", c.expectedLogLines[0], have, c.lvl) + } + + buf.Reset() + + logger.With("module", "mempool").Info("Kitty Pryde") + if have := strings.TrimSpace(buf.String()); c.expectedLogLines[1] != have { + t.Errorf("\nwant '%s'\nhave '%s'\nlevel '%s'", c.expectedLogLines[1], have, c.lvl) + } + + buf.Reset() + + logger.With("module", "mempool").Error("Mesmero") + if have := strings.TrimSpace(buf.String()); c.expectedLogLines[2] != have { + t.Errorf("\nwant '%s'\nhave '%s'\nlevel '%s'", c.expectedLogLines[2], have, c.lvl) + } + + buf.Reset() + + logger.With("module", "state").Info("Mind") + if have := strings.TrimSpace(buf.String()); c.expectedLogLines[3] != have { + t.Errorf("\nwant '%s'\nhave '%s'\nlevel '%s'", c.expectedLogLines[3], have, c.lvl) + } + + buf.Reset() + + logger.Debug("Gideon") + if have := strings.TrimSpace(buf.String()); c.expectedLogLines[4] != have { + t.Errorf("\nwant '%s'\nhave '%s'\nlevel '%s'", c.expectedLogLines[4], have, c.lvl) + } + } + + incorrectLogLevel := []string{"some", "mempool:some", "*:some,mempool:error"} + for _, lvl := range incorrectLogLevel { + if _, err := tmflags.ParseLogLevel(lvl, jsonLogger, defaultLogLevelValue); err == nil { + t.Fatalf("Expected %s to produce error", lvl) + } + } +} diff --git a/cli/helper.go b/cli/helper.go new file mode 100644 index 00000000..878cf26e --- /dev/null +++ b/cli/helper.go @@ -0,0 +1,87 @@ +package cli + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" +) + +// WriteConfigVals writes a toml file with the given values. +// It returns an error if writing was impossible. +func WriteConfigVals(dir string, vals map[string]string) error { + data := "" + for k, v := range vals { + data = data + fmt.Sprintf("%s = \"%s\"\n", k, v) + } + cfile := filepath.Join(dir, "config.toml") + return ioutil.WriteFile(cfile, []byte(data), 0666) +} + +// RunWithArgs executes the given command with the specified command line args +// and environmental variables set. It returns any error returned from cmd.Execute() +func RunWithArgs(cmd Executable, args []string, env map[string]string) error { + oargs := os.Args + oenv := map[string]string{} + // defer returns the environment back to normal + defer func() { + os.Args = oargs + for k, v := range oenv { + os.Setenv(k, v) + } + }() + + // set the args and env how we want them + os.Args = args + for k, v := range env { + // backup old value if there, to restore at end + oenv[k] = os.Getenv(k) + err := os.Setenv(k, v) + if err != nil { + return err + } + } + + // and finally run the command + return cmd.Execute() +} + +// RunCaptureWithArgs executes the given command with the specified command +// line args and environmental variables set. It returns string fields +// representing output written to stdout and stderr, additionally any error +// from cmd.Execute() is also returned +func RunCaptureWithArgs(cmd Executable, args []string, env map[string]string) (stdout, stderr string, err error) { + oldout, olderr := os.Stdout, os.Stderr // keep backup of the real stdout + rOut, wOut, _ := os.Pipe() + rErr, wErr, _ := os.Pipe() + os.Stdout, os.Stderr = wOut, wErr + defer func() { + os.Stdout, os.Stderr = oldout, olderr // restoring the real stdout + }() + + // copy the output in a separate goroutine so printing can't block indefinitely + copyStd := func(reader *os.File) *(chan string) { + stdC := make(chan string) + go func() { + var buf bytes.Buffer + // io.Copy will end when we call reader.Close() below + io.Copy(&buf, reader) + stdC <- buf.String() + }() + return &stdC + } + outC := copyStd(rOut) + errC := copyStd(rErr) + + // now run the command + err = RunWithArgs(cmd, args, env) + + // and grab the stdout to return + wOut.Close() + wErr.Close() + stdout = <-*outC + stderr = <-*errC + return stdout, stderr, err +} diff --git a/cli/setup.go b/cli/setup.go new file mode 100644 index 00000000..06cf1cd1 --- /dev/null +++ b/cli/setup.go @@ -0,0 +1,157 @@ +package cli + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const ( + HomeFlag = "home" + TraceFlag = "trace" + OutputFlag = "output" + EncodingFlag = "encoding" +) + +// Executable is the minimal interface to *corba.Command, so we can +// wrap if desired before the test +type Executable interface { + Execute() error +} + +// PrepareBaseCmd is meant for tendermint and other servers +func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor { + cobra.OnInitialize(func() { initEnv(envPrefix) }) + cmd.PersistentFlags().StringP(HomeFlag, "", defaultHome, "directory for config and data") + cmd.PersistentFlags().Bool(TraceFlag, false, "print out full stack trace on errors") + cmd.PersistentPreRunE = concatCobraCmdFuncs(bindFlagsLoadViper, cmd.PersistentPreRunE) + return Executor{cmd, os.Exit} +} + +// PrepareMainCmd is meant for client side libs that want some more flags +// +// This adds --encoding (hex, btc, base64) and --output (text, json) to +// the command. These only really make sense in interactive commands. +func PrepareMainCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor { + cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)") + cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)") + cmd.PersistentPreRunE = concatCobraCmdFuncs(validateOutput, cmd.PersistentPreRunE) + return PrepareBaseCmd(cmd, envPrefix, defaultHome) +} + +// initEnv sets to use ENV variables if set. +func initEnv(prefix string) { + copyEnvVars(prefix) + + // env variables with TM prefix (eg. TM_ROOT) + viper.SetEnvPrefix(prefix) + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) + viper.AutomaticEnv() +} + +// This copies all variables like TMROOT to TM_ROOT, +// so we can support both formats for the user +func copyEnvVars(prefix string) { + prefix = strings.ToUpper(prefix) + ps := prefix + "_" + for _, e := range os.Environ() { + kv := strings.SplitN(e, "=", 2) + if len(kv) == 2 { + k, v := kv[0], kv[1] + if strings.HasPrefix(k, prefix) && !strings.HasPrefix(k, ps) { + k2 := strings.Replace(k, prefix, ps, 1) + os.Setenv(k2, v) + } + } + } +} + +// Executor wraps the cobra Command with a nicer Execute method +type Executor struct { + *cobra.Command + Exit func(int) // this is os.Exit by default, override in tests +} + +type ExitCoder interface { + ExitCode() int +} + +// execute adds all child commands to the root command sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func (e Executor) Execute() error { + e.SilenceUsage = true + e.SilenceErrors = true + err := e.Command.Execute() + if err != nil { + if viper.GetBool(TraceFlag) { + fmt.Fprintf(os.Stderr, "ERROR: %+v\n", err) + } else { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + } + + // return error code 1 by default, can override it with a special error type + exitCode := 1 + if ec, ok := err.(ExitCoder); ok { + exitCode = ec.ExitCode() + } + e.Exit(exitCode) + } + return err +} + +type cobraCmdFunc func(cmd *cobra.Command, args []string) error + +// Returns a single function that calls each argument function in sequence +// RunE, PreRunE, PersistentPreRunE, etc. all have this same signature +func concatCobraCmdFuncs(fs ...cobraCmdFunc) cobraCmdFunc { + return func(cmd *cobra.Command, args []string) error { + for _, f := range fs { + if f != nil { + if err := f(cmd, args); err != nil { + return err + } + } + } + return nil + } +} + +// Bind all flags and read the config into viper +func bindFlagsLoadViper(cmd *cobra.Command, args []string) error { + // cmd.Flags() includes flags from this command and all persistent flags from the parent + if err := viper.BindPFlags(cmd.Flags()); err != nil { + return err + } + + homeDir := viper.GetString(HomeFlag) + viper.Set(HomeFlag, homeDir) + viper.SetConfigName("config") // name of config file (without extension) + viper.AddConfigPath(homeDir) // search root directory + viper.AddConfigPath(filepath.Join(homeDir, "config")) // search root directory /config + + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + // stderr, so if we redirect output to json file, this doesn't appear + // fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed()) + } else if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + // ignore not found error, return other errors + return err + } + return nil +} + +func validateOutput(cmd *cobra.Command, args []string) error { + // validate output format + output := viper.GetString(OutputFlag) + switch output { + case "text", "json": + default: + return errors.Errorf("Unsupported output format: %s", output) + } + return nil +} diff --git a/cli/setup_test.go b/cli/setup_test.go new file mode 100644 index 00000000..04209e49 --- /dev/null +++ b/cli/setup_test.go @@ -0,0 +1,237 @@ +package cli + +import ( + "fmt" + "io/ioutil" + "strconv" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSetupEnv(t *testing.T) { + cases := []struct { + args []string + env map[string]string + expected string + }{ + {nil, nil, ""}, + {[]string{"--foobar", "bang!"}, nil, "bang!"}, + // make sure reset is good + {nil, nil, ""}, + // test both variants of the prefix + {nil, map[string]string{"DEMO_FOOBAR": "good"}, "good"}, + {nil, map[string]string{"DEMOFOOBAR": "silly"}, "silly"}, + // and that cli overrides env... + {[]string{"--foobar", "important"}, + map[string]string{"DEMO_FOOBAR": "ignored"}, "important"}, + } + + for idx, tc := range cases { + i := strconv.Itoa(idx) + // test command that store value of foobar in local variable + var foo string + demo := &cobra.Command{ + Use: "demo", + RunE: func(cmd *cobra.Command, args []string) error { + foo = viper.GetString("foobar") + return nil + }, + } + demo.Flags().String("foobar", "", "Some test value from config") + cmd := PrepareBaseCmd(demo, "DEMO", "/qwerty/asdfgh") // some missing dir.. + cmd.Exit = func(int) {} + + viper.Reset() + args := append([]string{cmd.Use}, tc.args...) + err := RunWithArgs(cmd, args, tc.env) + require.Nil(t, err, i) + assert.Equal(t, tc.expected, foo, i) + } +} + +func tempDir() string { + cdir, err := ioutil.TempDir("", "test-cli") + if err != nil { + panic(err) + } + return cdir +} + +func TestSetupConfig(t *testing.T) { + // we pre-create two config files we can refer to in the rest of + // the test cases. + cval1 := "fubble" + conf1 := tempDir() + err := WriteConfigVals(conf1, map[string]string{"boo": cval1}) + require.Nil(t, err) + + cases := []struct { + args []string + env map[string]string + expected string + expectedTwo string + }{ + {nil, nil, "", ""}, + // setting on the command line + {[]string{"--boo", "haha"}, nil, "haha", ""}, + {[]string{"--two-words", "rocks"}, nil, "", "rocks"}, + {[]string{"--home", conf1}, nil, cval1, ""}, + // test both variants of the prefix + {nil, map[string]string{"RD_BOO": "bang"}, "bang", ""}, + {nil, map[string]string{"RD_TWO_WORDS": "fly"}, "", "fly"}, + {nil, map[string]string{"RDTWO_WORDS": "fly"}, "", "fly"}, + {nil, map[string]string{"RD_HOME": conf1}, cval1, ""}, + {nil, map[string]string{"RDHOME": conf1}, cval1, ""}, + } + + for idx, tc := range cases { + i := strconv.Itoa(idx) + // test command that store value of foobar in local variable + var foo, two string + boo := &cobra.Command{ + Use: "reader", + RunE: func(cmd *cobra.Command, args []string) error { + foo = viper.GetString("boo") + two = viper.GetString("two-words") + return nil + }, + } + boo.Flags().String("boo", "", "Some test value from config") + boo.Flags().String("two-words", "", "Check out env handling -") + cmd := PrepareBaseCmd(boo, "RD", "/qwerty/asdfgh") // some missing dir... + cmd.Exit = func(int) {} + + viper.Reset() + args := append([]string{cmd.Use}, tc.args...) + err := RunWithArgs(cmd, args, tc.env) + require.Nil(t, err, i) + assert.Equal(t, tc.expected, foo, i) + assert.Equal(t, tc.expectedTwo, two, i) + } +} + +type DemoConfig struct { + Name string `mapstructure:"name"` + Age int `mapstructure:"age"` + Unused int `mapstructure:"unused"` +} + +func TestSetupUnmarshal(t *testing.T) { + // we pre-create two config files we can refer to in the rest of + // the test cases. + cval1, cval2 := "someone", "else" + conf1 := tempDir() + err := WriteConfigVals(conf1, map[string]string{"name": cval1}) + require.Nil(t, err) + // even with some ignored fields, should be no problem + conf2 := tempDir() + err = WriteConfigVals(conf2, map[string]string{"name": cval2, "foo": "bar"}) + require.Nil(t, err) + + // unused is not declared on a flag and remains from base + base := DemoConfig{ + Name: "default", + Age: 42, + Unused: -7, + } + c := func(name string, age int) DemoConfig { + r := base + // anything set on the flags as a default is used over + // the default config object + r.Name = "from-flag" + if name != "" { + r.Name = name + } + if age != 0 { + r.Age = age + } + return r + } + + cases := []struct { + args []string + env map[string]string + expected DemoConfig + }{ + {nil, nil, c("", 0)}, + // setting on the command line + {[]string{"--name", "haha"}, nil, c("haha", 0)}, + {[]string{"--home", conf1}, nil, c(cval1, 0)}, + // test both variants of the prefix + {nil, map[string]string{"MR_AGE": "56"}, c("", 56)}, + {nil, map[string]string{"MR_HOME": conf1}, c(cval1, 0)}, + {[]string{"--age", "17"}, map[string]string{"MRHOME": conf2}, c(cval2, 17)}, + } + + for idx, tc := range cases { + i := strconv.Itoa(idx) + // test command that store value of foobar in local variable + cfg := base + marsh := &cobra.Command{ + Use: "marsh", + RunE: func(cmd *cobra.Command, args []string) error { + return viper.Unmarshal(&cfg) + }, + } + marsh.Flags().String("name", "from-flag", "Some test value from config") + // if we want a flag to use the proper default, then copy it + // from the default config here + marsh.Flags().Int("age", base.Age, "Some test value from config") + cmd := PrepareBaseCmd(marsh, "MR", "/qwerty/asdfgh") // some missing dir... + cmd.Exit = func(int) {} + + viper.Reset() + args := append([]string{cmd.Use}, tc.args...) + err := RunWithArgs(cmd, args, tc.env) + require.Nil(t, err, i) + assert.Equal(t, tc.expected, cfg, i) + } +} + +func TestSetupTrace(t *testing.T) { + cases := []struct { + args []string + env map[string]string + long bool + expected string + }{ + {nil, nil, false, "Trace flag = false"}, + {[]string{"--trace"}, nil, true, "Trace flag = true"}, + {[]string{"--no-such-flag"}, nil, false, "unknown flag: --no-such-flag"}, + {nil, map[string]string{"DBG_TRACE": "true"}, true, "Trace flag = true"}, + } + + for idx, tc := range cases { + i := strconv.Itoa(idx) + // test command that store value of foobar in local variable + trace := &cobra.Command{ + Use: "trace", + RunE: func(cmd *cobra.Command, args []string) error { + return errors.Errorf("Trace flag = %t", viper.GetBool(TraceFlag)) + }, + } + cmd := PrepareBaseCmd(trace, "DBG", "/qwerty/asdfgh") // some missing dir.. + cmd.Exit = func(int) {} + + viper.Reset() + args := append([]string{cmd.Use}, tc.args...) + stdout, stderr, err := RunCaptureWithArgs(cmd, args, tc.env) + require.NotNil(t, err, i) + require.Equal(t, "", stdout, i) + require.NotEqual(t, "", stderr, i) + msg := strings.Split(stderr, "\n") + desired := fmt.Sprintf("ERROR: %s", tc.expected) + assert.Equal(t, desired, msg[0], i) + if tc.long && assert.True(t, len(msg) > 2, i) { + // the next line starts the stack trace... + assert.Contains(t, msg[1], "TestSetupTrace", i) + assert.Contains(t, msg[2], "setup_test.go", i) + } + } +} diff --git a/clist/clist.go b/clist/clist.go new file mode 100644 index 00000000..ccb1f577 --- /dev/null +++ b/clist/clist.go @@ -0,0 +1,384 @@ +package clist + +/* + +The purpose of CList is to provide a goroutine-safe linked-list. +This list can be traversed concurrently by any number of goroutines. +However, removed CElements cannot be added back. +NOTE: Not all methods of container/list are (yet) implemented. +NOTE: Removed elements need to DetachPrev or DetachNext consistently +to ensure garbage collection of removed elements. + +*/ + +import ( + "sync" +) + +/* + +CElement is an element of a linked-list +Traversal from a CElement is goroutine-safe. + +We can't avoid using WaitGroups or for-loops given the documentation +spec without re-implementing the primitives that already exist in +golang/sync. Notice that WaitGroup allows many go-routines to be +simultaneously released, which is what we want. Mutex doesn't do +this. RWMutex does this, but it's clumsy to use in the way that a +WaitGroup would be used -- and we'd end up having two RWMutex's for +prev/next each, which is doubly confusing. + +sync.Cond would be sort-of useful, but we don't need a write-lock in +the for-loop. Use sync.Cond when you need serial access to the +"condition". In our case our condition is if `next != nil || removed`, +and there's no reason to serialize that condition for goroutines +waiting on NextWait() (since it's just a read operation). + +*/ +type CElement struct { + mtx sync.RWMutex + prev *CElement + prevWg *sync.WaitGroup + prevWaitCh chan struct{} + next *CElement + nextWg *sync.WaitGroup + nextWaitCh chan struct{} + removed bool + + Value interface{} // immutable +} + +// Blocking implementation of Next(). +// May return nil iff CElement was tail and got removed. +func (e *CElement) NextWait() *CElement { + for { + e.mtx.RLock() + next := e.next + nextWg := e.nextWg + removed := e.removed + e.mtx.RUnlock() + + if next != nil || removed { + return next + } + + nextWg.Wait() + // e.next doesn't necessarily exist here. + // That's why we need to continue a for-loop. + } +} + +// Blocking implementation of Prev(). +// May return nil iff CElement was head and got removed. +func (e *CElement) PrevWait() *CElement { + for { + e.mtx.RLock() + prev := e.prev + prevWg := e.prevWg + removed := e.removed + e.mtx.RUnlock() + + if prev != nil || removed { + return prev + } + + prevWg.Wait() + } +} + +// PrevWaitChan can be used to wait until Prev becomes not nil. Once it does, +// channel will be closed. +func (e *CElement) PrevWaitChan() <-chan struct{} { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.prevWaitCh +} + +// NextWaitChan can be used to wait until Next becomes not nil. Once it does, +// channel will be closed. +func (e *CElement) NextWaitChan() <-chan struct{} { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.nextWaitCh +} + +// Nonblocking, may return nil if at the end. +func (e *CElement) Next() *CElement { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.next +} + +// Nonblocking, may return nil if at the end. +func (e *CElement) Prev() *CElement { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.prev +} + +func (e *CElement) Removed() bool { + e.mtx.RLock() + defer e.mtx.RUnlock() + + return e.removed +} + +func (e *CElement) DetachNext() { + if !e.Removed() { + panic("DetachNext() must be called after Remove(e)") + } + e.mtx.Lock() + defer e.mtx.Unlock() + + e.next = nil +} + +func (e *CElement) DetachPrev() { + if !e.Removed() { + panic("DetachPrev() must be called after Remove(e)") + } + e.mtx.Lock() + defer e.mtx.Unlock() + + e.prev = nil +} + +// NOTE: This function needs to be safe for +// concurrent goroutines waiting on nextWg. +func (e *CElement) SetNext(newNext *CElement) { + e.mtx.Lock() + defer e.mtx.Unlock() + + oldNext := e.next + e.next = newNext + if oldNext != nil && newNext == nil { + // See https://golang.org/pkg/sync/: + // + // If a WaitGroup is reused to wait for several independent sets of + // events, new Add calls must happen after all previous Wait calls have + // returned. + e.nextWg = waitGroup1() // WaitGroups are difficult to re-use. + e.nextWaitCh = make(chan struct{}) + } + if oldNext == nil && newNext != nil { + e.nextWg.Done() + close(e.nextWaitCh) + } +} + +// NOTE: This function needs to be safe for +// concurrent goroutines waiting on prevWg +func (e *CElement) SetPrev(newPrev *CElement) { + e.mtx.Lock() + defer e.mtx.Unlock() + + oldPrev := e.prev + e.prev = newPrev + if oldPrev != nil && newPrev == nil { + e.prevWg = waitGroup1() // WaitGroups are difficult to re-use. + e.prevWaitCh = make(chan struct{}) + } + if oldPrev == nil && newPrev != nil { + e.prevWg.Done() + close(e.prevWaitCh) + } +} + +func (e *CElement) SetRemoved() { + e.mtx.Lock() + defer e.mtx.Unlock() + + e.removed = true + + // This wakes up anyone waiting in either direction. + if e.prev == nil { + e.prevWg.Done() + close(e.prevWaitCh) + } + if e.next == nil { + e.nextWg.Done() + close(e.nextWaitCh) + } +} + +//-------------------------------------------------------------------------------- + +// CList represents a linked list. +// The zero value for CList is an empty list ready to use. +// Operations are goroutine-safe. +type CList struct { + mtx sync.RWMutex + wg *sync.WaitGroup + waitCh chan struct{} + head *CElement // first element + tail *CElement // last element + len int // list length +} + +func (l *CList) Init() *CList { + l.mtx.Lock() + defer l.mtx.Unlock() + + l.wg = waitGroup1() + l.waitCh = make(chan struct{}) + l.head = nil + l.tail = nil + l.len = 0 + return l +} + +func New() *CList { return new(CList).Init() } + +func (l *CList) Len() int { + l.mtx.RLock() + defer l.mtx.RUnlock() + + return l.len +} + +func (l *CList) Front() *CElement { + l.mtx.RLock() + defer l.mtx.RUnlock() + + return l.head +} + +func (l *CList) FrontWait() *CElement { + // Loop until the head is non-nil else wait and try again + for { + l.mtx.RLock() + head := l.head + wg := l.wg + l.mtx.RUnlock() + + if head != nil { + return head + } + wg.Wait() + // NOTE: If you think l.head exists here, think harder. + } +} + +func (l *CList) Back() *CElement { + l.mtx.RLock() + defer l.mtx.RUnlock() + + return l.tail +} + +func (l *CList) BackWait() *CElement { + for { + l.mtx.RLock() + tail := l.tail + wg := l.wg + l.mtx.RUnlock() + + if tail != nil { + return tail + } + wg.Wait() + // l.tail doesn't necessarily exist here. + // That's why we need to continue a for-loop. + } +} + +// WaitChan can be used to wait until Front or Back becomes not nil. Once it +// does, channel will be closed. +func (l *CList) WaitChan() <-chan struct{} { + l.mtx.Lock() + defer l.mtx.Unlock() + + return l.waitCh +} + +func (l *CList) PushBack(v interface{}) *CElement { + l.mtx.Lock() + defer l.mtx.Unlock() + + // Construct a new element + e := &CElement{ + prev: nil, + prevWg: waitGroup1(), + prevWaitCh: make(chan struct{}), + next: nil, + nextWg: waitGroup1(), + nextWaitCh: make(chan struct{}), + removed: false, + Value: v, + } + + // Release waiters on FrontWait/BackWait maybe + if l.len == 0 { + l.wg.Done() + close(l.waitCh) + } + l.len++ + + // Modify the tail + if l.tail == nil { + l.head = e + l.tail = e + } else { + e.SetPrev(l.tail) // We must init e first. + l.tail.SetNext(e) // This will make e accessible. + l.tail = e // Update the list. + } + + return e +} + +// CONTRACT: Caller must call e.DetachPrev() and/or e.DetachNext() to avoid memory leaks. +// NOTE: As per the contract of CList, removed elements cannot be added back. +func (l *CList) Remove(e *CElement) interface{} { + l.mtx.Lock() + defer l.mtx.Unlock() + + prev := e.Prev() + next := e.Next() + + if l.head == nil || l.tail == nil { + panic("Remove(e) on empty CList") + } + if prev == nil && l.head != e { + panic("Remove(e) with false head") + } + if next == nil && l.tail != e { + panic("Remove(e) with false tail") + } + + // If we're removing the only item, make CList FrontWait/BackWait wait. + if l.len == 1 { + l.wg = waitGroup1() // WaitGroups are difficult to re-use. + l.waitCh = make(chan struct{}) + } + + // Update l.len + l.len-- + + // Connect next/prev and set head/tail + if prev == nil { + l.head = next + } else { + prev.SetNext(next) + } + if next == nil { + l.tail = prev + } else { + next.SetPrev(prev) + } + + // Set .Done() on e, otherwise waiters will wait forever. + e.SetRemoved() + + return e.Value +} + +func waitGroup1() (wg *sync.WaitGroup) { + wg = &sync.WaitGroup{} + wg.Add(1) + return +} diff --git a/clist/clist_test.go b/clist/clist_test.go new file mode 100644 index 00000000..6171f1a3 --- /dev/null +++ b/clist/clist_test.go @@ -0,0 +1,293 @@ +package clist + +import ( + "fmt" + "math/rand" + "runtime" + "sync/atomic" + "testing" + "time" +) + +func TestSmall(t *testing.T) { + l := New() + el1 := l.PushBack(1) + el2 := l.PushBack(2) + el3 := l.PushBack(3) + if l.Len() != 3 { + t.Error("Expected len 3, got ", l.Len()) + } + + //fmt.Printf("%p %v\n", el1, el1) + //fmt.Printf("%p %v\n", el2, el2) + //fmt.Printf("%p %v\n", el3, el3) + + r1 := l.Remove(el1) + + //fmt.Printf("%p %v\n", el1, el1) + //fmt.Printf("%p %v\n", el2, el2) + //fmt.Printf("%p %v\n", el3, el3) + + r2 := l.Remove(el2) + + //fmt.Printf("%p %v\n", el1, el1) + //fmt.Printf("%p %v\n", el2, el2) + //fmt.Printf("%p %v\n", el3, el3) + + r3 := l.Remove(el3) + + if r1 != 1 { + t.Error("Expected 1, got ", r1) + } + if r2 != 2 { + t.Error("Expected 2, got ", r2) + } + if r3 != 3 { + t.Error("Expected 3, got ", r3) + } + if l.Len() != 0 { + t.Error("Expected len 0, got ", l.Len()) + } + +} + +/* +This test is quite hacky because it relies on SetFinalizer +which isn't guaranteed to run at all. +*/ +// nolint: megacheck +func _TestGCFifo(t *testing.T) { + + const numElements = 1000000 + l := New() + gcCount := new(uint64) + + // SetFinalizer doesn't work well with circular structures, + // so we construct a trivial non-circular structure to + // track. + type value struct { + Int int + } + done := make(chan struct{}) + + for i := 0; i < numElements; i++ { + v := new(value) + v.Int = i + l.PushBack(v) + runtime.SetFinalizer(v, func(v *value) { + atomic.AddUint64(gcCount, 1) + }) + } + + for el := l.Front(); el != nil; { + l.Remove(el) + //oldEl := el + el = el.Next() + //oldEl.DetachPrev() + //oldEl.DetachNext() + } + + runtime.GC() + time.Sleep(time.Second * 3) + runtime.GC() + time.Sleep(time.Second * 3) + _ = done + + if *gcCount != numElements { + t.Errorf("Expected gcCount to be %v, got %v", numElements, + *gcCount) + } +} + +/* +This test is quite hacky because it relies on SetFinalizer +which isn't guaranteed to run at all. +*/ +// nolint: megacheck +func _TestGCRandom(t *testing.T) { + + const numElements = 1000000 + l := New() + gcCount := 0 + + // SetFinalizer doesn't work well with circular structures, + // so we construct a trivial non-circular structure to + // track. + type value struct { + Int int + } + + for i := 0; i < numElements; i++ { + v := new(value) + v.Int = i + l.PushBack(v) + runtime.SetFinalizer(v, func(v *value) { + gcCount++ + }) + } + + els := make([]*CElement, 0, numElements) + for el := l.Front(); el != nil; el = el.Next() { + els = append(els, el) + } + + for _, i := range rand.Perm(numElements) { + el := els[i] + l.Remove(el) + _ = el.Next() + } + + runtime.GC() + time.Sleep(time.Second * 3) + + if gcCount != numElements { + t.Errorf("Expected gcCount to be %v, got %v", numElements, + gcCount) + } +} + +func TestScanRightDeleteRandom(t *testing.T) { + + const numElements = 10000 + const numTimes = 1000 + const numScanners = 10 + + l := New() + stop := make(chan struct{}) + + els := make([]*CElement, numElements) + for i := 0; i < numElements; i++ { + el := l.PushBack(i) + els[i] = el + } + + // Launch scanner routines that will rapidly iterate over elements. + for i := 0; i < numScanners; i++ { + go func(scannerID int) { + var el *CElement + restartCounter := 0 + counter := 0 + FOR_LOOP: + for { + select { + case <-stop: + fmt.Println("stopped") + break FOR_LOOP + default: + } + if el == nil { + el = l.FrontWait() + restartCounter++ + } + el = el.Next() + counter++ + } + fmt.Printf("Scanner %v restartCounter: %v counter: %v\n", scannerID, restartCounter, counter) + }(i) + } + + // Remove an element, push back an element. + for i := 0; i < numTimes; i++ { + // Pick an element to remove + rmElIdx := rand.Intn(len(els)) + rmEl := els[rmElIdx] + + // Remove it + l.Remove(rmEl) + //fmt.Print(".") + + // Insert a new element + newEl := l.PushBack(-1*i - 1) + els[rmElIdx] = newEl + + if i%100000 == 0 { + fmt.Printf("Pushed %vK elements so far...\n", i/1000) + } + + } + + // Stop scanners + close(stop) + time.Sleep(time.Second * 1) + + // And remove all the elements. + for el := l.Front(); el != nil; el = el.Next() { + l.Remove(el) + } + if l.Len() != 0 { + t.Fatal("Failed to remove all elements from CList") + } +} + +func TestWaitChan(t *testing.T) { + l := New() + ch := l.WaitChan() + + // 1) add one element to an empty list + go l.PushBack(1) + <-ch + + // 2) and remove it + el := l.Front() + v := l.Remove(el) + if v != 1 { + t.Fatal("where is 1 coming from?") + } + + // 3) test iterating forward and waiting for Next (NextWaitChan and Next) + el = l.PushBack(0) + + done := make(chan struct{}) + pushed := 0 + go func() { + for i := 1; i < 100; i++ { + l.PushBack(i) + pushed++ + time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) + } + close(done) + }() + + next := el + seen := 0 +FOR_LOOP: + for { + select { + case <-next.NextWaitChan(): + next = next.Next() + seen++ + if next == nil { + continue + } + case <-done: + break FOR_LOOP + case <-time.After(10 * time.Second): + t.Fatal("max execution time") + } + } + + if pushed != seen { + t.Fatalf("number of pushed items (%d) not equal to number of seen items (%d)", pushed, seen) + } + + // 4) test iterating backwards (PrevWaitChan and Prev) + prev := next + seen = 0 +FOR_LOOP2: + for { + select { + case <-prev.PrevWaitChan(): + prev = prev.Prev() + seen++ + if prev == nil { + t.Fatal("expected PrevWaitChan to block forever on nil when reached first elem") + } + case <-time.After(5 * time.Second): + break FOR_LOOP2 + } + } + + if pushed != seen { + t.Fatalf("number of pushed items (%d) not equal to number of seen items (%d)", pushed, seen) + } +} diff --git a/common/LICENSE b/common/LICENSE new file mode 100644 index 00000000..8a142a71 --- /dev/null +++ b/common/LICENSE @@ -0,0 +1,193 @@ +Tendermint Go-Common +Copyright (C) 2015 Tendermint + + + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/common/array.go b/common/array.go new file mode 100644 index 00000000..adedc42b --- /dev/null +++ b/common/array.go @@ -0,0 +1,5 @@ +package common + +func Arr(items ...interface{}) []interface{} { + return items +} diff --git a/common/async.go b/common/async.go new file mode 100644 index 00000000..7be09a3c --- /dev/null +++ b/common/async.go @@ -0,0 +1,177 @@ +package common + +import ( + "sync/atomic" +) + +//---------------------------------------- +// Task + +// val: the value returned after task execution. +// err: the error returned during task completion. +// abort: tells Parallel to return, whether or not all tasks have completed. +type Task func(i int) (val interface{}, err error, abort bool) + +type TaskResult struct { + Value interface{} + Error error +} + +type TaskResultCh <-chan TaskResult + +type taskResultOK struct { + TaskResult + OK bool +} + +type TaskResultSet struct { + chz []TaskResultCh + results []taskResultOK +} + +func newTaskResultSet(chz []TaskResultCh) *TaskResultSet { + return &TaskResultSet{ + chz: chz, + results: make([]taskResultOK, len(chz)), + } +} + +func (trs *TaskResultSet) Channels() []TaskResultCh { + return trs.chz +} + +func (trs *TaskResultSet) LatestResult(index int) (TaskResult, bool) { + if len(trs.results) <= index { + return TaskResult{}, false + } + resultOK := trs.results[index] + return resultOK.TaskResult, resultOK.OK +} + +// NOTE: Not concurrency safe. +// Writes results to trs.results without waiting for all tasks to complete. +func (trs *TaskResultSet) Reap() *TaskResultSet { + for i := 0; i < len(trs.results); i++ { + var trch = trs.chz[i] + select { + case result, ok := <-trch: + if ok { + // Write result. + trs.results[i] = taskResultOK{ + TaskResult: result, + OK: true, + } + } else { + // We already wrote it. + } + default: + // Do nothing. + } + } + return trs +} + +// NOTE: Not concurrency safe. +// Like Reap() but waits until all tasks have returned or panic'd. +func (trs *TaskResultSet) Wait() *TaskResultSet { + for i := 0; i < len(trs.results); i++ { + var trch = trs.chz[i] + select { + case result, ok := <-trch: + if ok { + // Write result. + trs.results[i] = taskResultOK{ + TaskResult: result, + OK: true, + } + } else { + // We already wrote it. + } + } + } + return trs +} + +// Returns the firstmost (by task index) error as +// discovered by all previous Reap() calls. +func (trs *TaskResultSet) FirstValue() interface{} { + for _, result := range trs.results { + if result.Value != nil { + return result.Value + } + } + return nil +} + +// Returns the firstmost (by task index) error as +// discovered by all previous Reap() calls. +func (trs *TaskResultSet) FirstError() error { + for _, result := range trs.results { + if result.Error != nil { + return result.Error + } + } + return nil +} + +//---------------------------------------- +// Parallel + +// Run tasks in parallel, with ability to abort early. +// Returns ok=false iff any of the tasks returned abort=true. +// NOTE: Do not implement quit features here. Instead, provide convenient +// concurrent quit-like primitives, passed implicitly via Task closures. (e.g. +// it's not Parallel's concern how you quit/abort your tasks). +func Parallel(tasks ...Task) (trs *TaskResultSet, ok bool) { + var taskResultChz = make([]TaskResultCh, len(tasks)) // To return. + var taskDoneCh = make(chan bool, len(tasks)) // A "wait group" channel, early abort if any true received. + var numPanics = new(int32) // Keep track of panics to set ok=false later. + ok = true // We will set it to false iff any tasks panic'd or returned abort. + + // Start all tasks in parallel in separate goroutines. + // When the task is complete, it will appear in the + // respective taskResultCh (associated by task index). + for i, task := range tasks { + var taskResultCh = make(chan TaskResult, 1) // Capacity for 1 result. + taskResultChz[i] = taskResultCh + go func(i int, task Task, taskResultCh chan TaskResult) { + // Recovery + defer func() { + if pnk := recover(); pnk != nil { + atomic.AddInt32(numPanics, 1) + // Send panic to taskResultCh. + taskResultCh <- TaskResult{nil, ErrorWrap(pnk, "Panic in task")} + // Closing taskResultCh lets trs.Wait() work. + close(taskResultCh) + // Decrement waitgroup. + taskDoneCh <- false + } + }() + // Run the task. + var val, err, abort = task(i) + // Send val/err to taskResultCh. + // NOTE: Below this line, nothing must panic/ + taskResultCh <- TaskResult{val, err} + // Closing taskResultCh lets trs.Wait() work. + close(taskResultCh) + // Decrement waitgroup. + taskDoneCh <- abort + }(i, task, taskResultCh) + } + + // Wait until all tasks are done, or until abort. + // DONE_LOOP: + for i := 0; i < len(tasks); i++ { + abort := <-taskDoneCh + if abort { + ok = false + break + } + } + + // Ok is also false if there were any panics. + // We must do this check here (after DONE_LOOP). + ok = ok && (atomic.LoadInt32(numPanics) == 0) + + return newTaskResultSet(taskResultChz).Reap(), ok +} diff --git a/common/async_test.go b/common/async_test.go new file mode 100644 index 00000000..f565b4bd --- /dev/null +++ b/common/async_test.go @@ -0,0 +1,156 @@ +package common + +import ( + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParallel(t *testing.T) { + + // Create tasks. + var counter = new(int32) + var tasks = make([]Task, 100*1000) + for i := 0; i < len(tasks); i++ { + tasks[i] = func(i int) (res interface{}, err error, abort bool) { + atomic.AddInt32(counter, 1) + return -1 * i, nil, false + } + } + + // Run in parallel. + var trs, ok = Parallel(tasks...) + assert.True(t, ok) + + // Verify. + assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") + var failedTasks int + for i := 0; i < len(tasks); i++ { + taskResult, ok := trs.LatestResult(i) + if !ok { + assert.Fail(t, "Task #%v did not complete.", i) + failedTasks++ + } else if taskResult.Error != nil { + assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) + failedTasks++ + } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { + assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) + failedTasks++ + } else { + // Good! + } + } + assert.Equal(t, failedTasks, 0, "No task should have failed") + assert.Nil(t, trs.FirstError(), "There should be no errors") + assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") +} + +func TestParallelAbort(t *testing.T) { + + var flow1 = make(chan struct{}, 1) + var flow2 = make(chan struct{}, 1) + var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking. + var flow4 = make(chan struct{}, 1) + + // Create tasks. + var tasks = []Task{ + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 0) + flow1 <- struct{}{} + return 0, nil, false + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 1) + flow2 <- <-flow1 + return 1, errors.New("some error"), false + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 2) + flow3 <- <-flow2 + return 2, nil, true + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 3) + <-flow4 + return 3, nil, false + }, + } + + // Run in parallel. + var taskResultSet, ok = Parallel(tasks...) + assert.False(t, ok, "ok should be false since we aborted task #2.") + + // Verify task #3. + // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send. + waitTimeout(t, taskResultSet.chz[3], "Task #3") + + // Now let the last task (#3) complete after abort. + flow4 <- <-flow3 + + // Wait until all tasks have returned or panic'd. + taskResultSet.Wait() + + // Verify task #0, #1, #2. + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, 2, nil, nil) + checkResult(t, taskResultSet, 3, 3, nil, nil) +} + +func TestParallelRecover(t *testing.T) { + + // Create tasks. + var tasks = []Task{ + func(i int) (res interface{}, err error, abort bool) { + return 0, nil, false + }, + func(i int) (res interface{}, err error, abort bool) { + return 1, errors.New("some error"), false + }, + func(i int) (res interface{}, err error, abort bool) { + panic(2) + }, + } + + // Run in parallel. + var taskResultSet, ok = Parallel(tasks...) + assert.False(t, ok, "ok should be false since we panic'd in task #2.") + + // Verify task #0, #1, #2. + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, nil, nil, 2) +} + +// Wait for result +func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { + taskResult, ok := taskResultSet.LatestResult(index) + taskName := fmt.Sprintf("Task #%v", index) + assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) + assert.Equal(t, val, taskResult.Value, taskName) + if err != nil { + assert.Equal(t, err, taskResult.Error, taskName) + } else if pnk != nil { + assert.Equal(t, pnk, taskResult.Error.(Error).Data(), taskName) + } else { + assert.Nil(t, taskResult.Error, taskName) + } +} + +// Wait for timeout (no result) +func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) { + select { + case _, ok := <-taskResultCh: + if !ok { + assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName) + } else { + assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName) + } + case <-time.After(1 * time.Second): // TODO use deterministic time? + // Good! + } +} diff --git a/common/bit_array.go b/common/bit_array.go new file mode 100644 index 00000000..0290921a --- /dev/null +++ b/common/bit_array.go @@ -0,0 +1,378 @@ +package common + +import ( + "encoding/binary" + "fmt" + "regexp" + "strings" + "sync" +) + +type BitArray struct { + mtx sync.Mutex + Bits int `json:"bits"` // NOTE: persisted via reflect, must be exported + Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported +} + +// There is no BitArray whose Size is 0. Use nil instead. +func NewBitArray(bits int) *BitArray { + if bits <= 0 { + return nil + } + return &BitArray{ + Bits: bits, + Elems: make([]uint64, (bits+63)/64), + } +} + +func (bA *BitArray) Size() int { + if bA == nil { + return 0 + } + return bA.Bits +} + +// NOTE: behavior is undefined if i >= bA.Bits +func (bA *BitArray) GetIndex(i int) bool { + if bA == nil { + return false + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + return bA.getIndex(i) +} + +func (bA *BitArray) getIndex(i int) bool { + if i >= bA.Bits { + return false + } + return bA.Elems[i/64]&(uint64(1)< 0 +} + +// NOTE: behavior is undefined if i >= bA.Bits +func (bA *BitArray) SetIndex(i int, v bool) bool { + if bA == nil { + return false + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + return bA.setIndex(i, v) +} + +func (bA *BitArray) setIndex(i int, v bool) bool { + if i >= bA.Bits { + return false + } + if v { + bA.Elems[i/64] |= (uint64(1) << uint(i%64)) + } else { + bA.Elems[i/64] &= ^(uint64(1) << uint(i%64)) + } + return true +} + +func (bA *BitArray) Copy() *BitArray { + if bA == nil { + return nil + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + return bA.copy() +} + +func (bA *BitArray) copy() *BitArray { + c := make([]uint64, len(bA.Elems)) + copy(c, bA.Elems) + return &BitArray{ + Bits: bA.Bits, + Elems: c, + } +} + +func (bA *BitArray) copyBits(bits int) *BitArray { + c := make([]uint64, (bits+63)/64) + copy(c, bA.Elems) + return &BitArray{ + Bits: bits, + Elems: c, + } +} + +// Returns a BitArray of larger bits size. +func (bA *BitArray) Or(o *BitArray) *BitArray { + if bA == nil && o == nil { + return nil + } + if bA == nil && o != nil { + return o.Copy() + } + if o == nil { + return bA.Copy() + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + c := bA.copyBits(MaxInt(bA.Bits, o.Bits)) + for i := 0; i < len(c.Elems); i++ { + c.Elems[i] |= o.Elems[i] + } + return c +} + +// Returns a BitArray of smaller bit size. +func (bA *BitArray) And(o *BitArray) *BitArray { + if bA == nil || o == nil { + return nil + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + return bA.and(o) +} + +func (bA *BitArray) and(o *BitArray) *BitArray { + c := bA.copyBits(MinInt(bA.Bits, o.Bits)) + for i := 0; i < len(c.Elems); i++ { + c.Elems[i] &= o.Elems[i] + } + return c +} + +func (bA *BitArray) Not() *BitArray { + if bA == nil { + return nil // Degenerate + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + c := bA.copy() + for i := 0; i < len(c.Elems); i++ { + c.Elems[i] = ^c.Elems[i] + } + return c +} + +func (bA *BitArray) Sub(o *BitArray) *BitArray { + if bA == nil || o == nil { + // TODO: Decide if we should do 1's complement here? + return nil + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + if bA.Bits > o.Bits { + c := bA.copy() + for i := 0; i < len(o.Elems)-1; i++ { + c.Elems[i] &= ^c.Elems[i] + } + i := len(o.Elems) - 1 + if i >= 0 { + for idx := i * 64; idx < o.Bits; idx++ { + // NOTE: each individual GetIndex() call to o is safe. + c.setIndex(idx, c.getIndex(idx) && !o.GetIndex(idx)) + } + } + return c + } + return bA.and(o.Not()) // Note degenerate case where o == nil +} + +func (bA *BitArray) IsEmpty() bool { + if bA == nil { + return true // should this be opposite? + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + for _, e := range bA.Elems { + if e > 0 { + return false + } + } + return true +} + +func (bA *BitArray) IsFull() bool { + if bA == nil { + return true + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + + // Check all elements except the last + for _, elem := range bA.Elems[:len(bA.Elems)-1] { + if (^elem) != 0 { + return false + } + } + + // Check that the last element has (lastElemBits) 1's + lastElemBits := (bA.Bits+63)%64 + 1 + lastElem := bA.Elems[len(bA.Elems)-1] + return (lastElem+1)&((uint64(1)< 0 { + randBitStart := RandIntn(64) + for j := 0; j < 64; j++ { + bitIdx := ((j + randBitStart) % 64) + if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { + return 64*elemIdx + bitIdx, true + } + } + PanicSanity("should not happen") + } + } else { + // Special case for last elem, to ignore straggler bits + elemBits := bA.Bits % 64 + if elemBits == 0 { + elemBits = 64 + } + randBitStart := RandIntn(elemBits) + for j := 0; j < elemBits; j++ { + bitIdx := ((j + randBitStart) % elemBits) + if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { + return 64*elemIdx + bitIdx, true + } + } + } + } + return 0, false +} + +// String returns a string representation of BitArray: BA{}, +// where is a sequence of 'x' (1) and '_' (0). +// The includes spaces and newlines to help people. +// For a simple sequence of 'x' and '_' characters with no spaces or newlines, +// see the MarshalJSON() method. +// Example: "BA{_x_}" or "nil-BitArray" for nil. +func (bA *BitArray) String() string { + return bA.StringIndented("") +} + +func (bA *BitArray) StringIndented(indent string) string { + if bA == nil { + return "nil-BitArray" + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + return bA.stringIndented(indent) +} + +func (bA *BitArray) stringIndented(indent string) string { + lines := []string{} + bits := "" + for i := 0; i < bA.Bits; i++ { + if bA.getIndex(i) { + bits += "x" + } else { + bits += "_" + } + if i%100 == 99 { + lines = append(lines, bits) + bits = "" + } + if i%10 == 9 { + bits += indent + } + if i%50 == 49 { + bits += indent + } + } + if len(bits) > 0 { + lines = append(lines, bits) + } + return fmt.Sprintf("BA{%v:%v}", bA.Bits, strings.Join(lines, indent)) +} + +func (bA *BitArray) Bytes() []byte { + bA.mtx.Lock() + defer bA.mtx.Unlock() + + numBytes := (bA.Bits + 7) / 8 + bytes := make([]byte, numBytes) + for i := 0; i < len(bA.Elems); i++ { + elemBytes := [8]byte{} + binary.LittleEndian.PutUint64(elemBytes[:], bA.Elems[i]) + copy(bytes[i*8:], elemBytes[:]) + } + return bytes +} + +// NOTE: other bitarray o is not locked when reading, +// so if necessary, caller must copy or lock o prior to calling Update. +// If bA is nil, does nothing. +func (bA *BitArray) Update(o *BitArray) { + if bA == nil || o == nil { + return + } + bA.mtx.Lock() + defer bA.mtx.Unlock() + + copy(bA.Elems, o.Elems) +} + +// MarshalJSON implements json.Marshaler interface by marshaling bit array +// using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit. +func (bA *BitArray) MarshalJSON() ([]byte, error) { + if bA == nil { + return []byte("null"), nil + } + + bA.mtx.Lock() + defer bA.mtx.Unlock() + + bits := `"` + for i := 0; i < bA.Bits; i++ { + if bA.getIndex(i) { + bits += `x` + } else { + bits += `_` + } + } + bits += `"` + return []byte(bits), nil +} + +var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`) + +// UnmarshalJSON implements json.Unmarshaler interface by unmarshaling a custom +// JSON description. +func (bA *BitArray) UnmarshalJSON(bz []byte) error { + b := string(bz) + if b == "null" { + // This is required e.g. for encoding/json when decoding + // into a pointer with pre-allocated BitArray. + bA.Bits = 0 + bA.Elems = nil + return nil + } + + // Validate 'b'. + match := bitArrayJSONRegexp.FindStringSubmatch(b) + if match == nil { + return fmt.Errorf("BitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b) + } + bits := match[1] + + // Construct new BitArray and copy over. + numBits := len(bits) + bA2 := NewBitArray(numBits) + for i := 0; i < numBits; i++ { + if bits[i] == 'x' { + bA2.SetIndex(i, true) + } + } + *bA = *bA2 + return nil +} diff --git a/common/bit_array_test.go b/common/bit_array_test.go new file mode 100644 index 00000000..c697ba5d --- /dev/null +++ b/common/bit_array_test.go @@ -0,0 +1,267 @@ +package common + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func randBitArray(bits int) (*BitArray, []byte) { + src := RandBytes((bits + 7) / 8) + bA := NewBitArray(bits) + for i := 0; i < len(src); i++ { + for j := 0; j < 8; j++ { + if i*8+j >= bits { + return bA, src + } + setBit := src[i]&(1< 0 + bA.SetIndex(i*8+j, setBit) + } + } + return bA, src +} + +func TestAnd(t *testing.T) { + + bA1, _ := randBitArray(51) + bA2, _ := randBitArray(31) + bA3 := bA1.And(bA2) + + var bNil *BitArray + require.Equal(t, bNil.And(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.And(nil), (*BitArray)(nil)) + require.Equal(t, bNil.And(nil), (*BitArray)(nil)) + + if bA3.Bits != 31 { + t.Error("Expected min bits", bA3.Bits) + } + if len(bA3.Elems) != len(bA2.Elems) { + t.Error("Expected min elems length") + } + for i := 0; i < bA3.Bits; i++ { + expected := bA1.GetIndex(i) && bA2.GetIndex(i) + if bA3.GetIndex(i) != expected { + t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) + } + } +} + +func TestOr(t *testing.T) { + + bA1, _ := randBitArray(51) + bA2, _ := randBitArray(31) + bA3 := bA1.Or(bA2) + + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Or(bA1), bA1) + require.Equal(t, bA1.Or(nil), bA1) + require.Equal(t, bNil.Or(nil), (*BitArray)(nil)) + + if bA3.Bits != 51 { + t.Error("Expected max bits") + } + if len(bA3.Elems) != len(bA1.Elems) { + t.Error("Expected max elems length") + } + for i := 0; i < bA3.Bits; i++ { + expected := bA1.GetIndex(i) || bA2.GetIndex(i) + if bA3.GetIndex(i) != expected { + t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) + } + } +} + +func TestSub1(t *testing.T) { + + bA1, _ := randBitArray(31) + bA2, _ := randBitArray(51) + bA3 := bA1.Sub(bA2) + + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + + if bA3.Bits != bA1.Bits { + t.Error("Expected bA1 bits") + } + if len(bA3.Elems) != len(bA1.Elems) { + t.Error("Expected bA1 elems length") + } + for i := 0; i < bA3.Bits; i++ { + expected := bA1.GetIndex(i) + if bA2.GetIndex(i) { + expected = false + } + if bA3.GetIndex(i) != expected { + t.Error("Wrong bit from bA3", i, bA1.GetIndex(i), bA2.GetIndex(i), bA3.GetIndex(i)) + } + } +} + +func TestSub2(t *testing.T) { + + bA1, _ := randBitArray(51) + bA2, _ := randBitArray(31) + bA3 := bA1.Sub(bA2) + + bNil := (*BitArray)(nil) + require.Equal(t, bNil.Sub(bA1), (*BitArray)(nil)) + require.Equal(t, bA1.Sub(nil), (*BitArray)(nil)) + require.Equal(t, bNil.Sub(nil), (*BitArray)(nil)) + + if bA3.Bits != bA1.Bits { + t.Error("Expected bA1 bits") + } + if len(bA3.Elems) != len(bA1.Elems) { + t.Error("Expected bA1 elems length") + } + for i := 0; i < bA3.Bits; i++ { + expected := bA1.GetIndex(i) + if i < bA2.Bits && bA2.GetIndex(i) { + expected = false + } + if bA3.GetIndex(i) != expected { + t.Error("Wrong bit from bA3") + } + } +} + +func TestPickRandom(t *testing.T) { + for idx := 0; idx < 123; idx++ { + bA1 := NewBitArray(123) + bA1.SetIndex(idx, true) + index, ok := bA1.PickRandom() + if !ok { + t.Fatal("Expected to pick element but got none") + } + if index != idx { + t.Fatalf("Expected to pick element at %v but got wrong index", idx) + } + } +} + +func TestBytes(t *testing.T) { + bA := NewBitArray(4) + bA.SetIndex(0, true) + check := func(bA *BitArray, bz []byte) { + if !bytes.Equal(bA.Bytes(), bz) { + panic(Fmt("Expected %X but got %X", bz, bA.Bytes())) + } + } + check(bA, []byte{0x01}) + bA.SetIndex(3, true) + check(bA, []byte{0x09}) + + bA = NewBitArray(9) + check(bA, []byte{0x00, 0x00}) + bA.SetIndex(7, true) + check(bA, []byte{0x80, 0x00}) + bA.SetIndex(8, true) + check(bA, []byte{0x80, 0x01}) + + bA = NewBitArray(16) + check(bA, []byte{0x00, 0x00}) + bA.SetIndex(7, true) + check(bA, []byte{0x80, 0x00}) + bA.SetIndex(8, true) + check(bA, []byte{0x80, 0x01}) + bA.SetIndex(9, true) + check(bA, []byte{0x80, 0x03}) +} + +func TestEmptyFull(t *testing.T) { + ns := []int{47, 123} + for _, n := range ns { + bA := NewBitArray(n) + if !bA.IsEmpty() { + t.Fatal("Expected bit array to be empty") + } + for i := 0; i < n; i++ { + bA.SetIndex(i, true) + } + if !bA.IsFull() { + t.Fatal("Expected bit array to be full") + } + } +} + +func TestUpdateNeverPanics(t *testing.T) { + newRandBitArray := func(n int) *BitArray { + ba, _ := randBitArray(n) + return ba + } + pairs := []struct { + a, b *BitArray + }{ + {nil, nil}, + {newRandBitArray(10), newRandBitArray(12)}, + {newRandBitArray(23), newRandBitArray(23)}, + {newRandBitArray(37), nil}, + {nil, NewBitArray(10)}, + } + + for _, pair := range pairs { + a, b := pair.a, pair.b + a.Update(b) + b.Update(a) + } +} + +func TestNewBitArrayNeverCrashesOnNegatives(t *testing.T) { + bitList := []int{-127, -128, -1 << 31} + for _, bits := range bitList { + _ = NewBitArray(bits) + } +} + +func TestJSONMarshalUnmarshal(t *testing.T) { + + bA1 := NewBitArray(0) + + bA2 := NewBitArray(1) + + bA3 := NewBitArray(1) + bA3.SetIndex(0, true) + + bA4 := NewBitArray(5) + bA4.SetIndex(0, true) + bA4.SetIndex(1, true) + + testCases := []struct { + bA *BitArray + marshalledBA string + }{ + {nil, `null`}, + {bA1, `null`}, + {bA2, `"_"`}, + {bA3, `"x"`}, + {bA4, `"xx___"`}, + } + + for _, tc := range testCases { + t.Run(tc.bA.String(), func(t *testing.T) { + bz, err := json.Marshal(tc.bA) + require.NoError(t, err) + + assert.Equal(t, tc.marshalledBA, string(bz)) + + var unmarshalledBA *BitArray + err = json.Unmarshal(bz, &unmarshalledBA) + require.NoError(t, err) + + if tc.bA == nil { + require.Nil(t, unmarshalledBA) + } else { + require.NotNil(t, unmarshalledBA) + assert.EqualValues(t, tc.bA.Bits, unmarshalledBA.Bits) + if assert.EqualValues(t, tc.bA.String(), unmarshalledBA.String()) { + assert.EqualValues(t, tc.bA.Elems, unmarshalledBA.Elems) + } + } + }) + } +} diff --git a/common/bytes.go b/common/bytes.go new file mode 100644 index 00000000..711720aa --- /dev/null +++ b/common/bytes.go @@ -0,0 +1,62 @@ +package common + +import ( + "encoding/hex" + "fmt" + "strings" +) + +// The main purpose of HexBytes is to enable HEX-encoding for json/encoding. +type HexBytes []byte + +// Marshal needed for protobuf compatibility +func (bz HexBytes) Marshal() ([]byte, error) { + return bz, nil +} + +// Unmarshal needed for protobuf compatibility +func (bz *HexBytes) Unmarshal(data []byte) error { + *bz = data + return nil +} + +// This is the point of Bytes. +func (bz HexBytes) MarshalJSON() ([]byte, error) { + s := strings.ToUpper(hex.EncodeToString(bz)) + jbz := make([]byte, len(s)+2) + jbz[0] = '"' + copy(jbz[1:], []byte(s)) + jbz[len(jbz)-1] = '"' + return jbz, nil +} + +// This is the point of Bytes. +func (bz *HexBytes) UnmarshalJSON(data []byte) error { + if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' { + return fmt.Errorf("Invalid hex string: %s", data) + } + bz2, err := hex.DecodeString(string(data[1 : len(data)-1])) + if err != nil { + return err + } + *bz = bz2 + return nil +} + +// Allow it to fulfill various interfaces in light-client, etc... +func (bz HexBytes) Bytes() []byte { + return bz +} + +func (bz HexBytes) String() string { + return strings.ToUpper(hex.EncodeToString(bz)) +} + +func (bz HexBytes) Format(s fmt.State, verb rune) { + switch verb { + case 'p': + s.Write([]byte(fmt.Sprintf("%p", bz))) + default: + s.Write([]byte(fmt.Sprintf("%X", []byte(bz)))) + } +} diff --git a/common/bytes_test.go b/common/bytes_test.go new file mode 100644 index 00000000..9e11988f --- /dev/null +++ b/common/bytes_test.go @@ -0,0 +1,65 @@ +package common + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This is a trivial test for protobuf compatibility. +func TestMarshal(t *testing.T) { + bz := []byte("hello world") + dataB := HexBytes(bz) + bz2, err := dataB.Marshal() + assert.Nil(t, err) + assert.Equal(t, bz, bz2) + + var dataB2 HexBytes + err = (&dataB2).Unmarshal(bz) + assert.Nil(t, err) + assert.Equal(t, dataB, dataB2) +} + +// Test that the hex encoding works. +func TestJSONMarshal(t *testing.T) { + + type TestStruct struct { + B1 []byte + B2 HexBytes + } + + cases := []struct { + input []byte + expected string + }{ + {[]byte(``), `{"B1":"","B2":""}`}, + {[]byte(`a`), `{"B1":"YQ==","B2":"61"}`}, + {[]byte(`abc`), `{"B1":"YWJj","B2":"616263"}`}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { + ts := TestStruct{B1: tc.input, B2: tc.input} + + // Test that it marshals correctly to JSON. + jsonBytes, err := json.Marshal(ts) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, string(jsonBytes), tc.expected) + + // TODO do fuzz testing to ensure that unmarshal fails + + // Test that unmarshaling works correctly. + ts2 := TestStruct{} + err = json.Unmarshal(jsonBytes, &ts2) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, ts2.B1, tc.input) + assert.Equal(t, ts2.B2, HexBytes(tc.input)) + }) + } +} diff --git a/common/byteslice.go b/common/byteslice.go new file mode 100644 index 00000000..57b3a8a2 --- /dev/null +++ b/common/byteslice.go @@ -0,0 +1,73 @@ +package common + +import ( + "bytes" +) + +// Fingerprint returns the first 6 bytes of a byte slice. +// If the slice is less than 6 bytes, the fingerprint +// contains trailing zeroes. +func Fingerprint(slice []byte) []byte { + fingerprint := make([]byte, 6) + copy(fingerprint, slice) + return fingerprint +} + +func IsZeros(slice []byte) bool { + for _, byt := range slice { + if byt != byte(0) { + return false + } + } + return true +} + +func RightPadBytes(slice []byte, l int) []byte { + if l < len(slice) { + return slice + } + padded := make([]byte, l) + copy(padded[0:len(slice)], slice) + return padded +} + +func LeftPadBytes(slice []byte, l int) []byte { + if l < len(slice) { + return slice + } + padded := make([]byte, l) + copy(padded[l-len(slice):], slice) + return padded +} + +func TrimmedString(b []byte) string { + trimSet := string([]byte{0}) + return string(bytes.TrimLeft(b, trimSet)) + +} + +// PrefixEndBytes returns the end byteslice for a noninclusive range +// that would include all byte slices for which the input is the prefix +func PrefixEndBytes(prefix []byte) []byte { + if prefix == nil { + return nil + } + + end := make([]byte, len(prefix)) + copy(end, prefix) + finished := false + + for !finished { + if end[len(end)-1] != byte(255) { + end[len(end)-1]++ + finished = true + } else { + end = end[:len(end)-1] + if len(end) == 0 { + end = nil + finished = true + } + } + } + return end +} diff --git a/common/byteslice_test.go b/common/byteslice_test.go new file mode 100644 index 00000000..98085d12 --- /dev/null +++ b/common/byteslice_test.go @@ -0,0 +1,28 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPrefixEndBytes(t *testing.T) { + assert := assert.New(t) + + var testCases = []struct { + prefix []byte + expected []byte + }{ + {[]byte{byte(55), byte(255), byte(255), byte(0)}, []byte{byte(55), byte(255), byte(255), byte(1)}}, + {[]byte{byte(55), byte(255), byte(255), byte(15)}, []byte{byte(55), byte(255), byte(255), byte(16)}}, + {[]byte{byte(55), byte(200), byte(255)}, []byte{byte(55), byte(201)}}, + {[]byte{byte(55), byte(255), byte(255)}, []byte{byte(56)}}, + {[]byte{byte(255), byte(255), byte(255)}, nil}, + {nil, nil}, + } + + for _, test := range testCases { + end := PrefixEndBytes(test.prefix) + assert.Equal(test.expected, end) + } +} diff --git a/common/cmap.go b/common/cmap.go new file mode 100644 index 00000000..c65c27d4 --- /dev/null +++ b/common/cmap.go @@ -0,0 +1,73 @@ +package common + +import "sync" + +// CMap is a goroutine-safe map +type CMap struct { + m map[string]interface{} + l sync.Mutex +} + +func NewCMap() *CMap { + return &CMap{ + m: make(map[string]interface{}), + } +} + +func (cm *CMap) Set(key string, value interface{}) { + cm.l.Lock() + defer cm.l.Unlock() + cm.m[key] = value +} + +func (cm *CMap) Get(key string) interface{} { + cm.l.Lock() + defer cm.l.Unlock() + return cm.m[key] +} + +func (cm *CMap) Has(key string) bool { + cm.l.Lock() + defer cm.l.Unlock() + _, ok := cm.m[key] + return ok +} + +func (cm *CMap) Delete(key string) { + cm.l.Lock() + defer cm.l.Unlock() + delete(cm.m, key) +} + +func (cm *CMap) Size() int { + cm.l.Lock() + defer cm.l.Unlock() + return len(cm.m) +} + +func (cm *CMap) Clear() { + cm.l.Lock() + defer cm.l.Unlock() + cm.m = make(map[string]interface{}) +} + +func (cm *CMap) Keys() []string { + cm.l.Lock() + defer cm.l.Unlock() + + keys := []string{} + for k := range cm.m { + keys = append(keys, k) + } + return keys +} + +func (cm *CMap) Values() []interface{} { + cm.l.Lock() + defer cm.l.Unlock() + items := []interface{}{} + for _, v := range cm.m { + items = append(items, v) + } + return items +} diff --git a/common/cmap_test.go b/common/cmap_test.go new file mode 100644 index 00000000..c665a7f3 --- /dev/null +++ b/common/cmap_test.go @@ -0,0 +1,53 @@ +package common + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIterateKeysWithValues(t *testing.T) { + cmap := NewCMap() + + for i := 1; i <= 10; i++ { + cmap.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i)) + } + + // Testing size + assert.Equal(t, 10, cmap.Size()) + assert.Equal(t, 10, len(cmap.Keys())) + assert.Equal(t, 10, len(cmap.Values())) + + // Iterating Keys, checking for matching Value + for _, key := range cmap.Keys() { + val := strings.Replace(key, "key", "value", -1) + assert.Equal(t, val, cmap.Get(key)) + } + + // Test if all keys are within []Keys() + keys := cmap.Keys() + for i := 1; i <= 10; i++ { + assert.Contains(t, keys, fmt.Sprintf("key%d", i), "cmap.Keys() should contain key") + } + + // Delete 1 Key + cmap.Delete("key1") + + assert.NotEqual(t, len(keys), len(cmap.Keys()), "[]keys and []Keys() should not be equal, they are copies, one item was removed") +} + +func TestContains(t *testing.T) { + cmap := NewCMap() + + cmap.Set("key1", "value1") + + // Test for known values + assert.True(t, cmap.Has("key1")) + assert.Equal(t, "value1", cmap.Get("key1")) + + // Test for unknown values + assert.False(t, cmap.Has("key2")) + assert.Nil(t, cmap.Get("key2")) +} diff --git a/common/colors.go b/common/colors.go new file mode 100644 index 00000000..049ce7a5 --- /dev/null +++ b/common/colors.go @@ -0,0 +1,95 @@ +package common + +import ( + "fmt" + "strings" +) + +const ( + ANSIReset = "\x1b[0m" + ANSIBright = "\x1b[1m" + ANSIDim = "\x1b[2m" + ANSIUnderscore = "\x1b[4m" + ANSIBlink = "\x1b[5m" + ANSIReverse = "\x1b[7m" + ANSIHidden = "\x1b[8m" + + ANSIFgBlack = "\x1b[30m" + ANSIFgRed = "\x1b[31m" + ANSIFgGreen = "\x1b[32m" + ANSIFgYellow = "\x1b[33m" + ANSIFgBlue = "\x1b[34m" + ANSIFgMagenta = "\x1b[35m" + ANSIFgCyan = "\x1b[36m" + ANSIFgWhite = "\x1b[37m" + + ANSIBgBlack = "\x1b[40m" + ANSIBgRed = "\x1b[41m" + ANSIBgGreen = "\x1b[42m" + ANSIBgYellow = "\x1b[43m" + ANSIBgBlue = "\x1b[44m" + ANSIBgMagenta = "\x1b[45m" + ANSIBgCyan = "\x1b[46m" + ANSIBgWhite = "\x1b[47m" +) + +// color the string s with color 'color' +// unless s is already colored +func treat(s string, color string) string { + if len(s) > 2 && s[:2] == "\x1b[" { + return s + } + return color + s + ANSIReset +} + +func treatAll(color string, args ...interface{}) string { + var parts []string + for _, arg := range args { + parts = append(parts, treat(fmt.Sprintf("%v", arg), color)) + } + return strings.Join(parts, "") +} + +func Black(args ...interface{}) string { + return treatAll(ANSIFgBlack, args...) +} + +func Red(args ...interface{}) string { + return treatAll(ANSIFgRed, args...) +} + +func Green(args ...interface{}) string { + return treatAll(ANSIFgGreen, args...) +} + +func Yellow(args ...interface{}) string { + return treatAll(ANSIFgYellow, args...) +} + +func Blue(args ...interface{}) string { + return treatAll(ANSIFgBlue, args...) +} + +func Magenta(args ...interface{}) string { + return treatAll(ANSIFgMagenta, args...) +} + +func Cyan(args ...interface{}) string { + return treatAll(ANSIFgCyan, args...) +} + +func White(args ...interface{}) string { + return treatAll(ANSIFgWhite, args...) +} + +func ColoredBytes(data []byte, textColor, bytesColor func(...interface{}) string) string { + s := "" + for _, b := range data { + if 0x21 <= b && b < 0x7F { + s += textColor(string(b)) + } else { + s += bytesColor(Fmt("%02X", b)) + } + } + return s +} diff --git a/common/date.go b/common/date.go new file mode 100644 index 00000000..e017a4b4 --- /dev/null +++ b/common/date.go @@ -0,0 +1,43 @@ +package common + +import ( + "strings" + "time" + + "github.com/pkg/errors" +) + +// TimeLayout helps to parse a date string of the format YYYY-MM-DD +// Intended to be used with the following function: +// time.Parse(TimeLayout, date) +var TimeLayout = "2006-01-02" //this represents YYYY-MM-DD + +// ParseDateRange parses a date range string of the format start:end +// where the start and end date are of the format YYYY-MM-DD. +// The parsed dates are time.Time and will return the zero time for +// unbounded dates, ex: +// unbounded start: :2000-12-31 +// unbounded end: 2000-12-31: +func ParseDateRange(dateRange string) (startDate, endDate time.Time, err error) { + dates := strings.Split(dateRange, ":") + if len(dates) != 2 { + err = errors.New("bad date range, must be in format date:date") + return + } + parseDate := func(date string) (out time.Time, err error) { + if len(date) == 0 { + return + } + out, err = time.Parse(TimeLayout, date) + return + } + startDate, err = parseDate(dates[0]) + if err != nil { + return + } + endDate, err = parseDate(dates[1]) + if err != nil { + return + } + return +} diff --git a/common/date_test.go b/common/date_test.go new file mode 100644 index 00000000..2c063247 --- /dev/null +++ b/common/date_test.go @@ -0,0 +1,46 @@ +package common + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + date = time.Date(2015, time.Month(12), 31, 0, 0, 0, 0, time.UTC) + date2 = time.Date(2016, time.Month(12), 31, 0, 0, 0, 0, time.UTC) + zero time.Time +) + +func TestParseDateRange(t *testing.T) { + assert := assert.New(t) + + var testDates = []struct { + dateStr string + start time.Time + end time.Time + errNil bool + }{ + {"2015-12-31:2016-12-31", date, date2, true}, + {"2015-12-31:", date, zero, true}, + {":2016-12-31", zero, date2, true}, + {"2016-12-31", zero, zero, false}, + {"2016-31-12:", zero, zero, false}, + {":2016-31-12", zero, zero, false}, + } + + for _, test := range testDates { + start, end, err := ParseDateRange(test.dateStr) + if test.errNil { + assert.Nil(err) + testPtr := func(want, have time.Time) { + assert.True(have.Equal(want)) + } + testPtr(test.start, start) + testPtr(test.end, end) + } else { + assert.NotNil(err) + } + } +} diff --git a/common/errors.go b/common/errors.go new file mode 100644 index 00000000..5c31b896 --- /dev/null +++ b/common/errors.go @@ -0,0 +1,246 @@ +package common + +import ( + "fmt" + "runtime" +) + +//---------------------------------------- +// Convenience method. + +func ErrorWrap(cause interface{}, format string, args ...interface{}) Error { + if causeCmnError, ok := cause.(*cmnError); ok { + msg := Fmt(format, args...) + return causeCmnError.Stacktrace().Trace(1, msg) + } else if cause == nil { + return newCmnError(FmtError{format, args}).Stacktrace() + } else { + // NOTE: causeCmnError is a typed nil here. + msg := Fmt(format, args...) + return newCmnError(cause).Stacktrace().Trace(1, msg) + } +} + +//---------------------------------------- +// Error & cmnError + +/* + +Usage with arbitrary error data: + +```go + // Error construction + type MyError struct{} + var err1 error = NewErrorWithData(MyError{}, "my message") + ... + // Wrapping + var err2 error = ErrorWrap(err1, "another message") + if (err1 != err2) { panic("should be the same") + ... + // Error handling + switch err2.Data().(type){ + case MyError: ... + default: ... + } +``` +*/ +type Error interface { + Error() string + Stacktrace() Error + Trace(offset int, format string, args ...interface{}) Error + Data() interface{} +} + +// New Error with formatted message. +// The Error's Data will be a FmtError type. +func NewError(format string, args ...interface{}) Error { + err := FmtError{format, args} + return newCmnError(err) +} + +// New Error with specified data. +func NewErrorWithData(data interface{}) Error { + return newCmnError(data) +} + +type cmnError struct { + data interface{} // associated data + msgtraces []msgtraceItem // all messages traced + stacktrace []uintptr // first stack trace +} + +var _ Error = &cmnError{} + +// NOTE: do not expose. +func newCmnError(data interface{}) *cmnError { + return &cmnError{ + data: data, + msgtraces: nil, + stacktrace: nil, + } +} + +// Implements error. +func (err *cmnError) Error() string { + return fmt.Sprintf("%v", err) +} + +// Captures a stacktrace if one was not already captured. +func (err *cmnError) Stacktrace() Error { + if err.stacktrace == nil { + var offset = 3 + var depth = 32 + err.stacktrace = captureStacktrace(offset, depth) + } + return err +} + +// Add tracing information with msg. +// Set n=0 unless wrapped with some function, then n > 0. +func (err *cmnError) Trace(offset int, format string, args ...interface{}) Error { + msg := Fmt(format, args...) + return err.doTrace(msg, offset) +} + +// Return the "data" of this error. +// Data could be used for error handling/switching, +// or for holding general error/debug information. +func (err *cmnError) Data() interface{} { + return err.data +} + +func (err *cmnError) doTrace(msg string, n int) Error { + pc, _, _, _ := runtime.Caller(n + 2) // +1 for doTrace(). +1 for the caller. + // Include file & line number & msg. + // Do not include the whole stack trace. + err.msgtraces = append(err.msgtraces, msgtraceItem{ + pc: pc, + msg: msg, + }) + return err +} + +func (err *cmnError) Format(s fmt.State, verb rune) { + switch verb { + case 'p': + s.Write([]byte(fmt.Sprintf("%p", &err))) + default: + if s.Flag('#') { + s.Write([]byte("--= Error =--\n")) + // Write data. + s.Write([]byte(fmt.Sprintf("Data: %#v\n", err.data))) + // Write msg trace items. + s.Write([]byte(fmt.Sprintf("Msg Traces:\n"))) + for i, msgtrace := range err.msgtraces { + s.Write([]byte(fmt.Sprintf(" %4d %s\n", i, msgtrace.String()))) + } + // Write stack trace. + if err.stacktrace != nil { + s.Write([]byte(fmt.Sprintf("Stack Trace:\n"))) + for i, pc := range err.stacktrace { + fnc := runtime.FuncForPC(pc) + file, line := fnc.FileLine(pc) + s.Write([]byte(fmt.Sprintf(" %4d %s:%d\n", i, file, line))) + } + } + s.Write([]byte("--= /Error =--\n")) + } else { + // Write msg. + s.Write([]byte(fmt.Sprintf("Error{%v}", err.data))) // TODO tick-esc? + } + } +} + +//---------------------------------------- +// stacktrace & msgtraceItem + +func captureStacktrace(offset int, depth int) []uintptr { + var pcs = make([]uintptr, depth) + n := runtime.Callers(offset, pcs) + return pcs[0:n] +} + +type msgtraceItem struct { + pc uintptr + msg string +} + +func (mti msgtraceItem) String() string { + fnc := runtime.FuncForPC(mti.pc) + file, line := fnc.FileLine(mti.pc) + return fmt.Sprintf("%s:%d - %s", + file, line, + mti.msg, + ) +} + +//---------------------------------------- +// fmt error + +/* + +FmtError is the data type for NewError() (e.g. NewError().Data().(FmtError)) +Theoretically it could be used to switch on the format string. + +```go + // Error construction + var err1 error = NewError("invalid username %v", "BOB") + var err2 error = NewError("another kind of error") + ... + // Error handling + switch err1.Data().(cmn.FmtError).Format() { + case "invalid username %v": ... + case "another kind of error": ... + default: ... + } +``` +*/ +type FmtError struct { + format string + args []interface{} +} + +func (fe FmtError) Error() string { + return fmt.Sprintf(fe.format, fe.args...) +} + +func (fe FmtError) String() string { + return fmt.Sprintf("FmtError{format:%v,args:%v}", + fe.format, fe.args) +} + +func (fe FmtError) Format() string { + return fe.format +} + +//---------------------------------------- +// Panic wrappers +// XXX DEPRECATED + +// A panic resulting from a sanity check means there is a programmer error +// and some guarantee is not satisfied. +// XXX DEPRECATED +func PanicSanity(v interface{}) { + panic(Fmt("Panicked on a Sanity Check: %v", v)) +} + +// A panic here means something has gone horribly wrong, in the form of data corruption or +// failure of the operating system. In a correct/healthy system, these should never fire. +// If they do, it's indicative of a much more serious problem. +// XXX DEPRECATED +func PanicCrisis(v interface{}) { + panic(Fmt("Panicked on a Crisis: %v", v)) +} + +// Indicates a failure of consensus. Someone was malicious or something has +// gone horribly wrong. These should really boot us into an "emergency-recover" mode +// XXX DEPRECATED +func PanicConsensus(v interface{}) { + panic(Fmt("Panicked on a Consensus Failure: %v", v)) +} + +// For those times when we're not sure if we should panic +// XXX DEPRECATED +func PanicQ(v interface{}) { + panic(Fmt("Panicked questionably: %v", v)) +} diff --git a/common/errors_test.go b/common/errors_test.go new file mode 100644 index 00000000..16aede22 --- /dev/null +++ b/common/errors_test.go @@ -0,0 +1,103 @@ +package common + +import ( + fmt "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorPanic(t *testing.T) { + type pnk struct { + msg string + } + + capturePanic := func() (err Error) { + defer func() { + if r := recover(); r != nil { + err = ErrorWrap(r, "This is the message in ErrorWrap(r, message).") + } + return + }() + panic(pnk{"something"}) + return nil + } + + var err = capturePanic() + + assert.Equal(t, pnk{"something"}, err.Data()) + assert.Equal(t, "Error{{something}}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "This is the message in ErrorWrap(r, message).") + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorWrapSomething(t *testing.T) { + + var err = ErrorWrap("something", "formatter%v%v", 0, 1) + + assert.Equal(t, "something", err.Data()) + assert.Equal(t, "Error{something}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `formatter01\n`, fmt.Sprintf("%#v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorWrapNothing(t *testing.T) { + + var err = ErrorWrap(nil, "formatter%v%v", 0, 1) + + assert.Equal(t, + FmtError{"formatter%v%v", []interface{}{0, 1}}, + err.Data()) + assert.Equal(t, "Error{formatter01}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), `Data: common.FmtError{format:"formatter%v%v", args:[]interface {}{0, 1}}`) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorNewError(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1) + + assert.Equal(t, + FmtError{"formatter%v%v", []interface{}{0, 1}}, + err.Data()) + assert.Equal(t, "Error{formatter01}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), `Data: common.FmtError{format:"formatter%v%v", args:[]interface {}{0, 1}}`) + assert.NotContains(t, fmt.Sprintf("%#v", err), "Stack Trace") +} + +func TestErrorNewErrorWithStacktrace(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1).Stacktrace() + + assert.Equal(t, + FmtError{"formatter%v%v", []interface{}{0, 1}}, + err.Data()) + assert.Equal(t, "Error{formatter01}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), `Data: common.FmtError{format:"formatter%v%v", args:[]interface {}{0, 1}}`) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorNewErrorWithTrace(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1) + err.Trace(0, "trace %v", 1) + err.Trace(0, "trace %v", 2) + err.Trace(0, "trace %v", 3) + + assert.Equal(t, + FmtError{"formatter%v%v", []interface{}{0, 1}}, + err.Data()) + assert.Equal(t, "Error{formatter01}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), `Data: common.FmtError{format:"formatter%v%v", args:[]interface {}{0, 1}}`) + dump := fmt.Sprintf("%#v", err) + assert.NotContains(t, dump, "Stack Trace") + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 1`, dump) + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 2`, dump) + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 3`, dump) +} + +func TestErrorWrapError(t *testing.T) { + var err1 error = NewError("my message") + var err2 error = ErrorWrap(err1, "another message") + assert.Equal(t, err1, err2) +} diff --git a/common/heap.go b/common/heap.go new file mode 100644 index 00000000..b3bcb9db --- /dev/null +++ b/common/heap.go @@ -0,0 +1,125 @@ +package common + +import ( + "bytes" + "container/heap" +) + +/* + Example usage: + + ``` + h := NewHeap() + + h.Push("msg1", 1) + h.Push("msg3", 3) + h.Push("msg2", 2) + + fmt.Println(h.Pop()) // msg1 + fmt.Println(h.Pop()) // msg2 + fmt.Println(h.Pop()) // msg3 + ``` +*/ +type Heap struct { + pq priorityQueue +} + +func NewHeap() *Heap { + return &Heap{pq: make([]*pqItem, 0)} +} + +func (h *Heap) Len() int64 { + return int64(len(h.pq)) +} + +func (h *Heap) Push(value interface{}, priority int) { + heap.Push(&h.pq, &pqItem{value: value, priority: cmpInt(priority)}) +} + +func (h *Heap) PushBytes(value interface{}, priority []byte) { + heap.Push(&h.pq, &pqItem{value: value, priority: cmpBytes(priority)}) +} + +func (h *Heap) PushComparable(value interface{}, priority Comparable) { + heap.Push(&h.pq, &pqItem{value: value, priority: priority}) +} + +func (h *Heap) Peek() interface{} { + if len(h.pq) == 0 { + return nil + } + return h.pq[0].value +} + +func (h *Heap) Update(value interface{}, priority Comparable) { + h.pq.Update(h.pq[0], value, priority) +} + +func (h *Heap) Pop() interface{} { + item := heap.Pop(&h.pq).(*pqItem) + return item.value +} + +//----------------------------------------------------------------------------- +// From: http://golang.org/pkg/container/heap/#example__priorityQueue + +type pqItem struct { + value interface{} + priority Comparable + index int +} + +type priorityQueue []*pqItem + +func (pq priorityQueue) Len() int { return len(pq) } + +func (pq priorityQueue) Less(i, j int) bool { + return pq[i].priority.Less(pq[j].priority) +} + +func (pq priorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (pq *priorityQueue) Push(x interface{}) { + n := len(*pq) + item := x.(*pqItem) + item.index = n + *pq = append(*pq, item) +} + +func (pq *priorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + item.index = -1 // for safety + *pq = old[0 : n-1] + return item +} + +func (pq *priorityQueue) Update(item *pqItem, value interface{}, priority Comparable) { + item.value = value + item.priority = priority + heap.Fix(pq, item.index) +} + +//-------------------------------------------------------------------------------- +// Comparable + +type Comparable interface { + Less(o interface{}) bool +} + +type cmpInt int + +func (i cmpInt) Less(o interface{}) bool { + return int(i) < int(o.(cmpInt)) +} + +type cmpBytes []byte + +func (bz cmpBytes) Less(o interface{}) bool { + return bytes.Compare([]byte(bz), []byte(o.(cmpBytes))) < 0 +} diff --git a/common/int.go b/common/int.go new file mode 100644 index 00000000..a8a5f1e0 --- /dev/null +++ b/common/int.go @@ -0,0 +1,65 @@ +package common + +import ( + "encoding/binary" + "sort" +) + +// Sort for []uint64 + +type Uint64Slice []uint64 + +func (p Uint64Slice) Len() int { return len(p) } +func (p Uint64Slice) Less(i, j int) bool { return p[i] < p[j] } +func (p Uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p Uint64Slice) Sort() { sort.Sort(p) } + +func SearchUint64s(a []uint64, x uint64) int { + return sort.Search(len(a), func(i int) bool { return a[i] >= x }) +} + +func (p Uint64Slice) Search(x uint64) int { return SearchUint64s(p, x) } + +//-------------------------------------------------------------------------------- + +func PutUint64LE(dest []byte, i uint64) { + binary.LittleEndian.PutUint64(dest, i) +} + +func GetUint64LE(src []byte) uint64 { + return binary.LittleEndian.Uint64(src) +} + +func PutUint64BE(dest []byte, i uint64) { + binary.BigEndian.PutUint64(dest, i) +} + +func GetUint64BE(src []byte) uint64 { + return binary.BigEndian.Uint64(src) +} + +func PutInt64LE(dest []byte, i int64) { + binary.LittleEndian.PutUint64(dest, uint64(i)) +} + +func GetInt64LE(src []byte) int64 { + return int64(binary.LittleEndian.Uint64(src)) +} + +func PutInt64BE(dest []byte, i int64) { + binary.BigEndian.PutUint64(dest, uint64(i)) +} + +func GetInt64BE(src []byte) int64 { + return int64(binary.BigEndian.Uint64(src)) +} + +// IntInSlice returns true if a is found in the list. +func IntInSlice(a int, list []int) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} diff --git a/common/int_test.go b/common/int_test.go new file mode 100644 index 00000000..1ecc7844 --- /dev/null +++ b/common/int_test.go @@ -0,0 +1,14 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIntInSlice(t *testing.T) { + assert.True(t, IntInSlice(1, []int{1, 2, 3})) + assert.False(t, IntInSlice(4, []int{1, 2, 3})) + assert.True(t, IntInSlice(0, []int{0})) + assert.False(t, IntInSlice(0, []int{})) +} diff --git a/common/io.go b/common/io.go new file mode 100644 index 00000000..fa0443e0 --- /dev/null +++ b/common/io.go @@ -0,0 +1,74 @@ +package common + +import ( + "bytes" + "errors" + "io" +) + +type PrefixedReader struct { + Prefix []byte + reader io.Reader +} + +func NewPrefixedReader(prefix []byte, reader io.Reader) *PrefixedReader { + return &PrefixedReader{prefix, reader} +} + +func (pr *PrefixedReader) Read(p []byte) (n int, err error) { + if len(pr.Prefix) > 0 { + read := copy(p, pr.Prefix) + pr.Prefix = pr.Prefix[read:] + return read, nil + } + return pr.reader.Read(p) +} + +// NOTE: Not goroutine safe +type BufferCloser struct { + bytes.Buffer + Closed bool +} + +func NewBufferCloser(buf []byte) *BufferCloser { + return &BufferCloser{ + *bytes.NewBuffer(buf), + false, + } +} + +func (bc *BufferCloser) Close() error { + if bc.Closed { + return errors.New("BufferCloser already closed") + } + bc.Closed = true + return nil +} + +func (bc *BufferCloser) Write(p []byte) (n int, err error) { + if bc.Closed { + return 0, errors.New("Cannot write to closed BufferCloser") + } + return bc.Buffer.Write(p) +} + +func (bc *BufferCloser) WriteByte(c byte) error { + if bc.Closed { + return errors.New("Cannot write to closed BufferCloser") + } + return bc.Buffer.WriteByte(c) +} + +func (bc *BufferCloser) WriteRune(r rune) (n int, err error) { + if bc.Closed { + return 0, errors.New("Cannot write to closed BufferCloser") + } + return bc.Buffer.WriteRune(r) +} + +func (bc *BufferCloser) WriteString(s string) (n int, err error) { + if bc.Closed { + return 0, errors.New("Cannot write to closed BufferCloser") + } + return bc.Buffer.WriteString(s) +} diff --git a/common/kvpair.go b/common/kvpair.go new file mode 100644 index 00000000..54c3a58c --- /dev/null +++ b/common/kvpair.go @@ -0,0 +1,67 @@ +package common + +import ( + "bytes" + "sort" +) + +//---------------------------------------- +// KVPair + +/* +Defined in types.proto + +type KVPair struct { + Key []byte + Value []byte +} +*/ + +type KVPairs []KVPair + +// Sorting +func (kvs KVPairs) Len() int { return len(kvs) } +func (kvs KVPairs) Less(i, j int) bool { + switch bytes.Compare(kvs[i].Key, kvs[j].Key) { + case -1: + return true + case 0: + return bytes.Compare(kvs[i].Value, kvs[j].Value) < 0 + case 1: + return false + default: + panic("invalid comparison result") + } +} +func (kvs KVPairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } +func (kvs KVPairs) Sort() { sort.Sort(kvs) } + +//---------------------------------------- +// KI64Pair + +/* +Defined in types.proto +type KI64Pair struct { + Key []byte + Value int64 +} +*/ + +type KI64Pairs []KI64Pair + +// Sorting +func (kvs KI64Pairs) Len() int { return len(kvs) } +func (kvs KI64Pairs) Less(i, j int) bool { + switch bytes.Compare(kvs[i].Key, kvs[j].Key) { + case -1: + return true + case 0: + return kvs[i].Value < kvs[j].Value + case 1: + return false + default: + panic("invalid comparison result") + } +} +func (kvs KI64Pairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } +func (kvs KI64Pairs) Sort() { sort.Sort(kvs) } diff --git a/common/math.go b/common/math.go new file mode 100644 index 00000000..b037d1a7 --- /dev/null +++ b/common/math.go @@ -0,0 +1,157 @@ +package common + +func MaxInt8(a, b int8) int8 { + if a > b { + return a + } + return b +} + +func MaxUint8(a, b uint8) uint8 { + if a > b { + return a + } + return b +} + +func MaxInt16(a, b int16) int16 { + if a > b { + return a + } + return b +} + +func MaxUint16(a, b uint16) uint16 { + if a > b { + return a + } + return b +} + +func MaxInt32(a, b int32) int32 { + if a > b { + return a + } + return b +} + +func MaxUint32(a, b uint32) uint32 { + if a > b { + return a + } + return b +} + +func MaxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func MaxUint64(a, b uint64) uint64 { + if a > b { + return a + } + return b +} + +func MaxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func MaxUint(a, b uint) uint { + if a > b { + return a + } + return b +} + +//----------------------------------------------------------------------------- + +func MinInt8(a, b int8) int8 { + if a < b { + return a + } + return b +} + +func MinUint8(a, b uint8) uint8 { + if a < b { + return a + } + return b +} + +func MinInt16(a, b int16) int16 { + if a < b { + return a + } + return b +} + +func MinUint16(a, b uint16) uint16 { + if a < b { + return a + } + return b +} + +func MinInt32(a, b int32) int32 { + if a < b { + return a + } + return b +} + +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +func MinInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +func MinInt(a, b int) int { + if a < b { + return a + } + return b +} + +func MinUint(a, b uint) uint { + if a < b { + return a + } + return b +} + +//----------------------------------------------------------------------------- + +func ExpUint64(a, b uint64) uint64 { + accum := uint64(1) + for b > 0 { + if b&1 == 1 { + accum *= a + } + a *= a + b >>= 1 + } + return accum +} diff --git a/common/net.go b/common/net.go new file mode 100644 index 00000000..bdbe38f7 --- /dev/null +++ b/common/net.go @@ -0,0 +1,26 @@ +package common + +import ( + "net" + "strings" +) + +// Connect dials the given address and returns a net.Conn. The protoAddr argument should be prefixed with the protocol, +// eg. "tcp://127.0.0.1:8080" or "unix:///tmp/test.sock" +func Connect(protoAddr string) (net.Conn, error) { + proto, address := ProtocolAndAddress(protoAddr) + conn, err := net.Dial(proto, address) + return conn, err +} + +// ProtocolAndAddress splits an address into the protocol and address components. +// For instance, "tcp://127.0.0.1:8080" will be split into "tcp" and "127.0.0.1:8080". +// If the address has no protocol prefix, the default is "tcp". +func ProtocolAndAddress(listenAddr string) (string, string) { + protocol, address := "tcp", listenAddr + parts := strings.SplitN(address, "://", 2) + if len(parts) == 2 { + protocol, address = parts[0], parts[1] + } + return protocol, address +} diff --git a/common/net_test.go b/common/net_test.go new file mode 100644 index 00000000..38d2ae82 --- /dev/null +++ b/common/net_test.go @@ -0,0 +1,38 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProtocolAndAddress(t *testing.T) { + + cases := []struct { + fullAddr string + proto string + addr string + }{ + { + "tcp://mydomain:80", + "tcp", + "mydomain:80", + }, + { + "mydomain:80", + "tcp", + "mydomain:80", + }, + { + "unix://mydomain:80", + "unix", + "mydomain:80", + }, + } + + for _, c := range cases { + proto, addr := ProtocolAndAddress(c.fullAddr) + assert.Equal(t, proto, c.proto) + assert.Equal(t, addr, c.addr) + } +} diff --git a/common/nil.go b/common/nil.go new file mode 100644 index 00000000..31f75f00 --- /dev/null +++ b/common/nil.go @@ -0,0 +1,29 @@ +package common + +import "reflect" + +// Go lacks a simple and safe way to see if something is a typed nil. +// See: +// - https://dave.cheney.net/2017/08/09/typed-nils-in-go-2 +// - https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I/discussion +// - https://github.com/golang/go/issues/21538 +func IsTypedNil(o interface{}) bool { + rv := reflect.ValueOf(o) + switch rv.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice: + return rv.IsNil() + default: + return false + } +} + +// Returns true if it has zero length. +func IsEmpty(o interface{}) bool { + rv := reflect.ValueOf(o) + switch rv.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return rv.Len() == 0 + default: + return false + } +} diff --git a/common/os.go b/common/os.go new file mode 100644 index 00000000..00f4da57 --- /dev/null +++ b/common/os.go @@ -0,0 +1,195 @@ +package common + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "syscall" +) + +var gopath string + +// GoPath returns GOPATH env variable value. If it is not set, this function +// will try to call `go env GOPATH` subcommand. +func GoPath() string { + if gopath != "" { + return gopath + } + + path := os.Getenv("GOPATH") + if len(path) == 0 { + goCmd := exec.Command("go", "env", "GOPATH") + out, err := goCmd.Output() + if err != nil { + panic(fmt.Sprintf("failed to determine gopath: %v", err)) + } + path = string(out) + } + gopath = path + return path +} + +// TrapSignal catches the SIGTERM and executes cb function. After that it exits +// with code 1. +func TrapSignal(cb func()) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + for sig := range c { + fmt.Printf("captured %v, exiting...\n", sig) + if cb != nil { + cb() + } + os.Exit(1) + } + }() + select {} +} + +// Kill the running process by sending itself SIGTERM. +func Kill() error { + p, err := os.FindProcess(os.Getpid()) + if err != nil { + return err + } + return p.Signal(syscall.SIGTERM) +} + +func Exit(s string) { + fmt.Printf(s + "\n") + os.Exit(1) +} + +func EnsureDir(dir string, mode os.FileMode) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, mode) + if err != nil { + return fmt.Errorf("Could not create directory %v. %v", dir, err) + } + } + return nil +} + +func IsDirEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + if os.IsNotExist(err) { + return true, err + } + // Otherwise perhaps a permission + // error or some other error. + return false, err + } + defer f.Close() + + _, err = f.Readdirnames(1) // Or f.Readdir(1) + if err == io.EOF { + return true, nil + } + return false, err // Either not empty or error, suits both cases +} + +func FileExists(filePath string) bool { + _, err := os.Stat(filePath) + return !os.IsNotExist(err) +} + +func ReadFile(filePath string) ([]byte, error) { + return ioutil.ReadFile(filePath) +} + +func MustReadFile(filePath string) []byte { + fileBytes, err := ioutil.ReadFile(filePath) + if err != nil { + Exit(Fmt("MustReadFile failed: %v", err)) + return nil + } + return fileBytes +} + +func WriteFile(filePath string, contents []byte, mode os.FileMode) error { + return ioutil.WriteFile(filePath, contents, mode) +} + +func MustWriteFile(filePath string, contents []byte, mode os.FileMode) { + err := WriteFile(filePath, contents, mode) + if err != nil { + Exit(Fmt("MustWriteFile failed: %v", err)) + } +} + +// WriteFileAtomic creates a temporary file with data and the perm given and +// swaps it atomically with filename if successful. +func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { + var ( + dir = filepath.Dir(filename) + tempFile = filepath.Join(dir, "write-file-atomic-"+RandStr(32)) + // Override in case it does exist, create in case it doesn't and force kernel + // flush, which still leaves the potential of lingering disk cache. + flag = os.O_WRONLY | os.O_CREATE | os.O_SYNC | os.O_TRUNC + ) + + f, err := os.OpenFile(tempFile, flag, perm) + if err != nil { + return err + } + // Clean up in any case. Defer stacking order is last-in-first-out. + defer os.Remove(f.Name()) + defer f.Close() + + if n, err := f.Write(data); err != nil { + return err + } else if n < len(data) { + return io.ErrShortWrite + } + // Close the file before renaming it, otherwise it will cause "The process + // cannot access the file because it is being used by another process." on windows. + f.Close() + + return os.Rename(f.Name(), filename) +} + +//-------------------------------------------------------------------------------- + +func Tempfile(prefix string) (*os.File, string) { + file, err := ioutil.TempFile("", prefix) + if err != nil { + PanicCrisis(err) + } + return file, file.Name() +} + +func Tempdir(prefix string) (*os.File, string) { + tempDir := os.TempDir() + "/" + prefix + RandStr(12) + err := EnsureDir(tempDir, 0700) + if err != nil { + panic(Fmt("Error creating temp dir: %v", err)) + } + dir, err := os.Open(tempDir) + if err != nil { + panic(Fmt("Error opening temp dir: %v", err)) + } + return dir, tempDir +} + +//-------------------------------------------------------------------------------- + +func Prompt(prompt string, defaultValue string) (string, error) { + fmt.Print(prompt) + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil { + return defaultValue, err + } + line = strings.TrimSpace(line) + if line == "" { + return defaultValue, nil + } + return line, nil +} diff --git a/common/os_test.go b/common/os_test.go new file mode 100644 index 00000000..973d6890 --- /dev/null +++ b/common/os_test.go @@ -0,0 +1,91 @@ +package common + +import ( + "bytes" + "io/ioutil" + "math/rand" + "os" + "testing" + "time" +) + +func TestWriteFileAtomic(t *testing.T) { + var ( + seed = rand.New(rand.NewSource(time.Now().UnixNano())) + data = []byte(RandStr(seed.Intn(2048))) + old = RandBytes(seed.Intn(2048)) + perm os.FileMode = 0600 + ) + + f, err := ioutil.TempFile("/tmp", "write-atomic-test-") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + if err = ioutil.WriteFile(f.Name(), old, 0664); err != nil { + t.Fatal(err) + } + + if err = WriteFileAtomic(f.Name(), data, perm); err != nil { + t.Fatal(err) + } + + rData, err := ioutil.ReadFile(f.Name()) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, rData) { + t.Fatalf("data mismatch: %v != %v", data, rData) + } + + stat, err := os.Stat(f.Name()) + if err != nil { + t.Fatal(err) + } + + if have, want := stat.Mode().Perm(), perm; have != want { + t.Errorf("have %v, want %v", have, want) + } +} + +func TestGoPath(t *testing.T) { + // restore original gopath upon exit + path := os.Getenv("GOPATH") + defer func() { + _ = os.Setenv("GOPATH", path) + }() + + err := os.Setenv("GOPATH", "~/testgopath") + if err != nil { + t.Fatal(err) + } + path = GoPath() + if path != "~/testgopath" { + t.Fatalf("should get GOPATH env var value, got %v", path) + } + os.Unsetenv("GOPATH") + + path = GoPath() + if path != "~/testgopath" { + t.Fatalf("subsequent calls should return the same value, got %v", path) + } +} + +func TestGoPathWithoutEnvVar(t *testing.T) { + // restore original gopath upon exit + path := os.Getenv("GOPATH") + defer func() { + _ = os.Setenv("GOPATH", path) + }() + + os.Unsetenv("GOPATH") + // reset cache + gopath = "" + + path = GoPath() + if path == "" || path == "~/testgopath" { + t.Fatalf("should get nonempty result of calling go env GOPATH, got %v", path) + } +} diff --git a/common/random.go b/common/random.go new file mode 100644 index 00000000..389a32fc --- /dev/null +++ b/common/random.go @@ -0,0 +1,357 @@ +package common + +import ( + crand "crypto/rand" + mrand "math/rand" + "sync" + "time" +) + +const ( + strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters +) + +// pseudo random number generator. +// seeded with OS randomness (crand) + +type Rand struct { + sync.Mutex + rand *mrand.Rand +} + +var grand *Rand + +func init() { + grand = NewRand() + grand.init() +} + +func NewRand() *Rand { + rand := &Rand{} + rand.init() + return rand +} + +func (r *Rand) init() { + bz := cRandBytes(8) + var seed uint64 + for i := 0; i < 8; i++ { + seed |= uint64(bz[i]) + seed <<= 8 + } + r.reset(int64(seed)) +} + +func (r *Rand) reset(seed int64) { + r.rand = mrand.New(mrand.NewSource(seed)) +} + +//---------------------------------------- +// Global functions + +func Seed(seed int64) { + grand.Seed(seed) +} + +func RandStr(length int) string { + return grand.Str(length) +} + +func RandUint16() uint16 { + return grand.Uint16() +} + +func RandUint32() uint32 { + return grand.Uint32() +} + +func RandUint64() uint64 { + return grand.Uint64() +} + +func RandUint() uint { + return grand.Uint() +} + +func RandInt16() int16 { + return grand.Int16() +} + +func RandInt32() int32 { + return grand.Int32() +} + +func RandInt64() int64 { + return grand.Int64() +} + +func RandInt() int { + return grand.Int() +} + +func RandInt31() int32 { + return grand.Int31() +} + +func RandInt31n(n int32) int32 { + return grand.Int31n(n) +} + +func RandInt63() int64 { + return grand.Int63() +} + +func RandInt63n(n int64) int64 { + return grand.Int63n(n) +} + +func RandUint16Exp() uint16 { + return grand.Uint16Exp() +} + +func RandUint32Exp() uint32 { + return grand.Uint32Exp() +} + +func RandUint64Exp() uint64 { + return grand.Uint64Exp() +} + +func RandFloat32() float32 { + return grand.Float32() +} + +func RandFloat64() float64 { + return grand.Float64() +} + +func RandTime() time.Time { + return grand.Time() +} + +func RandBytes(n int) []byte { + return grand.Bytes(n) +} + +func RandIntn(n int) int { + return grand.Intn(n) +} + +func RandPerm(n int) []int { + return grand.Perm(n) +} + +//---------------------------------------- +// Rand methods + +func (r *Rand) Seed(seed int64) { + r.Lock() + r.reset(seed) + r.Unlock() +} + +// Constructs an alphanumeric string of given length. +// It is not safe for cryptographic usage. +func (r *Rand) Str(length int) string { + chars := []byte{} +MAIN_LOOP: + for { + val := r.Int63() + for i := 0; i < 10; i++ { + v := int(val & 0x3f) // rightmost 6 bits + if v >= 62 { // only 62 characters in strChars + val >>= 6 + continue + } else { + chars = append(chars, strChars[v]) + if len(chars) == length { + break MAIN_LOOP + } + val >>= 6 + } + } + } + + return string(chars) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Uint16() uint16 { + return uint16(r.Uint32() & (1<<16 - 1)) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Uint32() uint32 { + r.Lock() + u32 := r.rand.Uint32() + r.Unlock() + return u32 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Uint64() uint64 { + return uint64(r.Uint32())<<32 + uint64(r.Uint32()) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Uint() uint { + r.Lock() + i := r.rand.Int() + r.Unlock() + return uint(i) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int16() int16 { + return int16(r.Uint32() & (1<<16 - 1)) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int32() int32 { + return int32(r.Uint32()) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int64() int64 { + return int64(r.Uint64()) +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int() int { + r.Lock() + i := r.rand.Int() + r.Unlock() + return i +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int31() int32 { + r.Lock() + i31 := r.rand.Int31() + r.Unlock() + return i31 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int31n(n int32) int32 { + r.Lock() + i31n := r.rand.Int31n(n) + r.Unlock() + return i31n +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int63() int64 { + r.Lock() + i63 := r.rand.Int63() + r.Unlock() + return i63 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int63n(n int64) int64 { + r.Lock() + i63n := r.rand.Int63n(n) + r.Unlock() + return i63n +} + +// Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. +func (r *Rand) Uint16Exp() uint16 { + bits := r.Uint32() % 16 + if bits == 0 { + return 0 + } + n := uint16(1 << (bits - 1)) + n += uint16(r.Int31()) & ((1 << (bits - 1)) - 1) + return n +} + +// Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. +func (r *Rand) Uint32Exp() uint32 { + bits := r.Uint32() % 32 + if bits == 0 { + return 0 + } + n := uint32(1 << (bits - 1)) + n += uint32(r.Int31()) & ((1 << (bits - 1)) - 1) + return n +} + +// Distributed pseudo-exponentially to test for various cases +// It is not safe for cryptographic usage. +func (r *Rand) Uint64Exp() uint64 { + bits := r.Uint32() % 64 + if bits == 0 { + return 0 + } + n := uint64(1 << (bits - 1)) + n += uint64(r.Int63()) & ((1 << (bits - 1)) - 1) + return n +} + +// It is not safe for cryptographic usage. +func (r *Rand) Float32() float32 { + r.Lock() + f32 := r.rand.Float32() + r.Unlock() + return f32 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Float64() float64 { + r.Lock() + f64 := r.rand.Float64() + r.Unlock() + return f64 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Time() time.Time { + return time.Unix(int64(r.Uint64Exp()), 0) +} + +// RandBytes returns n random bytes from the OS's source of entropy ie. via crypto/rand. +// It is not safe for cryptographic usage. +func (r *Rand) Bytes(n int) []byte { + // cRandBytes isn't guaranteed to be fast so instead + // use random bytes generated from the internal PRNG + bs := make([]byte, n) + for i := 0; i < len(bs); i++ { + bs[i] = byte(r.Int() & 0xFF) + } + return bs +} + +// RandIntn returns, as an int, a non-negative pseudo-random number in [0, n). +// It panics if n <= 0. +// It is not safe for cryptographic usage. +func (r *Rand) Intn(n int) int { + r.Lock() + i := r.rand.Intn(n) + r.Unlock() + return i +} + +// RandPerm returns a pseudo-random permutation of n integers in [0, n). +// It is not safe for cryptographic usage. +func (r *Rand) Perm(n int) []int { + r.Lock() + perm := r.rand.Perm(n) + r.Unlock() + return perm +} + +// NOTE: This relies on the os's random number generator. +// For real security, we should salt that with some seed. +// See github.com/tendermint/go-crypto for a more secure reader. +func cRandBytes(numBytes int) []byte { + b := make([]byte, numBytes) + _, err := crand.Read(b) + if err != nil { + PanicCrisis(err) + } + return b +} diff --git a/common/random_test.go b/common/random_test.go new file mode 100644 index 00000000..b58b4a13 --- /dev/null +++ b/common/random_test.go @@ -0,0 +1,121 @@ +package common + +import ( + "bytes" + "encoding/json" + "fmt" + mrand "math/rand" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRandStr(t *testing.T) { + l := 243 + s := RandStr(l) + assert.Equal(t, l, len(s)) +} + +func TestRandBytes(t *testing.T) { + l := 243 + b := RandBytes(l) + assert.Equal(t, l, len(b)) +} + +func TestRandIntn(t *testing.T) { + n := 243 + for i := 0; i < 100; i++ { + x := RandIntn(n) + assert.True(t, x < n) + } +} + +// Test to make sure that we never call math.rand(). +// We do this by ensuring that outputs are deterministic. +func TestDeterminism(t *testing.T) { + var firstOutput string + + // Set math/rand's seed for the sake of debugging this test. + // (It isn't strictly necessary). + mrand.Seed(1) + + for i := 0; i < 100; i++ { + output := testThemAll() + if i == 0 { + firstOutput = output + } else { + if firstOutput != output { + t.Errorf("Run #%d's output was different from first run.\nfirst: %v\nlast: %v", + i, firstOutput, output) + } + } + } +} + +func testThemAll() string { + + // Such determinism. + grand.reset(1) + + // Use it. + out := new(bytes.Buffer) + perm := RandPerm(10) + blob, _ := json.Marshal(perm) + fmt.Fprintf(out, "perm: %s\n", blob) + fmt.Fprintf(out, "randInt: %d\n", RandInt()) + fmt.Fprintf(out, "randUint: %d\n", RandUint()) + fmt.Fprintf(out, "randIntn: %d\n", RandIntn(97)) + fmt.Fprintf(out, "randInt31: %d\n", RandInt31()) + fmt.Fprintf(out, "randInt32: %d\n", RandInt32()) + fmt.Fprintf(out, "randInt63: %d\n", RandInt63()) + fmt.Fprintf(out, "randInt64: %d\n", RandInt64()) + fmt.Fprintf(out, "randUint32: %d\n", RandUint32()) + fmt.Fprintf(out, "randUint64: %d\n", RandUint64()) + fmt.Fprintf(out, "randUint16Exp: %d\n", RandUint16Exp()) + fmt.Fprintf(out, "randUint32Exp: %d\n", RandUint32Exp()) + fmt.Fprintf(out, "randUint64Exp: %d\n", RandUint64Exp()) + return out.String() +} + +func TestRngConcurrencySafety(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + _ = RandUint64() + <-time.After(time.Millisecond * time.Duration(RandIntn(100))) + _ = RandPerm(3) + }() + } + wg.Wait() +} + +func BenchmarkRandBytes10B(b *testing.B) { + benchmarkRandBytes(b, 10) +} +func BenchmarkRandBytes100B(b *testing.B) { + benchmarkRandBytes(b, 100) +} +func BenchmarkRandBytes1KiB(b *testing.B) { + benchmarkRandBytes(b, 1024) +} +func BenchmarkRandBytes10KiB(b *testing.B) { + benchmarkRandBytes(b, 10*1024) +} +func BenchmarkRandBytes100KiB(b *testing.B) { + benchmarkRandBytes(b, 100*1024) +} +func BenchmarkRandBytes1MiB(b *testing.B) { + benchmarkRandBytes(b, 1024*1024) +} + +func benchmarkRandBytes(b *testing.B, n int) { + for i := 0; i < b.N; i++ { + _ = RandBytes(n) + } + b.ReportAllocs() +} diff --git a/common/repeat_timer.go b/common/repeat_timer.go new file mode 100644 index 00000000..5d049738 --- /dev/null +++ b/common/repeat_timer.go @@ -0,0 +1,232 @@ +package common + +import ( + "sync" + "time" +) + +// Used by RepeatTimer the first time, +// and every time it's Reset() after Stop(). +type TickerMaker func(dur time.Duration) Ticker + +// Ticker is a basic ticker interface. +type Ticker interface { + + // Never changes, never closes. + Chan() <-chan time.Time + + // Stopping a stopped Ticker will panic. + Stop() +} + +//---------------------------------------- +// defaultTicker + +var _ Ticker = (*defaultTicker)(nil) + +type defaultTicker time.Ticker + +func defaultTickerMaker(dur time.Duration) Ticker { + ticker := time.NewTicker(dur) + return (*defaultTicker)(ticker) +} + +// Implements Ticker +func (t *defaultTicker) Chan() <-chan time.Time { + return t.C +} + +// Implements Ticker +func (t *defaultTicker) Stop() { + ((*time.Ticker)(t)).Stop() +} + +//---------------------------------------- +// LogicalTickerMaker + +// Construct a TickerMaker that always uses `source`. +// It's useful for simulating a deterministic clock. +func NewLogicalTickerMaker(source chan time.Time) TickerMaker { + return func(dur time.Duration) Ticker { + return newLogicalTicker(source, dur) + } +} + +type logicalTicker struct { + source <-chan time.Time + ch chan time.Time + quit chan struct{} +} + +func newLogicalTicker(source <-chan time.Time, interval time.Duration) Ticker { + lt := &logicalTicker{ + source: source, + ch: make(chan time.Time), + quit: make(chan struct{}), + } + go lt.fireRoutine(interval) + return lt +} + +// We need a goroutine to read times from t.source +// and fire on t.Chan() when `interval` has passed. +func (t *logicalTicker) fireRoutine(interval time.Duration) { + source := t.source + + // Init `lasttime` + lasttime := time.Time{} + select { + case lasttime = <-source: + case <-t.quit: + return + } + // Init `lasttime` end + + for { + select { + case newtime := <-source: + elapsed := newtime.Sub(lasttime) + if interval <= elapsed { + // Block for determinism until the ticker is stopped. + select { + case t.ch <- newtime: + case <-t.quit: + return + } + // Reset timeleft. + // Don't try to "catch up" by sending more. + // "Ticker adjusts the intervals or drops ticks to make up for + // slow receivers" - https://golang.org/pkg/time/#Ticker + lasttime = newtime + } + case <-t.quit: + return // done + } + } +} + +// Implements Ticker +func (t *logicalTicker) Chan() <-chan time.Time { + return t.ch // immutable +} + +// Implements Ticker +func (t *logicalTicker) Stop() { + close(t.quit) // it *should* panic when stopped twice. +} + +//--------------------------------------------------------------------- + +/* + RepeatTimer repeatedly sends a struct{}{} to `.Chan()` after each `dur` + period. (It's good for keeping connections alive.) + A RepeatTimer must be stopped, or it will keep a goroutine alive. +*/ +type RepeatTimer struct { + name string + ch chan time.Time + tm TickerMaker + + mtx sync.Mutex + dur time.Duration + ticker Ticker + quit chan struct{} +} + +// NewRepeatTimer returns a RepeatTimer with a defaultTicker. +func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer { + return NewRepeatTimerWithTickerMaker(name, dur, defaultTickerMaker) +} + +// NewRepeatTimerWithTicker returns a RepeatTimer with the given ticker +// maker. +func NewRepeatTimerWithTickerMaker(name string, dur time.Duration, tm TickerMaker) *RepeatTimer { + var t = &RepeatTimer{ + name: name, + ch: make(chan time.Time), + tm: tm, + dur: dur, + ticker: nil, + quit: nil, + } + t.reset() + return t +} + +// receive ticks on ch, send out on t.ch +func (t *RepeatTimer) fireRoutine(ch <-chan time.Time, quit <-chan struct{}) { + for { + select { + case tick := <-ch: + select { + case t.ch <- tick: + case <-quit: + return + } + case <-quit: // NOTE: `t.quit` races. + return + } + } +} + +func (t *RepeatTimer) Chan() <-chan time.Time { + return t.ch +} + +func (t *RepeatTimer) Stop() { + t.mtx.Lock() + defer t.mtx.Unlock() + + t.stop() +} + +// Wait the duration again before firing. +func (t *RepeatTimer) Reset() { + t.mtx.Lock() + defer t.mtx.Unlock() + + t.reset() +} + +//---------------------------------------- +// Misc. + +// CONTRACT: (non-constructor) caller should hold t.mtx. +func (t *RepeatTimer) reset() { + if t.ticker != nil { + t.stop() + } + t.ticker = t.tm(t.dur) + t.quit = make(chan struct{}) + go t.fireRoutine(t.ticker.Chan(), t.quit) +} + +// CONTRACT: caller should hold t.mtx. +func (t *RepeatTimer) stop() { + if t.ticker == nil { + /* + Similar to the case of closing channels twice: + https://groups.google.com/forum/#!topic/golang-nuts/rhxMiNmRAPk + Stopping a RepeatTimer twice implies that you do + not know whether you are done or not. + If you're calling stop on a stopped RepeatTimer, + you probably have race conditions. + */ + panic("Tried to stop a stopped RepeatTimer") + } + t.ticker.Stop() + t.ticker = nil + /* + From https://golang.org/pkg/time/#Ticker: + "Stop the ticker to release associated resources" + "After Stop, no more ticks will be sent" + So we shouldn't have to do the below. + + select { + case <-t.ch: + // read off channel if there's anything there + default: + } + */ + close(t.quit) +} diff --git a/common/repeat_timer_test.go b/common/repeat_timer_test.go new file mode 100644 index 00000000..160f4394 --- /dev/null +++ b/common/repeat_timer_test.go @@ -0,0 +1,137 @@ +package common + +import ( + "math/rand" + "sync" + "testing" + "time" + + "github.com/fortytw2/leaktest" + "github.com/stretchr/testify/assert" +) + +func TestDefaultTicker(t *testing.T) { + ticker := defaultTickerMaker(time.Millisecond * 10) + <-ticker.Chan() + ticker.Stop() +} + +func TestRepeatTimer(t *testing.T) { + + ch := make(chan time.Time, 100) + mtx := new(sync.Mutex) + + // tick() fires from start to end + // (exclusive) in milliseconds with incr. + // It locks on mtx, so subsequent calls + // run in series. + tick := func(startMs, endMs, incrMs time.Duration) { + mtx.Lock() + go func() { + for tMs := startMs; tMs < endMs; tMs += incrMs { + lt := time.Time{} + lt = lt.Add(tMs * time.Millisecond) + ch <- lt + } + mtx.Unlock() + }() + } + + // tock consumes Ticker.Chan() events and checks them against the ms in "timesMs". + tock := func(t *testing.T, rt *RepeatTimer, timesMs []int64) { + + // Check against timesMs. + for _, timeMs := range timesMs { + tyme := <-rt.Chan() + sinceMs := tyme.Sub(time.Time{}) / time.Millisecond + assert.Equal(t, timeMs, int64(sinceMs)) + } + + // TODO detect number of running + // goroutines to ensure that + // no other times will fire. + // See https://github.com/tendermint/tmlibs/issues/120. + time.Sleep(time.Millisecond * 100) + done := true + select { + case <-rt.Chan(): + done = false + default: + } + assert.True(t, done) + } + + tm := NewLogicalTickerMaker(ch) + rt := NewRepeatTimerWithTickerMaker("bar", time.Second, tm) + + /* NOTE: Useful for debugging deadlocks... + go func() { + time.Sleep(time.Second * 3) + trace := make([]byte, 102400) + count := runtime.Stack(trace, true) + fmt.Printf("Stack of %d bytes: %s\n", count, trace) + }() + */ + + tick(0, 1000, 10) + tock(t, rt, []int64{}) + tick(1000, 2000, 10) + tock(t, rt, []int64{1000}) + tick(2005, 5000, 10) + tock(t, rt, []int64{2005, 3005, 4005}) + tick(5001, 5999, 1) + // Read 5005 instead of 5001 because + // it's 1 second greater than 4005. + tock(t, rt, []int64{5005}) + tick(6000, 7005, 1) + tock(t, rt, []int64{6005}) + tick(7033, 8032, 1) + tock(t, rt, []int64{7033}) + + // After a reset, nothing happens + // until two ticks are received. + rt.Reset() + tock(t, rt, []int64{}) + tick(8040, 8041, 1) + tock(t, rt, []int64{}) + tick(9555, 9556, 1) + tock(t, rt, []int64{9555}) + + // After a stop, nothing more is sent. + rt.Stop() + tock(t, rt, []int64{}) + + // Another stop panics. + assert.Panics(t, func() { rt.Stop() }) +} + +func TestRepeatTimerReset(t *testing.T) { + // check that we are not leaking any go-routines + defer leaktest.Check(t)() + + timer := NewRepeatTimer("test", 20*time.Millisecond) + defer timer.Stop() + + // test we don't receive tick before duration ms. + select { + case <-timer.Chan(): + t.Fatal("did not expect to receive tick") + default: + } + + timer.Reset() + + // test we receive tick after Reset is called + select { + case <-timer.Chan(): + // all good + case <-time.After(40 * time.Millisecond): + t.Fatal("expected to receive tick after reset") + } + + // just random calls + for i := 0; i < 100; i++ { + time.Sleep(time.Duration(rand.Intn(40)) * time.Millisecond) + timer.Reset() + } +} diff --git a/common/service.go b/common/service.go new file mode 100644 index 00000000..2f90fa4f --- /dev/null +++ b/common/service.go @@ -0,0 +1,205 @@ +package common + +import ( + "errors" + "fmt" + "sync/atomic" + + "github.com/tendermint/tmlibs/log" +) + +var ( + ErrAlreadyStarted = errors.New("already started") + ErrAlreadyStopped = errors.New("already stopped") +) + +// Service defines a service that can be started, stopped, and reset. +type Service interface { + // Start the service. + // If it's already started or stopped, will return an error. + // If OnStart() returns an error, it's returned by Start() + Start() error + OnStart() error + + // Stop the service. + // If it's already stopped, will return an error. + // OnStop must never error. + Stop() error + OnStop() + + // Reset the service. + // Panics by default - must be overwritten to enable reset. + Reset() error + OnReset() error + + // Return true if the service is running + IsRunning() bool + + // Quit returns a channel, which is closed once service is stopped. + Quit() <-chan struct{} + + // String representation of the service + String() string + + // SetLogger sets a logger. + SetLogger(log.Logger) +} + +/* +Classical-inheritance-style service declarations. Services can be started, then +stopped, then optionally restarted. + +Users can override the OnStart/OnStop methods. In the absence of errors, these +methods are guaranteed to be called at most once. If OnStart returns an error, +service won't be marked as started, so the user can call Start again. + +A call to Reset will panic, unless OnReset is overwritten, allowing +OnStart/OnStop to be called again. + +The caller must ensure that Start and Stop are not called concurrently. + +It is ok to call Stop without calling Start first. + +Typical usage: + + type FooService struct { + BaseService + // private fields + } + + func NewFooService() *FooService { + fs := &FooService{ + // init + } + fs.BaseService = *NewBaseService(log, "FooService", fs) + return fs + } + + func (fs *FooService) OnStart() error { + fs.BaseService.OnStart() // Always call the overridden method. + // initialize private fields + // start subroutines, etc. + } + + func (fs *FooService) OnStop() error { + fs.BaseService.OnStop() // Always call the overridden method. + // close/destroy private fields + // stop subroutines, etc. + } +*/ +type BaseService struct { + Logger log.Logger + name string + started uint32 // atomic + stopped uint32 // atomic + quit chan struct{} + + // The "subclass" of BaseService + impl Service +} + +// NewBaseService creates a new BaseService. +func NewBaseService(logger log.Logger, name string, impl Service) *BaseService { + if logger == nil { + logger = log.NewNopLogger() + } + + return &BaseService{ + Logger: logger, + name: name, + quit: make(chan struct{}), + impl: impl, + } +} + +// SetLogger implements Service by setting a logger. +func (bs *BaseService) SetLogger(l log.Logger) { + bs.Logger = l +} + +// Start implements Service by calling OnStart (if defined). An error will be +// returned if the service is already running or stopped. Not to start the +// stopped service, you need to call Reset. +func (bs *BaseService) Start() error { + if atomic.CompareAndSwapUint32(&bs.started, 0, 1) { + if atomic.LoadUint32(&bs.stopped) == 1 { + bs.Logger.Error(Fmt("Not starting %v -- already stopped", bs.name), "impl", bs.impl) + return ErrAlreadyStopped + } + bs.Logger.Info(Fmt("Starting %v", bs.name), "impl", bs.impl) + err := bs.impl.OnStart() + if err != nil { + // revert flag + atomic.StoreUint32(&bs.started, 0) + return err + } + return nil + } + bs.Logger.Debug(Fmt("Not starting %v -- already started", bs.name), "impl", bs.impl) + return ErrAlreadyStarted +} + +// OnStart implements Service by doing nothing. +// NOTE: Do not put anything in here, +// that way users don't need to call BaseService.OnStart() +func (bs *BaseService) OnStart() error { return nil } + +// Stop implements Service by calling OnStop (if defined) and closing quit +// channel. An error will be returned if the service is already stopped. +func (bs *BaseService) Stop() error { + if atomic.CompareAndSwapUint32(&bs.stopped, 0, 1) { + bs.Logger.Info(Fmt("Stopping %v", bs.name), "impl", bs.impl) + bs.impl.OnStop() + close(bs.quit) + return nil + } + bs.Logger.Debug(Fmt("Stopping %v (ignoring: already stopped)", bs.name), "impl", bs.impl) + return ErrAlreadyStopped +} + +// OnStop implements Service by doing nothing. +// NOTE: Do not put anything in here, +// that way users don't need to call BaseService.OnStop() +func (bs *BaseService) OnStop() {} + +// Reset implements Service by calling OnReset callback (if defined). An error +// will be returned if the service is running. +func (bs *BaseService) Reset() error { + if !atomic.CompareAndSwapUint32(&bs.stopped, 1, 0) { + bs.Logger.Debug(Fmt("Can't reset %v. Not stopped", bs.name), "impl", bs.impl) + return fmt.Errorf("can't reset running %s", bs.name) + } + + // whether or not we've started, we can reset + atomic.CompareAndSwapUint32(&bs.started, 1, 0) + + bs.quit = make(chan struct{}) + return bs.impl.OnReset() +} + +// OnReset implements Service by panicking. +func (bs *BaseService) OnReset() error { + PanicSanity("The service cannot be reset") + return nil +} + +// IsRunning implements Service by returning true or false depending on the +// service's state. +func (bs *BaseService) IsRunning() bool { + return atomic.LoadUint32(&bs.started) == 1 && atomic.LoadUint32(&bs.stopped) == 0 +} + +// Wait blocks until the service is stopped. +func (bs *BaseService) Wait() { + <-bs.quit +} + +// String implements Servce by returning a string representation of the service. +func (bs *BaseService) String() string { + return bs.name +} + +// Quit Implements Service by returning a quit channel. +func (bs *BaseService) Quit() <-chan struct{} { + return bs.quit +} diff --git a/common/service_test.go b/common/service_test.go new file mode 100644 index 00000000..ef360a64 --- /dev/null +++ b/common/service_test.go @@ -0,0 +1,54 @@ +package common + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type testService struct { + BaseService +} + +func (testService) OnReset() error { + return nil +} + +func TestBaseServiceWait(t *testing.T) { + ts := &testService{} + ts.BaseService = *NewBaseService(nil, "TestService", ts) + ts.Start() + + waitFinished := make(chan struct{}) + go func() { + ts.Wait() + waitFinished <- struct{}{} + }() + + go ts.Stop() + + select { + case <-waitFinished: + // all good + case <-time.After(100 * time.Millisecond): + t.Fatal("expected Wait() to finish within 100 ms.") + } +} + +func TestBaseServiceReset(t *testing.T) { + ts := &testService{} + ts.BaseService = *NewBaseService(nil, "TestService", ts) + ts.Start() + + err := ts.Reset() + require.Error(t, err, "expected cant reset service error") + + ts.Stop() + + err = ts.Reset() + require.NoError(t, err) + + err = ts.Start() + require.NoError(t, err) +} diff --git a/common/string.go b/common/string.go new file mode 100644 index 00000000..fac1be6c --- /dev/null +++ b/common/string.go @@ -0,0 +1,89 @@ +package common + +import ( + "encoding/hex" + "fmt" + "strings" +) + +// Like fmt.Sprintf, but skips formatting if args are empty. +var Fmt = func(format string, a ...interface{}) string { + if len(a) == 0 { + return format + } + return fmt.Sprintf(format, a...) +} + +// IsHex returns true for non-empty hex-string prefixed with "0x" +func IsHex(s string) bool { + if len(s) > 2 && strings.EqualFold(s[:2], "0x") { + _, err := hex.DecodeString(s[2:]) + return err == nil + } + return false +} + +// StripHex returns hex string without leading "0x" +func StripHex(s string) string { + if IsHex(s) { + return s[2:] + } + return s +} + +// StringInSlice returns true if a is found the list. +func StringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// SplitAndTrim slices s into all subslices separated by sep and returns a +// slice of the string s with all leading and trailing Unicode code points +// contained in cutset removed. If sep is empty, SplitAndTrim splits after each +// UTF-8 sequence. First part is equivalent to strings.SplitN with a count of +// -1. +func SplitAndTrim(s, sep, cutset string) []string { + if s == "" { + return []string{} + } + + spl := strings.Split(s, sep) + for i := 0; i < len(spl); i++ { + spl[i] = strings.Trim(spl[i], cutset) + } + return spl +} + +// Returns true if s is a non-empty printable non-tab ascii character. +func IsASCIIText(s string) bool { + if len(s) == 0 { + return false + } + for _, b := range []byte(s) { + if 32 <= b && b <= 126 { + // good + } else { + return false + } + } + return true +} + +// NOTE: Assumes that s is ASCII as per IsASCIIText(), otherwise panics. +func ASCIITrim(s string) string { + r := make([]byte, 0, len(s)) + for _, b := range []byte(s) { + if b == 32 { + continue // skip space + } else if 32 < b && b <= 126 { + r = append(r, b) + } else { + panic(fmt.Sprintf("non-ASCII (non-tab) char 0x%X", b)) + } + } + return string(r) +} diff --git a/common/string_test.go b/common/string_test.go new file mode 100644 index 00000000..5d1b68fe --- /dev/null +++ b/common/string_test.go @@ -0,0 +1,74 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStringInSlice(t *testing.T) { + assert.True(t, StringInSlice("a", []string{"a", "b", "c"})) + assert.False(t, StringInSlice("d", []string{"a", "b", "c"})) + assert.True(t, StringInSlice("", []string{""})) + assert.False(t, StringInSlice("", []string{})) +} + +func TestIsHex(t *testing.T) { + notHex := []string{ + "", " ", "a", "x", "0", "0x", "0X", "0x ", "0X ", "0X a", + "0xf ", "0x f", "0xp", "0x-", + "0xf", "0XBED", "0xF", "0xbed", // Odd lengths + } + for _, v := range notHex { + assert.False(t, IsHex(v), "%q is not hex", v) + } + hex := []string{ + "0x00", "0x0a", "0x0F", "0xFFFFFF", "0Xdeadbeef", "0x0BED", + "0X12", "0X0A", + } + for _, v := range hex { + assert.True(t, IsHex(v), "%q is hex", v) + } +} + +func TestSplitAndTrim(t *testing.T) { + testCases := []struct { + s string + sep string + cutset string + expected []string + }{ + {"a,b,c", ",", " ", []string{"a", "b", "c"}}, + {" a , b , c ", ",", " ", []string{"a", "b", "c"}}, + {" a, b, c ", ",", " ", []string{"a", "b", "c"}}, + {" , ", ",", " ", []string{"", ""}}, + {" ", ",", " ", []string{""}}, + } + + for _, tc := range testCases { + assert.Equal(t, tc.expected, SplitAndTrim(tc.s, tc.sep, tc.cutset), "%s", tc.s) + } +} + +func TestIsASCIIText(t *testing.T) { + notASCIIText := []string{ + "", "\xC2", "\xC2\xA2", "\xFF", "\x80", "\xF0", "\n", "\t", + } + for _, v := range notASCIIText { + assert.False(t, IsHex(v), "%q is not ascii-text", v) + } + asciiText := []string{ + " ", ".", "x", "$", "_", "abcdefg;", "-", "0x00", "0", "123", + } + for _, v := range asciiText { + assert.True(t, IsASCIIText(v), "%q is ascii-text", v) + } +} + +func TestASCIITrim(t *testing.T) { + assert.Equal(t, ASCIITrim(" "), "") + assert.Equal(t, ASCIITrim(" a"), "a") + assert.Equal(t, ASCIITrim("a "), "a") + assert.Equal(t, ASCIITrim(" a "), "a") + assert.Panics(t, func() { ASCIITrim("\xC2\xA2") }) +} diff --git a/common/throttle_timer.go b/common/throttle_timer.go new file mode 100644 index 00000000..38ef4e9a --- /dev/null +++ b/common/throttle_timer.go @@ -0,0 +1,75 @@ +package common + +import ( + "sync" + "time" +) + +/* +ThrottleTimer fires an event at most "dur" after each .Set() call. +If a short burst of .Set() calls happens, ThrottleTimer fires once. +If a long continuous burst of .Set() calls happens, ThrottleTimer fires +at most once every "dur". +*/ +type ThrottleTimer struct { + Name string + Ch chan struct{} + quit chan struct{} + dur time.Duration + + mtx sync.Mutex + timer *time.Timer + isSet bool +} + +func NewThrottleTimer(name string, dur time.Duration) *ThrottleTimer { + var ch = make(chan struct{}) + var quit = make(chan struct{}) + var t = &ThrottleTimer{Name: name, Ch: ch, dur: dur, quit: quit} + t.mtx.Lock() + t.timer = time.AfterFunc(dur, t.fireRoutine) + t.mtx.Unlock() + t.timer.Stop() + return t +} + +func (t *ThrottleTimer) fireRoutine() { + t.mtx.Lock() + defer t.mtx.Unlock() + select { + case t.Ch <- struct{}{}: + t.isSet = false + case <-t.quit: + // do nothing + default: + t.timer.Reset(t.dur) + } +} + +func (t *ThrottleTimer) Set() { + t.mtx.Lock() + defer t.mtx.Unlock() + if !t.isSet { + t.isSet = true + t.timer.Reset(t.dur) + } +} + +func (t *ThrottleTimer) Unset() { + t.mtx.Lock() + defer t.mtx.Unlock() + t.isSet = false + t.timer.Stop() +} + +// For ease of .Stop()'ing services before .Start()'ing them, +// we ignore .Stop()'s on nil ThrottleTimers +func (t *ThrottleTimer) Stop() bool { + if t == nil { + return false + } + close(t.quit) + t.mtx.Lock() + defer t.mtx.Unlock() + return t.timer.Stop() +} diff --git a/common/throttle_timer_test.go b/common/throttle_timer_test.go new file mode 100644 index 00000000..00f5abde --- /dev/null +++ b/common/throttle_timer_test.go @@ -0,0 +1,78 @@ +package common + +import ( + "sync" + "testing" + "time" + + // make govet noshadow happy... + asrt "github.com/stretchr/testify/assert" +) + +type thCounter struct { + input chan struct{} + mtx sync.Mutex + count int +} + +func (c *thCounter) Increment() { + c.mtx.Lock() + c.count++ + c.mtx.Unlock() +} + +func (c *thCounter) Count() int { + c.mtx.Lock() + val := c.count + c.mtx.Unlock() + return val +} + +// Read should run in a go-routine and +// updates count by one every time a packet comes in +func (c *thCounter) Read() { + for range c.input { + c.Increment() + } +} + +func TestThrottle(test *testing.T) { + assert := asrt.New(test) + + ms := 50 + delay := time.Duration(ms) * time.Millisecond + longwait := time.Duration(2) * delay + t := NewThrottleTimer("foo", delay) + + // start at 0 + c := &thCounter{input: t.Ch} + assert.Equal(0, c.Count()) + go c.Read() + + // waiting does nothing + time.Sleep(longwait) + assert.Equal(0, c.Count()) + + // send one event adds one + t.Set() + time.Sleep(longwait) + assert.Equal(1, c.Count()) + + // send a burst adds one + for i := 0; i < 5; i++ { + t.Set() + } + time.Sleep(longwait) + assert.Equal(2, c.Count()) + + // send 12, over 2 delay sections, adds 3 + short := time.Duration(ms/5) * time.Millisecond + for i := 0; i < 13; i++ { + t.Set() + time.Sleep(short) + } + time.Sleep(longwait) + assert.Equal(5, c.Count()) + + close(t.Ch) +} diff --git a/common/types.pb.go b/common/types.pb.go new file mode 100644 index 00000000..f6645602 --- /dev/null +++ b/common/types.pb.go @@ -0,0 +1,98 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: common/types.proto + +/* +Package common is a generated protocol buffer package. + +It is generated from these files: + common/types.proto + +It has these top-level messages: + KVPair + KI64Pair +*/ +//nolint: gas +package common + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// Define these here for compatibility but use tmlibs/common.KVPair. +type KVPair struct { + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` +} + +func (m *KVPair) Reset() { *m = KVPair{} } +func (m *KVPair) String() string { return proto.CompactTextString(m) } +func (*KVPair) ProtoMessage() {} +func (*KVPair) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *KVPair) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *KVPair) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +// Define these here for compatibility but use tmlibs/common.KI64Pair. +type KI64Pair struct { + Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value int64 `protobuf:"varint,2,opt,name=value" json:"value,omitempty"` +} + +func (m *KI64Pair) Reset() { *m = KI64Pair{} } +func (m *KI64Pair) String() string { return proto.CompactTextString(m) } +func (*KI64Pair) ProtoMessage() {} +func (*KI64Pair) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *KI64Pair) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *KI64Pair) GetValue() int64 { + if m != nil { + return m.Value + } + return 0 +} + +func init() { + proto.RegisterType((*KVPair)(nil), "common.KVPair") + proto.RegisterType((*KI64Pair)(nil), "common.KI64Pair") +} + +func init() { proto.RegisterFile("common/types.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 107 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4a, 0xce, 0xcf, 0xcd, + 0xcd, 0xcf, 0xd3, 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, + 0x83, 0x88, 0x29, 0x19, 0x70, 0xb1, 0x79, 0x87, 0x05, 0x24, 0x66, 0x16, 0x09, 0x09, 0x70, 0x31, + 0x67, 0xa7, 0x56, 0x4a, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0x81, 0x98, 0x42, 0x22, 0x5c, 0xac, + 0x65, 0x89, 0x39, 0xa5, 0xa9, 0x12, 0x4c, 0x60, 0x31, 0x08, 0x47, 0xc9, 0x88, 0x8b, 0xc3, 0xdb, + 0xd3, 0xcc, 0x84, 0x18, 0x3d, 0xcc, 0x50, 0x3d, 0x49, 0x6c, 0x60, 0x4b, 0x8d, 0x01, 0x01, 0x00, + 0x00, 0xff, 0xff, 0xd8, 0xf1, 0xc3, 0x8c, 0x8a, 0x00, 0x00, 0x00, +} diff --git a/common/types.proto b/common/types.proto new file mode 100644 index 00000000..8406fcfd --- /dev/null +++ b/common/types.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; +package common; + +//---------------------------------------- +// Abstract types + +// Define these here for compatibility but use tmlibs/common.KVPair. +message KVPair { + bytes key = 1; + bytes value = 2; +} + +// Define these here for compatibility but use tmlibs/common.KI64Pair. +message KI64Pair { + bytes key = 1; + int64 value = 2; +} diff --git a/common/word.go b/common/word.go new file mode 100644 index 00000000..a5b841f5 --- /dev/null +++ b/common/word.go @@ -0,0 +1,90 @@ +package common + +import ( + "bytes" + "sort" +) + +var ( + Zero256 = Word256{0} + One256 = Word256{1} +) + +type Word256 [32]byte + +func (w Word256) String() string { return string(w[:]) } +func (w Word256) TrimmedString() string { return TrimmedString(w.Bytes()) } +func (w Word256) Copy() Word256 { return w } +func (w Word256) Bytes() []byte { return w[:] } // copied. +func (w Word256) Prefix(n int) []byte { return w[:n] } +func (w Word256) Postfix(n int) []byte { return w[32-n:] } +func (w Word256) IsZero() bool { + accum := byte(0) + for _, byt := range w { + accum |= byt + } + return accum == 0 +} +func (w Word256) Compare(other Word256) int { + return bytes.Compare(w[:], other[:]) +} + +func Uint64ToWord256(i uint64) Word256 { + buf := [8]byte{} + PutUint64BE(buf[:], i) + return LeftPadWord256(buf[:]) +} + +func Int64ToWord256(i int64) Word256 { + buf := [8]byte{} + PutInt64BE(buf[:], i) + return LeftPadWord256(buf[:]) +} + +func RightPadWord256(bz []byte) (word Word256) { + copy(word[:], bz) + return +} + +func LeftPadWord256(bz []byte) (word Word256) { + copy(word[32-len(bz):], bz) + return +} + +func Uint64FromWord256(word Word256) uint64 { + buf := word.Postfix(8) + return GetUint64BE(buf) +} + +func Int64FromWord256(word Word256) int64 { + buf := word.Postfix(8) + return GetInt64BE(buf) +} + +//------------------------------------- + +type Tuple256 struct { + First Word256 + Second Word256 +} + +func (tuple Tuple256) Compare(other Tuple256) int { + firstCompare := tuple.First.Compare(other.First) + if firstCompare == 0 { + return tuple.Second.Compare(other.Second) + } + return firstCompare +} + +func Tuple256Split(t Tuple256) (Word256, Word256) { + return t.First, t.Second +} + +type Tuple256Slice []Tuple256 + +func (p Tuple256Slice) Len() int { return len(p) } +func (p Tuple256Slice) Less(i, j int) bool { + return p[i].Compare(p[j]) < 0 +} +func (p Tuple256Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p Tuple256Slice) Sort() { sort.Sort(p) } diff --git a/db/LICENSE.md b/db/LICENSE.md new file mode 100644 index 00000000..ab8da59d --- /dev/null +++ b/db/LICENSE.md @@ -0,0 +1,3 @@ +Tendermint Go-DB Copyright (C) 2015 All in Bits, Inc + +Released under the Apache2.0 license diff --git a/db/README.md b/db/README.md new file mode 100644 index 00000000..ca5ab33f --- /dev/null +++ b/db/README.md @@ -0,0 +1 @@ +TODO: syndtr/goleveldb should be replaced with actual LevelDB instance diff --git a/db/backend_test.go b/db/backend_test.go new file mode 100644 index 00000000..d451b7c5 --- /dev/null +++ b/db/backend_test.go @@ -0,0 +1,215 @@ +package db + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" +) + +func cleanupDBDir(dir, name string) { + os.RemoveAll(filepath.Join(dir, name) + ".db") +} + +func testBackendGetSetDelete(t *testing.T, backend DBBackendType) { + // Default + dir, dirname := cmn.Tempdir(fmt.Sprintf("test_backend_%s_", backend)) + defer dir.Close() + db := NewDB("testdb", backend, dirname) + + // A nonexistent key should return nil, even if the key is empty + require.Nil(t, db.Get([]byte(""))) + + // A nonexistent key should return nil, even if the key is nil + require.Nil(t, db.Get(nil)) + + // A nonexistent key should return nil. + key := []byte("abc") + require.Nil(t, db.Get(key)) + + // Set empty value. + db.Set(key, []byte("")) + require.NotNil(t, db.Get(key)) + require.Empty(t, db.Get(key)) + + // Set nil value. + db.Set(key, nil) + require.NotNil(t, db.Get(key)) + require.Empty(t, db.Get(key)) + + // Delete. + db.Delete(key) + require.Nil(t, db.Get(key)) +} + +func TestBackendsGetSetDelete(t *testing.T) { + for dbType := range backends { + testBackendGetSetDelete(t, dbType) + } +} + +func withDB(t *testing.T, creator dbCreator, fn func(DB)) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db, err := creator(name, "") + defer cleanupDBDir("", name) + assert.Nil(t, err) + fn(db) + db.Close() +} + +func TestBackendsNilKeys(t *testing.T) { + + // Test all backends. + for dbType, creator := range backends { + withDB(t, creator, func(db DB) { + t.Run(fmt.Sprintf("Testing %s", dbType), func(t *testing.T) { + + // Nil keys are treated as the empty key for most operations. + expect := func(key, value []byte) { + if len(key) == 0 { // nil or empty + assert.Equal(t, db.Get(nil), db.Get([]byte(""))) + assert.Equal(t, db.Has(nil), db.Has([]byte(""))) + } + assert.Equal(t, db.Get(key), value) + assert.Equal(t, db.Has(key), value != nil) + } + + // Not set + expect(nil, nil) + + // Set nil value + db.Set(nil, nil) + expect(nil, []byte("")) + + // Set empty value + db.Set(nil, []byte("")) + expect(nil, []byte("")) + + // Set nil, Delete nil + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set nil, Delete empty + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // Set empty, Delete nil + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set empty, Delete empty + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // SetSync nil, DeleteSync nil + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync nil, DeleteSync empty + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + + // SetSync empty, DeleteSync nil + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync empty, DeleteSync empty + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + }) + }) + } +} + +func TestGoLevelDBBackend(t *testing.T) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, GoLevelDBBackend, "") + defer cleanupDBDir("", name) + + _, ok := db.(*GoLevelDB) + assert.True(t, ok) +} + +func TestDBIterator(t *testing.T) { + for dbType := range backends { + t.Run(fmt.Sprintf("%v", dbType), func(t *testing.T) { + testDBIterator(t, dbType) + }) + } +} + +func testDBIterator(t *testing.T, backend DBBackendType) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, backend, "") + defer cleanupDBDir("", name) + + for i := 0; i < 10; i++ { + if i != 6 { // but skip 6. + db.Set(int642Bytes(int64(i)), nil) + } + } + + verifyIterator(t, db.Iterator(nil, nil), []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator") + verifyIterator(t, db.ReverseIterator(nil, nil), []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator") + + verifyIterator(t, db.Iterator(nil, int642Bytes(0)), []int64(nil), "forward iterator to 0") + verifyIterator(t, db.ReverseIterator(nil, int642Bytes(10)), []int64(nil), "reverse iterator 10") + + verifyIterator(t, db.Iterator(int642Bytes(0), nil), []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 0") + verifyIterator(t, db.Iterator(int642Bytes(1), nil), []int64{1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 1") + verifyIterator(t, db.ReverseIterator(int642Bytes(10), nil), []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 10") + verifyIterator(t, db.ReverseIterator(int642Bytes(9), nil), []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 9") + verifyIterator(t, db.ReverseIterator(int642Bytes(8), nil), []int64{8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 8") + + verifyIterator(t, db.Iterator(int642Bytes(5), int642Bytes(6)), []int64{5}, "forward iterator from 5 to 6") + verifyIterator(t, db.Iterator(int642Bytes(5), int642Bytes(7)), []int64{5}, "forward iterator from 5 to 7") + verifyIterator(t, db.Iterator(int642Bytes(5), int642Bytes(8)), []int64{5, 7}, "forward iterator from 5 to 8") + verifyIterator(t, db.Iterator(int642Bytes(6), int642Bytes(7)), []int64(nil), "forward iterator from 6 to 7") + verifyIterator(t, db.Iterator(int642Bytes(6), int642Bytes(8)), []int64{7}, "forward iterator from 6 to 8") + verifyIterator(t, db.Iterator(int642Bytes(7), int642Bytes(8)), []int64{7}, "forward iterator from 7 to 8") + + verifyIterator(t, db.ReverseIterator(int642Bytes(5), int642Bytes(4)), []int64{5}, "reverse iterator from 5 to 4") + verifyIterator(t, db.ReverseIterator(int642Bytes(6), int642Bytes(4)), []int64{5}, "reverse iterator from 6 to 4") + verifyIterator(t, db.ReverseIterator(int642Bytes(7), int642Bytes(4)), []int64{7, 5}, "reverse iterator from 7 to 4") + verifyIterator(t, db.ReverseIterator(int642Bytes(6), int642Bytes(5)), []int64(nil), "reverse iterator from 6 to 5") + verifyIterator(t, db.ReverseIterator(int642Bytes(7), int642Bytes(5)), []int64{7}, "reverse iterator from 7 to 5") + verifyIterator(t, db.ReverseIterator(int642Bytes(7), int642Bytes(6)), []int64{7}, "reverse iterator from 7 to 6") + + verifyIterator(t, db.Iterator(int642Bytes(0), int642Bytes(1)), []int64{0}, "forward iterator from 0 to 1") + verifyIterator(t, db.ReverseIterator(int642Bytes(9), int642Bytes(8)), []int64{9}, "reverse iterator from 9 to 8") + + verifyIterator(t, db.Iterator(int642Bytes(2), int642Bytes(4)), []int64{2, 3}, "forward iterator from 2 to 4") + verifyIterator(t, db.Iterator(int642Bytes(4), int642Bytes(2)), []int64(nil), "forward iterator from 4 to 2") + verifyIterator(t, db.ReverseIterator(int642Bytes(4), int642Bytes(2)), []int64{4, 3}, "reverse iterator from 4 to 2") + verifyIterator(t, db.ReverseIterator(int642Bytes(2), int642Bytes(4)), []int64(nil), "reverse iterator from 2 to 4") + +} + +func verifyIterator(t *testing.T, itr Iterator, expected []int64, msg string) { + var list []int64 + for itr.Valid() { + list = append(list, bytes2Int64(itr.Key())) + itr.Next() + } + assert.Equal(t, expected, list, msg) +} diff --git a/db/c_level_db.go b/db/c_level_db.go new file mode 100644 index 00000000..30746126 --- /dev/null +++ b/db/c_level_db.go @@ -0,0 +1,312 @@ +// +build gcc + +package db + +import ( + "bytes" + "fmt" + "path/filepath" + + "github.com/jmhodges/levigo" +) + +func init() { + dbCreator := func(name string, dir string) (DB, error) { + return NewCLevelDB(name, dir) + } + registerDBCreator(LevelDBBackend, dbCreator, true) + registerDBCreator(CLevelDBBackend, dbCreator, false) +} + +var _ DB = (*CLevelDB)(nil) + +type CLevelDB struct { + db *levigo.DB + ro *levigo.ReadOptions + wo *levigo.WriteOptions + woSync *levigo.WriteOptions +} + +func NewCLevelDB(name string, dir string) (*CLevelDB, error) { + dbPath := filepath.Join(dir, name+".db") + + opts := levigo.NewOptions() + opts.SetCache(levigo.NewLRUCache(1 << 30)) + opts.SetCreateIfMissing(true) + db, err := levigo.Open(dbPath, opts) + if err != nil { + return nil, err + } + ro := levigo.NewReadOptions() + wo := levigo.NewWriteOptions() + woSync := levigo.NewWriteOptions() + woSync.SetSync(true) + database := &CLevelDB{ + db: db, + ro: ro, + wo: wo, + woSync: woSync, + } + return database, nil +} + +// Implements DB. +func (db *CLevelDB) Get(key []byte) []byte { + key = nonNilBytes(key) + res, err := db.db.Get(db.ro, key) + if err != nil { + panic(err) + } + return res +} + +// Implements DB. +func (db *CLevelDB) Has(key []byte) bool { + return db.Get(key) != nil +} + +// Implements DB. +func (db *CLevelDB) Set(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + err := db.db.Put(db.wo, key, value) + if err != nil { + panic(err) + } +} + +// Implements DB. +func (db *CLevelDB) SetSync(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + err := db.db.Put(db.woSync, key, value) + if err != nil { + panic(err) + } +} + +// Implements DB. +func (db *CLevelDB) Delete(key []byte) { + key = nonNilBytes(key) + err := db.db.Delete(db.wo, key) + if err != nil { + panic(err) + } +} + +// Implements DB. +func (db *CLevelDB) DeleteSync(key []byte) { + key = nonNilBytes(key) + err := db.db.Delete(db.woSync, key) + if err != nil { + panic(err) + } +} + +func (db *CLevelDB) DB() *levigo.DB { + return db.db +} + +// Implements DB. +func (db *CLevelDB) Close() { + db.db.Close() + db.ro.Close() + db.wo.Close() + db.woSync.Close() +} + +// Implements DB. +func (db *CLevelDB) Print() { + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + key := itr.Key() + value := itr.Value() + fmt.Printf("[%X]:\t[%X]\n", key, value) + } +} + +// Implements DB. +func (db *CLevelDB) Stats() map[string]string { + // TODO: Find the available properties for the C LevelDB implementation + keys := []string{} + + stats := make(map[string]string) + for _, key := range keys { + str := db.db.PropertyValue(key) + stats[key] = str + } + return stats +} + +//---------------------------------------- +// Batch + +// Implements DB. +func (db *CLevelDB) NewBatch() Batch { + batch := levigo.NewWriteBatch() + return &cLevelDBBatch{db, batch} +} + +type cLevelDBBatch struct { + db *CLevelDB + batch *levigo.WriteBatch +} + +// Implements Batch. +func (mBatch *cLevelDBBatch) Set(key, value []byte) { + mBatch.batch.Put(key, value) +} + +// Implements Batch. +func (mBatch *cLevelDBBatch) Delete(key []byte) { + mBatch.batch.Delete(key) +} + +// Implements Batch. +func (mBatch *cLevelDBBatch) Write() { + err := mBatch.db.db.Write(mBatch.db.wo, mBatch.batch) + if err != nil { + panic(err) + } +} + +// Implements Batch. +func (mBatch *cLevelDBBatch) WriteSync() { + err := mBatch.db.db.Write(mBatch.db.woSync, mBatch.batch) + if err != nil { + panic(err) + } +} + +//---------------------------------------- +// Iterator +// NOTE This is almost identical to db/go_level_db.Iterator +// Before creating a third version, refactor. + +func (db *CLevelDB) Iterator(start, end []byte) Iterator { + itr := db.db.NewIterator(db.ro) + return newCLevelDBIterator(itr, start, end, false) +} + +func (db *CLevelDB) ReverseIterator(start, end []byte) Iterator { + itr := db.db.NewIterator(db.ro) + return newCLevelDBIterator(itr, start, end, true) +} + +var _ Iterator = (*cLevelDBIterator)(nil) + +type cLevelDBIterator struct { + source *levigo.Iterator + start, end []byte + isReverse bool + isInvalid bool +} + +func newCLevelDBIterator(source *levigo.Iterator, start, end []byte, isReverse bool) *cLevelDBIterator { + if isReverse { + if start == nil { + source.SeekToLast() + } else { + source.Seek(start) + if source.Valid() { + soakey := source.Key() // start or after key + if bytes.Compare(start, soakey) < 0 { + source.Prev() + } + } else { + source.SeekToLast() + } + } + } else { + if start == nil { + source.SeekToFirst() + } else { + source.Seek(start) + } + } + return &cLevelDBIterator{ + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, + } +} + +func (itr cLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +func (itr cLevelDBIterator) Valid() bool { + + // Once invalid, forever invalid. + if itr.isInvalid { + return false + } + + // Panic on DB error. No way to recover. + itr.assertNoError() + + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false + } + + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + if itr.isReverse { + if end != nil && bytes.Compare(key, end) <= 0 { + itr.isInvalid = true + return false + } + } else { + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true + return false + } + } + + // It's valid. + return true +} + +func (itr cLevelDBIterator) Key() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Key() +} + +func (itr cLevelDBIterator) Value() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Value() +} + +func (itr cLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + if itr.isReverse { + itr.source.Prev() + } else { + itr.source.Next() + } +} + +func (itr cLevelDBIterator) Close() { + itr.source.Close() +} + +func (itr cLevelDBIterator) assertNoError() { + if err := itr.source.GetError(); err != nil { + panic(err) + } +} + +func (itr cLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("cLevelDBIterator is invalid") + } +} diff --git a/db/c_level_db_test.go b/db/c_level_db_test.go new file mode 100644 index 00000000..34bb7227 --- /dev/null +++ b/db/c_level_db_test.go @@ -0,0 +1,96 @@ +// +build gcc + +package db + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + cmn "github.com/tendermint/tmlibs/common" +) + +func BenchmarkRandomReadsWrites2(b *testing.B) { + b.StopTimer() + + numItems := int64(1000000) + internal := map[int64]int64{} + for i := 0; i < int(numItems); i++ { + internal[int64(i)] = int64(0) + } + db, err := NewCLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") + if err != nil { + b.Fatal(err.Error()) + return + } + + fmt.Println("ok, starting") + b.StartTimer() + + for i := 0; i < b.N; i++ { + // Write something + { + idx := (int64(cmn.RandInt()) % numItems) + internal[idx] += 1 + val := internal[idx] + idxBytes := int642Bytes(int64(idx)) + valBytes := int642Bytes(int64(val)) + //fmt.Printf("Set %X -> %X\n", idxBytes, valBytes) + db.Set( + idxBytes, + valBytes, + ) + } + // Read something + { + idx := (int64(cmn.RandInt()) % numItems) + val := internal[idx] + idxBytes := int642Bytes(int64(idx)) + valBytes := db.Get(idxBytes) + //fmt.Printf("Get %X -> %X\n", idxBytes, valBytes) + if val == 0 { + if !bytes.Equal(valBytes, nil) { + b.Errorf("Expected %v for %v, got %X", + nil, idx, valBytes) + break + } + } else { + if len(valBytes) != 8 { + b.Errorf("Expected length 8 for %v, got %X", + idx, valBytes) + break + } + valGot := bytes2Int64(valBytes) + if val != valGot { + b.Errorf("Expected %v for %v, got %v", + val, idx, valGot) + break + } + } + } + } + + db.Close() +} + +/* +func int642Bytes(i int64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +func bytes2Int64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} +*/ + +func TestCLevelDBBackend(t *testing.T) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, LevelDBBackend, "") + defer cleanupDBDir("", name) + + _, ok := db.(*CLevelDB) + assert.True(t, ok) +} diff --git a/db/common_test.go b/db/common_test.go new file mode 100644 index 00000000..6af6e15e --- /dev/null +++ b/db/common_test.go @@ -0,0 +1,191 @@ +package db + +import ( + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" +) + +//---------------------------------------- +// Helper functions. + +func checkValue(t *testing.T, db DB, key []byte, valueWanted []byte) { + valueGot := db.Get(key) + assert.Equal(t, valueWanted, valueGot) +} + +func checkValid(t *testing.T, itr Iterator, expected bool) { + valid := itr.Valid() + require.Equal(t, expected, valid) +} + +func checkNext(t *testing.T, itr Iterator, expected bool) { + itr.Next() + valid := itr.Valid() + require.Equal(t, expected, valid) +} + +func checkNextPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Next() }, "checkNextPanics expected panic but didn't") +} + +func checkDomain(t *testing.T, itr Iterator, start, end []byte) { + ds, de := itr.Domain() + assert.Equal(t, start, ds, "checkDomain domain start incorrect") + assert.Equal(t, end, de, "checkDomain domain end incorrect") +} + +func checkItem(t *testing.T, itr Iterator, key []byte, value []byte) { + k, v := itr.Key(), itr.Value() + assert.Exactly(t, key, k) + assert.Exactly(t, value, v) +} + +func checkInvalid(t *testing.T, itr Iterator) { + checkValid(t, itr, false) + checkKeyPanics(t, itr) + checkValuePanics(t, itr) + checkNextPanics(t, itr) +} + +func checkKeyPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Key() }, "checkKeyPanics expected panic but didn't") +} + +func checkValuePanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Key() }, "checkValuePanics expected panic but didn't") +} + +func newTempDB(t *testing.T, backend DBBackendType) (db DB) { + dir, dirname := cmn.Tempdir("db_common_test") + db = NewDB("testdb", backend, dirname) + dir.Close() + return db +} + +//---------------------------------------- +// mockDB + +// NOTE: not actually goroutine safe. +// If you want something goroutine safe, maybe you just want a MemDB. +type mockDB struct { + mtx sync.Mutex + calls map[string]int +} + +func newMockDB() *mockDB { + return &mockDB{ + calls: make(map[string]int), + } +} + +func (mdb *mockDB) Mutex() *sync.Mutex { + return &(mdb.mtx) +} + +func (mdb *mockDB) Get([]byte) []byte { + mdb.calls["Get"]++ + return nil +} + +func (mdb *mockDB) Has([]byte) bool { + mdb.calls["Has"]++ + return false +} + +func (mdb *mockDB) Set([]byte, []byte) { + mdb.calls["Set"]++ +} + +func (mdb *mockDB) SetSync([]byte, []byte) { + mdb.calls["SetSync"]++ +} + +func (mdb *mockDB) SetNoLock([]byte, []byte) { + mdb.calls["SetNoLock"]++ +} + +func (mdb *mockDB) SetNoLockSync([]byte, []byte) { + mdb.calls["SetNoLockSync"]++ +} + +func (mdb *mockDB) Delete([]byte) { + mdb.calls["Delete"]++ +} + +func (mdb *mockDB) DeleteSync([]byte) { + mdb.calls["DeleteSync"]++ +} + +func (mdb *mockDB) DeleteNoLock([]byte) { + mdb.calls["DeleteNoLock"]++ +} + +func (mdb *mockDB) DeleteNoLockSync([]byte) { + mdb.calls["DeleteNoLockSync"]++ +} + +func (mdb *mockDB) Iterator(start, end []byte) Iterator { + mdb.calls["Iterator"]++ + return &mockIterator{} +} + +func (mdb *mockDB) ReverseIterator(start, end []byte) Iterator { + mdb.calls["ReverseIterator"]++ + return &mockIterator{} +} + +func (mdb *mockDB) Close() { + mdb.calls["Close"]++ +} + +func (mdb *mockDB) NewBatch() Batch { + mdb.calls["NewBatch"]++ + return &memBatch{db: mdb} +} + +func (mdb *mockDB) Print() { + mdb.calls["Print"]++ + fmt.Printf("mockDB{%v}", mdb.Stats()) +} + +func (mdb *mockDB) Stats() map[string]string { + mdb.calls["Stats"]++ + + res := make(map[string]string) + for key, count := range mdb.calls { + res[key] = fmt.Sprintf("%d", count) + } + return res +} + +//---------------------------------------- +// mockIterator + +type mockIterator struct{} + +func (mockIterator) Domain() (start []byte, end []byte) { + return nil, nil +} + +func (mockIterator) Valid() bool { + return false +} + +func (mockIterator) Next() { +} + +func (mockIterator) Key() []byte { + return nil +} + +func (mockIterator) Value() []byte { + return nil +} + +func (mockIterator) Close() { +} diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..86993766 --- /dev/null +++ b/db/db.go @@ -0,0 +1,36 @@ +package db + +import "fmt" + +//---------------------------------------- +// Main entry + +type DBBackendType string + +const ( + LevelDBBackend DBBackendType = "leveldb" // legacy, defaults to goleveldb unless +gcc + CLevelDBBackend DBBackendType = "cleveldb" + GoLevelDBBackend DBBackendType = "goleveldb" + MemDBBackend DBBackendType = "memdb" + FSDBBackend DBBackendType = "fsdb" // using the filesystem naively +) + +type dbCreator func(name string, dir string) (DB, error) + +var backends = map[DBBackendType]dbCreator{} + +func registerDBCreator(backend DBBackendType, creator dbCreator, force bool) { + _, ok := backends[backend] + if !force && ok { + return + } + backends[backend] = creator +} + +func NewDB(name string, backend DBBackendType, dir string) DB { + db, err := backends[backend](name, dir) + if err != nil { + panic(fmt.Sprintf("Error initializing DB: %v", err)) + } + return db +} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 00000000..a5690101 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,194 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDBIteratorSingleKey(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(nil, nil) + + checkValid(t, itr, true) + checkNext(t, itr, false) + checkValid(t, itr, false) + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorTwoKeys(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + db.SetSync(bz("2"), bz("value_1")) + + { // Fail by calling Next too much + itr := db.Iterator(nil, nil) + checkValid(t, itr, true) + + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkNext(t, itr, false) + checkValid(t, itr, false) + + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + } + }) + } +} + +func TestDBIteratorMany(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + keys := make([][]byte, 100) + for i := 0; i < 100; i++ { + keys[i] = []byte{byte(i)} + } + + value := []byte{5} + for _, k := range keys { + db.Set(k, value) + } + + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + assert.Equal(t, db.Get(itr.Key()), itr.Value()) + } + }) + } +} + +func TestDBIteratorEmpty(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(nil, nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorEmptyBeginAfter(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(bz("1"), nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorNonemptyBeginAfter(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(bz("2"), nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBBatchWrite1(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Delete(bz("3")) + batch.Set(bz("4"), bz("4")) + batch.Write() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWrite2(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Set(bz("4"), bz("4")) + batch.Delete(bz("3")) + batch.Write() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWriteSync1(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Delete(bz("3")) + batch.Set(bz("4"), bz("4")) + batch.WriteSync() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 2, mdb.calls["SetNoLock"]) + assert.Equal(t, 1, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWriteSync2(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Set(bz("4"), bz("4")) + batch.Delete(bz("3")) + batch.WriteSync() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLockSync"]) +} diff --git a/db/debug_db.go b/db/debug_db.go new file mode 100644 index 00000000..a3e785c2 --- /dev/null +++ b/db/debug_db.go @@ -0,0 +1,256 @@ +package db + +import ( + "fmt" + "sync" + + cmn "github.com/tendermint/tmlibs/common" +) + +func _fmt(f string, az ...interface{}) string { + return fmt.Sprintf(f, az...) +} + +//---------------------------------------- +// debugDB + +type debugDB struct { + label string + db DB +} + +// For printing all operationgs to the console for debugging. +func NewDebugDB(label string, db DB) debugDB { + return debugDB{ + label: label, + db: db, + } +} + +// Implements atomicSetDeleter. +func (ddb debugDB) Mutex() *sync.Mutex { return nil } + +// Implements DB. +func (ddb debugDB) Get(key []byte) (value []byte) { + defer func() { + fmt.Printf("%v.Get(%v) %v\n", ddb.label, + cmn.ColoredBytes(key, cmn.Cyan, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + }() + value = ddb.db.Get(key) + return +} + +// Implements DB. +func (ddb debugDB) Has(key []byte) (has bool) { + defer func() { + fmt.Printf("%v.Has(%v) %v\n", ddb.label, + cmn.ColoredBytes(key, cmn.Cyan, cmn.Blue), has) + }() + return ddb.db.Has(key) +} + +// Implements DB. +func (ddb debugDB) Set(key []byte, value []byte) { + fmt.Printf("%v.Set(%v, %v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Yellow, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + ddb.db.Set(key, value) +} + +// Implements DB. +func (ddb debugDB) SetSync(key []byte, value []byte) { + fmt.Printf("%v.SetSync(%v, %v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Yellow, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + ddb.db.SetSync(key, value) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) SetNoLock(key []byte, value []byte) { + fmt.Printf("%v.SetNoLock(%v, %v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Yellow, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + ddb.db.(atomicSetDeleter).SetNoLock(key, value) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) SetNoLockSync(key []byte, value []byte) { + fmt.Printf("%v.SetNoLockSync(%v, %v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Yellow, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + ddb.db.(atomicSetDeleter).SetNoLockSync(key, value) +} + +// Implements DB. +func (ddb debugDB) Delete(key []byte) { + fmt.Printf("%v.Delete(%v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Red, cmn.Yellow)) + ddb.db.Delete(key) +} + +// Implements DB. +func (ddb debugDB) DeleteSync(key []byte) { + fmt.Printf("%v.DeleteSync(%v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Red, cmn.Yellow)) + ddb.db.DeleteSync(key) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) DeleteNoLock(key []byte) { + fmt.Printf("%v.DeleteNoLock(%v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Red, cmn.Yellow)) + ddb.db.(atomicSetDeleter).DeleteNoLock(key) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) DeleteNoLockSync(key []byte) { + fmt.Printf("%v.DeleteNoLockSync(%v)\n", ddb.label, + cmn.ColoredBytes(key, cmn.Red, cmn.Yellow)) + ddb.db.(atomicSetDeleter).DeleteNoLockSync(key) +} + +// Implements DB. +func (ddb debugDB) Iterator(start, end []byte) Iterator { + fmt.Printf("%v.Iterator(%v, %v)\n", ddb.label, + cmn.ColoredBytes(start, cmn.Cyan, cmn.Blue), + cmn.ColoredBytes(end, cmn.Cyan, cmn.Blue)) + return NewDebugIterator(ddb.label, ddb.db.Iterator(start, end)) +} + +// Implements DB. +func (ddb debugDB) ReverseIterator(start, end []byte) Iterator { + fmt.Printf("%v.ReverseIterator(%v, %v)\n", ddb.label, + cmn.ColoredBytes(start, cmn.Cyan, cmn.Blue), + cmn.ColoredBytes(end, cmn.Cyan, cmn.Blue)) + return NewDebugIterator(ddb.label, ddb.db.ReverseIterator(start, end)) +} + +// Implements DB. +// Panics if the underlying db is not an +// atomicSetDeleter. +func (ddb debugDB) NewBatch() Batch { + fmt.Printf("%v.NewBatch()\n", ddb.label) + return NewDebugBatch(ddb.label, ddb.db.NewBatch()) +} + +// Implements DB. +func (ddb debugDB) Close() { + fmt.Printf("%v.Close()\n", ddb.label) + ddb.db.Close() +} + +// Implements DB. +func (ddb debugDB) Print() { + ddb.db.Print() +} + +// Implements DB. +func (ddb debugDB) Stats() map[string]string { + return ddb.db.Stats() +} + +//---------------------------------------- +// debugIterator + +type debugIterator struct { + label string + itr Iterator +} + +// For printing all operationgs to the console for debugging. +func NewDebugIterator(label string, itr Iterator) debugIterator { + return debugIterator{ + label: label, + itr: itr, + } +} + +// Implements Iterator. +func (ditr debugIterator) Domain() (start []byte, end []byte) { + defer func() { + fmt.Printf("%v.itr.Domain() (%X,%X)\n", ditr.label, start, end) + }() + start, end = ditr.itr.Domain() + return +} + +// Implements Iterator. +func (ditr debugIterator) Valid() (ok bool) { + defer func() { + fmt.Printf("%v.itr.Valid() %v\n", ditr.label, ok) + }() + ok = ditr.itr.Valid() + return +} + +// Implements Iterator. +func (ditr debugIterator) Next() { + fmt.Printf("%v.itr.Next()\n", ditr.label) + ditr.itr.Next() +} + +// Implements Iterator. +func (ditr debugIterator) Key() (key []byte) { + key = ditr.itr.Key() + fmt.Printf("%v.itr.Key() %v\n", ditr.label, + cmn.ColoredBytes(key, cmn.Cyan, cmn.Blue)) + return +} + +// Implements Iterator. +func (ditr debugIterator) Value() (value []byte) { + value = ditr.itr.Value() + fmt.Printf("%v.itr.Value() %v\n", ditr.label, + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + return +} + +// Implements Iterator. +func (ditr debugIterator) Close() { + fmt.Printf("%v.itr.Close()\n", ditr.label) + ditr.itr.Close() +} + +//---------------------------------------- +// debugBatch + +type debugBatch struct { + label string + bch Batch +} + +// For printing all operationgs to the console for debugging. +func NewDebugBatch(label string, bch Batch) debugBatch { + return debugBatch{ + label: label, + bch: bch, + } +} + +// Implements Batch. +func (dbch debugBatch) Set(key, value []byte) { + fmt.Printf("%v.batch.Set(%v, %v)\n", dbch.label, + cmn.ColoredBytes(key, cmn.Yellow, cmn.Blue), + cmn.ColoredBytes(value, cmn.Green, cmn.Blue)) + dbch.bch.Set(key, value) +} + +// Implements Batch. +func (dbch debugBatch) Delete(key []byte) { + fmt.Printf("%v.batch.Delete(%v)\n", dbch.label, + cmn.ColoredBytes(key, cmn.Red, cmn.Yellow)) + dbch.bch.Delete(key) +} + +// Implements Batch. +func (dbch debugBatch) Write() { + fmt.Printf("%v.batch.Write()\n", dbch.label) + dbch.bch.Write() +} + +// Implements Batch. +func (dbch debugBatch) WriteSync() { + fmt.Printf("%v.batch.WriteSync()\n", dbch.label) + dbch.bch.WriteSync() +} diff --git a/db/fsdb.go b/db/fsdb.go new file mode 100644 index 00000000..b5711ba3 --- /dev/null +++ b/db/fsdb.go @@ -0,0 +1,262 @@ +package db + +import ( + "fmt" + "io/ioutil" + "net/url" + "os" + "path/filepath" + "sort" + "sync" + + "github.com/pkg/errors" + cmn "github.com/tendermint/tmlibs/common" +) + +const ( + keyPerm = os.FileMode(0600) + dirPerm = os.FileMode(0700) +) + +func init() { + registerDBCreator(FSDBBackend, func(name string, dir string) (DB, error) { + dbPath := filepath.Join(dir, name+".db") + return NewFSDB(dbPath), nil + }, false) +} + +var _ DB = (*FSDB)(nil) + +// It's slow. +type FSDB struct { + mtx sync.Mutex + dir string +} + +func NewFSDB(dir string) *FSDB { + err := os.MkdirAll(dir, dirPerm) + if err != nil { + panic(errors.Wrap(err, "Creating FSDB dir "+dir)) + } + database := &FSDB{ + dir: dir, + } + return database +} + +func (db *FSDB) Get(key []byte) []byte { + db.mtx.Lock() + defer db.mtx.Unlock() + key = escapeKey(key) + + path := db.nameToPath(key) + value, err := read(path) + if os.IsNotExist(err) { + return nil + } else if err != nil { + panic(errors.Wrapf(err, "Getting key %s (0x%X)", string(key), key)) + } + return value +} + +func (db *FSDB) Has(key []byte) bool { + db.mtx.Lock() + defer db.mtx.Unlock() + key = escapeKey(key) + + path := db.nameToPath(key) + return cmn.FileExists(path) +} + +func (db *FSDB) Set(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +func (db *FSDB) SetSync(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// NOTE: Implements atomicSetDeleter. +func (db *FSDB) SetNoLock(key []byte, value []byte) { + key = escapeKey(key) + value = nonNilBytes(value) + path := db.nameToPath(key) + err := write(path, value) + if err != nil { + panic(errors.Wrapf(err, "Setting key %s (0x%X)", string(key), key)) + } +} + +func (db *FSDB) Delete(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +func (db *FSDB) DeleteSync(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// NOTE: Implements atomicSetDeleter. +func (db *FSDB) DeleteNoLock(key []byte) { + key = escapeKey(key) + path := db.nameToPath(key) + err := remove(path) + if os.IsNotExist(err) { + return + } else if err != nil { + panic(errors.Wrapf(err, "Removing key %s (0x%X)", string(key), key)) + } +} + +func (db *FSDB) Close() { + // Nothing to do. +} + +func (db *FSDB) Print() { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Print not yet implemented") +} + +func (db *FSDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Stats not yet implemented") +} + +func (db *FSDB) NewBatch() Batch { + db.mtx.Lock() + defer db.mtx.Unlock() + + // Not sure we would ever want to try... + // It doesn't seem easy for general filesystems. + panic("FSDB.NewBatch not yet implemented") +} + +func (db *FSDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +func (db *FSDB) Iterator(start, end []byte) Iterator { + return db.MakeIterator(start, end, false) +} + +func (db *FSDB) MakeIterator(start, end []byte, isReversed bool) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + keys, err := list(db.dir, start, end, isReversed) + if err != nil { + panic(errors.Wrapf(err, "Listing keys in %s", db.dir)) + } + if isReversed { + sort.Sort(sort.Reverse(sort.StringSlice(keys))) + } else { + sort.Strings(keys) + } + return newMemDBIterator(db, keys, start, end) +} + +func (db *FSDB) ReverseIterator(start, end []byte) Iterator { + return db.MakeIterator(start, end, true) +} + +func (db *FSDB) nameToPath(name []byte) string { + n := url.PathEscape(string(name)) + return filepath.Join(db.dir, n) +} + +// Read some bytes to a file. +// CONTRACT: returns os errors directly without wrapping. +func read(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + d, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + return d, nil +} + +// Write some bytes from a file. +// CONTRACT: returns os errors directly without wrapping. +func write(path string, d []byte) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, keyPerm) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(d) + if err != nil { + return err + } + err = f.Sync() + return err +} + +// Remove a file. +// CONTRACT: returns os errors directly without wrapping. +func remove(path string) error { + return os.Remove(path) +} + +// List keys in a directory, stripping of escape sequences and dir portions. +// CONTRACT: returns os errors directly without wrapping. +func list(dirPath string, start, end []byte, isReversed bool) ([]string, error) { + dir, err := os.Open(dirPath) + if err != nil { + return nil, err + } + defer dir.Close() + + names, err := dir.Readdirnames(0) + if err != nil { + return nil, err + } + var keys []string + for _, name := range names { + n, err := url.PathUnescape(name) + if err != nil { + return nil, fmt.Errorf("Failed to unescape %s while listing", name) + } + key := unescapeKey([]byte(n)) + if IsKeyInDomain(key, start, end, isReversed) { + keys = append(keys, string(key)) + } + } + return keys, nil +} + +// To support empty or nil keys, while the file system doesn't allow empty +// filenames. +func escapeKey(key []byte) []byte { + return []byte("k_" + string(key)) +} +func unescapeKey(escKey []byte) []byte { + if len(escKey) < 2 { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + if string(escKey[:2]) != "k_" { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + return escKey[2:] +} diff --git a/db/go_level_db.go b/db/go_level_db.go new file mode 100644 index 00000000..eca8a07f --- /dev/null +++ b/db/go_level_db.go @@ -0,0 +1,327 @@ +package db + +import ( + "bytes" + "fmt" + "path/filepath" + + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/iterator" + "github.com/syndtr/goleveldb/leveldb/opt" + + cmn "github.com/tendermint/tmlibs/common" +) + +func init() { + dbCreator := func(name string, dir string) (DB, error) { + return NewGoLevelDB(name, dir) + } + registerDBCreator(LevelDBBackend, dbCreator, false) + registerDBCreator(GoLevelDBBackend, dbCreator, false) +} + +var _ DB = (*GoLevelDB)(nil) + +type GoLevelDB struct { + db *leveldb.DB +} + +func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { + dbPath := filepath.Join(dir, name+".db") + db, err := leveldb.OpenFile(dbPath, nil) + if err != nil { + return nil, err + } + database := &GoLevelDB{ + db: db, + } + return database, nil +} + +// Implements DB. +func (db *GoLevelDB) Get(key []byte) []byte { + key = nonNilBytes(key) + res, err := db.db.Get(key, nil) + if err != nil { + if err == errors.ErrNotFound { + return nil + } + panic(err) + } + return res +} + +// Implements DB. +func (db *GoLevelDB) Has(key []byte) bool { + return db.Get(key) != nil +} + +// Implements DB. +func (db *GoLevelDB) Set(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + err := db.db.Put(key, value, nil) + if err != nil { + cmn.PanicCrisis(err) + } +} + +// Implements DB. +func (db *GoLevelDB) SetSync(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + err := db.db.Put(key, value, &opt.WriteOptions{Sync: true}) + if err != nil { + cmn.PanicCrisis(err) + } +} + +// Implements DB. +func (db *GoLevelDB) Delete(key []byte) { + key = nonNilBytes(key) + err := db.db.Delete(key, nil) + if err != nil { + cmn.PanicCrisis(err) + } +} + +// Implements DB. +func (db *GoLevelDB) DeleteSync(key []byte) { + key = nonNilBytes(key) + err := db.db.Delete(key, &opt.WriteOptions{Sync: true}) + if err != nil { + cmn.PanicCrisis(err) + } +} + +func (db *GoLevelDB) DB() *leveldb.DB { + return db.db +} + +// Implements DB. +func (db *GoLevelDB) Close() { + db.db.Close() +} + +// Implements DB. +func (db *GoLevelDB) Print() { + str, _ := db.db.GetProperty("leveldb.stats") + fmt.Printf("%v\n", str) + + itr := db.db.NewIterator(nil, nil) + for itr.Next() { + key := itr.Key() + value := itr.Value() + fmt.Printf("[%X]:\t[%X]\n", key, value) + } +} + +// Implements DB. +func (db *GoLevelDB) Stats() map[string]string { + keys := []string{ + "leveldb.num-files-at-level{n}", + "leveldb.stats", + "leveldb.sstables", + "leveldb.blockpool", + "leveldb.cachedblock", + "leveldb.openedtables", + "leveldb.alivesnaps", + "leveldb.aliveiters", + } + + stats := make(map[string]string) + for _, key := range keys { + str, err := db.db.GetProperty(key) + if err == nil { + stats[key] = str + } + } + return stats +} + +//---------------------------------------- +// Batch + +// Implements DB. +func (db *GoLevelDB) NewBatch() Batch { + batch := new(leveldb.Batch) + return &goLevelDBBatch{db, batch} +} + +type goLevelDBBatch struct { + db *GoLevelDB + batch *leveldb.Batch +} + +// Implements Batch. +func (mBatch *goLevelDBBatch) Set(key, value []byte) { + mBatch.batch.Put(key, value) +} + +// Implements Batch. +func (mBatch *goLevelDBBatch) Delete(key []byte) { + mBatch.batch.Delete(key) +} + +// Implements Batch. +func (mBatch *goLevelDBBatch) Write() { + err := mBatch.db.db.Write(mBatch.batch, &opt.WriteOptions{Sync: false}) + if err != nil { + panic(err) + } +} + +// Implements Batch. +func (mBatch *goLevelDBBatch) WriteSync() { + err := mBatch.db.db.Write(mBatch.batch, &opt.WriteOptions{Sync: true}) + if err != nil { + panic(err) + } +} + +//---------------------------------------- +// Iterator +// NOTE This is almost identical to db/c_level_db.Iterator +// Before creating a third version, refactor. + +// Implements DB. +func (db *GoLevelDB) Iterator(start, end []byte) Iterator { + itr := db.db.NewIterator(nil, nil) + return newGoLevelDBIterator(itr, start, end, false) +} + +// Implements DB. +func (db *GoLevelDB) ReverseIterator(start, end []byte) Iterator { + itr := db.db.NewIterator(nil, nil) + return newGoLevelDBIterator(itr, start, end, true) +} + +type goLevelDBIterator struct { + source iterator.Iterator + start []byte + end []byte + isReverse bool + isInvalid bool +} + +var _ Iterator = (*goLevelDBIterator)(nil) + +func newGoLevelDBIterator(source iterator.Iterator, start, end []byte, isReverse bool) *goLevelDBIterator { + if isReverse { + if start == nil { + source.Last() + } else { + valid := source.Seek(start) + if valid { + soakey := source.Key() // start or after key + if bytes.Compare(start, soakey) < 0 { + source.Prev() + } + } else { + source.Last() + } + } + } else { + if start == nil { + source.First() + } else { + source.Seek(start) + } + } + return &goLevelDBIterator{ + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, + } +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Valid() bool { + + // Once invalid, forever invalid. + if itr.isInvalid { + return false + } + + // Panic on DB error. No way to recover. + itr.assertNoError() + + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false + } + + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + + if itr.isReverse { + if end != nil && bytes.Compare(key, end) <= 0 { + itr.isInvalid = true + return false + } + } else { + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true + return false + } + } + + // Valid + return true +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Key() []byte { + // Key returns a copy of the current key. + // See https://github.com/syndtr/goleveldb/blob/52c212e6c196a1404ea59592d3f1c227c9f034b2/leveldb/iterator/iter.go#L88 + itr.assertNoError() + itr.assertIsValid() + return cp(itr.source.Key()) +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Value() []byte { + // Value returns a copy of the current value. + // See https://github.com/syndtr/goleveldb/blob/52c212e6c196a1404ea59592d3f1c227c9f034b2/leveldb/iterator/iter.go#L88 + itr.assertNoError() + itr.assertIsValid() + return cp(itr.source.Value()) +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + if itr.isReverse { + itr.source.Prev() + } else { + itr.source.Next() + } +} + +// Implements Iterator. +func (itr *goLevelDBIterator) Close() { + itr.source.Release() +} + +func (itr *goLevelDBIterator) assertNoError() { + if err := itr.source.Error(); err != nil { + panic(err) + } +} + +func (itr goLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("goLevelDBIterator is invalid") + } +} diff --git a/db/go_level_db_test.go b/db/go_level_db_test.go new file mode 100644 index 00000000..266add8b --- /dev/null +++ b/db/go_level_db_test.go @@ -0,0 +1,83 @@ +package db + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" + + cmn "github.com/tendermint/tmlibs/common" +) + +func BenchmarkRandomReadsWrites(b *testing.B) { + b.StopTimer() + + numItems := int64(1000000) + internal := map[int64]int64{} + for i := 0; i < int(numItems); i++ { + internal[int64(i)] = int64(0) + } + db, err := NewGoLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") + if err != nil { + b.Fatal(err.Error()) + return + } + + fmt.Println("ok, starting") + b.StartTimer() + + for i := 0; i < b.N; i++ { + // Write something + { + idx := (int64(cmn.RandInt()) % numItems) + internal[idx]++ + val := internal[idx] + idxBytes := int642Bytes(int64(idx)) + valBytes := int642Bytes(int64(val)) + //fmt.Printf("Set %X -> %X\n", idxBytes, valBytes) + db.Set( + idxBytes, + valBytes, + ) + } + // Read something + { + idx := (int64(cmn.RandInt()) % numItems) + val := internal[idx] + idxBytes := int642Bytes(int64(idx)) + valBytes := db.Get(idxBytes) + //fmt.Printf("Get %X -> %X\n", idxBytes, valBytes) + if val == 0 { + if !bytes.Equal(valBytes, nil) { + b.Errorf("Expected %v for %v, got %X", + nil, idx, valBytes) + break + } + } else { + if len(valBytes) != 8 { + b.Errorf("Expected length 8 for %v, got %X", + idx, valBytes) + break + } + valGot := bytes2Int64(valBytes) + if val != valGot { + b.Errorf("Expected %v for %v, got %v", + val, idx, valGot) + break + } + } + } + } + + db.Close() +} + +func int642Bytes(i int64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +func bytes2Int64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} diff --git a/db/mem_batch.go b/db/mem_batch.go new file mode 100644 index 00000000..5c5d0c13 --- /dev/null +++ b/db/mem_batch.go @@ -0,0 +1,72 @@ +package db + +import ( + "sync" +) + +type atomicSetDeleter interface { + Mutex() *sync.Mutex + SetNoLock(key, value []byte) + SetNoLockSync(key, value []byte) + DeleteNoLock(key []byte) + DeleteNoLockSync(key []byte) +} + +type memBatch struct { + db atomicSetDeleter + ops []operation +} + +type opType int + +const ( + opTypeSet opType = 1 + opTypeDelete opType = 2 +) + +type operation struct { + opType + key []byte + value []byte +} + +func (mBatch *memBatch) Set(key, value []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) +} + +func (mBatch *memBatch) Delete(key []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) +} + +func (mBatch *memBatch) Write() { + mBatch.write(false) +} + +func (mBatch *memBatch) WriteSync() { + mBatch.write(true) +} + +func (mBatch *memBatch) write(doSync bool) { + if mtx := mBatch.db.Mutex(); mtx != nil { + mtx.Lock() + defer mtx.Unlock() + } + + for i, op := range mBatch.ops { + if doSync && i == (len(mBatch.ops)-1) { + switch op.opType { + case opTypeSet: + mBatch.db.SetNoLockSync(op.key, op.value) + case opTypeDelete: + mBatch.db.DeleteNoLockSync(op.key) + } + break // we're done. + } + switch op.opType { + case opTypeSet: + mBatch.db.SetNoLock(op.key, op.value) + case opTypeDelete: + mBatch.db.DeleteNoLock(op.key) + } + } +} diff --git a/db/mem_db.go b/db/mem_db.go new file mode 100644 index 00000000..1521f87a --- /dev/null +++ b/db/mem_db.go @@ -0,0 +1,255 @@ +package db + +import ( + "fmt" + "sort" + "sync" +) + +func init() { + registerDBCreator(MemDBBackend, func(name string, dir string) (DB, error) { + return NewMemDB(), nil + }, false) +} + +var _ DB = (*MemDB)(nil) + +type MemDB struct { + mtx sync.Mutex + db map[string][]byte +} + +func NewMemDB() *MemDB { + database := &MemDB{ + db: make(map[string][]byte), + } + return database +} + +// Implements atomicSetDeleter. +func (db *MemDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +// Implements DB. +func (db *MemDB) Get(key []byte) []byte { + db.mtx.Lock() + defer db.mtx.Unlock() + key = nonNilBytes(key) + + value := db.db[string(key)] + return value +} + +// Implements DB. +func (db *MemDB) Has(key []byte) bool { + db.mtx.Lock() + defer db.mtx.Unlock() + key = nonNilBytes(key) + + _, ok := db.db[string(key)] + return ok +} + +// Implements DB. +func (db *MemDB) Set(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// Implements DB. +func (db *MemDB) SetSync(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// Implements atomicSetDeleter. +func (db *MemDB) SetNoLock(key []byte, value []byte) { + db.SetNoLockSync(key, value) +} + +// Implements atomicSetDeleter. +func (db *MemDB) SetNoLockSync(key []byte, value []byte) { + key = nonNilBytes(key) + value = nonNilBytes(value) + + db.db[string(key)] = value +} + +// Implements DB. +func (db *MemDB) Delete(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// Implements DB. +func (db *MemDB) DeleteSync(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// Implements atomicSetDeleter. +func (db *MemDB) DeleteNoLock(key []byte) { + db.DeleteNoLockSync(key) +} + +// Implements atomicSetDeleter. +func (db *MemDB) DeleteNoLockSync(key []byte) { + key = nonNilBytes(key) + + delete(db.db, string(key)) +} + +// Implements DB. +func (db *MemDB) Close() { + // Close is a noop since for an in-memory + // database, we don't have a destination + // to flush contents to nor do we want + // any data loss on invoking Close() + // See the discussion in https://github.com/tendermint/tmlibs/pull/56 +} + +// Implements DB. +func (db *MemDB) Print() { + db.mtx.Lock() + defer db.mtx.Unlock() + + for key, value := range db.db { + fmt.Printf("[%X]:\t[%X]\n", []byte(key), value) + } +} + +// Implements DB. +func (db *MemDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + + stats := make(map[string]string) + stats["database.type"] = "memDB" + stats["database.size"] = fmt.Sprintf("%d", len(db.db)) + return stats +} + +// Implements DB. +func (db *MemDB) NewBatch() Batch { + db.mtx.Lock() + defer db.mtx.Unlock() + + return &memBatch{db, nil} +} + +//---------------------------------------- +// Iterator + +// Implements DB. +func (db *MemDB) Iterator(start, end []byte) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() + + keys := db.getSortedKeys(start, end, false) + return newMemDBIterator(db, keys, start, end) +} + +// Implements DB. +func (db *MemDB) ReverseIterator(start, end []byte) Iterator { + db.mtx.Lock() + defer db.mtx.Unlock() + + keys := db.getSortedKeys(start, end, true) + return newMemDBIterator(db, keys, start, end) +} + +// We need a copy of all of the keys. +// Not the best, but probably not a bottleneck depending. +type memDBIterator struct { + db DB + cur int + keys []string + start []byte + end []byte +} + +var _ Iterator = (*memDBIterator)(nil) + +// Keys is expected to be in reverse order for reverse iterators. +func newMemDBIterator(db DB, keys []string, start, end []byte) *memDBIterator { + return &memDBIterator{ + db: db, + cur: 0, + keys: keys, + start: start, + end: end, + } +} + +// Implements Iterator. +func (itr *memDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +// Implements Iterator. +func (itr *memDBIterator) Valid() bool { + return 0 <= itr.cur && itr.cur < len(itr.keys) +} + +// Implements Iterator. +func (itr *memDBIterator) Next() { + itr.assertIsValid() + itr.cur++ +} + +// Implements Iterator. +func (itr *memDBIterator) Key() []byte { + itr.assertIsValid() + return []byte(itr.keys[itr.cur]) +} + +// Implements Iterator. +func (itr *memDBIterator) Value() []byte { + itr.assertIsValid() + key := []byte(itr.keys[itr.cur]) + return itr.db.Get(key) +} + +// Implements Iterator. +func (itr *memDBIterator) Close() { + itr.keys = nil + itr.db = nil +} + +func (itr *memDBIterator) assertIsValid() { + if !itr.Valid() { + panic("memDBIterator is invalid") + } +} + +//---------------------------------------- +// Misc. + +func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string { + keys := []string{} + for key := range db.db { + inDomain := IsKeyInDomain([]byte(key), start, end, reverse) + if inDomain { + keys = append(keys, key) + } + } + sort.Strings(keys) + if reverse { + nkeys := len(keys) + for i := 0; i < nkeys/2; i++ { + temp := keys[i] + keys[i] = keys[nkeys-i-1] + keys[nkeys-i-1] = temp + } + } + return keys +} diff --git a/db/prefix_db.go b/db/prefix_db.go new file mode 100644 index 00000000..5bb53ebd --- /dev/null +++ b/db/prefix_db.go @@ -0,0 +1,355 @@ +package db + +import ( + "bytes" + "fmt" + "sync" +) + +// IteratePrefix is a convenience function for iterating over a key domain +// restricted by prefix. +func IteratePrefix(db DB, prefix []byte) Iterator { + var start, end []byte + if len(prefix) == 0 { + start = nil + end = nil + } else { + start = cp(prefix) + end = cpIncr(prefix) + } + return db.Iterator(start, end) +} + +/* +TODO: Make test, maybe rename. +// Like IteratePrefix but the iterator strips the prefix from the keys. +func IteratePrefixStripped(db DB, prefix []byte) Iterator { + start, end := ... + return newPrefixIterator(prefix, start, end, IteratePrefix(db, prefix)) +} +*/ + +//---------------------------------------- +// prefixDB + +type prefixDB struct { + mtx sync.Mutex + prefix []byte + db DB +} + +// NewPrefixDB lets you namespace multiple DBs within a single DB. +func NewPrefixDB(db DB, prefix []byte) *prefixDB { + return &prefixDB{ + prefix: prefix, + db: db, + } +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) Mutex() *sync.Mutex { + return &(pdb.mtx) +} + +// Implements DB. +func (pdb *prefixDB) Get(key []byte) []byte { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pkey := pdb.prefixed(key) + value := pdb.db.Get(pkey) + return value +} + +// Implements DB. +func (pdb *prefixDB) Has(key []byte) bool { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + return pdb.db.Has(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) Set(key []byte, value []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pkey := pdb.prefixed(key) + pdb.db.Set(pkey, value) +} + +// Implements DB. +func (pdb *prefixDB) SetSync(key []byte, value []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.SetSync(pdb.prefixed(key), value) +} + +// Implements DB. +func (pdb *prefixDB) Delete(key []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.Delete(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) DeleteSync(key []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.DeleteSync(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) Iterator(start, end []byte) Iterator { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + var pstart, pend []byte + pstart = append(cp(pdb.prefix), start...) + if end == nil { + pend = cpIncr(pdb.prefix) + } else { + pend = append(cp(pdb.prefix), end...) + } + return newPrefixIterator( + pdb.prefix, + start, + end, + pdb.db.Iterator( + pstart, + pend, + ), + ) +} + +// Implements DB. +func (pdb *prefixDB) ReverseIterator(start, end []byte) Iterator { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + var pstart, pend []byte + if start == nil { + // This may cause the underlying iterator to start with + // an item which doesn't start with prefix. We will skip + // that item later in this function. See 'skipOne'. + pstart = cpIncr(pdb.prefix) + } else { + pstart = append(cp(pdb.prefix), start...) + } + if end == nil { + // This may cause the underlying iterator to end with an + // item which doesn't start with prefix. The + // prefixIterator will terminate iteration + // automatically upon detecting this. + pend = cpDecr(pdb.prefix) + } else { + pend = append(cp(pdb.prefix), end...) + } + ritr := pdb.db.ReverseIterator(pstart, pend) + if start == nil { + skipOne(ritr, cpIncr(pdb.prefix)) + } + return newPrefixIterator( + pdb.prefix, + start, + end, + ritr, + ) +} + +// Implements DB. +// Panics if the underlying DB is not an +// atomicSetDeleter. +func (pdb *prefixDB) NewBatch() Batch { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + return newPrefixBatch(pdb.prefix, pdb.db.NewBatch()) +} + +/* NOTE: Uncomment to use memBatch instead of prefixBatch +// Implements atomicSetDeleter. +func (pdb *prefixDB) SetNoLock(key []byte, value []byte) { + pdb.db.(atomicSetDeleter).SetNoLock(pdb.prefixed(key), value) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) SetNoLockSync(key []byte, value []byte) { + pdb.db.(atomicSetDeleter).SetNoLockSync(pdb.prefixed(key), value) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) DeleteNoLock(key []byte) { + pdb.db.(atomicSetDeleter).DeleteNoLock(pdb.prefixed(key)) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) DeleteNoLockSync(key []byte) { + pdb.db.(atomicSetDeleter).DeleteNoLockSync(pdb.prefixed(key)) +} +*/ + +// Implements DB. +func (pdb *prefixDB) Close() { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.Close() +} + +// Implements DB. +func (pdb *prefixDB) Print() { + fmt.Printf("prefix: %X\n", pdb.prefix) + + itr := pdb.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + key := itr.Key() + value := itr.Value() + fmt.Printf("[%X]:\t[%X]\n", key, value) + } +} + +// Implements DB. +func (pdb *prefixDB) Stats() map[string]string { + stats := make(map[string]string) + stats["prefixdb.prefix.string"] = string(pdb.prefix) + stats["prefixdb.prefix.hex"] = fmt.Sprintf("%X", pdb.prefix) + source := pdb.db.Stats() + for key, value := range source { + stats["prefixdb.source."+key] = value + } + return stats +} + +func (pdb *prefixDB) prefixed(key []byte) []byte { + return append(cp(pdb.prefix), key...) +} + +//---------------------------------------- +// prefixBatch + +type prefixBatch struct { + prefix []byte + source Batch +} + +func newPrefixBatch(prefix []byte, source Batch) prefixBatch { + return prefixBatch{ + prefix: prefix, + source: source, + } +} + +func (pb prefixBatch) Set(key, value []byte) { + pkey := append(cp(pb.prefix), key...) + pb.source.Set(pkey, value) +} + +func (pb prefixBatch) Delete(key []byte) { + pkey := append(cp(pb.prefix), key...) + pb.source.Delete(pkey) +} + +func (pb prefixBatch) Write() { + pb.source.Write() +} + +func (pb prefixBatch) WriteSync() { + pb.source.WriteSync() +} + +//---------------------------------------- +// prefixIterator + +// Strips prefix while iterating from Iterator. +type prefixIterator struct { + prefix []byte + start []byte + end []byte + source Iterator + valid bool +} + +func newPrefixIterator(prefix, start, end []byte, source Iterator) prefixIterator { + if !source.Valid() || !bytes.HasPrefix(source.Key(), prefix) { + return prefixIterator{ + prefix: prefix, + start: start, + end: end, + source: source, + valid: false, + } + } else { + return prefixIterator{ + prefix: prefix, + start: start, + end: end, + source: source, + valid: true, + } + } +} + +func (itr prefixIterator) Domain() (start []byte, end []byte) { + return itr.start, itr.end +} + +func (itr prefixIterator) Valid() bool { + return itr.valid && itr.source.Valid() +} + +func (itr prefixIterator) Next() { + if !itr.valid { + panic("prefixIterator invalid, cannot call Next()") + } + itr.source.Next() + if !itr.source.Valid() || !bytes.HasPrefix(itr.source.Key(), itr.prefix) { + itr.source.Close() + itr.valid = false + return + } +} + +func (itr prefixIterator) Key() (key []byte) { + if !itr.valid { + panic("prefixIterator invalid, cannot call Key()") + } + return stripPrefix(itr.source.Key(), itr.prefix) +} + +func (itr prefixIterator) Value() (value []byte) { + if !itr.valid { + panic("prefixIterator invalid, cannot call Value()") + } + return itr.source.Value() +} + +func (itr prefixIterator) Close() { + itr.source.Close() +} + +//---------------------------------------- + +func stripPrefix(key []byte, prefix []byte) (stripped []byte) { + if len(key) < len(prefix) { + panic("should not happen") + } + if !bytes.Equal(key[:len(prefix)], prefix) { + panic("should not happne") + } + return key[len(prefix):] +} + +// If the first iterator item is skipKey, then +// skip it. +func skipOne(itr Iterator, skipKey []byte) { + if itr.Valid() { + if bytes.Equal(itr.Key(), skipKey) { + itr.Next() + } + } +} diff --git a/db/prefix_db_test.go b/db/prefix_db_test.go new file mode 100644 index 00000000..60809f15 --- /dev/null +++ b/db/prefix_db_test.go @@ -0,0 +1,147 @@ +package db + +import "testing" + +func mockDBWithStuff() DB { + db := NewMemDB() + // Under "key" prefix + db.Set(bz("key"), bz("value")) + db.Set(bz("key1"), bz("value1")) + db.Set(bz("key2"), bz("value2")) + db.Set(bz("key3"), bz("value3")) + db.Set(bz("something"), bz("else")) + db.Set(bz(""), bz("")) + db.Set(bz("k"), bz("val")) + db.Set(bz("ke"), bz("valu")) + db.Set(bz("kee"), bz("valuu")) + return db +} + +func TestPrefixDBSimple(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + checkValue(t, pdb, bz("key"), nil) + checkValue(t, pdb, bz(""), bz("value")) + checkValue(t, pdb, bz("key1"), nil) + checkValue(t, pdb, bz("1"), bz("value1")) + checkValue(t, pdb, bz("key2"), nil) + checkValue(t, pdb, bz("2"), bz("value2")) + checkValue(t, pdb, bz("key3"), nil) + checkValue(t, pdb, bz("3"), bz("value3")) + checkValue(t, pdb, bz("something"), nil) + checkValue(t, pdb, bz("k"), nil) + checkValue(t, pdb, bz("ke"), nil) + checkValue(t, pdb, bz("kee"), nil) +} + +func TestPrefixDBIterator1(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.Iterator(nil, nil) + checkDomain(t, itr, nil, nil) + checkItem(t, itr, bz(""), bz("value")) + checkNext(t, itr, true) + checkItem(t, itr, bz("1"), bz("value1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("2"), bz("value2")) + checkNext(t, itr, true) + checkItem(t, itr, bz("3"), bz("value3")) + checkNext(t, itr, false) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBIterator2(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.Iterator(nil, bz("")) + checkDomain(t, itr, nil, bz("")) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBIterator3(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.Iterator(bz(""), nil) + checkDomain(t, itr, bz(""), nil) + checkItem(t, itr, bz(""), bz("value")) + checkNext(t, itr, true) + checkItem(t, itr, bz("1"), bz("value1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("2"), bz("value2")) + checkNext(t, itr, true) + checkItem(t, itr, bz("3"), bz("value3")) + checkNext(t, itr, false) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBIterator4(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.Iterator(bz(""), bz("")) + checkDomain(t, itr, bz(""), bz("")) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBReverseIterator1(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.ReverseIterator(nil, nil) + checkDomain(t, itr, nil, nil) + checkItem(t, itr, bz("3"), bz("value3")) + checkNext(t, itr, true) + checkItem(t, itr, bz("2"), bz("value2")) + checkNext(t, itr, true) + checkItem(t, itr, bz("1"), bz("value1")) + checkNext(t, itr, true) + checkItem(t, itr, bz(""), bz("value")) + checkNext(t, itr, false) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBReverseIterator2(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.ReverseIterator(nil, bz("")) + checkDomain(t, itr, nil, bz("")) + checkItem(t, itr, bz("3"), bz("value3")) + checkNext(t, itr, true) + checkItem(t, itr, bz("2"), bz("value2")) + checkNext(t, itr, true) + checkItem(t, itr, bz("1"), bz("value1")) + checkNext(t, itr, false) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBReverseIterator3(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.ReverseIterator(bz(""), nil) + checkDomain(t, itr, bz(""), nil) + checkItem(t, itr, bz(""), bz("value")) + checkNext(t, itr, false) + checkInvalid(t, itr) + itr.Close() +} + +func TestPrefixDBReverseIterator4(t *testing.T) { + db := mockDBWithStuff() + pdb := NewPrefixDB(db, bz("key")) + + itr := pdb.ReverseIterator(bz(""), bz("")) + checkInvalid(t, itr) + itr.Close() +} diff --git a/db/remotedb/doc.go b/db/remotedb/doc.go new file mode 100644 index 00000000..07c95a56 --- /dev/null +++ b/db/remotedb/doc.go @@ -0,0 +1,37 @@ +/* +remotedb is a package for connecting to distributed Tendermint db.DB +instances. The purpose is to detach difficult deployments such as +CLevelDB that requires gcc or perhaps for databases that require +custom configurations such as extra disk space. It also eases +the burden and cost of deployment of dependencies for databases +to be used by Tendermint developers. Most importantly it is built +over the high performant gRPC transport. + +remotedb's RemoteDB implements db.DB so can be used normally +like other databases. One just has to explicitly connect to the +remote database with a client setup such as: + + client, err := remotedb.NewInsecure(addr) + // Make sure to invoke InitRemote! + if err := client.InitRemote(&remotedb.Init{Name: "test-remote-db", Type: "leveldb"}); err != nil { + log.Fatalf("Failed to initialize the remote db") + } + + client.Set(key1, value) + gv1 := client.SetSync(k2, v2) + + client.Delete(k1) + gv2 := client.Get(k1) + + for itr := client.Iterator(k1, k9); itr.Valid(); itr.Next() { + ik, iv := itr.Key(), itr.Value() + ds, de := itr.Domain() + } + + stats := client.Stats() + + if !client.Has(dk1) { + client.SetSync(dk1, dv1) + } +*/ +package remotedb diff --git a/db/remotedb/grpcdb/client.go b/db/remotedb/grpcdb/client.go new file mode 100644 index 00000000..86aa12c7 --- /dev/null +++ b/db/remotedb/grpcdb/client.go @@ -0,0 +1,30 @@ +package grpcdb + +import ( + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + protodb "github.com/tendermint/tmlibs/db/remotedb/proto" +) + +// Security defines how the client will talk to the gRPC server. +type Security uint + +const ( + Insecure Security = iota + Secure +) + +// NewClient creates a gRPC client connected to the bound gRPC server at serverAddr. +// Use kind to set the level of security to either Secure or Insecure. +func NewClient(serverAddr, serverCert string) (protodb.DBClient, error) { + creds, err := credentials.NewClientTLSFromFile(serverCert, "") + if err != nil { + return nil, err + } + cc, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(creds)) + if err != nil { + return nil, err + } + return protodb.NewDBClient(cc), nil +} diff --git a/db/remotedb/grpcdb/doc.go b/db/remotedb/grpcdb/doc.go new file mode 100644 index 00000000..0d8e380c --- /dev/null +++ b/db/remotedb/grpcdb/doc.go @@ -0,0 +1,32 @@ +/* +grpcdb is the distribution of Tendermint's db.DB instances using +the gRPC transport to decouple local db.DB usages from applications, +to using them over a network in a highly performant manner. + +grpcdb allows users to initialize a database's server like +they would locally and invoke the respective methods of db.DB. + +Most users shouldn't use this package, but should instead use +remotedb. Only the lower level users and database server deployers +should use it, for functionality such as: + + ln, err := net.Listen("tcp", "0.0.0.0:0") + srv := grpcdb.NewServer() + defer srv.Stop() + go func() { + if err := srv.Serve(ln); err != nil { + t.Fatalf("BindServer: %v", err) + } + }() + +or + addr := ":8998" + cert := "server.crt" + key := "server.key" + go func() { + if err := grpcdb.ListenAndServe(addr, cert, key); err != nil { + log.Fatalf("BindServer: %v", err) + } + }() +*/ +package grpcdb diff --git a/db/remotedb/grpcdb/example_test.go b/db/remotedb/grpcdb/example_test.go new file mode 100644 index 00000000..827a1cf3 --- /dev/null +++ b/db/remotedb/grpcdb/example_test.go @@ -0,0 +1,52 @@ +package grpcdb_test + +import ( + "bytes" + "context" + "log" + + grpcdb "github.com/tendermint/tmlibs/db/remotedb/grpcdb" + protodb "github.com/tendermint/tmlibs/db/remotedb/proto" +) + +func Example() { + addr := ":8998" + cert := "server.crt" + key := "server.key" + go func() { + if err := grpcdb.ListenAndServe(addr, cert, key); err != nil { + log.Fatalf("BindServer: %v", err) + } + }() + + client, err := grpcdb.NewClient(addr, cert) + if err != nil { + log.Fatalf("Failed to create grpcDB client: %v", err) + } + + ctx := context.Background() + // 1. Initialize the DB + in := &protodb.Init{ + Type: "leveldb", + Name: "grpc-uno-test", + Dir: ".", + } + if _, err := client.Init(ctx, in); err != nil { + log.Fatalf("Init error: %v", err) + } + + // 2. Now it can be used! + query1 := &protodb.Entity{Key: []byte("Project"), Value: []byte("Tmlibs-on-gRPC")} + if _, err := client.SetSync(ctx, query1); err != nil { + log.Fatalf("SetSync err: %v", err) + } + + query2 := &protodb.Entity{Key: []byte("Project")} + read, err := client.Get(ctx, query2) + if err != nil { + log.Fatalf("Get err: %v", err) + } + if g, w := read.Value, []byte("Tmlibs-on-gRPC"); !bytes.Equal(g, w) { + log.Fatalf("got= (%q ==> % X)\nwant=(%q ==> % X)", g, g, w, w) + } +} diff --git a/db/remotedb/grpcdb/server.go b/db/remotedb/grpcdb/server.go new file mode 100644 index 00000000..8320c051 --- /dev/null +++ b/db/remotedb/grpcdb/server.go @@ -0,0 +1,197 @@ +package grpcdb + +import ( + "context" + "net" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/tendermint/tmlibs/db" + protodb "github.com/tendermint/tmlibs/db/remotedb/proto" +) + +// ListenAndServe is a blocking function that sets up a gRPC based +// server at the address supplied, with the gRPC options passed in. +// Normally in usage, invoke it in a goroutine like you would for http.ListenAndServe. +func ListenAndServe(addr, cert, key string, opts ...grpc.ServerOption) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + srv, err := NewServer(cert, key, opts...) + if err != nil { + return err + } + return srv.Serve(ln) +} + +func NewServer(cert, key string, opts ...grpc.ServerOption) (*grpc.Server, error) { + creds, err := credentials.NewServerTLSFromFile(cert, key) + if err != nil { + return nil, err + } + opts = append(opts, grpc.Creds(creds)) + srv := grpc.NewServer(opts...) + protodb.RegisterDBServer(srv, new(server)) + return srv, nil +} + +type server struct { + mu sync.Mutex + db db.DB +} + +var _ protodb.DBServer = (*server)(nil) + +// Init initializes the server's database. Only one type of database +// can be initialized per server. +// +// Dir is the directory on the file system in which the DB will be stored(if backed by disk) (TODO: remove) +// +// Name is representative filesystem entry's basepath +// +// Type can be either one of: +// * cleveldb (if built with gcc enabled) +// * fsdb +// * memdB +// * leveldb +// See https://godoc.org/github.com/tendermint/tmlibs/db#DBBackendType +func (s *server) Init(ctx context.Context, in *protodb.Init) (*protodb.Entity, error) { + s.mu.Lock() + defer s.mu.Unlock() + + s.db = db.NewDB(in.Name, db.DBBackendType(in.Type), in.Dir) + return &protodb.Entity{CreatedAt: time.Now().Unix()}, nil +} + +func (s *server) Delete(ctx context.Context, in *protodb.Entity) (*protodb.Nothing, error) { + s.db.Delete(in.Key) + return nothing, nil +} + +var nothing = new(protodb.Nothing) + +func (s *server) DeleteSync(ctx context.Context, in *protodb.Entity) (*protodb.Nothing, error) { + s.db.DeleteSync(in.Key) + return nothing, nil +} + +func (s *server) Get(ctx context.Context, in *protodb.Entity) (*protodb.Entity, error) { + value := s.db.Get(in.Key) + return &protodb.Entity{Value: value}, nil +} + +func (s *server) GetStream(ds protodb.DB_GetStreamServer) error { + // Receive routine + responsesChan := make(chan *protodb.Entity) + go func() { + defer close(responsesChan) + ctx := context.Background() + for { + in, err := ds.Recv() + if err != nil { + responsesChan <- &protodb.Entity{Err: err.Error()} + return + } + out, err := s.Get(ctx, in) + if err != nil { + if out == nil { + out = new(protodb.Entity) + out.Key = in.Key + } + out.Err = err.Error() + responsesChan <- out + return + } + + // Otherwise continue on + responsesChan <- out + } + }() + + // Send routine, block until we return + for out := range responsesChan { + if err := ds.Send(out); err != nil { + return err + } + } + return nil +} + +func (s *server) Has(ctx context.Context, in *protodb.Entity) (*protodb.Entity, error) { + exists := s.db.Has(in.Key) + return &protodb.Entity{Exists: exists}, nil +} + +func (s *server) Set(ctx context.Context, in *protodb.Entity) (*protodb.Nothing, error) { + s.db.Set(in.Key, in.Value) + return nothing, nil +} + +func (s *server) SetSync(ctx context.Context, in *protodb.Entity) (*protodb.Nothing, error) { + s.db.SetSync(in.Key, in.Value) + return nothing, nil +} + +func (s *server) Iterator(query *protodb.Entity, dis protodb.DB_IteratorServer) error { + it := s.db.Iterator(query.Start, query.End) + return s.handleIterator(it, dis.Send) +} + +func (s *server) handleIterator(it db.Iterator, sendFunc func(*protodb.Iterator) error) error { + for it.Valid() { + start, end := it.Domain() + out := &protodb.Iterator{ + Domain: &protodb.Domain{Start: start, End: end}, + Valid: it.Valid(), + Key: it.Key(), + Value: it.Value(), + } + if err := sendFunc(out); err != nil { + return err + } + + // Finally move the iterator forward + it.Next() + } + return nil +} + +func (s *server) ReverseIterator(query *protodb.Entity, dis protodb.DB_ReverseIteratorServer) error { + it := s.db.ReverseIterator(query.Start, query.End) + return s.handleIterator(it, dis.Send) +} + +func (s *server) Stats(context.Context, *protodb.Nothing) (*protodb.Stats, error) { + stats := s.db.Stats() + return &protodb.Stats{Data: stats, TimeAt: time.Now().Unix()}, nil +} + +func (s *server) BatchWrite(c context.Context, b *protodb.Batch) (*protodb.Nothing, error) { + return s.batchWrite(c, b, false) +} + +func (s *server) BatchWriteSync(c context.Context, b *protodb.Batch) (*protodb.Nothing, error) { + return s.batchWrite(c, b, true) +} + +func (s *server) batchWrite(c context.Context, b *protodb.Batch, sync bool) (*protodb.Nothing, error) { + bat := s.db.NewBatch() + for _, op := range b.Ops { + switch op.Type { + case protodb.Operation_SET: + bat.Set(op.Entity.Key, op.Entity.Value) + case protodb.Operation_DELETE: + bat.Delete(op.Entity.Key) + } + } + if sync { + bat.WriteSync() + } else { + bat.Write() + } + return nothing, nil +} diff --git a/db/remotedb/proto/defs.pb.go b/db/remotedb/proto/defs.pb.go new file mode 100644 index 00000000..4d9f0b27 --- /dev/null +++ b/db/remotedb/proto/defs.pb.go @@ -0,0 +1,914 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: defs.proto + +/* +Package protodb is a generated protocol buffer package. + +It is generated from these files: + defs.proto + +It has these top-level messages: + Batch + Operation + Entity + Nothing + Domain + Iterator + Stats + Init +*/ +package protodb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Operation_Type int32 + +const ( + Operation_SET Operation_Type = 0 + Operation_DELETE Operation_Type = 1 +) + +var Operation_Type_name = map[int32]string{ + 0: "SET", + 1: "DELETE", +} +var Operation_Type_value = map[string]int32{ + "SET": 0, + "DELETE": 1, +} + +func (x Operation_Type) String() string { + return proto.EnumName(Operation_Type_name, int32(x)) +} +func (Operation_Type) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{1, 0} } + +type Batch struct { + Ops []*Operation `protobuf:"bytes,1,rep,name=ops" json:"ops,omitempty"` +} + +func (m *Batch) Reset() { *m = Batch{} } +func (m *Batch) String() string { return proto.CompactTextString(m) } +func (*Batch) ProtoMessage() {} +func (*Batch) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *Batch) GetOps() []*Operation { + if m != nil { + return m.Ops + } + return nil +} + +type Operation struct { + Entity *Entity `protobuf:"bytes,1,opt,name=entity" json:"entity,omitempty"` + Type Operation_Type `protobuf:"varint,2,opt,name=type,enum=protodb.Operation_Type" json:"type,omitempty"` +} + +func (m *Operation) Reset() { *m = Operation{} } +func (m *Operation) String() string { return proto.CompactTextString(m) } +func (*Operation) ProtoMessage() {} +func (*Operation) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *Operation) GetEntity() *Entity { + if m != nil { + return m.Entity + } + return nil +} + +func (m *Operation) GetType() Operation_Type { + if m != nil { + return m.Type + } + return Operation_SET +} + +type Entity struct { + Id int32 `protobuf:"varint,1,opt,name=id" json:"id,omitempty"` + Key []byte `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + Value []byte `protobuf:"bytes,3,opt,name=value,proto3" json:"value,omitempty"` + Exists bool `protobuf:"varint,4,opt,name=exists" json:"exists,omitempty"` + Start []byte `protobuf:"bytes,5,opt,name=start,proto3" json:"start,omitempty"` + End []byte `protobuf:"bytes,6,opt,name=end,proto3" json:"end,omitempty"` + Err string `protobuf:"bytes,7,opt,name=err" json:"err,omitempty"` + CreatedAt int64 `protobuf:"varint,8,opt,name=created_at,json=createdAt" json:"created_at,omitempty"` +} + +func (m *Entity) Reset() { *m = Entity{} } +func (m *Entity) String() string { return proto.CompactTextString(m) } +func (*Entity) ProtoMessage() {} +func (*Entity) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +func (m *Entity) GetId() int32 { + if m != nil { + return m.Id + } + return 0 +} + +func (m *Entity) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *Entity) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +func (m *Entity) GetExists() bool { + if m != nil { + return m.Exists + } + return false +} + +func (m *Entity) GetStart() []byte { + if m != nil { + return m.Start + } + return nil +} + +func (m *Entity) GetEnd() []byte { + if m != nil { + return m.End + } + return nil +} + +func (m *Entity) GetErr() string { + if m != nil { + return m.Err + } + return "" +} + +func (m *Entity) GetCreatedAt() int64 { + if m != nil { + return m.CreatedAt + } + return 0 +} + +type Nothing struct { +} + +func (m *Nothing) Reset() { *m = Nothing{} } +func (m *Nothing) String() string { return proto.CompactTextString(m) } +func (*Nothing) ProtoMessage() {} +func (*Nothing) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } + +type Domain struct { + Start []byte `protobuf:"bytes,1,opt,name=start,proto3" json:"start,omitempty"` + End []byte `protobuf:"bytes,2,opt,name=end,proto3" json:"end,omitempty"` +} + +func (m *Domain) Reset() { *m = Domain{} } +func (m *Domain) String() string { return proto.CompactTextString(m) } +func (*Domain) ProtoMessage() {} +func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } + +func (m *Domain) GetStart() []byte { + if m != nil { + return m.Start + } + return nil +} + +func (m *Domain) GetEnd() []byte { + if m != nil { + return m.End + } + return nil +} + +type Iterator struct { + Domain *Domain `protobuf:"bytes,1,opt,name=domain" json:"domain,omitempty"` + Valid bool `protobuf:"varint,2,opt,name=valid" json:"valid,omitempty"` + Key []byte `protobuf:"bytes,3,opt,name=key,proto3" json:"key,omitempty"` + Value []byte `protobuf:"bytes,4,opt,name=value,proto3" json:"value,omitempty"` +} + +func (m *Iterator) Reset() { *m = Iterator{} } +func (m *Iterator) String() string { return proto.CompactTextString(m) } +func (*Iterator) ProtoMessage() {} +func (*Iterator) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +func (m *Iterator) GetDomain() *Domain { + if m != nil { + return m.Domain + } + return nil +} + +func (m *Iterator) GetValid() bool { + if m != nil { + return m.Valid + } + return false +} + +func (m *Iterator) GetKey() []byte { + if m != nil { + return m.Key + } + return nil +} + +func (m *Iterator) GetValue() []byte { + if m != nil { + return m.Value + } + return nil +} + +type Stats struct { + Data map[string]string `protobuf:"bytes,1,rep,name=data" json:"data,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + TimeAt int64 `protobuf:"varint,2,opt,name=time_at,json=timeAt" json:"time_at,omitempty"` +} + +func (m *Stats) Reset() { *m = Stats{} } +func (m *Stats) String() string { return proto.CompactTextString(m) } +func (*Stats) ProtoMessage() {} +func (*Stats) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + +func (m *Stats) GetData() map[string]string { + if m != nil { + return m.Data + } + return nil +} + +func (m *Stats) GetTimeAt() int64 { + if m != nil { + return m.TimeAt + } + return 0 +} + +type Init struct { + Type string `protobuf:"bytes,1,opt,name=Type" json:"Type,omitempty"` + Name string `protobuf:"bytes,2,opt,name=Name" json:"Name,omitempty"` + Dir string `protobuf:"bytes,3,opt,name=Dir" json:"Dir,omitempty"` +} + +func (m *Init) Reset() { *m = Init{} } +func (m *Init) String() string { return proto.CompactTextString(m) } +func (*Init) ProtoMessage() {} +func (*Init) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } + +func (m *Init) GetType() string { + if m != nil { + return m.Type + } + return "" +} + +func (m *Init) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *Init) GetDir() string { + if m != nil { + return m.Dir + } + return "" +} + +func init() { + proto.RegisterType((*Batch)(nil), "protodb.Batch") + proto.RegisterType((*Operation)(nil), "protodb.Operation") + proto.RegisterType((*Entity)(nil), "protodb.Entity") + proto.RegisterType((*Nothing)(nil), "protodb.Nothing") + proto.RegisterType((*Domain)(nil), "protodb.Domain") + proto.RegisterType((*Iterator)(nil), "protodb.Iterator") + proto.RegisterType((*Stats)(nil), "protodb.Stats") + proto.RegisterType((*Init)(nil), "protodb.Init") + proto.RegisterEnum("protodb.Operation_Type", Operation_Type_name, Operation_Type_value) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for DB service + +type DBClient interface { + Init(ctx context.Context, in *Init, opts ...grpc.CallOption) (*Entity, error) + Get(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Entity, error) + GetStream(ctx context.Context, opts ...grpc.CallOption) (DB_GetStreamClient, error) + Has(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Entity, error) + Set(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) + SetSync(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) + Delete(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) + DeleteSync(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) + Iterator(ctx context.Context, in *Entity, opts ...grpc.CallOption) (DB_IteratorClient, error) + ReverseIterator(ctx context.Context, in *Entity, opts ...grpc.CallOption) (DB_ReverseIteratorClient, error) + // rpc print(Nothing) returns (Entity) {} + Stats(ctx context.Context, in *Nothing, opts ...grpc.CallOption) (*Stats, error) + BatchWrite(ctx context.Context, in *Batch, opts ...grpc.CallOption) (*Nothing, error) + BatchWriteSync(ctx context.Context, in *Batch, opts ...grpc.CallOption) (*Nothing, error) +} + +type dBClient struct { + cc *grpc.ClientConn +} + +func NewDBClient(cc *grpc.ClientConn) DBClient { + return &dBClient{cc} +} + +func (c *dBClient) Init(ctx context.Context, in *Init, opts ...grpc.CallOption) (*Entity, error) { + out := new(Entity) + err := grpc.Invoke(ctx, "/protodb.DB/init", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) Get(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Entity, error) { + out := new(Entity) + err := grpc.Invoke(ctx, "/protodb.DB/get", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) GetStream(ctx context.Context, opts ...grpc.CallOption) (DB_GetStreamClient, error) { + stream, err := grpc.NewClientStream(ctx, &_DB_serviceDesc.Streams[0], c.cc, "/protodb.DB/getStream", opts...) + if err != nil { + return nil, err + } + x := &dBGetStreamClient{stream} + return x, nil +} + +type DB_GetStreamClient interface { + Send(*Entity) error + Recv() (*Entity, error) + grpc.ClientStream +} + +type dBGetStreamClient struct { + grpc.ClientStream +} + +func (x *dBGetStreamClient) Send(m *Entity) error { + return x.ClientStream.SendMsg(m) +} + +func (x *dBGetStreamClient) Recv() (*Entity, error) { + m := new(Entity) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *dBClient) Has(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Entity, error) { + out := new(Entity) + err := grpc.Invoke(ctx, "/protodb.DB/has", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) Set(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/set", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) SetSync(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/setSync", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) Delete(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/delete", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) DeleteSync(ctx context.Context, in *Entity, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/deleteSync", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) Iterator(ctx context.Context, in *Entity, opts ...grpc.CallOption) (DB_IteratorClient, error) { + stream, err := grpc.NewClientStream(ctx, &_DB_serviceDesc.Streams[1], c.cc, "/protodb.DB/iterator", opts...) + if err != nil { + return nil, err + } + x := &dBIteratorClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type DB_IteratorClient interface { + Recv() (*Iterator, error) + grpc.ClientStream +} + +type dBIteratorClient struct { + grpc.ClientStream +} + +func (x *dBIteratorClient) Recv() (*Iterator, error) { + m := new(Iterator) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *dBClient) ReverseIterator(ctx context.Context, in *Entity, opts ...grpc.CallOption) (DB_ReverseIteratorClient, error) { + stream, err := grpc.NewClientStream(ctx, &_DB_serviceDesc.Streams[2], c.cc, "/protodb.DB/reverseIterator", opts...) + if err != nil { + return nil, err + } + x := &dBReverseIteratorClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type DB_ReverseIteratorClient interface { + Recv() (*Iterator, error) + grpc.ClientStream +} + +type dBReverseIteratorClient struct { + grpc.ClientStream +} + +func (x *dBReverseIteratorClient) Recv() (*Iterator, error) { + m := new(Iterator) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *dBClient) Stats(ctx context.Context, in *Nothing, opts ...grpc.CallOption) (*Stats, error) { + out := new(Stats) + err := grpc.Invoke(ctx, "/protodb.DB/stats", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) BatchWrite(ctx context.Context, in *Batch, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/batchWrite", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dBClient) BatchWriteSync(ctx context.Context, in *Batch, opts ...grpc.CallOption) (*Nothing, error) { + out := new(Nothing) + err := grpc.Invoke(ctx, "/protodb.DB/batchWriteSync", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for DB service + +type DBServer interface { + Init(context.Context, *Init) (*Entity, error) + Get(context.Context, *Entity) (*Entity, error) + GetStream(DB_GetStreamServer) error + Has(context.Context, *Entity) (*Entity, error) + Set(context.Context, *Entity) (*Nothing, error) + SetSync(context.Context, *Entity) (*Nothing, error) + Delete(context.Context, *Entity) (*Nothing, error) + DeleteSync(context.Context, *Entity) (*Nothing, error) + Iterator(*Entity, DB_IteratorServer) error + ReverseIterator(*Entity, DB_ReverseIteratorServer) error + // rpc print(Nothing) returns (Entity) {} + Stats(context.Context, *Nothing) (*Stats, error) + BatchWrite(context.Context, *Batch) (*Nothing, error) + BatchWriteSync(context.Context, *Batch) (*Nothing, error) +} + +func RegisterDBServer(s *grpc.Server, srv DBServer) { + s.RegisterService(&_DB_serviceDesc, srv) +} + +func _DB_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Init) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Init(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Init", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Init(ctx, req.(*Init)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Get(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Get", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Get(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_GetStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(DBServer).GetStream(&dBGetStreamServer{stream}) +} + +type DB_GetStreamServer interface { + Send(*Entity) error + Recv() (*Entity, error) + grpc.ServerStream +} + +type dBGetStreamServer struct { + grpc.ServerStream +} + +func (x *dBGetStreamServer) Send(m *Entity) error { + return x.ServerStream.SendMsg(m) +} + +func (x *dBGetStreamServer) Recv() (*Entity, error) { + m := new(Entity) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _DB_Has_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Has(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Has", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Has(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_Set_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Set(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Set", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Set(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_SetSync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).SetSync(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/SetSync", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).SetSync(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_Delete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Delete(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Delete", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Delete(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_DeleteSync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Entity) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).DeleteSync(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/DeleteSync", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).DeleteSync(ctx, req.(*Entity)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_Iterator_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Entity) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DBServer).Iterator(m, &dBIteratorServer{stream}) +} + +type DB_IteratorServer interface { + Send(*Iterator) error + grpc.ServerStream +} + +type dBIteratorServer struct { + grpc.ServerStream +} + +func (x *dBIteratorServer) Send(m *Iterator) error { + return x.ServerStream.SendMsg(m) +} + +func _DB_ReverseIterator_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Entity) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DBServer).ReverseIterator(m, &dBReverseIteratorServer{stream}) +} + +type DB_ReverseIteratorServer interface { + Send(*Iterator) error + grpc.ServerStream +} + +type dBReverseIteratorServer struct { + grpc.ServerStream +} + +func (x *dBReverseIteratorServer) Send(m *Iterator) error { + return x.ServerStream.SendMsg(m) +} + +func _DB_Stats_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Nothing) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).Stats(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/Stats", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).Stats(ctx, req.(*Nothing)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_BatchWrite_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Batch) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).BatchWrite(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/BatchWrite", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).BatchWrite(ctx, req.(*Batch)) + } + return interceptor(ctx, in, info, handler) +} + +func _DB_BatchWriteSync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Batch) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DBServer).BatchWriteSync(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/protodb.DB/BatchWriteSync", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DBServer).BatchWriteSync(ctx, req.(*Batch)) + } + return interceptor(ctx, in, info, handler) +} + +var _DB_serviceDesc = grpc.ServiceDesc{ + ServiceName: "protodb.DB", + HandlerType: (*DBServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "init", + Handler: _DB_Init_Handler, + }, + { + MethodName: "get", + Handler: _DB_Get_Handler, + }, + { + MethodName: "has", + Handler: _DB_Has_Handler, + }, + { + MethodName: "set", + Handler: _DB_Set_Handler, + }, + { + MethodName: "setSync", + Handler: _DB_SetSync_Handler, + }, + { + MethodName: "delete", + Handler: _DB_Delete_Handler, + }, + { + MethodName: "deleteSync", + Handler: _DB_DeleteSync_Handler, + }, + { + MethodName: "stats", + Handler: _DB_Stats_Handler, + }, + { + MethodName: "batchWrite", + Handler: _DB_BatchWrite_Handler, + }, + { + MethodName: "batchWriteSync", + Handler: _DB_BatchWriteSync_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "getStream", + Handler: _DB_GetStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + { + StreamName: "iterator", + Handler: _DB_Iterator_Handler, + ServerStreams: true, + }, + { + StreamName: "reverseIterator", + Handler: _DB_ReverseIterator_Handler, + ServerStreams: true, + }, + }, + Metadata: "defs.proto", +} + +func init() { proto.RegisterFile("defs.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 606 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x54, 0x4f, 0x6f, 0xd3, 0x4e, + 0x10, 0xcd, 0xda, 0x8e, 0x13, 0x4f, 0x7f, 0xbf, 0x34, 0x8c, 0x10, 0xb5, 0x8a, 0x90, 0x22, 0x0b, + 0x09, 0x43, 0x69, 0x14, 0x52, 0x24, 0xfe, 0x9c, 0x68, 0x95, 0x1c, 0x2a, 0xa1, 0x22, 0x39, 0x95, + 0x38, 0xa2, 0x6d, 0x3d, 0x34, 0x2b, 0x1a, 0x3b, 0xac, 0x87, 0x8a, 0x5c, 0xb8, 0xf2, 0x79, 0xf8, + 0x7c, 0x5c, 0xd0, 0xae, 0x1d, 0x87, 0x36, 0x39, 0x84, 0x53, 0x76, 0x66, 0xde, 0x7b, 0xb3, 0xf3, + 0x32, 0x5e, 0x80, 0x94, 0x3e, 0x17, 0xfd, 0xb9, 0xce, 0x39, 0xc7, 0x96, 0xfd, 0x49, 0x2f, 0xa2, + 0x43, 0x68, 0x9e, 0x48, 0xbe, 0x9c, 0xe2, 0x63, 0x70, 0xf3, 0x79, 0x11, 0x8a, 0x9e, 0x1b, 0xef, + 0x0c, 0xb1, 0x5f, 0xd5, 0xfb, 0x1f, 0xe6, 0xa4, 0x25, 0xab, 0x3c, 0x4b, 0x4c, 0x39, 0xfa, 0x01, + 0x41, 0x9d, 0xc1, 0x27, 0xe0, 0x53, 0xc6, 0x8a, 0x17, 0xa1, 0xe8, 0x89, 0x78, 0x67, 0xb8, 0x5b, + 0xb3, 0xc6, 0x36, 0x9d, 0x54, 0x65, 0x3c, 0x00, 0x8f, 0x17, 0x73, 0x0a, 0x9d, 0x9e, 0x88, 0x3b, + 0xc3, 0xbd, 0x75, 0xf1, 0xfe, 0xf9, 0x62, 0x4e, 0x89, 0x05, 0x45, 0x0f, 0xc1, 0x33, 0x11, 0xb6, + 0xc0, 0x9d, 0x8c, 0xcf, 0xbb, 0x0d, 0x04, 0xf0, 0x47, 0xe3, 0xf7, 0xe3, 0xf3, 0x71, 0x57, 0x44, + 0xbf, 0x04, 0xf8, 0xa5, 0x38, 0x76, 0xc0, 0x51, 0xa9, 0xed, 0xdc, 0x4c, 0x1c, 0x95, 0x62, 0x17, + 0xdc, 0x2f, 0xb4, 0xb0, 0x3d, 0xfe, 0x4b, 0xcc, 0x11, 0xef, 0x43, 0xf3, 0x46, 0x5e, 0x7f, 0xa3, + 0xd0, 0xb5, 0xb9, 0x32, 0xc0, 0x07, 0xe0, 0xd3, 0x77, 0x55, 0x70, 0x11, 0x7a, 0x3d, 0x11, 0xb7, + 0x93, 0x2a, 0x32, 0xe8, 0x82, 0xa5, 0xe6, 0xb0, 0x59, 0xa2, 0x6d, 0x60, 0x54, 0x29, 0x4b, 0x43, + 0xbf, 0x54, 0xa5, 0xcc, 0xf6, 0x21, 0xad, 0xc3, 0x56, 0x4f, 0xc4, 0x41, 0x62, 0x8e, 0xf8, 0x08, + 0xe0, 0x52, 0x93, 0x64, 0x4a, 0x3f, 0x49, 0x0e, 0xdb, 0x3d, 0x11, 0xbb, 0x49, 0x50, 0x65, 0x8e, + 0x39, 0x0a, 0xa0, 0x75, 0x96, 0xf3, 0x54, 0x65, 0x57, 0xd1, 0x00, 0xfc, 0x51, 0x3e, 0x93, 0x2a, + 0x5b, 0x75, 0x13, 0x1b, 0xba, 0x39, 0x75, 0xb7, 0xe8, 0x2b, 0xb4, 0x4f, 0xd9, 0xb8, 0x94, 0x6b, + 0xe3, 0x77, 0x6a, 0xd9, 0x6b, 0x7e, 0x97, 0xa2, 0x49, 0x55, 0xae, 0x06, 0x57, 0xa5, 0x50, 0x3b, + 0x29, 0x83, 0xa5, 0x41, 0xee, 0x06, 0x83, 0xbc, 0xbf, 0x0c, 0x8a, 0x7e, 0x0a, 0x68, 0x4e, 0x58, + 0x72, 0x81, 0xcf, 0xc1, 0x4b, 0x25, 0xcb, 0x6a, 0x29, 0xc2, 0xba, 0x9d, 0xad, 0xf6, 0x47, 0x92, + 0xe5, 0x38, 0x63, 0xbd, 0x48, 0x2c, 0x0a, 0xf7, 0xa0, 0xc5, 0x6a, 0x46, 0xc6, 0x03, 0xc7, 0x7a, + 0xe0, 0x9b, 0xf0, 0x98, 0xf7, 0x5f, 0x41, 0x50, 0x63, 0x97, 0xb7, 0x10, 0xa5, 0x7d, 0xb7, 0x6e, + 0xe1, 0xd8, 0x5c, 0x19, 0xbc, 0x75, 0x5e, 0x8b, 0xe8, 0x1d, 0x78, 0xa7, 0x99, 0x62, 0xc4, 0x72, + 0x25, 0x2a, 0x52, 0xb9, 0x1e, 0x08, 0xde, 0x99, 0x9c, 0x2d, 0x49, 0xf6, 0x6c, 0xb4, 0x47, 0x4a, + 0xdb, 0x09, 0x83, 0xc4, 0x1c, 0x87, 0xbf, 0x3d, 0x70, 0x46, 0x27, 0x18, 0x83, 0xa7, 0x8c, 0xd0, + 0xff, 0xf5, 0x08, 0x46, 0x77, 0xff, 0xee, 0xc2, 0x46, 0x0d, 0x7c, 0x0a, 0xee, 0x15, 0x31, 0xde, + 0xad, 0x6c, 0x82, 0x1e, 0x41, 0x70, 0x45, 0x3c, 0x61, 0x4d, 0x72, 0xb6, 0x0d, 0x21, 0x16, 0x03, + 0x61, 0xf4, 0xa7, 0xb2, 0xd8, 0x4a, 0xff, 0x19, 0xb8, 0xc5, 0xa6, 0xab, 0x74, 0xeb, 0xc4, 0x72, + 0xad, 0x1a, 0xd8, 0x87, 0x56, 0x41, 0x3c, 0x59, 0x64, 0x97, 0xdb, 0xe1, 0x0f, 0xc1, 0x4f, 0xe9, + 0x9a, 0x98, 0xb6, 0x83, 0xbf, 0x30, 0x8f, 0x87, 0x81, 0x6f, 0xdf, 0x61, 0x08, 0x6d, 0xb5, 0x5c, + 0xdc, 0x35, 0xc2, 0xbd, 0xd5, 0xff, 0x50, 0x61, 0xa2, 0xc6, 0x40, 0xe0, 0x1b, 0xd8, 0xd5, 0x74, + 0x43, 0xba, 0xa0, 0xd3, 0x7f, 0xa5, 0x1e, 0xd8, 0xef, 0x89, 0x0b, 0x5c, 0xbb, 0xcb, 0x7e, 0xe7, + 0xf6, 0xde, 0x46, 0x0d, 0x1c, 0x00, 0x5c, 0x98, 0x47, 0xef, 0xa3, 0x56, 0x4c, 0xb8, 0xaa, 0xdb, + 0x97, 0x70, 0xe3, 0x34, 0x2f, 0xa1, 0xb3, 0x62, 0x58, 0x13, 0xb6, 0x60, 0x5d, 0xf8, 0x36, 0x75, + 0xf4, 0x27, 0x00, 0x00, 0xff, 0xff, 0x95, 0xf4, 0xe3, 0x82, 0x7a, 0x05, 0x00, 0x00, +} diff --git a/db/remotedb/proto/defs.proto b/db/remotedb/proto/defs.proto new file mode 100644 index 00000000..70471f23 --- /dev/null +++ b/db/remotedb/proto/defs.proto @@ -0,0 +1,71 @@ +syntax = "proto3"; + +package protodb; + +message Batch { + repeated Operation ops = 1; +} + +message Operation { + Entity entity = 1; + enum Type { + SET = 0; + DELETE = 1; + } + Type type = 2; +} + +message Entity { + int32 id = 1; + bytes key = 2; + bytes value = 3; + bool exists = 4; + bytes start = 5; + bytes end = 6; + string err = 7; + int64 created_at = 8; +} + +message Nothing { +} + +message Domain { + bytes start = 1; + bytes end = 2; +} + +message Iterator { + Domain domain = 1; + bool valid = 2; + bytes key = 3; + bytes value = 4; +} + +message Stats { + map data = 1; + int64 time_at = 2; +} + +message Init { + string Type = 1; + string Name = 2; + string Dir = 3; +} + +service DB { + rpc init(Init) returns (Entity) {} + rpc get(Entity) returns (Entity) {} + rpc getStream(stream Entity) returns (stream Entity) {} + + rpc has(Entity) returns (Entity) {} + rpc set(Entity) returns (Nothing) {} + rpc setSync(Entity) returns (Nothing) {} + rpc delete(Entity) returns (Nothing) {} + rpc deleteSync(Entity) returns (Nothing) {} + rpc iterator(Entity) returns (stream Iterator) {} + rpc reverseIterator(Entity) returns (stream Iterator) {} + // rpc print(Nothing) returns (Entity) {} + rpc stats(Nothing) returns (Stats) {} + rpc batchWrite(Batch) returns (Nothing) {} + rpc batchWriteSync(Batch) returns (Nothing) {} +} diff --git a/db/remotedb/remotedb.go b/db/remotedb/remotedb.go new file mode 100644 index 00000000..5332bd68 --- /dev/null +++ b/db/remotedb/remotedb.go @@ -0,0 +1,262 @@ +package remotedb + +import ( + "context" + "fmt" + + "github.com/tendermint/tmlibs/db" + "github.com/tendermint/tmlibs/db/remotedb/grpcdb" + protodb "github.com/tendermint/tmlibs/db/remotedb/proto" +) + +type RemoteDB struct { + ctx context.Context + dc protodb.DBClient +} + +func NewRemoteDB(serverAddr string, serverKey string) (*RemoteDB, error) { + return newRemoteDB(grpcdb.NewClient(serverAddr, serverKey)) +} + +func newRemoteDB(gdc protodb.DBClient, err error) (*RemoteDB, error) { + if err != nil { + return nil, err + } + return &RemoteDB{dc: gdc, ctx: context.Background()}, nil +} + +type Init struct { + Dir string + Name string + Type string +} + +func (rd *RemoteDB) InitRemote(in *Init) error { + _, err := rd.dc.Init(rd.ctx, &protodb.Init{Dir: in.Dir, Type: in.Type, Name: in.Name}) + return err +} + +var _ db.DB = (*RemoteDB)(nil) + +// Close is a noop currently +func (rd *RemoteDB) Close() { +} + +func (rd *RemoteDB) Delete(key []byte) { + if _, err := rd.dc.Delete(rd.ctx, &protodb.Entity{Key: key}); err != nil { + panic(fmt.Sprintf("RemoteDB.Delete: %v", err)) + } +} + +func (rd *RemoteDB) DeleteSync(key []byte) { + if _, err := rd.dc.DeleteSync(rd.ctx, &protodb.Entity{Key: key}); err != nil { + panic(fmt.Sprintf("RemoteDB.DeleteSync: %v", err)) + } +} + +func (rd *RemoteDB) Set(key, value []byte) { + if _, err := rd.dc.Set(rd.ctx, &protodb.Entity{Key: key, Value: value}); err != nil { + panic(fmt.Sprintf("RemoteDB.Set: %v", err)) + } +} + +func (rd *RemoteDB) SetSync(key, value []byte) { + if _, err := rd.dc.SetSync(rd.ctx, &protodb.Entity{Key: key, Value: value}); err != nil { + panic(fmt.Sprintf("RemoteDB.SetSync: %v", err)) + } +} + +func (rd *RemoteDB) Get(key []byte) []byte { + res, err := rd.dc.Get(rd.ctx, &protodb.Entity{Key: key}) + if err != nil { + panic(fmt.Sprintf("RemoteDB.Get error: %v", err)) + } + return res.Value +} + +func (rd *RemoteDB) Has(key []byte) bool { + res, err := rd.dc.Has(rd.ctx, &protodb.Entity{Key: key}) + if err != nil { + panic(fmt.Sprintf("RemoteDB.Has error: %v", err)) + } + return res.Exists +} + +func (rd *RemoteDB) ReverseIterator(start, end []byte) db.Iterator { + dic, err := rd.dc.ReverseIterator(rd.ctx, &protodb.Entity{Start: start, End: end}) + if err != nil { + panic(fmt.Sprintf("RemoteDB.Iterator error: %v", err)) + } + return makeReverseIterator(dic) +} + +func (rd *RemoteDB) NewBatch() db.Batch { + return &batch{ + db: rd, + ops: nil, + } +} + +// TODO: Implement Print when db.DB implements a method +// to print to a string and not db.Print to stdout. +func (rd *RemoteDB) Print() { + panic("Unimplemented") +} + +func (rd *RemoteDB) Stats() map[string]string { + stats, err := rd.dc.Stats(rd.ctx, &protodb.Nothing{}) + if err != nil { + panic(fmt.Sprintf("RemoteDB.Stats error: %v", err)) + } + if stats == nil { + return nil + } + return stats.Data +} + +func (rd *RemoteDB) Iterator(start, end []byte) db.Iterator { + dic, err := rd.dc.Iterator(rd.ctx, &protodb.Entity{Start: start, End: end}) + if err != nil { + panic(fmt.Sprintf("RemoteDB.Iterator error: %v", err)) + } + return makeIterator(dic) +} + +func makeIterator(dic protodb.DB_IteratorClient) db.Iterator { + return &iterator{dic: dic} +} + +func makeReverseIterator(dric protodb.DB_ReverseIteratorClient) db.Iterator { + return &reverseIterator{dric: dric} +} + +type reverseIterator struct { + dric protodb.DB_ReverseIteratorClient + cur *protodb.Iterator +} + +var _ db.Iterator = (*iterator)(nil) + +func (rItr *reverseIterator) Valid() bool { + return rItr.cur != nil && rItr.cur.Valid +} + +func (rItr *reverseIterator) Domain() (start, end []byte) { + if rItr.cur == nil || rItr.cur.Domain == nil { + return nil, nil + } + return rItr.cur.Domain.Start, rItr.cur.Domain.End +} + +// Next advances the current reverseIterator +func (rItr *reverseIterator) Next() { + var err error + rItr.cur, err = rItr.dric.Recv() + if err != nil { + panic(fmt.Sprintf("RemoteDB.ReverseIterator.Next error: %v", err)) + } +} + +func (rItr *reverseIterator) Key() []byte { + if rItr.cur == nil { + return nil + } + return rItr.cur.Key +} + +func (rItr *reverseIterator) Value() []byte { + if rItr.cur == nil { + return nil + } + return rItr.cur.Value +} + +func (rItr *reverseIterator) Close() { +} + +// iterator implements the db.Iterator by retrieving +// streamed iterators from the remote backend as +// needed. It is NOT safe for concurrent usage, +// matching the behavior of other iterators. +type iterator struct { + dic protodb.DB_IteratorClient + cur *protodb.Iterator +} + +var _ db.Iterator = (*iterator)(nil) + +func (itr *iterator) Valid() bool { + return itr.cur != nil && itr.cur.Valid +} + +func (itr *iterator) Domain() (start, end []byte) { + if itr.cur == nil || itr.cur.Domain == nil { + return nil, nil + } + return itr.cur.Domain.Start, itr.cur.Domain.End +} + +// Next advances the current iterator +func (itr *iterator) Next() { + var err error + itr.cur, err = itr.dic.Recv() + if err != nil { + panic(fmt.Sprintf("RemoteDB.Iterator.Next error: %v", err)) + } +} + +func (itr *iterator) Key() []byte { + if itr.cur == nil { + return nil + } + return itr.cur.Key +} + +func (itr *iterator) Value() []byte { + if itr.cur == nil { + return nil + } + return itr.cur.Value +} + +func (itr *iterator) Close() { + err := itr.dic.CloseSend() + if err != nil { + panic(fmt.Sprintf("Error closing iterator: %v", err)) + } +} + +type batch struct { + db *RemoteDB + ops []*protodb.Operation +} + +var _ db.Batch = (*batch)(nil) + +func (bat *batch) Set(key, value []byte) { + op := &protodb.Operation{ + Entity: &protodb.Entity{Key: key, Value: value}, + Type: protodb.Operation_SET, + } + bat.ops = append(bat.ops, op) +} + +func (bat *batch) Delete(key []byte) { + op := &protodb.Operation{ + Entity: &protodb.Entity{Key: key}, + Type: protodb.Operation_DELETE, + } + bat.ops = append(bat.ops, op) +} + +func (bat *batch) Write() { + if _, err := bat.db.dc.BatchWrite(bat.db.ctx, &protodb.Batch{Ops: bat.ops}); err != nil { + panic(fmt.Sprintf("RemoteDB.BatchWrite: %v", err)) + } +} + +func (bat *batch) WriteSync() { + if _, err := bat.db.dc.BatchWriteSync(bat.db.ctx, &protodb.Batch{Ops: bat.ops}); err != nil { + panic(fmt.Sprintf("RemoteDB.BatchWriteSync: %v", err)) + } +} diff --git a/db/remotedb/remotedb_test.go b/db/remotedb/remotedb_test.go new file mode 100644 index 00000000..3cf698a6 --- /dev/null +++ b/db/remotedb/remotedb_test.go @@ -0,0 +1,123 @@ +package remotedb_test + +import ( + "net" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/tendermint/tmlibs/db/remotedb" + "github.com/tendermint/tmlibs/db/remotedb/grpcdb" +) + +func TestRemoteDB(t *testing.T) { + cert := "::.crt" + key := "::.key" + ln, err := net.Listen("tcp", "0.0.0.0:0") + require.Nil(t, err, "expecting a port to have been assigned on which we can listen") + srv, err := grpcdb.NewServer(cert, key) + require.Nil(t, err) + defer srv.Stop() + go func() { + if err := srv.Serve(ln); err != nil { + t.Fatalf("BindServer: %v", err) + } + }() + + client, err := remotedb.NewRemoteDB(ln.Addr().String(), cert) + require.Nil(t, err, "expecting a successful client creation") + dbName := "test-remote-db" + require.Nil(t, client.InitRemote(&remotedb.Init{Name: dbName, Type: "leveldb"})) + defer func() { + err := os.RemoveAll(dbName + ".db") + if err != nil { + panic(err) + } + }() + + k1 := []byte("key-1") + v1 := client.Get(k1) + require.Equal(t, 0, len(v1), "expecting no key1 to have been stored, got %X (%s)", v1, v1) + vv1 := []byte("value-1") + client.Set(k1, vv1) + gv1 := client.Get(k1) + require.Equal(t, gv1, vv1) + + // Simple iteration + itr := client.Iterator(nil, nil) + itr.Next() + require.Equal(t, itr.Key(), []byte("key-1")) + require.Equal(t, itr.Value(), []byte("value-1")) + require.Panics(t, itr.Next) + itr.Close() + + // Set some more keys + k2 := []byte("key-2") + v2 := []byte("value-2") + client.SetSync(k2, v2) + has := client.Has(k2) + require.True(t, has) + gv2 := client.Get(k2) + require.Equal(t, gv2, v2) + + // More iteration + itr = client.Iterator(nil, nil) + itr.Next() + require.Equal(t, itr.Key(), []byte("key-1")) + require.Equal(t, itr.Value(), []byte("value-1")) + itr.Next() + require.Equal(t, itr.Key(), []byte("key-2")) + require.Equal(t, itr.Value(), []byte("value-2")) + require.Panics(t, itr.Next) + itr.Close() + + // Deletion + client.Delete(k1) + client.DeleteSync(k2) + gv1 = client.Get(k1) + gv2 = client.Get(k2) + require.Equal(t, len(gv2), 0, "after deletion, not expecting the key to exist anymore") + require.Equal(t, len(gv1), 0, "after deletion, not expecting the key to exist anymore") + + // Batch tests - set + k3 := []byte("key-3") + k4 := []byte("key-4") + k5 := []byte("key-5") + v3 := []byte("value-3") + v4 := []byte("value-4") + v5 := []byte("value-5") + bat := client.NewBatch() + bat.Set(k3, v3) + bat.Set(k4, v4) + rv3 := client.Get(k3) + require.Equal(t, 0, len(rv3), "expecting no k3 to have been stored") + rv4 := client.Get(k4) + require.Equal(t, 0, len(rv4), "expecting no k4 to have been stored") + bat.Write() + rv3 = client.Get(k3) + require.Equal(t, rv3, v3, "expecting k3 to have been stored") + rv4 = client.Get(k4) + require.Equal(t, rv4, v4, "expecting k4 to have been stored") + + // Batch tests - deletion + bat = client.NewBatch() + bat.Delete(k4) + bat.Delete(k3) + bat.WriteSync() + rv3 = client.Get(k3) + require.Equal(t, 0, len(rv3), "expecting k3 to have been deleted") + rv4 = client.Get(k4) + require.Equal(t, 0, len(rv4), "expecting k4 to have been deleted") + + // Batch tests - set and delete + bat = client.NewBatch() + bat.Set(k4, v4) + bat.Set(k5, v5) + bat.Delete(k4) + bat.WriteSync() + rv4 = client.Get(k4) + require.Equal(t, 0, len(rv4), "expecting k4 to have been deleted") + rv5 := client.Get(k5) + require.Equal(t, rv5, v5, "expecting k5 to have been stored") +} diff --git a/db/types.go b/db/types.go new file mode 100644 index 00000000..ad78859a --- /dev/null +++ b/db/types.go @@ -0,0 +1,134 @@ +package db + +// DBs are goroutine safe. +type DB interface { + + // Get returns nil iff key doesn't exist. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Get([]byte) []byte + + // Has checks if a key exists. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Has(key []byte) bool + + // Set sets the key. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key, value readonly []byte + Set([]byte, []byte) + SetSync([]byte, []byte) + + // Delete deletes the key. + // A nil key is interpreted as an empty byteslice. + // CONTRACT: key readonly []byte + Delete([]byte) + DeleteSync([]byte) + + // Iterate over a domain of keys in ascending order. End is exclusive. + // Start must be less than end, or the Iterator is invalid. + // A nil start is interpreted as an empty byteslice. + // If end is nil, iterates up to the last item (inclusive). + // CONTRACT: No writes may happen within a domain while an iterator exists over it. + // CONTRACT: start, end readonly []byte + Iterator(start, end []byte) Iterator + + // Iterate over a domain of keys in descending order. End is exclusive. + // Start must be greater than end, or the Iterator is invalid. + // If start is nil, iterates from the last/greatest item (inclusive). + // If end is nil, iterates up to the first/least item (inclusive). + // CONTRACT: No writes may happen within a domain while an iterator exists over it. + // CONTRACT: start, end readonly []byte + ReverseIterator(start, end []byte) Iterator + + // Closes the connection. + Close() + + // Creates a batch for atomic updates. + NewBatch() Batch + + // For debugging + Print() + + // Stats returns a map of property values for all keys and the size of the cache. + Stats() map[string]string +} + +//---------------------------------------- +// Batch + +type Batch interface { + SetDeleter + Write() + WriteSync() +} + +type SetDeleter interface { + Set(key, value []byte) // CONTRACT: key, value readonly []byte + Delete(key []byte) // CONTRACT: key readonly []byte +} + +//---------------------------------------- +// Iterator + +/* + Usage: + + var itr Iterator = ... + defer itr.Close() + + for ; itr.Valid(); itr.Next() { + k, v := itr.Key(); itr.Value() + // ... + } +*/ +type Iterator interface { + + // The start & end (exclusive) limits to iterate over. + // If end < start, then the Iterator goes in reverse order. + // + // A domain of ([]byte{12, 13}, []byte{12, 14}) will iterate + // over anything with the prefix []byte{12, 13}. + // + // The smallest key is the empty byte array []byte{} - see BeginningKey(). + // The largest key is the nil byte array []byte(nil) - see EndingKey(). + // CONTRACT: start, end readonly []byte + Domain() (start []byte, end []byte) + + // Valid returns whether the current position is valid. + // Once invalid, an Iterator is forever invalid. + Valid() bool + + // Next moves the iterator to the next sequential key in the database, as + // defined by order of iteration. + // + // If Valid returns false, this method will panic. + Next() + + // Key returns the key of the cursor. + // If Valid returns false, this method will panic. + // CONTRACT: key readonly []byte + Key() (key []byte) + + // Value returns the value of the cursor. + // If Valid returns false, this method will panic. + // CONTRACT: value readonly []byte + Value() (value []byte) + + // Close releases the Iterator. + Close() +} + +// For testing convenience. +func bz(s string) []byte { + return []byte(s) +} + +// We defensively turn nil keys or values into []byte{} for +// most operations. +func nonNilBytes(bz []byte) []byte { + if bz == nil { + return []byte{} + } + return bz +} diff --git a/db/util.go b/db/util.go new file mode 100644 index 00000000..51277ac4 --- /dev/null +++ b/db/util.go @@ -0,0 +1,78 @@ +package db + +import ( + "bytes" +) + +func cp(bz []byte) (ret []byte) { + ret = make([]byte, len(bz)) + copy(ret, bz) + return ret +} + +// Returns a slice of the same length (big endian) +// except incremented by one. +// Returns nil on overflow (e.g. if bz bytes are all 0xFF) +// CONTRACT: len(bz) > 0 +func cpIncr(bz []byte) (ret []byte) { + if len(bz) == 0 { + panic("cpIncr expects non-zero bz length") + } + ret = cp(bz) + for i := len(bz) - 1; i >= 0; i-- { + if ret[i] < byte(0xFF) { + ret[i]++ + return + } + ret[i] = byte(0x00) + if i == 0 { + // Overflow + return nil + } + } + return nil +} + +// Returns a slice of the same length (big endian) +// except decremented by one. +// Returns nil on underflow (e.g. if bz bytes are all 0x00) +// CONTRACT: len(bz) > 0 +func cpDecr(bz []byte) (ret []byte) { + if len(bz) == 0 { + panic("cpDecr expects non-zero bz length") + } + ret = cp(bz) + for i := len(bz) - 1; i >= 0; i-- { + if ret[i] > byte(0x00) { + ret[i]-- + return + } + ret[i] = byte(0xFF) + if i == 0 { + // Underflow + return nil + } + } + return nil +} + +// See DB interface documentation for more information. +func IsKeyInDomain(key, start, end []byte, isReverse bool) bool { + if !isReverse { + if bytes.Compare(key, start) < 0 { + return false + } + if end != nil && bytes.Compare(end, key) <= 0 { + return false + } + return true + } else { + if start != nil && bytes.Compare(start, key) < 0 { + return false + } + if end != nil && bytes.Compare(key, end) <= 0 { + return false + } + return true + } +} diff --git a/db/util_test.go b/db/util_test.go new file mode 100644 index 00000000..44f1f9f7 --- /dev/null +++ b/db/util_test.go @@ -0,0 +1,93 @@ +package db + +import ( + "fmt" + "testing" +) + +// Empty iterator for empty db. +func TestPrefixIteratorNoMatchNil(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + + checkInvalid(t, itr) + }) + } +} + +// Empty iterator for db populated after iterator created. +func TestPrefixIteratorNoMatch1(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + db.SetSync(bz("1"), bz("value_1")) + + checkInvalid(t, itr) + }) + } +} + +// Empty iterator for prefix starting after db entry. +func TestPrefixIteratorNoMatch2(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("3"), bz("value_3")) + itr := IteratePrefix(db, []byte("4")) + + checkInvalid(t, itr) + }) + } +} + +// Iterator with single val for db with single val, starting from that val. +func TestPrefixIteratorMatch1(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("2"), bz("value_2")) + itr := IteratePrefix(db, bz("2")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("2"), bz("value_2")) + checkNext(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Iterator with prefix iterates over everything with same prefix. +func TestPrefixIteratorMatches1N(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + // prefixed + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + + // not + db.SetSync(bz("b/3"), bz("value_3")) + db.SetSync(bz("a-3"), bz("value_3")) + db.SetSync(bz("a.3"), bz("value_3")) + db.SetSync(bz("abcdefg"), bz("value_3")) + itr := IteratePrefix(db, bz("a/")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + + // Bad! + checkNext(t, itr, false) + + //Once invalid... + checkInvalid(t, itr) + }) + } +} diff --git a/flowrate/README.md b/flowrate/README.md new file mode 100644 index 00000000..db428090 --- /dev/null +++ b/flowrate/README.md @@ -0,0 +1,10 @@ +Data Flow Rate Control +====================== + +To download and install this package run: + +go get github.com/mxk/go-flowrate/flowrate + +The documentation is available at: + +http://godoc.org/github.com/mxk/go-flowrate/flowrate diff --git a/flowrate/flowrate.go b/flowrate/flowrate.go new file mode 100644 index 00000000..e233eae0 --- /dev/null +++ b/flowrate/flowrate.go @@ -0,0 +1,275 @@ +// +// Written by Maxim Khitrov (November 2012) +// + +// Package flowrate provides the tools for monitoring and limiting the flow rate +// of an arbitrary data stream. +package flowrate + +import ( + "math" + "sync" + "time" +) + +// Monitor monitors and limits the transfer rate of a data stream. +type Monitor struct { + mu sync.Mutex // Mutex guarding access to all internal fields + active bool // Flag indicating an active transfer + start time.Duration // Transfer start time (clock() value) + bytes int64 // Total number of bytes transferred + samples int64 // Total number of samples taken + + rSample float64 // Most recent transfer rate sample (bytes per second) + rEMA float64 // Exponential moving average of rSample + rPeak float64 // Peak transfer rate (max of all rSamples) + rWindow float64 // rEMA window (seconds) + + sBytes int64 // Number of bytes transferred since sLast + sLast time.Duration // Most recent sample time (stop time when inactive) + sRate time.Duration // Sampling rate + + tBytes int64 // Number of bytes expected in the current transfer + tLast time.Duration // Time of the most recent transfer of at least 1 byte +} + +// New creates a new flow control monitor. Instantaneous transfer rate is +// measured and updated for each sampleRate interval. windowSize determines the +// weight of each sample in the exponential moving average (EMA) calculation. +// The exact formulas are: +// +// sampleTime = currentTime - prevSampleTime +// sampleRate = byteCount / sampleTime +// weight = 1 - exp(-sampleTime/windowSize) +// newRate = weight*sampleRate + (1-weight)*oldRate +// +// The default values for sampleRate and windowSize (if <= 0) are 100ms and 1s, +// respectively. +func New(sampleRate, windowSize time.Duration) *Monitor { + if sampleRate = clockRound(sampleRate); sampleRate <= 0 { + sampleRate = 5 * clockRate + } + if windowSize <= 0 { + windowSize = 1 * time.Second + } + now := clock() + return &Monitor{ + active: true, + start: now, + rWindow: windowSize.Seconds(), + sLast: now, + sRate: sampleRate, + tLast: now, + } +} + +// Update records the transfer of n bytes and returns n. It should be called +// after each Read/Write operation, even if n is 0. +func (m *Monitor) Update(n int) int { + m.mu.Lock() + m.update(n) + m.mu.Unlock() + return n +} + +// Hack to set the current rEMA. +func (m *Monitor) SetREMA(rEMA float64) { + m.mu.Lock() + m.rEMA = rEMA + m.samples++ + m.mu.Unlock() +} + +// IO is a convenience method intended to wrap io.Reader and io.Writer method +// execution. It calls m.Update(n) and then returns (n, err) unmodified. +func (m *Monitor) IO(n int, err error) (int, error) { + return m.Update(n), err +} + +// Done marks the transfer as finished and prevents any further updates or +// limiting. Instantaneous and current transfer rates drop to 0. Update, IO, and +// Limit methods become NOOPs. It returns the total number of bytes transferred. +func (m *Monitor) Done() int64 { + m.mu.Lock() + if now := m.update(0); m.sBytes > 0 { + m.reset(now) + } + m.active = false + m.tLast = 0 + n := m.bytes + m.mu.Unlock() + return n +} + +// timeRemLimit is the maximum Status.TimeRem value. +const timeRemLimit = 999*time.Hour + 59*time.Minute + 59*time.Second + +// Status represents the current Monitor status. All transfer rates are in bytes +// per second rounded to the nearest byte. +type Status struct { + Active bool // Flag indicating an active transfer + Start time.Time // Transfer start time + Duration time.Duration // Time period covered by the statistics + Idle time.Duration // Time since the last transfer of at least 1 byte + Bytes int64 // Total number of bytes transferred + Samples int64 // Total number of samples taken + InstRate int64 // Instantaneous transfer rate + CurRate int64 // Current transfer rate (EMA of InstRate) + AvgRate int64 // Average transfer rate (Bytes / Duration) + PeakRate int64 // Maximum instantaneous transfer rate + BytesRem int64 // Number of bytes remaining in the transfer + TimeRem time.Duration // Estimated time to completion + Progress Percent // Overall transfer progress +} + +// Status returns current transfer status information. The returned value +// becomes static after a call to Done. +func (m *Monitor) Status() Status { + m.mu.Lock() + now := m.update(0) + s := Status{ + Active: m.active, + Start: clockToTime(m.start), + Duration: m.sLast - m.start, + Idle: now - m.tLast, + Bytes: m.bytes, + Samples: m.samples, + PeakRate: round(m.rPeak), + BytesRem: m.tBytes - m.bytes, + Progress: percentOf(float64(m.bytes), float64(m.tBytes)), + } + if s.BytesRem < 0 { + s.BytesRem = 0 + } + if s.Duration > 0 { + rAvg := float64(s.Bytes) / s.Duration.Seconds() + s.AvgRate = round(rAvg) + if s.Active { + s.InstRate = round(m.rSample) + s.CurRate = round(m.rEMA) + if s.BytesRem > 0 { + if tRate := 0.8*m.rEMA + 0.2*rAvg; tRate > 0 { + ns := float64(s.BytesRem) / tRate * 1e9 + if ns > float64(timeRemLimit) { + ns = float64(timeRemLimit) + } + s.TimeRem = clockRound(time.Duration(ns)) + } + } + } + } + m.mu.Unlock() + return s +} + +// Limit restricts the instantaneous (per-sample) data flow to rate bytes per +// second. It returns the maximum number of bytes (0 <= n <= want) that may be +// transferred immediately without exceeding the limit. If block == true, the +// call blocks until n > 0. want is returned unmodified if want < 1, rate < 1, +// or the transfer is inactive (after a call to Done). +// +// At least one byte is always allowed to be transferred in any given sampling +// period. Thus, if the sampling rate is 100ms, the lowest achievable flow rate +// is 10 bytes per second. +// +// For usage examples, see the implementation of Reader and Writer in io.go. +func (m *Monitor) Limit(want int, rate int64, block bool) (n int) { + if want < 1 || rate < 1 { + return want + } + m.mu.Lock() + + // Determine the maximum number of bytes that can be sent in one sample + limit := round(float64(rate) * m.sRate.Seconds()) + if limit <= 0 { + limit = 1 + } + + // If block == true, wait until m.sBytes < limit + if now := m.update(0); block { + for m.sBytes >= limit && m.active { + now = m.waitNextSample(now) + } + } + + // Make limit <= want (unlimited if the transfer is no longer active) + if limit -= m.sBytes; limit > int64(want) || !m.active { + limit = int64(want) + } + m.mu.Unlock() + + if limit < 0 { + limit = 0 + } + return int(limit) +} + +// SetTransferSize specifies the total size of the data transfer, which allows +// the Monitor to calculate the overall progress and time to completion. +func (m *Monitor) SetTransferSize(bytes int64) { + if bytes < 0 { + bytes = 0 + } + m.mu.Lock() + m.tBytes = bytes + m.mu.Unlock() +} + +// update accumulates the transferred byte count for the current sample until +// clock() - m.sLast >= m.sRate. The monitor status is updated once the current +// sample is done. +func (m *Monitor) update(n int) (now time.Duration) { + if !m.active { + return + } + if now = clock(); n > 0 { + m.tLast = now + } + m.sBytes += int64(n) + if sTime := now - m.sLast; sTime >= m.sRate { + t := sTime.Seconds() + if m.rSample = float64(m.sBytes) / t; m.rSample > m.rPeak { + m.rPeak = m.rSample + } + + // Exponential moving average using a method similar to *nix load + // average calculation. Longer sampling periods carry greater weight. + if m.samples > 0 { + w := math.Exp(-t / m.rWindow) + m.rEMA = m.rSample + w*(m.rEMA-m.rSample) + } else { + m.rEMA = m.rSample + } + m.reset(now) + } + return +} + +// reset clears the current sample state in preparation for the next sample. +func (m *Monitor) reset(sampleTime time.Duration) { + m.bytes += m.sBytes + m.samples++ + m.sBytes = 0 + m.sLast = sampleTime +} + +// waitNextSample sleeps for the remainder of the current sample. The lock is +// released and reacquired during the actual sleep period, so it's possible for +// the transfer to be inactive when this method returns. +func (m *Monitor) waitNextSample(now time.Duration) time.Duration { + const minWait = 5 * time.Millisecond + current := m.sLast + + // sleep until the last sample time changes (ideally, just one iteration) + for m.sLast == current && m.active { + d := current + m.sRate - now + m.mu.Unlock() + if d < minWait { + d = minWait + } + time.Sleep(d) + m.mu.Lock() + now = m.update(0) + } + return now +} diff --git a/flowrate/io.go b/flowrate/io.go new file mode 100644 index 00000000..fbe09097 --- /dev/null +++ b/flowrate/io.go @@ -0,0 +1,133 @@ +// +// Written by Maxim Khitrov (November 2012) +// + +package flowrate + +import ( + "errors" + "io" +) + +// ErrLimit is returned by the Writer when a non-blocking write is short due to +// the transfer rate limit. +var ErrLimit = errors.New("flowrate: flow rate limit exceeded") + +// Limiter is implemented by the Reader and Writer to provide a consistent +// interface for monitoring and controlling data transfer. +type Limiter interface { + Done() int64 + Status() Status + SetTransferSize(bytes int64) + SetLimit(new int64) (old int64) + SetBlocking(new bool) (old bool) +} + +// Reader implements io.ReadCloser with a restriction on the rate of data +// transfer. +type Reader struct { + io.Reader // Data source + *Monitor // Flow control monitor + + limit int64 // Rate limit in bytes per second (unlimited when <= 0) + block bool // What to do when no new bytes can be read due to the limit +} + +// NewReader restricts all Read operations on r to limit bytes per second. +func NewReader(r io.Reader, limit int64) *Reader { + return &Reader{r, New(0, 0), limit, true} +} + +// Read reads up to len(p) bytes into p without exceeding the current transfer +// rate limit. It returns (0, nil) immediately if r is non-blocking and no new +// bytes can be read at this time. +func (r *Reader) Read(p []byte) (n int, err error) { + p = p[:r.Limit(len(p), r.limit, r.block)] + if len(p) > 0 { + n, err = r.IO(r.Reader.Read(p)) + } + return +} + +// SetLimit changes the transfer rate limit to new bytes per second and returns +// the previous setting. +func (r *Reader) SetLimit(new int64) (old int64) { + old, r.limit = r.limit, new + return +} + +// SetBlocking changes the blocking behavior and returns the previous setting. A +// Read call on a non-blocking reader returns immediately if no additional bytes +// may be read at this time due to the rate limit. +func (r *Reader) SetBlocking(new bool) (old bool) { + old, r.block = r.block, new + return +} + +// Close closes the underlying reader if it implements the io.Closer interface. +func (r *Reader) Close() error { + defer r.Done() + if c, ok := r.Reader.(io.Closer); ok { + return c.Close() + } + return nil +} + +// Writer implements io.WriteCloser with a restriction on the rate of data +// transfer. +type Writer struct { + io.Writer // Data destination + *Monitor // Flow control monitor + + limit int64 // Rate limit in bytes per second (unlimited when <= 0) + block bool // What to do when no new bytes can be written due to the limit +} + +// NewWriter restricts all Write operations on w to limit bytes per second. The +// transfer rate and the default blocking behavior (true) can be changed +// directly on the returned *Writer. +func NewWriter(w io.Writer, limit int64) *Writer { + return &Writer{w, New(0, 0), limit, true} +} + +// Write writes len(p) bytes from p to the underlying data stream without +// exceeding the current transfer rate limit. It returns (n, ErrLimit) if w is +// non-blocking and no additional bytes can be written at this time. +func (w *Writer) Write(p []byte) (n int, err error) { + var c int + for len(p) > 0 && err == nil { + s := p[:w.Limit(len(p), w.limit, w.block)] + if len(s) > 0 { + c, err = w.IO(w.Writer.Write(s)) + } else { + return n, ErrLimit + } + p = p[c:] + n += c + } + return +} + +// SetLimit changes the transfer rate limit to new bytes per second and returns +// the previous setting. +func (w *Writer) SetLimit(new int64) (old int64) { + old, w.limit = w.limit, new + return +} + +// SetBlocking changes the blocking behavior and returns the previous setting. A +// Write call on a non-blocking writer returns as soon as no additional bytes +// may be written at this time due to the rate limit. +func (w *Writer) SetBlocking(new bool) (old bool) { + old, w.block = w.block, new + return +} + +// Close closes the underlying writer if it implements the io.Closer interface. +func (w *Writer) Close() error { + defer w.Done() + if c, ok := w.Writer.(io.Closer); ok { + return c.Close() + } + return nil +} diff --git a/flowrate/io_test.go b/flowrate/io_test.go new file mode 100644 index 00000000..c84029d5 --- /dev/null +++ b/flowrate/io_test.go @@ -0,0 +1,194 @@ +// +// Written by Maxim Khitrov (November 2012) +// + +package flowrate + +import ( + "bytes" + "testing" + "time" +) + +const ( + _50ms = 50 * time.Millisecond + _100ms = 100 * time.Millisecond + _200ms = 200 * time.Millisecond + _300ms = 300 * time.Millisecond + _400ms = 400 * time.Millisecond + _500ms = 500 * time.Millisecond +) + +func nextStatus(m *Monitor) Status { + samples := m.samples + for i := 0; i < 30; i++ { + if s := m.Status(); s.Samples != samples { + return s + } + time.Sleep(5 * time.Millisecond) + } + return m.Status() +} + +func TestReader(t *testing.T) { + in := make([]byte, 100) + for i := range in { + in[i] = byte(i) + } + b := make([]byte, 100) + r := NewReader(bytes.NewReader(in), 100) + start := time.Now() + + // Make sure r implements Limiter + _ = Limiter(r) + + // 1st read of 10 bytes is performed immediately + if n, err := r.Read(b); n != 10 || err != nil { + t.Fatalf("r.Read(b) expected 10 (); got %v (%v)", n, err) + } else if rt := time.Since(start); rt > _50ms { + t.Fatalf("r.Read(b) took too long (%v)", rt) + } + + // No new Reads allowed in the current sample + r.SetBlocking(false) + if n, err := r.Read(b); n != 0 || err != nil { + t.Fatalf("r.Read(b) expected 0 (); got %v (%v)", n, err) + } else if rt := time.Since(start); rt > _50ms { + t.Fatalf("r.Read(b) took too long (%v)", rt) + } + + status := [6]Status{0: r.Status()} // No samples in the first status + + // 2nd read of 10 bytes blocks until the next sample + r.SetBlocking(true) + if n, err := r.Read(b[10:]); n != 10 || err != nil { + t.Fatalf("r.Read(b[10:]) expected 10 (); got %v (%v)", n, err) + } else if rt := time.Since(start); rt < _100ms { + t.Fatalf("r.Read(b[10:]) returned ahead of time (%v)", rt) + } + + status[1] = r.Status() // 1st sample + status[2] = nextStatus(r.Monitor) // 2nd sample + status[3] = nextStatus(r.Monitor) // No activity for the 3rd sample + + if n := r.Done(); n != 20 { + t.Fatalf("r.Done() expected 20; got %v", n) + } + + status[4] = r.Status() + status[5] = nextStatus(r.Monitor) // Timeout + start = status[0].Start + + // Active, Start, Duration, Idle, Bytes, Samples, InstRate, CurRate, AvgRate, PeakRate, BytesRem, TimeRem, Progress + want := []Status{ + Status{true, start, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + Status{true, start, _100ms, 0, 10, 1, 100, 100, 100, 100, 0, 0, 0}, + Status{true, start, _200ms, _100ms, 20, 2, 100, 100, 100, 100, 0, 0, 0}, + Status{true, start, _300ms, _200ms, 20, 3, 0, 90, 67, 100, 0, 0, 0}, + Status{false, start, _300ms, 0, 20, 3, 0, 0, 67, 100, 0, 0, 0}, + Status{false, start, _300ms, 0, 20, 3, 0, 0, 67, 100, 0, 0, 0}, + } + for i, s := range status { + if !statusesAreEqual(&s, &want[i]) { + t.Errorf("r.Status(%v)\nexpected: %v\ngot : %v", i, want[i], s) + } + } + if !bytes.Equal(b[:20], in[:20]) { + t.Errorf("r.Read() input doesn't match output") + } +} + +func TestWriter(t *testing.T) { + b := make([]byte, 100) + for i := range b { + b[i] = byte(i) + } + w := NewWriter(&bytes.Buffer{}, 200) + start := time.Now() + + // Make sure w implements Limiter + _ = Limiter(w) + + // Non-blocking 20-byte write for the first sample returns ErrLimit + w.SetBlocking(false) + if n, err := w.Write(b); n != 20 || err != ErrLimit { + t.Fatalf("w.Write(b) expected 20 (ErrLimit); got %v (%v)", n, err) + } else if rt := time.Since(start); rt > _50ms { + t.Fatalf("w.Write(b) took too long (%v)", rt) + } + + // Blocking 80-byte write + w.SetBlocking(true) + if n, err := w.Write(b[20:]); n != 80 || err != nil { + t.Fatalf("w.Write(b[20:]) expected 80 (); got %v (%v)", n, err) + } else if rt := time.Since(start); rt < _300ms { + // Explanation for `rt < _300ms` (as opposed to `< _400ms`) + // + // |<-- start | | + // epochs: -----0ms|---100ms|---200ms|---300ms|---400ms + // sends: 20|20 |20 |20 |20# + // + // NOTE: The '#' symbol can thus happen before 400ms is up. + // Thus, we can only panic if rt < _300ms. + t.Fatalf("w.Write(b[20:]) returned ahead of time (%v)", rt) + } + + w.SetTransferSize(100) + status := []Status{w.Status(), nextStatus(w.Monitor)} + start = status[0].Start + + // Active, Start, Duration, Idle, Bytes, Samples, InstRate, CurRate, AvgRate, PeakRate, BytesRem, TimeRem, Progress + want := []Status{ + Status{true, start, _400ms, 0, 80, 4, 200, 200, 200, 200, 20, _100ms, 80000}, + Status{true, start, _500ms, _100ms, 100, 5, 200, 200, 200, 200, 0, 0, 100000}, + } + for i, s := range status { + if !statusesAreEqual(&s, &want[i]) { + t.Errorf("w.Status(%v)\nexpected: %v\ngot : %v\n", i, want[i], s) + } + } + if !bytes.Equal(b, w.Writer.(*bytes.Buffer).Bytes()) { + t.Errorf("w.Write() input doesn't match output") + } +} + +const maxDeviationForDuration = 50 * time.Millisecond +const maxDeviationForRate int64 = 50 + +// statusesAreEqual returns true if s1 is equal to s2. Equality here means +// general equality of fields except for the duration and rates, which can +// drift due to unpredictable delays (e.g. thread wakes up 25ms after +// `time.Sleep` has ended). +func statusesAreEqual(s1 *Status, s2 *Status) bool { + if s1.Active == s2.Active && + s1.Start == s2.Start && + durationsAreEqual(s1.Duration, s2.Duration, maxDeviationForDuration) && + s1.Idle == s2.Idle && + s1.Bytes == s2.Bytes && + s1.Samples == s2.Samples && + ratesAreEqual(s1.InstRate, s2.InstRate, maxDeviationForRate) && + ratesAreEqual(s1.CurRate, s2.CurRate, maxDeviationForRate) && + ratesAreEqual(s1.AvgRate, s2.AvgRate, maxDeviationForRate) && + ratesAreEqual(s1.PeakRate, s2.PeakRate, maxDeviationForRate) && + s1.BytesRem == s2.BytesRem && + durationsAreEqual(s1.TimeRem, s2.TimeRem, maxDeviationForDuration) && + s1.Progress == s2.Progress { + return true + } + return false +} + +func durationsAreEqual(d1 time.Duration, d2 time.Duration, maxDeviation time.Duration) bool { + return d2-d1 <= maxDeviation +} + +func ratesAreEqual(r1 int64, r2 int64, maxDeviation int64) bool { + sub := r1 - r2 + if sub < 0 { + sub = -sub + } + if sub <= maxDeviation { + return true + } + return false +} diff --git a/flowrate/util.go b/flowrate/util.go new file mode 100644 index 00000000..b33ddc70 --- /dev/null +++ b/flowrate/util.go @@ -0,0 +1,67 @@ +// +// Written by Maxim Khitrov (November 2012) +// + +package flowrate + +import ( + "math" + "strconv" + "time" +) + +// clockRate is the resolution and precision of clock(). +const clockRate = 20 * time.Millisecond + +// czero is the process start time rounded down to the nearest clockRate +// increment. +var czero = time.Now().Round(clockRate) + +// clock returns a low resolution timestamp relative to the process start time. +func clock() time.Duration { + return time.Now().Round(clockRate).Sub(czero) +} + +// clockToTime converts a clock() timestamp to an absolute time.Time value. +func clockToTime(c time.Duration) time.Time { + return czero.Add(c) +} + +// clockRound returns d rounded to the nearest clockRate increment. +func clockRound(d time.Duration) time.Duration { + return (d + clockRate>>1) / clockRate * clockRate +} + +// round returns x rounded to the nearest int64 (non-negative values only). +func round(x float64) int64 { + if _, frac := math.Modf(x); frac >= 0.5 { + return int64(math.Ceil(x)) + } + return int64(math.Floor(x)) +} + +// Percent represents a percentage in increments of 1/1000th of a percent. +type Percent uint32 + +// percentOf calculates what percent of the total is x. +func percentOf(x, total float64) Percent { + if x < 0 || total <= 0 { + return 0 + } else if p := round(x / total * 1e5); p <= math.MaxUint32 { + return Percent(p) + } + return Percent(math.MaxUint32) +} + +func (p Percent) Float() float64 { + return float64(p) * 1e-3 +} + +func (p Percent) String() string { + var buf [12]byte + b := strconv.AppendUint(buf[:0], uint64(p)/1000, 10) + n := len(b) + b = strconv.AppendUint(b, 1000+uint64(p)%1000, 10) + b[n] = '.' + return string(append(b, '%')) +} diff --git a/glide.lock b/glide.lock new file mode 100644 index 00000000..a0ada5a4 --- /dev/null +++ b/glide.lock @@ -0,0 +1,108 @@ +hash: 98752078f39da926f655268b3b143f713d64edd379fc9fcb1210d9d8aa7ab4e0 +updated: 2018-02-03T01:28:00.221548057-05:00 +imports: +- name: github.com/fsnotify/fsnotify + version: c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9 +- name: github.com/go-kit/kit + version: 4dc7be5d2d12881735283bcab7352178e190fc71 + subpackages: + - log + - log/level + - log/term +- name: github.com/go-logfmt/logfmt + version: 390ab7935ee28ec6b286364bba9b4dd6410cb3d5 +- name: github.com/go-stack/stack + version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf +- name: github.com/gogo/protobuf + version: 1adfc126b41513cc696b209667c8656ea7aac67c + subpackages: + - gogoproto + - proto + - protoc-gen-gogo/descriptor +- name: github.com/golang/snappy + version: 553a641470496b2327abcac10b36396bd98e45c9 +- name: github.com/hashicorp/hcl + version: 23c074d0eceb2b8a5bfdbb271ab780cde70f05a8 + subpackages: + - hcl/ast + - hcl/parser + - hcl/scanner + - hcl/strconv + - hcl/token + - json/parser + - json/scanner + - json/token +- name: github.com/inconshreveable/mousetrap + version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75 +- name: github.com/jmhodges/levigo + version: c42d9e0ca023e2198120196f842701bb4c55d7b9 +- name: github.com/kr/logfmt + version: b84e30acd515aadc4b783ad4ff83aff3299bdfe0 +- name: github.com/magiconair/properties + version: 49d762b9817ba1c2e9d0c69183c2b4a8b8f1d934 +- name: github.com/mitchellh/mapstructure + version: b4575eea38cca1123ec2dc90c26529b5c5acfcff +- name: github.com/pelletier/go-toml + version: acdc4509485b587f5e675510c4f2c63e90ff68a8 +- name: github.com/pkg/errors + version: 645ef00459ed84a119197bfb8d8205042c6df63d +- name: github.com/spf13/afero + version: bb8f1927f2a9d3ab41c9340aa034f6b803f4359c + subpackages: + - mem +- name: github.com/spf13/cast + version: acbeb36b902d72a7a4c18e8f3241075e7ab763e4 +- name: github.com/spf13/cobra + version: 7b2c5ac9fc04fc5efafb60700713d4fa609b777b +- name: github.com/spf13/jwalterweatherman + version: 7c0cea34c8ece3fbeb2b27ab9b59511d360fb394 +- name: github.com/spf13/pflag + version: 97afa5e7ca8a08a383cb259e06636b5e2cc7897f +- name: github.com/spf13/viper + version: 25b30aa063fc18e48662b86996252eabdcf2f0c7 +- name: github.com/syndtr/goleveldb + version: b89cc31ef7977104127d34c1bd31ebd1a9db2199 + subpackages: + - leveldb + - leveldb/cache + - leveldb/comparer + - leveldb/errors + - leveldb/filter + - leveldb/iterator + - leveldb/journal + - leveldb/memdb + - leveldb/opt + - leveldb/storage + - leveldb/table + - leveldb/util +- name: golang.org/x/crypto + version: edd5e9b0879d13ee6970a50153d85b8fec9f7686 + subpackages: + - ripemd160 +- name: golang.org/x/sys + version: 37707fdb30a5b38865cfb95e5aab41707daec7fd + subpackages: + - unix +- name: golang.org/x/text + version: c01e4764d870b77f8abe5096ee19ad20d80e8075 + subpackages: + - transform + - unicode/norm +- name: gopkg.in/yaml.v2 + version: d670f9405373e636a5a2765eea47fac0c9bc91a4 +testImports: +- name: github.com/davecgh/go-spew + version: 346938d642f2ec3594ed81d874461961cd0faa76 + subpackages: + - spew +- name: github.com/fortytw2/leaktest + version: 3b724c3d7b8729a35bf4e577f71653aec6e53513 +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 12b6f73e6084dad08a7c6e575284b177ecafbc71 + subpackages: + - assert + - require diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 00000000..cf3da346 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,38 @@ +package: github.com/tendermint/tmlibs +import: +- package: github.com/go-kit/kit + version: ^0.6.0 + subpackages: + - log + - log/level + - log/term +- package: github.com/go-logfmt/logfmt + version: ^0.3.0 +- package: github.com/gogo/protobuf + version: ^1.0.0 + subpackages: + - gogoproto + - proto +- package: github.com/jmhodges/levigo +- package: github.com/pkg/errors + version: ^0.8.0 +- package: github.com/spf13/cobra + version: ^0.0.1 +- package: github.com/spf13/viper + version: ^1.0.0 +- package: github.com/syndtr/goleveldb + subpackages: + - leveldb + - leveldb/errors + - leveldb/iterator + - leveldb/opt +- package: golang.org/x/crypto + subpackages: + - ripemd160 +testImport: +- package: github.com/stretchr/testify + version: ^1.2.1 + subpackages: + - assert + - require +- package: github.com/fortytw2/leaktest diff --git a/log/filter.go b/log/filter.go new file mode 100644 index 00000000..768c09b8 --- /dev/null +++ b/log/filter.go @@ -0,0 +1,158 @@ +package log + +import "fmt" + +type level byte + +const ( + levelDebug level = 1 << iota + levelInfo + levelError +) + +type filter struct { + next Logger + allowed level // XOR'd levels for default case + allowedKeyvals map[keyval]level // When key-value match, use this level +} + +type keyval struct { + key interface{} + value interface{} +} + +// NewFilter wraps next and implements filtering. See the commentary on the +// Option functions for a detailed description of how to configure levels. If +// no options are provided, all leveled log events created with Debug, Info or +// Error helper methods are squelched. +func NewFilter(next Logger, options ...Option) Logger { + l := &filter{ + next: next, + allowedKeyvals: make(map[keyval]level), + } + for _, option := range options { + option(l) + } + return l +} + +func (l *filter) Info(msg string, keyvals ...interface{}) { + levelAllowed := l.allowed&levelInfo != 0 + if !levelAllowed { + return + } + l.next.Info(msg, keyvals...) +} + +func (l *filter) Debug(msg string, keyvals ...interface{}) { + levelAllowed := l.allowed&levelDebug != 0 + if !levelAllowed { + return + } + l.next.Debug(msg, keyvals...) +} + +func (l *filter) Error(msg string, keyvals ...interface{}) { + levelAllowed := l.allowed&levelError != 0 + if !levelAllowed { + return + } + l.next.Error(msg, keyvals...) +} + +// With implements Logger by constructing a new filter with a keyvals appended +// to the logger. +// +// If custom level was set for a keyval pair using one of the +// Allow*With methods, it is used as the logger's level. +// +// Examples: +// logger = log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("module", "crypto")) +// logger.With("module", "crypto").Info("Hello") # produces "I... Hello module=crypto" +// +// logger = log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("module", "crypto"), log.AllowNoneWith("user", "Sam")) +// logger.With("module", "crypto", "user", "Sam").Info("Hello") # returns nil +// +// logger = log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("module", "crypto"), log.AllowNoneWith("user", "Sam")) +// logger.With("user", "Sam").With("module", "crypto").Info("Hello") # produces "I... Hello module=crypto user=Sam" +func (l *filter) With(keyvals ...interface{}) Logger { + for i := len(keyvals) - 2; i >= 0; i -= 2 { + for kv, allowed := range l.allowedKeyvals { + if keyvals[i] == kv.key && keyvals[i+1] == kv.value { + return &filter{next: l.next.With(keyvals...), allowed: allowed, allowedKeyvals: l.allowedKeyvals} + } + } + } + return &filter{next: l.next.With(keyvals...), allowed: l.allowed, allowedKeyvals: l.allowedKeyvals} +} + +//-------------------------------------------------------------------------------- + +// Option sets a parameter for the filter. +type Option func(*filter) + +// AllowLevel returns an option for the given level or error if no option exist +// for such level. +func AllowLevel(lvl string) (Option, error) { + switch lvl { + case "debug": + return AllowDebug(), nil + case "info": + return AllowInfo(), nil + case "error": + return AllowError(), nil + case "none": + return AllowNone(), nil + default: + return nil, fmt.Errorf("Expected either \"info\", \"debug\", \"error\" or \"none\" level, given %s", lvl) + } +} + +// AllowAll is an alias for AllowDebug. +func AllowAll() Option { + return AllowDebug() +} + +// AllowDebug allows error, info and debug level log events to pass. +func AllowDebug() Option { + return allowed(levelError | levelInfo | levelDebug) +} + +// AllowInfo allows error and info level log events to pass. +func AllowInfo() Option { + return allowed(levelError | levelInfo) +} + +// AllowError allows only error level log events to pass. +func AllowError() Option { + return allowed(levelError) +} + +// AllowNone allows no leveled log events to pass. +func AllowNone() Option { + return allowed(0) +} + +func allowed(allowed level) Option { + return func(l *filter) { l.allowed = allowed } +} + +// AllowDebugWith allows error, info and debug level log events to pass for a specific key value pair. +func AllowDebugWith(key interface{}, value interface{}) Option { + return func(l *filter) { l.allowedKeyvals[keyval{key, value}] = levelError | levelInfo | levelDebug } +} + +// AllowInfoWith allows error and info level log events to pass for a specific key value pair. +func AllowInfoWith(key interface{}, value interface{}) Option { + return func(l *filter) { l.allowedKeyvals[keyval{key, value}] = levelError | levelInfo } +} + +// AllowErrorWith allows only error level log events to pass for a specific key value pair. +func AllowErrorWith(key interface{}, value interface{}) Option { + return func(l *filter) { l.allowedKeyvals[keyval{key, value}] = levelError } +} + +// AllowNoneWith allows no leveled log events to pass for a specific key value pair. +func AllowNoneWith(key interface{}, value interface{}) Option { + return func(l *filter) { l.allowedKeyvals[keyval{key, value}] = 0 } +} diff --git a/log/filter_test.go b/log/filter_test.go new file mode 100644 index 00000000..8d8b3b27 --- /dev/null +++ b/log/filter_test.go @@ -0,0 +1,118 @@ +package log_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/tendermint/tmlibs/log" +) + +func TestVariousLevels(t *testing.T) { + testCases := []struct { + name string + allowed log.Option + want string + }{ + { + "AllowAll", + log.AllowAll(), + strings.Join([]string{ + `{"_msg":"here","level":"debug","this is":"debug log"}`, + `{"_msg":"here","level":"info","this is":"info log"}`, + `{"_msg":"here","level":"error","this is":"error log"}`, + }, "\n"), + }, + { + "AllowDebug", + log.AllowDebug(), + strings.Join([]string{ + `{"_msg":"here","level":"debug","this is":"debug log"}`, + `{"_msg":"here","level":"info","this is":"info log"}`, + `{"_msg":"here","level":"error","this is":"error log"}`, + }, "\n"), + }, + { + "AllowInfo", + log.AllowInfo(), + strings.Join([]string{ + `{"_msg":"here","level":"info","this is":"info log"}`, + `{"_msg":"here","level":"error","this is":"error log"}`, + }, "\n"), + }, + { + "AllowError", + log.AllowError(), + strings.Join([]string{ + `{"_msg":"here","level":"error","this is":"error log"}`, + }, "\n"), + }, + { + "AllowNone", + log.AllowNone(), + ``, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + logger := log.NewFilter(log.NewTMJSONLogger(&buf), tc.allowed) + + logger.Debug("here", "this is", "debug log") + logger.Info("here", "this is", "info log") + logger.Error("here", "this is", "error log") + + if want, have := tc.want, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant:\n%s\nhave:\n%s", want, have) + } + }) + } +} + +func TestLevelContext(t *testing.T) { + var buf bytes.Buffer + + logger := log.NewTMJSONLogger(&buf) + logger = log.NewFilter(logger, log.AllowError()) + logger = logger.With("context", "value") + + logger.Error("foo", "bar", "baz") + if want, have := `{"_msg":"foo","bar":"baz","context":"value","level":"error"}`, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } + + buf.Reset() + logger.Info("foo", "bar", "baz") + if want, have := ``, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } +} + +func TestVariousAllowWith(t *testing.T) { + var buf bytes.Buffer + + logger := log.NewTMJSONLogger(&buf) + + logger1 := log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("context", "value")) + logger1.With("context", "value").Info("foo", "bar", "baz") + if want, have := `{"_msg":"foo","bar":"baz","context":"value","level":"info"}`, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } + + buf.Reset() + + logger2 := log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("context", "value"), log.AllowNoneWith("user", "Sam")) + logger2.With("context", "value", "user", "Sam").Info("foo", "bar", "baz") + if want, have := ``, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } + + buf.Reset() + + logger3 := log.NewFilter(logger, log.AllowError(), log.AllowInfoWith("context", "value"), log.AllowNoneWith("user", "Sam")) + logger3.With("user", "Sam").With("context", "value").Info("foo", "bar", "baz") + if want, have := `{"_msg":"foo","bar":"baz","context":"value","level":"info","user":"Sam"}`, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } +} diff --git a/log/logger.go b/log/logger.go new file mode 100644 index 00000000..ddb187bc --- /dev/null +++ b/log/logger.go @@ -0,0 +1,30 @@ +package log + +import ( + "io" + + kitlog "github.com/go-kit/kit/log" +) + +// Logger is what any Tendermint library should take. +type Logger interface { + Debug(msg string, keyvals ...interface{}) + Info(msg string, keyvals ...interface{}) + Error(msg string, keyvals ...interface{}) + + With(keyvals ...interface{}) Logger +} + +// NewSyncWriter returns a new writer that is safe for concurrent use by +// multiple goroutines. Writes to the returned writer are passed on to w. If +// another write is already in progress, the calling goroutine blocks until +// the writer is available. +// +// If w implements the following interface, so does the returned writer. +// +// interface { +// Fd() uintptr +// } +func NewSyncWriter(w io.Writer) io.Writer { + return kitlog.NewSyncWriter(w) +} diff --git a/log/nop_logger.go b/log/nop_logger.go new file mode 100644 index 00000000..12d75abe --- /dev/null +++ b/log/nop_logger.go @@ -0,0 +1,17 @@ +package log + +type nopLogger struct{} + +// Interface assertions +var _ Logger = (*nopLogger)(nil) + +// NewNopLogger returns a logger that doesn't do anything. +func NewNopLogger() Logger { return &nopLogger{} } + +func (nopLogger) Info(string, ...interface{}) {} +func (nopLogger) Debug(string, ...interface{}) {} +func (nopLogger) Error(string, ...interface{}) {} + +func (l *nopLogger) With(...interface{}) Logger { + return l +} diff --git a/log/testing_logger.go b/log/testing_logger.go new file mode 100644 index 00000000..81482bef --- /dev/null +++ b/log/testing_logger.go @@ -0,0 +1,49 @@ +package log + +import ( + "os" + "testing" + + "github.com/go-kit/kit/log/term" +) + +var ( + // reuse the same logger across all tests + _testingLogger Logger +) + +// TestingLogger returns a TMLogger which writes to STDOUT if testing being run +// with the verbose (-v) flag, NopLogger otherwise. +// +// Note that the call to TestingLogger() must be made +// inside a test (not in the init func) because +// verbose flag only set at the time of testing. +func TestingLogger() Logger { + if _testingLogger != nil { + return _testingLogger + } + + if testing.Verbose() { + _testingLogger = NewTMLogger(NewSyncWriter(os.Stdout)) + } else { + _testingLogger = NewNopLogger() + } + + return _testingLogger +} + +// TestingLoggerWithColorFn allow you to provide your own color function. See +// TestingLogger for documentation. +func TestingLoggerWithColorFn(colorFn func(keyvals ...interface{}) term.FgBgColor) Logger { + if _testingLogger != nil { + return _testingLogger + } + + if testing.Verbose() { + _testingLogger = NewTMLoggerWithColorFn(NewSyncWriter(os.Stdout), colorFn) + } else { + _testingLogger = NewNopLogger() + } + + return _testingLogger +} diff --git a/log/tm_json_logger.go b/log/tm_json_logger.go new file mode 100644 index 00000000..a71ac103 --- /dev/null +++ b/log/tm_json_logger.go @@ -0,0 +1,15 @@ +package log + +import ( + "io" + + kitlog "github.com/go-kit/kit/log" +) + +// NewTMJSONLogger returns a Logger that encodes keyvals to the Writer as a +// single JSON object. Each log event produces no more than one call to +// w.Write. The passed Writer must be safe for concurrent use by multiple +// goroutines if the returned Logger will be used concurrently. +func NewTMJSONLogger(w io.Writer) Logger { + return &tmLogger{kitlog.NewJSONLogger(w)} +} diff --git a/log/tm_logger.go b/log/tm_logger.go new file mode 100644 index 00000000..d49e8d22 --- /dev/null +++ b/log/tm_logger.go @@ -0,0 +1,83 @@ +package log + +import ( + "fmt" + "io" + + kitlog "github.com/go-kit/kit/log" + kitlevel "github.com/go-kit/kit/log/level" + "github.com/go-kit/kit/log/term" +) + +const ( + msgKey = "_msg" // "_" prefixed to avoid collisions + moduleKey = "module" +) + +type tmLogger struct { + srcLogger kitlog.Logger +} + +// Interface assertions +var _ Logger = (*tmLogger)(nil) + +// NewTMTermLogger returns a logger that encodes msg and keyvals to the Writer +// using go-kit's log as an underlying logger and our custom formatter. Note +// that underlying logger could be swapped with something else. +func NewTMLogger(w io.Writer) Logger { + // Color by level value + colorFn := func(keyvals ...interface{}) term.FgBgColor { + if keyvals[0] != kitlevel.Key() { + panic(fmt.Sprintf("expected level key to be first, got %v", keyvals[0])) + } + switch keyvals[1].(kitlevel.Value).String() { + case "debug": + return term.FgBgColor{Fg: term.DarkGray} + case "error": + return term.FgBgColor{Fg: term.Red} + default: + return term.FgBgColor{} + } + } + + return &tmLogger{term.NewLogger(w, NewTMFmtLogger, colorFn)} +} + +// NewTMLoggerWithColorFn allows you to provide your own color function. See +// NewTMLogger for documentation. +func NewTMLoggerWithColorFn(w io.Writer, colorFn func(keyvals ...interface{}) term.FgBgColor) Logger { + return &tmLogger{term.NewLogger(w, NewTMFmtLogger, colorFn)} +} + +// Info logs a message at level Info. +func (l *tmLogger) Info(msg string, keyvals ...interface{}) { + lWithLevel := kitlevel.Info(l.srcLogger) + if err := kitlog.With(lWithLevel, msgKey, msg).Log(keyvals...); err != nil { + errLogger := kitlevel.Error(l.srcLogger) + kitlog.With(errLogger, msgKey, msg).Log("err", err) + } +} + +// Debug logs a message at level Debug. +func (l *tmLogger) Debug(msg string, keyvals ...interface{}) { + lWithLevel := kitlevel.Debug(l.srcLogger) + if err := kitlog.With(lWithLevel, msgKey, msg).Log(keyvals...); err != nil { + errLogger := kitlevel.Error(l.srcLogger) + kitlog.With(errLogger, msgKey, msg).Log("err", err) + } +} + +// Error logs a message at level Error. +func (l *tmLogger) Error(msg string, keyvals ...interface{}) { + lWithLevel := kitlevel.Error(l.srcLogger) + lWithMsg := kitlog.With(lWithLevel, msgKey, msg) + if err := lWithMsg.Log(keyvals...); err != nil { + lWithMsg.Log("err", err) + } +} + +// With returns a new contextual logger with keyvals prepended to those passed +// to calls to Info, Debug or Error. +func (l *tmLogger) With(keyvals ...interface{}) Logger { + return &tmLogger{kitlog.With(l.srcLogger, keyvals...)} +} diff --git a/log/tm_logger_test.go b/log/tm_logger_test.go new file mode 100644 index 00000000..b2b600ad --- /dev/null +++ b/log/tm_logger_test.go @@ -0,0 +1,44 @@ +package log_test + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" + + "github.com/go-logfmt/logfmt" + "github.com/tendermint/tmlibs/log" +) + +func TestLoggerLogsItsErrors(t *testing.T) { + var buf bytes.Buffer + + logger := log.NewTMLogger(&buf) + logger.Info("foo", "baz baz", "bar") + msg := strings.TrimSpace(buf.String()) + if !strings.Contains(msg, logfmt.ErrInvalidKey.Error()) { + t.Errorf("Expected logger msg to contain ErrInvalidKey, got %s", msg) + } +} + +func BenchmarkTMLoggerSimple(b *testing.B) { + benchmarkRunner(b, log.NewTMLogger(ioutil.Discard), baseInfoMessage) +} + +func BenchmarkTMLoggerContextual(b *testing.B) { + benchmarkRunner(b, log.NewTMLogger(ioutil.Discard), withInfoMessage) +} + +func benchmarkRunner(b *testing.B, logger log.Logger, f func(log.Logger)) { + lc := logger.With("common_key", "common_value") + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + f(lc) + } +} + +var ( + baseInfoMessage = func(logger log.Logger) { logger.Info("foo_message", "foo_key", "foo_value") } + withInfoMessage = func(logger log.Logger) { logger.With("a", "b").Info("c", "d", "f") } +) diff --git a/log/tmfmt_logger.go b/log/tmfmt_logger.go new file mode 100644 index 00000000..d0397971 --- /dev/null +++ b/log/tmfmt_logger.go @@ -0,0 +1,127 @@ +package log + +import ( + "bytes" + "fmt" + "io" + "sync" + "time" + + kitlog "github.com/go-kit/kit/log" + kitlevel "github.com/go-kit/kit/log/level" + "github.com/go-logfmt/logfmt" +) + +type tmfmtEncoder struct { + *logfmt.Encoder + buf bytes.Buffer +} + +func (l *tmfmtEncoder) Reset() { + l.Encoder.Reset() + l.buf.Reset() +} + +var tmfmtEncoderPool = sync.Pool{ + New: func() interface{} { + var enc tmfmtEncoder + enc.Encoder = logfmt.NewEncoder(&enc.buf) + return &enc + }, +} + +type tmfmtLogger struct { + w io.Writer +} + +// NewTMFmtLogger returns a logger that encodes keyvals to the Writer in +// Tendermint custom format. Note complex types (structs, maps, slices) +// formatted as "%+v". +// +// Each log event produces no more than one call to w.Write. +// The passed Writer must be safe for concurrent use by multiple goroutines if +// the returned Logger will be used concurrently. +func NewTMFmtLogger(w io.Writer) kitlog.Logger { + return &tmfmtLogger{w} +} + +func (l tmfmtLogger) Log(keyvals ...interface{}) error { + enc := tmfmtEncoderPool.Get().(*tmfmtEncoder) + enc.Reset() + defer tmfmtEncoderPool.Put(enc) + + const unknown = "unknown" + lvl := "none" + msg := unknown + module := unknown + + // indexes of keys to skip while encoding later + excludeIndexes := make([]int, 0) + + for i := 0; i < len(keyvals)-1; i += 2 { + // Extract level + if keyvals[i] == kitlevel.Key() { + excludeIndexes = append(excludeIndexes, i) + switch keyvals[i+1].(type) { + case string: + lvl = keyvals[i+1].(string) + case kitlevel.Value: + lvl = keyvals[i+1].(kitlevel.Value).String() + default: + panic(fmt.Sprintf("level value of unknown type %T", keyvals[i+1])) + } + // and message + } else if keyvals[i] == msgKey { + excludeIndexes = append(excludeIndexes, i) + msg = keyvals[i+1].(string) + // and module (could be multiple keyvals; if such case last keyvalue wins) + } else if keyvals[i] == moduleKey { + excludeIndexes = append(excludeIndexes, i) + module = keyvals[i+1].(string) + } + } + + // Form a custom Tendermint line + // + // Example: + // D[05-02|11:06:44.322] Stopping AddrBook (ignoring: already stopped) + // + // Description: + // D - first character of the level, uppercase (ASCII only) + // [05-02|11:06:44.322] - our time format (see https://golang.org/src/time/format.go) + // Stopping ... - message + enc.buf.WriteString(fmt.Sprintf("%c[%s] %-44s ", lvl[0]-32, time.Now().UTC().Format("01-02|15:04:05.000"), msg)) + + if module != unknown { + enc.buf.WriteString("module=" + module + " ") + } + +KeyvalueLoop: + for i := 0; i < len(keyvals)-1; i += 2 { + for _, j := range excludeIndexes { + if i == j { + continue KeyvalueLoop + } + } + + err := enc.EncodeKeyval(keyvals[i], keyvals[i+1]) + if err == logfmt.ErrUnsupportedValueType { + enc.EncodeKeyval(keyvals[i], fmt.Sprintf("%+v", keyvals[i+1])) + } else if err != nil { + return err + } + } + + // Add newline to the end of the buffer + if err := enc.EndRecord(); err != nil { + return err + } + + // The Logger interface requires implementations to be safe for concurrent + // use by multiple goroutines. For this implementation that means making + // only one call to l.w.Write() for each call to Log. + if _, err := l.w.Write(enc.buf.Bytes()); err != nil { + return err + } + return nil +} diff --git a/log/tmfmt_logger_test.go b/log/tmfmt_logger_test.go new file mode 100644 index 00000000..a07b323c --- /dev/null +++ b/log/tmfmt_logger_test.go @@ -0,0 +1,118 @@ +package log_test + +import ( + "bytes" + "errors" + "io/ioutil" + "math" + "regexp" + "testing" + + kitlog "github.com/go-kit/kit/log" + "github.com/stretchr/testify/assert" + "github.com/tendermint/tmlibs/log" +) + +func TestTMFmtLogger(t *testing.T) { + t.Parallel() + buf := &bytes.Buffer{} + logger := log.NewTMFmtLogger(buf) + + if err := logger.Log("hello", "world"); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`N\[.+\] unknown \s+ hello=world\n$`), buf.String()) + + buf.Reset() + if err := logger.Log("a", 1, "err", errors.New("error")); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`N\[.+\] unknown \s+ a=1 err=error\n$`), buf.String()) + + buf.Reset() + if err := logger.Log("std_map", map[int]int{1: 2}, "my_map", mymap{0: 0}); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`N\[.+\] unknown \s+ std_map=map\[1:2\] my_map=special_behavior\n$`), buf.String()) + + buf.Reset() + if err := logger.Log("level", "error"); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`E\[.+\] unknown \s+\n$`), buf.String()) + + buf.Reset() + if err := logger.Log("_msg", "Hello"); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`N\[.+\] Hello \s+\n$`), buf.String()) + + buf.Reset() + if err := logger.Log("module", "main", "module", "crypto", "module", "wire"); err != nil { + t.Fatal(err) + } + assert.Regexp(t, regexp.MustCompile(`N\[.+\] unknown \s+module=wire\s+\n$`), buf.String()) +} + +func BenchmarkTMFmtLoggerSimple(b *testing.B) { + benchmarkRunnerKitlog(b, log.NewTMFmtLogger(ioutil.Discard), baseMessage) +} + +func BenchmarkTMFmtLoggerContextual(b *testing.B) { + benchmarkRunnerKitlog(b, log.NewTMFmtLogger(ioutil.Discard), withMessage) +} + +func TestTMFmtLoggerConcurrency(t *testing.T) { + t.Parallel() + testConcurrency(t, log.NewTMFmtLogger(ioutil.Discard), 10000) +} + +func benchmarkRunnerKitlog(b *testing.B, logger kitlog.Logger, f func(kitlog.Logger)) { + lc := kitlog.With(logger, "common_key", "common_value") + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + f(lc) + } +} + +var ( + baseMessage = func(logger kitlog.Logger) { logger.Log("foo_key", "foo_value") } + withMessage = func(logger kitlog.Logger) { kitlog.With(logger, "a", "b").Log("d", "f") } +) + +// These test are designed to be run with the race detector. + +func testConcurrency(t *testing.T, logger kitlog.Logger, total int) { + n := int(math.Sqrt(float64(total))) + share := total / n + + errC := make(chan error, n) + + for i := 0; i < n; i++ { + go func() { + errC <- spam(logger, share) + }() + } + + for i := 0; i < n; i++ { + err := <-errC + if err != nil { + t.Fatalf("concurrent logging error: %v", err) + } + } +} + +func spam(logger kitlog.Logger, count int) error { + for i := 0; i < count; i++ { + err := logger.Log("key", i) + if err != nil { + return err + } + } + return nil +} + +type mymap map[int]int + +func (m mymap) String() string { return "special_behavior" } diff --git a/log/tracing_logger.go b/log/tracing_logger.go new file mode 100644 index 00000000..d2a6ff44 --- /dev/null +++ b/log/tracing_logger.go @@ -0,0 +1,76 @@ +package log + +import ( + "fmt" + + "github.com/pkg/errors" +) + +// NewTracingLogger enables tracing by wrapping all errors (if they +// implement stackTracer interface) in tracedError. +// +// All errors returned by https://github.com/pkg/errors implement stackTracer +// interface. +// +// For debugging purposes only as it doubles the amount of allocations. +func NewTracingLogger(next Logger) Logger { + return &tracingLogger{ + next: next, + } +} + +type stackTracer interface { + error + StackTrace() errors.StackTrace +} + +type tracingLogger struct { + next Logger +} + +func (l *tracingLogger) Info(msg string, keyvals ...interface{}) { + l.next.Info(msg, formatErrors(keyvals)...) +} + +func (l *tracingLogger) Debug(msg string, keyvals ...interface{}) { + l.next.Debug(msg, formatErrors(keyvals)...) +} + +func (l *tracingLogger) Error(msg string, keyvals ...interface{}) { + l.next.Error(msg, formatErrors(keyvals)...) +} + +func (l *tracingLogger) With(keyvals ...interface{}) Logger { + return &tracingLogger{next: l.next.With(formatErrors(keyvals)...)} +} + +func formatErrors(keyvals []interface{}) []interface{} { + newKeyvals := make([]interface{}, len(keyvals)) + copy(newKeyvals, keyvals) + for i := 0; i < len(newKeyvals)-1; i += 2 { + if err, ok := newKeyvals[i+1].(stackTracer); ok { + newKeyvals[i+1] = tracedError{err} + } + } + return newKeyvals +} + +// tracedError wraps a stackTracer and just makes the Error() result +// always return a full stack trace. +type tracedError struct { + wrapped stackTracer +} + +var _ stackTracer = tracedError{} + +func (t tracedError) StackTrace() errors.StackTrace { + return t.wrapped.StackTrace() +} + +func (t tracedError) Cause() error { + return t.wrapped +} + +func (t tracedError) Error() string { + return fmt.Sprintf("%+v", t.wrapped) +} diff --git a/log/tracing_logger_test.go b/log/tracing_logger_test.go new file mode 100644 index 00000000..6b0838ca --- /dev/null +++ b/log/tracing_logger_test.go @@ -0,0 +1,41 @@ +package log_test + +import ( + "bytes" + stderr "errors" + "fmt" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/tendermint/tmlibs/log" +) + +func TestTracingLogger(t *testing.T) { + var buf bytes.Buffer + + logger := log.NewTMJSONLogger(&buf) + + logger1 := log.NewTracingLogger(logger) + err1 := errors.New("Courage is grace under pressure.") + err2 := errors.New("It does not matter how slowly you go, so long as you do not stop.") + logger1.With("err1", err1).Info("foo", "err2", err2) + have := strings.Replace(strings.Replace(strings.TrimSpace(buf.String()), "\\n", "", -1), "\\t", "", -1) + if want := strings.Replace(strings.Replace(`{"_msg":"foo","err1":"`+fmt.Sprintf("%+v", err1)+`","err2":"`+fmt.Sprintf("%+v", err2)+`","level":"info"}`, "\t", "", -1), "\n", "", -1); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } + + buf.Reset() + + logger.With("err1", stderr.New("Opportunities don't happen. You create them.")).Info("foo", "err2", stderr.New("Once you choose hope, anything's possible.")) + if want, have := `{"_msg":"foo","err1":"Opportunities don't happen. You create them.","err2":"Once you choose hope, anything's possible.","level":"info"}`, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } + + buf.Reset() + + logger.With("user", "Sam").With("context", "value").Info("foo", "bar", "baz") + if want, have := `{"_msg":"foo","bar":"baz","context":"value","level":"info","user":"Sam"}`, strings.TrimSpace(buf.String()); want != have { + t.Errorf("\nwant '%s'\nhave '%s'", want, have) + } +} diff --git a/merge.sh b/merge.sh new file mode 100644 index 00000000..7ff0ed94 --- /dev/null +++ b/merge.sh @@ -0,0 +1,54 @@ +#! /bin/bash +set -e + +# NOTE: go-alert depends on go-common + +REPOS=("autofile" "clist" "db" "events" "flowrate" "logger" "process") + +mkdir common +git mv *.go common +git mv LICENSE common + +git commit -m "move all files to common/ to begin repo merge" + +for repo in "${REPOS[@]}"; do + # add and fetch the repo + git remote add -f "$repo" "https://github.com/tendermint/go-${repo}" + + # merge master and move into subdir + git merge "$repo/master" --no-edit + + if [[ "$repo" != "flowrate" ]]; then + mkdir "$repo" + git mv *.go "$repo/" + fi + + set +e # these might not exist + git mv *.md "$repo/" + git mv README "$repo/README.md" + git mv Makefile "$repo/Makefile" + git rm LICENSE + set -e + + # commit + git commit -m "merge go-${repo}" + + git remote rm "$repo" +done + +go get github.com/ebuchman/got +got replace "tendermint/go-common" "tendermint/go-common/common" +for repo in "${REPOS[@]}"; do + + if [[ "$repo" != "flowrate" ]]; then + got replace "tendermint/go-${repo}" "tendermint/go-common/${repo}" + else + got replace "tendermint/go-${repo}/flowrate" "tendermint/go-common/flowrate" + fi +done + +git add -u +git commit -m "update import paths" + +# TODO: change any paths in non-Go files +# TODO: add license diff --git a/merkle/README.md b/merkle/README.md new file mode 100644 index 00000000..c4497836 --- /dev/null +++ b/merkle/README.md @@ -0,0 +1,4 @@ +## Simple Merkle Tree + +For smaller static data structures that don't require immutable snapshots or mutability; +for instance the transactions and validation signatures of a block can be hashed using this simple merkle tree logic. diff --git a/merkle/simple_map.go b/merkle/simple_map.go new file mode 100644 index 00000000..bd5c88d8 --- /dev/null +++ b/merkle/simple_map.go @@ -0,0 +1,84 @@ +package merkle + +import ( + cmn "github.com/tendermint/tmlibs/common" + "github.com/tendermint/tmlibs/merkle/tmhash" +) + +type SimpleMap struct { + kvs cmn.KVPairs + sorted bool +} + +func NewSimpleMap() *SimpleMap { + return &SimpleMap{ + kvs: nil, + sorted: false, + } +} + +func (sm *SimpleMap) Set(key string, value Hasher) { + sm.sorted = false + + // Hash the key to blind it... why not? + khash := SimpleHashFromBytes([]byte(key)) + + // And the value is hashed too, so you can + // check for equality with a cached value (say) + // and make a determination to fetch or not. + vhash := value.Hash() + + sm.kvs = append(sm.kvs, cmn.KVPair{ + Key: khash, + Value: vhash, + }) +} + +// Merkle root hash of items sorted by key +// (UNSTABLE: and by value too if duplicate key). +func (sm *SimpleMap) Hash() []byte { + sm.Sort() + return hashKVPairs(sm.kvs) +} + +func (sm *SimpleMap) Sort() { + if sm.sorted { + return + } + sm.kvs.Sort() + sm.sorted = true +} + +// Returns a copy of sorted KVPairs. +func (sm *SimpleMap) KVPairs() cmn.KVPairs { + sm.Sort() + kvs := make(cmn.KVPairs, len(sm.kvs)) + copy(kvs, sm.kvs) + return kvs +} + +//---------------------------------------- + +// A local extension to KVPair that can be hashed. +type KVPair cmn.KVPair + +func (kv KVPair) Hash() []byte { + hasher := tmhash.New() + err := encodeByteSlice(hasher, kv.Key) + if err != nil { + panic(err) + } + err = encodeByteSlice(hasher, kv.Value) + if err != nil { + panic(err) + } + return hasher.Sum(nil) +} + +func hashKVPairs(kvs cmn.KVPairs) []byte { + kvsH := make([]Hasher, 0, len(kvs)) + for _, kvp := range kvs { + kvsH = append(kvsH, KVPair(kvp)) + } + return SimpleHashFromHashers(kvsH) +} diff --git a/merkle/simple_map_test.go b/merkle/simple_map_test.go new file mode 100644 index 00000000..6e1004db --- /dev/null +++ b/merkle/simple_map_test.go @@ -0,0 +1,53 @@ +package merkle + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type strHasher string + +func (str strHasher) Hash() []byte { + return SimpleHashFromBytes([]byte(str)) +} + +func TestSimpleMap(t *testing.T) { + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + assert.Equal(t, "3dafc06a52039d029be57c75c9d16356a4256ef4", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value2")) + assert.Equal(t, "03eb5cfdff646bc4e80fec844e72fd248a1c6b2c", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + db.Set("key2", strHasher("value2")) + assert.Equal(t, "acc3971eab8513171cc90ce8b74f368c38f9657d", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", strHasher("value2")) // NOTE: out of order + db.Set("key1", strHasher("value1")) + assert.Equal(t, "acc3971eab8513171cc90ce8b74f368c38f9657d", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", strHasher("value1")) + db.Set("key2", strHasher("value2")) + db.Set("key3", strHasher("value3")) + assert.Equal(t, "0cd117ad14e6cd22edcd9aa0d84d7063b54b862f", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", strHasher("value2")) // NOTE: out of order + db.Set("key1", strHasher("value1")) + db.Set("key3", strHasher("value3")) + assert.Equal(t, "0cd117ad14e6cd22edcd9aa0d84d7063b54b862f", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } +} diff --git a/merkle/simple_proof.go b/merkle/simple_proof.go new file mode 100644 index 00000000..ca6ccf37 --- /dev/null +++ b/merkle/simple_proof.go @@ -0,0 +1,144 @@ +package merkle + +import ( + "bytes" + "fmt" +) + +type SimpleProof struct { + Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. +} + +// proofs[0] is the proof for items[0]. +func SimpleProofsFromHashers(items []Hasher) (rootHash []byte, proofs []*SimpleProof) { + trails, rootSPN := trailsFromHashers(items) + rootHash = rootSPN.Hash + proofs = make([]*SimpleProof, len(items)) + for i, trail := range trails { + proofs[i] = &SimpleProof{ + Aunts: trail.FlattenAunts(), + } + } + return +} + +func SimpleProofsFromMap(m map[string]Hasher) (rootHash []byte, proofs []*SimpleProof) { + sm := NewSimpleMap() + for k, v := range m { + sm.Set(k, v) + } + sm.Sort() + kvs := sm.kvs + kvsH := make([]Hasher, 0, len(kvs)) + for _, kvp := range kvs { + kvsH = append(kvsH, KVPair(kvp)) + } + return SimpleProofsFromHashers(kvsH) +} + +// Verify that leafHash is a leaf hash of the simple-merkle-tree +// which hashes to rootHash. +func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { + computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) + return computedHash != nil && bytes.Equal(computedHash, rootHash) +} + +func (sp *SimpleProof) String() string { + return sp.StringIndented("") +} + +func (sp *SimpleProof) StringIndented(indent string) string { + return fmt.Sprintf(`SimpleProof{ +%s Aunts: %X +%s}`, + indent, sp.Aunts, + indent) +} + +// Use the leafHash and innerHashes to get the root merkle hash. +// If the length of the innerHashes slice isn't exactly correct, the result is nil. +// Recursive impl. +func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { + if index >= total || index < 0 || total <= 0 { + return nil + } + switch total { + case 0: + panic("Cannot call computeHashFromAunts() with 0 total") + case 1: + if len(innerHashes) != 0 { + return nil + } + return leafHash + default: + if len(innerHashes) == 0 { + return nil + } + numLeft := (total + 1) / 2 + if index < numLeft { + leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if leftHash == nil { + return nil + } + return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) + } + rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if rightHash == nil { + return nil + } + return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) + } +} + +// Helper structure to construct merkle proof. +// The node and the tree is thrown away afterwards. +// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. +// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +type SimpleProofNode struct { + Hash []byte + Parent *SimpleProofNode + Left *SimpleProofNode // Left sibling (only one of Left,Right is set) + Right *SimpleProofNode // Right sibling (only one of Left,Right is set) +} + +// Starting from a leaf SimpleProofNode, FlattenAunts() will return +// the inner hashes for the item corresponding to the leaf. +func (spn *SimpleProofNode) FlattenAunts() [][]byte { + // Nonrecursive impl. + innerHashes := [][]byte{} + for spn != nil { + if spn.Left != nil { + innerHashes = append(innerHashes, spn.Left.Hash) + } else if spn.Right != nil { + innerHashes = append(innerHashes, spn.Right.Hash) + } else { + break + } + spn = spn.Parent + } + return innerHashes +} + +// trails[0].Hash is the leaf hash for items[0]. +// trails[i].Parent.Parent....Parent == root for all i. +func trailsFromHashers(items []Hasher) (trails []*SimpleProofNode, root *SimpleProofNode) { + // Recursive impl. + switch len(items) { + case 0: + return nil, nil + case 1: + trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} + return []*SimpleProofNode{trail}, trail + default: + lefts, leftRoot := trailsFromHashers(items[:(len(items)+1)/2]) + rights, rightRoot := trailsFromHashers(items[(len(items)+1)/2:]) + rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) + root := &SimpleProofNode{rootHash, nil, nil, nil} + leftRoot.Parent = root + leftRoot.Right = rightRoot + rightRoot.Parent = root + rightRoot.Left = leftRoot + return append(lefts, rights...), root + } +} diff --git a/merkle/simple_tree.go b/merkle/simple_tree.go new file mode 100644 index 00000000..6bd80f55 --- /dev/null +++ b/merkle/simple_tree.go @@ -0,0 +1,91 @@ +/* +Computes a deterministic minimal height merkle tree hash. +If the number of items is not a power of two, some leaves +will be at different levels. Tries to keep both sides of +the tree the same size, but the left may be one greater. + +Use this for short deterministic trees, such as the validator list. +For larger datasets, use IAVLTree. + + * + / \ + / \ + / \ + / \ + * * + / \ / \ + / \ / \ + / \ / \ + * * * h6 + / \ / \ / \ + h0 h1 h2 h3 h4 h5 + +*/ + +package merkle + +import ( + "github.com/tendermint/tmlibs/merkle/tmhash" +) + +func SimpleHashFromTwoHashes(left []byte, right []byte) []byte { + var hasher = tmhash.New() + err := encodeByteSlice(hasher, left) + if err != nil { + panic(err) + } + err = encodeByteSlice(hasher, right) + if err != nil { + panic(err) + } + return hasher.Sum(nil) +} + +func SimpleHashFromHashes(hashes [][]byte) []byte { + // Recursive impl. + switch len(hashes) { + case 0: + return nil + case 1: + return hashes[0] + default: + left := SimpleHashFromHashes(hashes[:(len(hashes)+1)/2]) + right := SimpleHashFromHashes(hashes[(len(hashes)+1)/2:]) + return SimpleHashFromTwoHashes(left, right) + } +} + +// NOTE: Do not implement this, use SimpleHashFromByteslices instead. +// type Byteser interface { Bytes() []byte } +// func SimpleHashFromBytesers(items []Byteser) []byte { ... } + +func SimpleHashFromByteslices(bzs [][]byte) []byte { + hashes := make([][]byte, len(bzs)) + for i, bz := range bzs { + hashes[i] = SimpleHashFromBytes(bz) + } + return SimpleHashFromHashes(hashes) +} + +func SimpleHashFromBytes(bz []byte) []byte { + hasher := tmhash.New() + hasher.Write(bz) + return hasher.Sum(nil) +} + +func SimpleHashFromHashers(items []Hasher) []byte { + hashes := make([][]byte, len(items)) + for i, item := range items { + hash := item.Hash() + hashes[i] = hash + } + return SimpleHashFromHashes(hashes) +} + +func SimpleHashFromMap(m map[string]Hasher) []byte { + sm := NewSimpleMap() + for k, v := range m { + sm.Set(k, v) + } + return sm.Hash() +} diff --git a/merkle/simple_tree_test.go b/merkle/simple_tree_test.go new file mode 100644 index 00000000..8c4ed01f --- /dev/null +++ b/merkle/simple_tree_test.go @@ -0,0 +1,87 @@ +package merkle + +import ( + "bytes" + + cmn "github.com/tendermint/tmlibs/common" + . "github.com/tendermint/tmlibs/test" + + "testing" +) + +type testItem []byte + +func (tI testItem) Hash() []byte { + return []byte(tI) +} + +func TestSimpleProof(t *testing.T) { + + total := 100 + + items := make([]Hasher, total) + for i := 0; i < total; i++ { + items[i] = testItem(cmn.RandBytes(32)) + } + + rootHash := SimpleHashFromHashers(items) + + rootHash2, proofs := SimpleProofsFromHashers(items) + + if !bytes.Equal(rootHash, rootHash2) { + t.Errorf("Unmatched root hashes: %X vs %X", rootHash, rootHash2) + } + + // For each item, check the trail. + for i, item := range items { + itemHash := item.Hash() + proof := proofs[i] + + // Verify success + ok := proof.Verify(i, total, itemHash, rootHash) + if !ok { + t.Errorf("Verification failed for index %v.", i) + } + + // Wrong item index should make it fail + { + ok = proof.Verify((i+1)%total, total, itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong index %v.", i) + } + } + + // Trail too long should make it fail + origAunts := proof.Aunts + proof.Aunts = append(proof.Aunts, cmn.RandBytes(32)) + { + ok = proof.Verify(i, total, itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong trail length.") + } + } + proof.Aunts = origAunts + + // Trail too short should make it fail + proof.Aunts = proof.Aunts[0 : len(proof.Aunts)-1] + { + ok = proof.Verify(i, total, itemHash, rootHash) + if ok { + t.Errorf("Expected verification to fail for wrong trail length.") + } + } + proof.Aunts = origAunts + + // Mutating the itemHash should make it fail. + ok = proof.Verify(i, total, MutateByteSlice(itemHash), rootHash) + if ok { + t.Errorf("Expected verification to fail for mutated leaf hash") + } + + // Mutating the rootHash should make it fail. + ok = proof.Verify(i, total, itemHash, MutateByteSlice(rootHash)) + if ok { + t.Errorf("Expected verification to fail for mutated root hash") + } + } +} diff --git a/merkle/tmhash/hash.go b/merkle/tmhash/hash.go new file mode 100644 index 00000000..de69c406 --- /dev/null +++ b/merkle/tmhash/hash.go @@ -0,0 +1,41 @@ +package tmhash + +import ( + "crypto/sha256" + "hash" +) + +var ( + Size = 20 + BlockSize = sha256.BlockSize +) + +type sha256trunc struct { + sha256 hash.Hash +} + +func (h sha256trunc) Write(p []byte) (n int, err error) { + return h.sha256.Write(p) +} +func (h sha256trunc) Sum(b []byte) []byte { + shasum := h.sha256.Sum(b) + return shasum[:Size] +} + +func (h sha256trunc) Reset() { + h.sha256.Reset() +} + +func (h sha256trunc) Size() int { + return Size +} + +func (h sha256trunc) BlockSize() int { + return h.sha256.BlockSize() +} + +func New() hash.Hash { + return sha256trunc{ + sha256: sha256.New(), + } +} diff --git a/merkle/tmhash/hash_test.go b/merkle/tmhash/hash_test.go new file mode 100644 index 00000000..c9e80f2b --- /dev/null +++ b/merkle/tmhash/hash_test.go @@ -0,0 +1,23 @@ +package tmhash_test + +import ( + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tendermint/tmlibs/merkle/tmhash" +) + +func TestHash(t *testing.T) { + testVector := []byte("abc") + hasher := tmhash.New() + hasher.Write(testVector) + bz := hasher.Sum(nil) + + hasher = sha256.New() + hasher.Write(testVector) + bz2 := hasher.Sum(nil) + bz2 = bz2[:20] + + assert.Equal(t, bz, bz2) +} diff --git a/merkle/types.go b/merkle/types.go new file mode 100644 index 00000000..a0c491a7 --- /dev/null +++ b/merkle/types.go @@ -0,0 +1,47 @@ +package merkle + +import ( + "encoding/binary" + "io" +) + +type Tree interface { + Size() (size int) + Height() (height int8) + Has(key []byte) (has bool) + Proof(key []byte) (value []byte, proof []byte, exists bool) // TODO make it return an index + Get(key []byte) (index int, value []byte, exists bool) + GetByIndex(index int) (key []byte, value []byte) + Set(key []byte, value []byte) (updated bool) + Remove(key []byte) (value []byte, removed bool) + HashWithCount() (hash []byte, count int) + Hash() (hash []byte) + Save() (hash []byte) + Load(hash []byte) + Copy() Tree + Iterate(func(key []byte, value []byte) (stop bool)) (stopped bool) + IterateRange(start []byte, end []byte, ascending bool, fx func(key []byte, value []byte) (stop bool)) (stopped bool) +} + +type Hasher interface { + Hash() []byte +} + +//----------------------------------------------------------------------- +// NOTE: these are duplicated from go-amino so we dont need go-amino as a dep + +func encodeByteSlice(w io.Writer, bz []byte) (err error) { + err = encodeUvarint(w, uint64(len(bz))) + if err != nil { + return + } + _, err = w.Write(bz) + return +} + +func encodeUvarint(w io.Writer, i uint64) (err error) { + var buf [10]byte + n := binary.PutUvarint(buf[:], i) + _, err = w.Write(buf[0:n]) + return +} diff --git a/test.sh b/test.sh new file mode 100755 index 00000000..ecf17fc4 --- /dev/null +++ b/test.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -e + +# run the linter +# make metalinter_test + +# setup certs +make gen_certs + +# run the unit tests with coverage +echo "" > coverage.txt +for d in $(go list ./... | grep -v vendor); do + go test -race -coverprofile=profile.out -covermode=atomic "$d" + if [ -f profile.out ]; then + cat profile.out >> coverage.txt + rm profile.out + fi +done + +# cleanup certs +make clean_certs diff --git a/test/assert.go b/test/assert.go new file mode 100644 index 00000000..a6ffed0c --- /dev/null +++ b/test/assert.go @@ -0,0 +1,14 @@ +package test + +import ( + "testing" +) + +func AssertPanics(t *testing.T, msg string, f func()) { + defer func() { + if err := recover(); err == nil { + t.Errorf("Should have panic'd, but didn't: %v", msg) + } + }() + f() +} diff --git a/test/mutate.go b/test/mutate.go new file mode 100644 index 00000000..76534e8b --- /dev/null +++ b/test/mutate.go @@ -0,0 +1,28 @@ +package test + +import ( + cmn "github.com/tendermint/tmlibs/common" +) + +// Contract: !bytes.Equal(input, output) && len(input) >= len(output) +func MutateByteSlice(bytez []byte) []byte { + // If bytez is empty, panic + if len(bytez) == 0 { + panic("Cannot mutate an empty bytez") + } + + // Copy bytez + mBytez := make([]byte, len(bytez)) + copy(mBytez, bytez) + bytez = mBytez + + // Try a random mutation + switch cmn.RandInt() % 2 { + case 0: // Mutate a single byte + bytez[cmn.RandInt()%len(bytez)] += byte(cmn.RandInt()%255 + 1) + case 1: // Remove an arbitrary byte + pos := cmn.RandInt() % len(bytez) + bytez = append(bytez[:pos], bytez[pos+1:]...) + } + return bytez +} diff --git a/version/version.go b/version/version.go new file mode 100644 index 00000000..6e73a937 --- /dev/null +++ b/version/version.go @@ -0,0 +1,3 @@ +package version + +const Version = "0.9.0"