gecko/database/encdb/encdb.go

284 lines
5.9 KiB
Go

// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package encdb
import (
"crypto/cipher"
"crypto/rand"
"sync"
"golang.org/x/crypto/chacha20poly1305"
"github.com/ava-labs/gecko/database"
"github.com/ava-labs/gecko/database/nodb"
"github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/vms/components/codec"
)
// Database wraps another database.Database and encrypts every value before it
// is handed to the wrapped database. Keys are passed through unchanged.
type Database struct {
// lock guards db.
lock sync.RWMutex
// codec serializes and deserializes the encryptedValue envelope.
codec codec.Codec
// cipher seals values on write and opens them on read.
cipher cipher.AEAD
// db is the wrapped database; Close sets it to nil to mark this wrapper as
// closed.
db database.Database
}
// New builds an encrypting wrapper around db.
//
// The hash of password (hashing.ComputeHash256) is used as the key for the
// chacha20poly1305.NewX cipher. An error is returned if the cipher cannot be
// constructed from that key.
func New(password []byte, db database.Database) (*Database, error) {
	key := hashing.ComputeHash256(password)

	aead, err := chacha20poly1305.NewX(key)
	if err != nil {
		return nil, err
	}

	encDB := &Database{
		codec:  codec.NewDefault(),
		cipher: aead,
		db:     db,
	}
	return encDB, nil
}
// Has implements the Database interface.
//
// The lookup is forwarded to the wrapped database; database.ErrClosed is
// returned once Close has been called.
func (db *Database) Has(key []byte) (bool, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	if inner := db.db; inner != nil {
		return inner.Has(key)
	}
	return false, database.ErrClosed
}
// Get implements the Database interface.
//
// The stored bytes are fetched from the wrapped database and decrypted before
// being returned. database.ErrClosed is returned once Close has been called.
func (db *Database) Get(key []byte) ([]byte, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	inner := db.db
	if inner == nil {
		return nil, database.ErrClosed
	}

	stored, err := inner.Get(key)
	if err != nil {
		return nil, err
	}
	return db.decrypt(stored)
}
// Put implements the Database interface.
//
// value is encrypted and the resulting bytes are stored under key in the
// wrapped database. database.ErrClosed is returned once Close has been called.
func (db *Database) Put(key, value []byte) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	inner := db.db
	if inner == nil {
		return database.ErrClosed
	}

	sealed, err := db.encrypt(value)
	if err != nil {
		return err
	}
	return inner.Put(key, sealed)
}
// Delete implements the Database interface.
//
// The deletion is forwarded to the wrapped database; database.ErrClosed is
// returned once Close has been called.
func (db *Database) Delete(key []byte) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	if inner := db.db; inner != nil {
		return inner.Delete(key)
	}
	return database.ErrClosed
}
// NewBatch implements the Database interface.
//
// The returned batch encrypts values as they are queued and forwards them to a
// batch created by the wrapped database.
//
// The read lock is held while the wrapped database is inspected so that this
// call cannot race with Close, and a closed database no longer causes a nil
// dereference: a nodb.Batch is returned instead.
// NOTE(review): assumes nodb.Batch satisfies database.Batch and reports
// database.ErrClosed from its operations, mirroring nodb.Iterator — confirm.
func (db *Database) NewBatch() database.Batch {
	db.lock.RLock()
	defer db.lock.RUnlock()

	if db.db == nil {
		return &nodb.Batch{}
	}
	return &batch{
		Batch: db.db.NewBatch(),
		db:    db,
	}
}
// NewIterator implements the Database interface.
//
// It is NewIteratorWithStartAndPrefix with neither a start key nor a prefix.
func (db *Database) NewIterator() database.Iterator {
	return db.NewIteratorWithStartAndPrefix(nil, nil)
}
// NewIteratorWithStart implements the Database interface.
//
// It is NewIteratorWithStartAndPrefix with the given start key and no prefix.
func (db *Database) NewIteratorWithStart(start []byte) database.Iterator {
	var noPrefix []byte
	return db.NewIteratorWithStartAndPrefix(start, noPrefix)
}
// NewIteratorWithPrefix implements the Database interface.
//
// It is NewIteratorWithStartAndPrefix with the given prefix and no start key.
func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator {
	var noStart []byte
	return db.NewIteratorWithStartAndPrefix(noStart, prefix)
}
// NewIteratorWithStartAndPrefix implements the Database interface.
//
// The returned iterator wraps an iterator of the inner database and decrypts
// each value as it advances. Once Close has been called, a nodb.Iterator
// carrying database.ErrClosed is returned instead.
func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator {
	db.lock.RLock()
	defer db.lock.RUnlock()

	inner := db.db
	if inner == nil {
		return &nodb.Iterator{Err: database.ErrClosed}
	}

	encIt := inner.NewIteratorWithStartAndPrefix(start, prefix)
	return &iterator{
		Iterator: encIt,
		db:       db,
	}
}
// Stat implements the Database interface.
//
// The request is forwarded to the wrapped database; database.ErrClosed is
// returned once Close has been called.
func (db *Database) Stat(stat string) (string, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	if inner := db.db; inner != nil {
		return inner.Stat(stat)
	}
	return "", database.ErrClosed
}
// Compact implements the Database interface.
//
// The request is forwarded to the wrapped database; database.ErrClosed is
// returned once Close has been called.
func (db *Database) Compact(start, limit []byte) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	if inner := db.db; inner != nil {
		return inner.Compact(start, limit)
	}
	return database.ErrClosed
}
// Close implements the Database interface.
//
// It marks this wrapper as closed by dropping its reference to the wrapped
// database; the wrapped database's own Close is not invoked here. Closing an
// already-closed wrapper returns database.ErrClosed.
func (db *Database) Close() error {
	db.lock.Lock()
	defer db.lock.Unlock()

	alreadyClosed := db.db == nil
	if alreadyClosed {
		return database.ErrClosed
	}

	db.db = nil
	return nil
}
// keyValue records one plaintext operation queued on a batch so that it can be
// replayed later.
//
// The field order is significant: batch.Put and batch.Delete construct this
// struct with positional literals.
type keyValue struct {
// key is the key the operation applies to.
key []byte
// value is the plaintext value of a put; it is nil for a delete.
value []byte
// delete is true when the operation is a delete rather than a put.
delete bool
}
// batch wraps a batch of the inner database. Values are encrypted before they
// are queued on the embedded Batch, while the plaintext operations are kept in
// writes so that Replay can reproduce them.
type batch struct {
// Batch is the wrapped batch that receives the encrypted values.
database.Batch
// db is the owning Database, used for encryption and for its lock.
db *Database
// writes is the log of plaintext operations, in the order they were queued.
writes []keyValue
}
// Put queues a write of value under key.
//
// The value is encrypted first. The plaintext operation is recorded for Replay
// only once encryption has succeeded, so that a failed encryption cannot leave
// an entry in the replay log that was never handed to the wrapped batch.
// Copies of key and value are recorded, so the caller may reuse its slices.
func (b *batch) Put(key, value []byte) error {
	encValue, err := b.db.encrypt(value)
	if err != nil {
		return err
	}

	b.writes = append(b.writes, keyValue{copyBytes(key), copyBytes(value), false})
	return b.Batch.Put(key, encValue)
}
// Delete queues a deletion of key.
//
// A copy of key is recorded for Replay and the deletion is then forwarded to
// the wrapped batch.
func (b *batch) Delete(key []byte) error {
	op := keyValue{key: copyBytes(key), delete: true}
	b.writes = append(b.writes, op)
	return b.Batch.Delete(key)
}
// Write flushes the wrapped batch while holding the owning database's write
// lock. database.ErrClosed is returned if the database has been closed.
func (b *batch) Write() error {
	owner := b.db
	owner.lock.Lock()
	defer owner.lock.Unlock()

	if owner.db == nil {
		return database.ErrClosed
	}
	return b.Batch.Write()
}
// Reset empties the batch so it can be reused: the replay log is truncated
// (its backing storage is kept) and the wrapped batch is reset.
func (b *batch) Reset() {
	recorded := b.writes
	b.writes = recorded[:0]

	b.Batch.Reset()
}
// Replay applies every recorded operation to w, in the order it was queued,
// using the plaintext values. It stops at, and returns, the first error
// reported by w.
func (b *batch) Replay(w database.KeyValueWriter) error {
	recorded := b.writes
	for i := range recorded {
		op := recorded[i]

		var err error
		if op.delete {
			err = w.Delete(op.key)
		} else {
			err = w.Put(op.key, op.value)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// iterator wraps an iterator of the inner database and exposes decrypted
// values.
type iterator struct {
// Iterator is the wrapped iterator yielding encrypted values.
database.Iterator
// db is the owning Database, used to decrypt values.
db *Database
// val is the decrypted value for the current position, set by Next.
val []byte
// err is the decryption error recorded by Next, if any.
err error
}
// Next advances the wrapped iterator and decrypts the value at the new
// position. It returns false when the wrapped iterator is exhausted or when
// the value cannot be decrypted; in the latter case the error is available
// from Error.
//
// A decryption failure is sticky: the cached value is cleared so that Value
// cannot return the previous entry's plaintext, and every later call returns
// false rather than resuming iteration while Error still reports a failure.
func (it *iterator) Next() bool {
	if it.err != nil {
		it.val = nil
		return false
	}

	if !it.Iterator.Next() {
		it.val = nil
		return false
	}

	val, err := it.db.decrypt(it.Iterator.Value())
	if err != nil {
		it.val = nil
		it.err = err
		return false
	}

	it.val = val
	return true
}
// Error reports the decryption error recorded by Next, if there is one, and
// otherwise whatever error the wrapped iterator reports.
func (it *iterator) Error() error {
	if err := it.err; err != nil {
		return err
	}
	return it.Iterator.Error()
}
// Value returns the decrypted value cached by the most recent call to Next.
func (it *iterator) Value() []byte {
	return it.val
}
// copyBytes returns a newly allocated slice holding the same contents as
// bytes. The result never aliases the input and is non-nil even when the input
// is nil or empty.
func copyBytes(bytes []byte) []byte {
	n := len(bytes)
	dup := make([]byte, n)
	copy(dup, bytes)
	return dup
}
// encryptedValue is the envelope that encrypt serializes with the codec and
// that decrypt deserializes: the sealed data together with the nonce that was
// used to seal it.
// NOTE(review): the field order presumably determines the serialized layout —
// confirm against the codec before reordering.
type encryptedValue struct {
// Ciphertext is the output of the cipher's Seal for the plaintext value.
Ciphertext []byte `serialize:"true"`
// Nonce is the nonce that was passed to Seal and must be passed to Open.
Nonce []byte `serialize:"true"`
}
// encrypt seals plaintext with the database cipher and returns the serialized
// envelope (ciphertext plus nonce) that is stored in the wrapped database.
//
// A fresh nonce of chacha20poly1305.NonceSizeX bytes is read from crypto/rand
// for every call; an error is returned if the random source fails or the
// envelope cannot be serialized.
func (db *Database) encrypt(plaintext []byte) ([]byte, error) {
	nonce := make([]byte, chacha20poly1305.NonceSizeX)
	_, err := rand.Read(nonce)
	if err != nil {
		return nil, err
	}

	envelope := encryptedValue{
		Ciphertext: db.cipher.Seal(nil, nonce, plaintext, nil),
		Nonce:      nonce,
	}
	return db.codec.Marshal(&envelope)
}
// invalidNonceLengthError is returned by decrypt when a stored envelope
// carries a nonce whose length differs from the cipher's nonce size.
type invalidNonceLengthError struct{}

// Error implements the error interface.
func (invalidNonceLengthError) Error() string {
	return "encdb: stored value has an invalid nonce length"
}

// decrypt reverses encrypt: ciphertext is the serialized envelope read from
// the wrapped database, and the returned slice is the original plaintext.
//
// An error is returned if the envelope cannot be deserialized, if its nonce
// has the wrong length, or if the cipher rejects the data.
func (db *Database) decrypt(ciphertext []byte) ([]byte, error) {
	val := encryptedValue{}
	if err := db.codec.Unmarshal(ciphertext, &val); err != nil {
		return nil, err
	}

	// The nonce comes from stored bytes that may be corrupted. AEAD
	// implementations require a nonce of exactly NonceSize() bytes and may
	// panic otherwise, so reject a malformed envelope with an error instead.
	if len(val.Nonce) != db.cipher.NonceSize() {
		return nil, invalidNonceLengthError{}
	}
	return db.cipher.Open(nil, val.Nonce, val.Ciphertext, nil)
}