Added user tests

This commit is contained in:
StephenButtolph 2020-05-02 14:05:57 -04:00
parent acb96c8184
commit 4c3fce408e
4 changed files with 129 additions and 15 deletions

View File

@ -63,7 +63,7 @@ func (tx *rewardValidatorTx) SemanticVerify(db database.Database) (*versiondb.Da
return nil, nil, nil, nil, err
}
if db == nil {
return nil, nil, nil, nil, errDbNil
return nil, nil, nil, nil, errDBNil
}
currentEvents, err := tx.vm.getCurrentValidators(db, DefaultSubnetID)

View File

@ -322,7 +322,7 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re
return errGetAccounts
}
var accounts []APIAccount
reply.Accounts = []APIAccount{}
for _, accountID := range accountIDs {
account, err := service.vm.getAccount(service.vm.DB, accountID) // Get account whose ID is [accountID]
if err != nil && err != database.ErrNotFound {
@ -331,13 +331,12 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re
} else if err == database.ErrNotFound {
account = newAccount(accountID, 0, 0)
}
accounts = append(accounts, APIAccount{
reply.Accounts = append(reply.Accounts, APIAccount{
Address: accountID,
Nonce: json.Uint64(account.Nonce),
Balance: json.Uint64(account.Balance),
})
}
reply.Accounts = accounts
return nil
}

View File

@ -15,7 +15,10 @@ import (
// account IDs this user controls
var accountIDsKey = ids.Empty.Bytes()
var errDbNil = errors.New("db uninitialized")
var (
errDBNil = errors.New("db uninitialized")
errKeyNil = errors.New("key uninitialized")
)
type user struct {
// This user's database, acquired from the keystore
@ -25,7 +28,7 @@ type user struct {
// Get the IDs of the accounts controlled by this user
func (u *user) getAccountIDs() ([]ids.ShortID, error) {
if u.db == nil {
return nil, errDbNil
return nil, errDBNil
}
// If user has no accounts, return empty list
@ -34,8 +37,9 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) {
return nil, errDB
}
if !hasAccounts {
return make([]ids.ShortID, 0), nil
return nil, nil
}
// User has accounts. Get them.
bytes, err := u.db.Get(accountIDsKey)
if err != nil {
@ -50,21 +54,24 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) {
// controlsAccount returns true iff this user controls the account
// with the specified ID
func (u *user) controlsAccount(ID ids.ShortID) (bool, error) {
func (u *user) controlsAccount(accountID ids.ShortID) (bool, error) {
if u.db == nil {
return false, errDbNil
return false, errDBNil
}
if _, err := u.db.Get(ID.Bytes()); err == nil {
return true, nil
if accountID.IsZero() {
return false, errEmptyAccountAddress
}
return false, nil
return u.db.Has(accountID.Bytes())
}
// putAccount persists that this user controls the account whose ID is
// [privKey].PublicKey().Address()
func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error {
newAccountID := privKey.PublicKey().Address() // Account thie privKey controls
if privKey == nil {
return errKeyNil
}
newAccountID := privKey.PublicKey().Address() // Account the privKey controls
controlsAccount, err := u.controlsAccount(newAccountID)
if err != nil {
return err
@ -102,7 +109,7 @@ func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error {
// Key returns the private key that controls the account with the specified ID
func (u *user) getKey(accountID ids.ShortID) (*crypto.PrivateKeySECP256K1R, error) {
if u.db == nil {
return nil, errDbNil
return nil, errDBNil
}
if accountID.IsZero() {
return nil, errEmptyAccountAddress

108
vms/platformvm/user_test.go Normal file
View File

@ -0,0 +1,108 @@
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package platformvm
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/ava-labs/gecko/database/memdb"
"github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/crypto"
)
func TestUserNilDB(t *testing.T) {
u := user{}
_, err := u.getAccountIDs()
assert.Error(t, err, "nil db should have caused an error")
_, err = u.controlsAccount(ids.ShortEmpty)
assert.Error(t, err, "nil db should have caused an error")
_, err = u.getKey(ids.ShortEmpty)
assert.Error(t, err, "nil db should have caused an error")
factory := crypto.FactorySECP256K1R{}
sk, err := factory.NewPrivateKey()
assert.NoError(t, err)
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
assert.Error(t, err, "nil db should have caused an error")
}
func TestUserClosedDB(t *testing.T) {
db := memdb.New()
err := db.Close()
assert.NoError(t, err)
u := user{db: db}
_, err = u.getAccountIDs()
assert.Error(t, err, "closed db should have caused an error")
_, err = u.controlsAccount(ids.ShortEmpty)
assert.Error(t, err, "closed db should have caused an error")
_, err = u.getKey(ids.ShortEmpty)
assert.Error(t, err, "closed db should have caused an error")
factory := crypto.FactorySECP256K1R{}
sk, err := factory.NewPrivateKey()
assert.NoError(t, err)
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
assert.Error(t, err, "closed db should have caused an error")
}
func TestUserNilSK(t *testing.T) {
u := user{db: memdb.New()}
err := u.putAccount(nil)
assert.Error(t, err, "nil key should have caused an error")
}
func TestUserNilAccount(t *testing.T) {
u := user{db: memdb.New()}
_, err := u.controlsAccount(ids.ShortID{})
assert.Error(t, err, "nil accountID should have caused an error")
_, err = u.getKey(ids.ShortID{})
assert.Error(t, err, "nil accountID should have caused an error")
}
func TestUser(t *testing.T) {
u := user{db: memdb.New()}
accountIDs, err := u.getAccountIDs()
assert.NoError(t, err)
assert.Empty(t, accountIDs, "new user shouldn't have accounts")
factory := crypto.FactorySECP256K1R{}
sk, err := factory.NewPrivateKey()
assert.NoError(t, err)
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
assert.NoError(t, err)
addr := sk.PublicKey().Address()
ok, err := u.controlsAccount(addr)
assert.NoError(t, err)
assert.True(t, ok, "added account should have been marked as controlled")
savedSk, err := u.getKey(addr)
assert.NoError(t, err)
assert.Equal(t, sk.Bytes(), savedSk.Bytes(), "wrong key returned")
accountIDs, err = u.getAccountIDs()
assert.NoError(t, err)
assert.Len(t, accountIDs, 1, "account should have been added")
savedAddr := accountIDs[0]
equals := addr.Equals(savedAddr)
assert.True(t, equals, "saved address should match provided address")
}