mirror of https://github.com/poanetwork/gecko.git
284 lines
5.9 KiB
Go
284 lines
5.9 KiB
Go
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
|
|
// See the file LICENSE for licensing terms.
|
|
|
|
package encdb
|
|
|
|
import (
	"crypto/cipher"
	"crypto/rand"
	"fmt"
	"sync"

	"golang.org/x/crypto/chacha20poly1305"

	"github.com/ava-labs/gecko/database"
	"github.com/ava-labs/gecko/database/nodb"
	"github.com/ava-labs/gecko/utils/hashing"
	"github.com/ava-labs/gecko/vms/components/codec"
)
|
|
|
|
// Database encrypts all values that are provided
|
|
type Database struct {
|
|
lock sync.RWMutex
|
|
codec codec.Codec
|
|
cipher cipher.AEAD
|
|
db database.Database
|
|
}
|
|
|
|
// New returns a new encrypted database
|
|
func New(password []byte, db database.Database) (*Database, error) {
|
|
h := hashing.ComputeHash256(password)
|
|
aead, err := chacha20poly1305.NewX(h)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Database{
|
|
codec: codec.NewDefault(),
|
|
cipher: aead,
|
|
db: db,
|
|
}, nil
|
|
}
|
|
|
|
// Has implements the Database interface
|
|
func (db *Database) Has(key []byte) (bool, error) {
|
|
db.lock.RLock()
|
|
defer db.lock.RUnlock()
|
|
|
|
if db.db == nil {
|
|
return false, database.ErrClosed
|
|
}
|
|
return db.db.Has(key)
|
|
}
|
|
|
|
// Get implements the Database interface
|
|
func (db *Database) Get(key []byte) ([]byte, error) {
|
|
db.lock.RLock()
|
|
defer db.lock.RUnlock()
|
|
|
|
if db.db == nil {
|
|
return nil, database.ErrClosed
|
|
}
|
|
encVal, err := db.db.Get(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return db.decrypt(encVal)
|
|
}
|
|
|
|
// Put implements the Database interface
|
|
func (db *Database) Put(key, value []byte) error {
|
|
db.lock.Lock()
|
|
defer db.lock.Unlock()
|
|
|
|
if db.db == nil {
|
|
return database.ErrClosed
|
|
}
|
|
|
|
encValue, err := db.encrypt(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return db.db.Put(key, encValue)
|
|
}
|
|
|
|
// Delete implements the Database interface
|
|
func (db *Database) Delete(key []byte) error {
|
|
db.lock.Lock()
|
|
defer db.lock.Unlock()
|
|
|
|
if db.db == nil {
|
|
return database.ErrClosed
|
|
}
|
|
return db.db.Delete(key)
|
|
}
|
|
|
|
// NewBatch implements the Database interface
|
|
func (db *Database) NewBatch() database.Batch {
|
|
return &batch{
|
|
Batch: db.db.NewBatch(),
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// NewIterator implements the Database interface
|
|
func (db *Database) NewIterator() database.Iterator { return db.NewIteratorWithStartAndPrefix(nil, nil) }
|
|
|
|
// NewIteratorWithStart implements the Database interface
|
|
func (db *Database) NewIteratorWithStart(start []byte) database.Iterator {
|
|
return db.NewIteratorWithStartAndPrefix(start, nil)
|
|
}
|
|
|
|
// NewIteratorWithPrefix implements the Database interface
|
|
func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator {
|
|
return db.NewIteratorWithStartAndPrefix(nil, prefix)
|
|
}
|
|
|
|
// NewIteratorWithStartAndPrefix implements the Database interface
|
|
func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator {
|
|
db.lock.RLock()
|
|
defer db.lock.RUnlock()
|
|
|
|
if db.db == nil {
|
|
return &nodb.Iterator{Err: database.ErrClosed}
|
|
}
|
|
return &iterator{
|
|
Iterator: db.db.NewIteratorWithStartAndPrefix(start, prefix),
|
|
db: db,
|
|
}
|
|
}
|
|
|
|
// Stat implements the Database interface
|
|
func (db *Database) Stat(stat string) (string, error) {
|
|
db.lock.RLock()
|
|
defer db.lock.RUnlock()
|
|
|
|
if db.db == nil {
|
|
return "", database.ErrClosed
|
|
}
|
|
return db.db.Stat(stat)
|
|
}
|
|
|
|
// Compact implements the Database interface
|
|
func (db *Database) Compact(start, limit []byte) error {
|
|
db.lock.Lock()
|
|
defer db.lock.Unlock()
|
|
|
|
if db.db == nil {
|
|
return database.ErrClosed
|
|
}
|
|
return db.db.Compact(start, limit)
|
|
}
|
|
|
|
// Close implements the Database interface
|
|
func (db *Database) Close() error {
|
|
db.lock.Lock()
|
|
defer db.lock.Unlock()
|
|
|
|
if db.db == nil {
|
|
return database.ErrClosed
|
|
}
|
|
db.db = nil
|
|
return nil
|
|
}
|
|
|
|
// keyValue is one operation recorded by a batch: a Put of value under key,
// or, when delete is set, a Delete of key. Field order matters because
// batch methods build it with positional literals.
type keyValue struct {
	key, value []byte
	delete     bool
}
|
|
|
|
type batch struct {
|
|
database.Batch
|
|
|
|
db *Database
|
|
writes []keyValue
|
|
}
|
|
|
|
func (b *batch) Put(key, value []byte) error {
|
|
b.writes = append(b.writes, keyValue{copyBytes(key), copyBytes(value), false})
|
|
encValue, err := b.db.encrypt(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return b.Batch.Put(key, encValue)
|
|
}
|
|
|
|
func (b *batch) Delete(key []byte) error {
|
|
b.writes = append(b.writes, keyValue{copyBytes(key), nil, true})
|
|
return b.Batch.Delete(key)
|
|
}
|
|
|
|
func (b *batch) Write() error {
|
|
b.db.lock.Lock()
|
|
defer b.db.lock.Unlock()
|
|
|
|
if b.db.db == nil {
|
|
return database.ErrClosed
|
|
}
|
|
|
|
return b.Batch.Write()
|
|
}
|
|
|
|
// Reset resets the batch for reuse.
|
|
func (b *batch) Reset() {
|
|
b.writes = b.writes[:0]
|
|
b.Batch.Reset()
|
|
}
|
|
|
|
// Replay replays the batch contents.
|
|
func (b *batch) Replay(w database.KeyValueWriter) error {
|
|
for _, keyvalue := range b.writes {
|
|
if keyvalue.delete {
|
|
if err := w.Delete(keyvalue.key); err != nil {
|
|
return err
|
|
}
|
|
} else if err := w.Put(keyvalue.key, keyvalue.value); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type iterator struct {
|
|
database.Iterator
|
|
db *Database
|
|
|
|
val []byte
|
|
err error
|
|
}
|
|
|
|
func (it *iterator) Next() bool {
|
|
next := it.Iterator.Next()
|
|
if next {
|
|
encVal := it.Iterator.Value()
|
|
val, err := it.db.decrypt(encVal)
|
|
if err != nil {
|
|
it.err = err
|
|
return false
|
|
}
|
|
it.val = val
|
|
} else {
|
|
it.val = nil
|
|
}
|
|
return next
|
|
}
|
|
|
|
func (it *iterator) Error() error {
|
|
if it.err != nil {
|
|
return it.err
|
|
}
|
|
return it.Iterator.Error()
|
|
}
|
|
|
|
// Value returns the value cached by the most recent successful decryption in
// Next. Next clears it once the underlying iterator is exhausted.
func (it *iterator) Value() []byte {
	current := it.val
	return current
}
|
|
|
|
// copyBytes returns a freshly allocated copy of src that shares no memory
// with it. The result is always non-nil, even for a nil or empty input.
func copyBytes(src []byte) []byte {
	return append(make([]byte, 0, len(src)), src...)
}
|
|
|
|
// encryptedValue is the envelope stored for each value: the AEAD output plus
// the random nonce needed to open it.
//
// NOTE(review): the codec presumably serializes the tagged fields in
// declaration order, so reordering them would likely break previously
// stored data. Do not reorder.
type encryptedValue struct {
	Ciphertext []byte `serialize:"true"`
	Nonce      []byte `serialize:"true"`
}
|
|
|
|
func (db *Database) encrypt(plaintext []byte) ([]byte, error) {
|
|
nonce := make([]byte, chacha20poly1305.NonceSizeX)
|
|
if _, err := rand.Read(nonce); err != nil {
|
|
return nil, err
|
|
}
|
|
ciphertext := db.cipher.Seal(nil, nonce, plaintext, nil)
|
|
return db.codec.Marshal(&encryptedValue{
|
|
Ciphertext: ciphertext,
|
|
Nonce: nonce,
|
|
})
|
|
}
|
|
|
|
func (db *Database) decrypt(ciphertext []byte) ([]byte, error) {
|
|
val := encryptedValue{}
|
|
if err := db.codec.Unmarshal(ciphertext, &val); err != nil {
|
|
return nil, err
|
|
}
|
|
return db.cipher.Open(nil, val.Nonce, val.Ciphertext, nil)
|
|
}
|