mirror of https://github.com/poanetwork/gecko.git
Merge pull request #124 from StephenButtolph/check-nil-account
check for nil account IDs in get user
This commit is contained in:
commit
f257187941
|
@ -63,7 +63,7 @@ func (tx *rewardValidatorTx) SemanticVerify(db database.Database) (*versiondb.Da
|
|||
return nil, nil, nil, nil, err
|
||||
}
|
||||
if db == nil {
|
||||
return nil, nil, nil, nil, errDbNil
|
||||
return nil, nil, nil, nil, errDBNil
|
||||
}
|
||||
|
||||
currentEvents, err := tx.vm.getCurrentValidators(db, DefaultSubnetID)
|
||||
|
|
|
@ -322,7 +322,7 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re
|
|||
return errGetAccounts
|
||||
}
|
||||
|
||||
var accounts []APIAccount
|
||||
reply.Accounts = []APIAccount{}
|
||||
for _, accountID := range accountIDs {
|
||||
account, err := service.vm.getAccount(service.vm.DB, accountID) // Get account whose ID is [accountID]
|
||||
if err != nil && err != database.ErrNotFound {
|
||||
|
@ -331,13 +331,12 @@ func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, re
|
|||
} else if err == database.ErrNotFound {
|
||||
account = newAccount(accountID, 0, 0)
|
||||
}
|
||||
accounts = append(accounts, APIAccount{
|
||||
reply.Accounts = append(reply.Accounts, APIAccount{
|
||||
Address: accountID,
|
||||
Nonce: json.Uint64(account.Nonce),
|
||||
Balance: json.Uint64(account.Balance),
|
||||
})
|
||||
}
|
||||
reply.Accounts = accounts
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,10 @@ import (
|
|||
// account IDs this user controls
|
||||
var accountIDsKey = ids.Empty.Bytes()
|
||||
|
||||
var errDbNil = errors.New("db uninitialized")
|
||||
var (
|
||||
errDBNil = errors.New("db uninitialized")
|
||||
errKeyNil = errors.New("key uninitialized")
|
||||
)
|
||||
|
||||
type user struct {
|
||||
// This user's database, acquired from the keystore
|
||||
|
@ -25,7 +28,7 @@ type user struct {
|
|||
// Get the IDs of the accounts controlled by this user
|
||||
func (u *user) getAccountIDs() ([]ids.ShortID, error) {
|
||||
if u.db == nil {
|
||||
return nil, errDbNil
|
||||
return nil, errDBNil
|
||||
}
|
||||
|
||||
// If user has no accounts, return empty list
|
||||
|
@ -34,8 +37,9 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) {
|
|||
return nil, errDB
|
||||
}
|
||||
if !hasAccounts {
|
||||
return make([]ids.ShortID, 0), nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// User has accounts. Get them.
|
||||
bytes, err := u.db.Get(accountIDsKey)
|
||||
if err != nil {
|
||||
|
@ -50,21 +54,24 @@ func (u *user) getAccountIDs() ([]ids.ShortID, error) {
|
|||
|
||||
// controlsAccount returns true iff this user controls the account
|
||||
// with the specified ID
|
||||
func (u *user) controlsAccount(ID ids.ShortID) (bool, error) {
|
||||
func (u *user) controlsAccount(accountID ids.ShortID) (bool, error) {
|
||||
if u.db == nil {
|
||||
return false, errDbNil
|
||||
return false, errDBNil
|
||||
}
|
||||
|
||||
if _, err := u.db.Get(ID.Bytes()); err == nil {
|
||||
return true, nil
|
||||
if accountID.IsZero() {
|
||||
return false, errEmptyAccountAddress
|
||||
}
|
||||
return false, nil
|
||||
return u.db.Has(accountID.Bytes())
|
||||
}
|
||||
|
||||
// putAccount persists that this user controls the account whose ID is
|
||||
// [privKey].PublicKey().Address()
|
||||
func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error {
|
||||
newAccountID := privKey.PublicKey().Address() // Account thie privKey controls
|
||||
if privKey == nil {
|
||||
return errKeyNil
|
||||
}
|
||||
|
||||
newAccountID := privKey.PublicKey().Address() // Account the privKey controls
|
||||
controlsAccount, err := u.controlsAccount(newAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -102,7 +109,10 @@ func (u *user) putAccount(privKey *crypto.PrivateKeySECP256K1R) error {
|
|||
// Key returns the private key that controls the account with the specified ID
|
||||
func (u *user) getKey(accountID ids.ShortID) (*crypto.PrivateKeySECP256K1R, error) {
|
||||
if u.db == nil {
|
||||
return nil, errDbNil
|
||||
return nil, errDBNil
|
||||
}
|
||||
if accountID.IsZero() {
|
||||
return nil, errEmptyAccountAddress
|
||||
}
|
||||
|
||||
factory := crypto.FactorySECP256K1R{}
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
|
||||
// See the file LICENSE for licensing terms.
|
||||
|
||||
package platformvm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ava-labs/gecko/database/memdb"
|
||||
"github.com/ava-labs/gecko/ids"
|
||||
"github.com/ava-labs/gecko/utils/crypto"
|
||||
)
|
||||
|
||||
func TestUserNilDB(t *testing.T) {
|
||||
u := user{}
|
||||
|
||||
_, err := u.getAccountIDs()
|
||||
assert.Error(t, err, "nil db should have caused an error")
|
||||
|
||||
_, err = u.controlsAccount(ids.ShortEmpty)
|
||||
assert.Error(t, err, "nil db should have caused an error")
|
||||
|
||||
_, err = u.getKey(ids.ShortEmpty)
|
||||
assert.Error(t, err, "nil db should have caused an error")
|
||||
|
||||
factory := crypto.FactorySECP256K1R{}
|
||||
sk, err := factory.NewPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
|
||||
assert.Error(t, err, "nil db should have caused an error")
|
||||
}
|
||||
|
||||
func TestUserClosedDB(t *testing.T) {
|
||||
db := memdb.New()
|
||||
err := db.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
u := user{db: db}
|
||||
|
||||
_, err = u.getAccountIDs()
|
||||
assert.Error(t, err, "closed db should have caused an error")
|
||||
|
||||
_, err = u.controlsAccount(ids.ShortEmpty)
|
||||
assert.Error(t, err, "closed db should have caused an error")
|
||||
|
||||
_, err = u.getKey(ids.ShortEmpty)
|
||||
assert.Error(t, err, "closed db should have caused an error")
|
||||
|
||||
factory := crypto.FactorySECP256K1R{}
|
||||
sk, err := factory.NewPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
|
||||
assert.Error(t, err, "closed db should have caused an error")
|
||||
}
|
||||
|
||||
func TestUserNilSK(t *testing.T) {
|
||||
u := user{db: memdb.New()}
|
||||
|
||||
err := u.putAccount(nil)
|
||||
assert.Error(t, err, "nil key should have caused an error")
|
||||
}
|
||||
|
||||
func TestUserNilAccount(t *testing.T) {
|
||||
u := user{db: memdb.New()}
|
||||
|
||||
_, err := u.controlsAccount(ids.ShortID{})
|
||||
assert.Error(t, err, "nil accountID should have caused an error")
|
||||
|
||||
_, err = u.getKey(ids.ShortID{})
|
||||
assert.Error(t, err, "nil accountID should have caused an error")
|
||||
}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
u := user{db: memdb.New()}
|
||||
|
||||
accountIDs, err := u.getAccountIDs()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, accountIDs, "new user shouldn't have accounts")
|
||||
|
||||
factory := crypto.FactorySECP256K1R{}
|
||||
sk, err := factory.NewPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = u.putAccount(sk.(*crypto.PrivateKeySECP256K1R))
|
||||
assert.NoError(t, err)
|
||||
|
||||
addr := sk.PublicKey().Address()
|
||||
|
||||
ok, err := u.controlsAccount(addr)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ok, "added account should have been marked as controlled")
|
||||
|
||||
savedSk, err := u.getKey(addr)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, sk.Bytes(), savedSk.Bytes(), "wrong key returned")
|
||||
|
||||
accountIDs, err = u.getAccountIDs()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, accountIDs, 1, "account should have been added")
|
||||
|
||||
savedAddr := accountIDs[0]
|
||||
equals := addr.Equals(savedAddr)
|
||||
assert.True(t, equals, "saved address should match provided address")
|
||||
}
|
Loading…
Reference in New Issue