diff --git a/vms/platformvm/reward_validator_tx.go b/vms/platformvm/reward_validator_tx.go index fa1d309..2903e42 100644 --- a/vms/platformvm/reward_validator_tx.go +++ b/vms/platformvm/reward_validator_tx.go @@ -63,7 +63,7 @@ func (tx *rewardValidatorTx) SemanticVerify(db database.Database) (*versiondb.Da return nil, nil, nil, nil, err } if db == nil { - return nil, nil, nil, nil, errDbNil + return nil, nil, nil, nil, errDBNil } currentEvents, err := tx.vm.getCurrentValidators(db, DefaultSubnetID) diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index 86093b0..e2a9ffe 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -322,7 +322,7 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re return errGetAccounts } - var accounts []APIAccount + reply.Accounts = []APIAccount{} for _, accountID := range accountIDs { account, err := service.vm.getAccount(service.vm.DB, accountID) // Get account whose ID is [accountID] if err != nil && err != database.ErrNotFound { @@ -331,13 +331,12 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re } else if err == database.ErrNotFound { account = newAccount(accountID, 0, 0) } - accounts = append(accounts, APIAccount{ + reply.Accounts = append(reply.Accounts, APIAccount{ Address: accountID, Nonce: json.Uint64(account.Nonce), Balance: json.Uint64(account.Balance), }) } - reply.Accounts = accounts return nil } diff --git a/vms/platformvm/user.go b/vms/platformvm/user.go index e8d9c7e..45f7ce9 100644 --- a/vms/platformvm/user.go +++ b/vms/platformvm/user.go @@ -15,7 +15,10 @@ import ( // account IDs this user controls var accountIDsKey = ids.Empty.Bytes() -var errDbNil = errors.New("db uninitialized") +var ( + errDBNil = errors.New("db uninitialized") + errKeyNil = errors.New("key uninitialized") +) type user struct { // This user's database, acquired from the keystore @@ -25,7 +28,7 @@ type user struct { // Get the IDs of the accounts controlled by this user func (u *user) getAccountIDs() ([]ids.ShortID, error) { if u.db == nil { - return nil, errDbNil + return nil, errDBNil } // If user has no accounts, return empty list @@ -34,8 +37,9 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) { return nil, errDB } if !hasAccounts { - return make([]ids.ShortID, 0), nil + return nil, nil } + // User has accounts. Get them. bytes, err := u.db.Get(accountIDsKey) if err != nil { @@ -50,21 +54,24 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) { // controlsAccount returns true iff this user controls the account // with the specified ID -func (u *user) controlsAccount(ID ids.ShortID) (bool, error) { +func (u *user) controlsAccount(accountID ids.ShortID) (bool, error) { if u.db == nil { - return false, errDbNil + return false, errDBNil } - - if _, err := u.db.Get(ID.Bytes()); err == nil { - return true, nil + if accountID.IsZero() { + return false, errEmptyAccountAddress } - return false, nil + return u.db.Has(accountID.Bytes()) } // putAccount persists that this user controls the account whose ID is // [privKey].PublicKey().Address() func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error { - newAccountID := privKey.PublicKey().Address() // Account thie privKey controls + if privKey == nil { + return errKeyNil + } + + newAccountID := privKey.PublicKey().Address() // Account the privKey controls controlsAccount, err := u.controlsAccount(newAccountID) if err != nil { return err @@ -102,7 +109,10 @@ func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error { // Key returns the private key that controls the account with the specified ID func (u *user) getKey(accountID ids.ShortID) (*crypto.PrivateKeySECP256K1R, error) { if u.db == nil { - return nil, errDbNil + return nil, errDBNil + } + if accountID.IsZero() { + return nil, errEmptyAccountAddress } factory := crypto.FactorySECP256K1R{} diff --git a/vms/platformvm/user_test.go b/vms/platformvm/user_test.go new file mode 100644 index 0000000..758be35 --- /dev/null +++ b/vms/platformvm/user_test.go @@ -0,0 +1,108 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package platformvm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ava-labs/gecko/database/memdb" + "github.com/ava-labs/gecko/ids" + "github.com/ava-labs/gecko/utils/crypto" +) + +func TestUserNilDB(t *testing.T) { + u := user{} + + _, err := u.getAccountIDs() + assert.Error(t, err, "nil db should have caused an error") + + _, err = u.controlsAccount(ids.ShortEmpty) + assert.Error(t, err, "nil db should have caused an error") + + _, err = u.getKey(ids.ShortEmpty) + assert.Error(t, err, "nil db should have caused an error") + + factory := crypto.FactorySECP256K1R{} + sk, err := factory.NewPrivateKey() + assert.NoError(t, err) + + err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R)) + assert.Error(t, err, "nil db should have caused an error") +} + +func TestUserClosedDB(t *testing.T) { + db := memdb.New() + err := db.Close() + assert.NoError(t, err) + + u := user{db: db} + + _, err = u.getAccountIDs() + assert.Error(t, err, "closed db should have caused an error") + + _, err = u.controlsAccount(ids.ShortEmpty) + assert.Error(t, err, "closed db should have caused an error") + + _, err = u.getKey(ids.ShortEmpty) + assert.Error(t, err, "closed db should have caused an error") + + factory := crypto.FactorySECP256K1R{} + sk, err := factory.NewPrivateKey() + assert.NoError(t, err) + + err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R)) + assert.Error(t, err, "closed db should have caused an error") +} + +func TestUserNilSK(t *testing.T) { + u := user{db: memdb.New()} + + err := u.putAccount(nil) + assert.Error(t, err, "nil key should have caused an error") +} + +func TestUserNilAccount(t *testing.T) { + u := user{db: memdb.New()} + + _, err := u.controlsAccount(ids.ShortID{}) + assert.Error(t, err, "nil accountID should have caused an error") + + _, err = u.getKey(ids.ShortID{}) + assert.Error(t, err, "nil accountID should have caused an error") +} + +func TestUser(t *testing.T) { + u := user{db: memdb.New()} + + accountIDs, err := u.getAccountIDs() + assert.NoError(t, err) + assert.Empty(t, accountIDs, "new user shouldn't have accounts") + + factory := crypto.FactorySECP256K1R{} + sk, err := factory.NewPrivateKey() + assert.NoError(t, err) + + err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R)) + assert.NoError(t, err) + + addr := sk.PublicKey().Address() + + ok, err := u.controlsAccount(addr) + assert.NoError(t, err) + assert.True(t, ok, "added account should have been marked as controlled") + + savedSk, err := u.getKey(addr) + assert.NoError(t, err) + assert.Equal(t, sk.Bytes(), savedSk.Bytes(), "wrong key returned") + + accountIDs, err = u.getAccountIDs() + assert.NoError(t, err) + assert.Len(t, accountIDs, 1, "account should have been added") + + savedAddr := accountIDs[0] + equals := addr.Equals(savedAddr) + assert.True(t, equals, "saved address should match provided address") +}