diff --git a/ethdb/database_test.go b/ethdb/database_test.go index 0e69a1218..4740cdaed 100644 --- a/ethdb/database_test.go +++ b/ethdb/database_test.go @@ -14,21 +14,164 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package ethdb +package ethdb_test import ( + "bytes" + "fmt" + "io/ioutil" "os" - "path/filepath" + "strconv" + "sync" + "testing" - "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" ) -func newDb() *LDBDatabase { - file := filepath.Join("/", "tmp", "ldbtesttmpfile") - if common.FileExist(file) { - os.RemoveAll(file) +func newTestLDB() (*ethdb.LDBDatabase, func()) { + dirname, err := ioutil.TempDir(os.TempDir(), "ethdb_test_") + if err != nil { + panic("failed to create test file: " + err.Error()) + } + db, err := ethdb.NewLDBDatabase(dirname, 0, 0) + if err != nil { + panic("failed to create test database: " + err.Error()) } - db, _ := NewLDBDatabase(file, 0, 0) - return db + return db, func() { + db.Close() + os.RemoveAll(dirname) + } +} + +var test_values = []string{"", "a", "1251", "\x00123\x00"} + +func TestLDB_PutGet(t *testing.T) { + db, remove := newTestLDB() + defer remove() + testPutGet(db, t) +} + +func TestMemoryDB_PutGet(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + testPutGet(db, t) +} + +func testPutGet(db ethdb.Database, t *testing.T) { + t.Parallel() + + for _, v := range test_values { + err := db.Put([]byte(v), []byte(v)) + if err != nil { + t.Fatalf("put failed: %v", err) + } + } + + for _, v := range test_values { + data, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if !bytes.Equal(data, []byte(v)) { + t.Fatalf("get returned wrong result, got %q expected %q", string(data), v) + } + } + + for _, v := range test_values { + err := db.Put([]byte(v), []byte("?")) + if err != nil { + t.Fatalf("put override failed: %v", err) + } + } + + for _, v := range test_values { + data, err := db.Get([]byte(v)) + if err != nil { + t.Fatalf("get failed: %v", err) + } + if !bytes.Equal(data, []byte("?")) { + t.Fatalf("get returned wrong result, got %q expected ?", string(data)) + } + } + + for _, v := range test_values { + err := db.Delete([]byte(v)) + if err != nil { + t.Fatalf("delete %q failed: %v", v, err) + } + } + + for _, v := range test_values { + _, err := db.Get([]byte(v)) + if err == nil { + t.Fatalf("got deleted value %q", v) + } + } +} + +func TestLDB_ParallelPutGet(t *testing.T) { + db, remove := newTestLDB() + defer remove() + testParallelPutGet(db, t) +} + +func TestMemoryDB_ParallelPutGet(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + testParallelPutGet(db, t) +} + +func testParallelPutGet(db ethdb.Database, t *testing.T) { + const n = 8 + var pending sync.WaitGroup + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + err := db.Put([]byte(key), []byte("v"+key)) + if err != nil { + panic("put failed: " + err.Error()) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + data, err := db.Get([]byte(key)) + if err != nil { + panic("get failed: " + err.Error()) + } + if !bytes.Equal(data, []byte("v"+key)) { + panic(fmt.Sprintf("get failed, got %q expected %q", []byte(data), []byte("v"+key))) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + err := db.Delete([]byte(key)) + if err != nil { + panic("delete failed: " + err.Error()) + } + }(strconv.Itoa(i)) + } + pending.Wait() + + pending.Add(n) + for i := 0; i < n; i++ { + go func(key string) { + defer pending.Done() + _, err := db.Get([]byte(key)) + if err == nil { + panic("get succeeded") + } + }(strconv.Itoa(i)) + } + pending.Wait() }