StephenButtolph 2020-06-16 16:10:45 -04:00
commit ad130e848e
71 changed files with 1672 additions and 710 deletions

View File

@@ -57,7 +57,7 @@ type GetNodeVersionReply struct {
 // GetNodeVersion returns the version this node is running
 func (service *Admin) GetNodeVersion(_ *http.Request, _ *struct{}, reply *GetNodeVersionReply) error {
-	service.log.Debug("Admin: GetNodeVersion called")
+	service.log.Info("Admin: GetNodeVersion called")
 	reply.Version = service.version.String()
 	return nil
@@ -70,7 +70,7 @@ type GetNodeIDReply struct {
 // GetNodeID returns the node ID of this node
 func (service *Admin) GetNodeID(_ *http.Request, _ *struct{}, reply *GetNodeIDReply) error {
-	service.log.Debug("Admin: GetNodeID called")
+	service.log.Info("Admin: GetNodeID called")
 	reply.NodeID = service.nodeID
 	return nil
@@ -83,7 +83,7 @@ type GetNetworkIDReply struct {
 // GetNetworkID returns the network ID this node is running on
 func (service *Admin) GetNetworkID(_ *http.Request, _ *struct{}, reply *GetNetworkIDReply) error {
-	service.log.Debug("Admin: GetNetworkID called")
+	service.log.Info("Admin: GetNetworkID called")
 	reply.NetworkID = cjson.Uint32(service.networkID)
 	return nil
@@ -96,7 +96,7 @@ type GetNetworkNameReply struct {
 // GetNetworkName returns the network name this node is running on
 func (service *Admin) GetNetworkName(_ *http.Request, _ *struct{}, reply *GetNetworkNameReply) error {
-	service.log.Debug("Admin: GetNetworkName called")
+	service.log.Info("Admin: GetNetworkName called")
 	reply.NetworkName = genesis.NetworkName(service.networkID)
 	return nil
@@ -114,7 +114,7 @@ type GetBlockchainIDReply struct {
 // GetBlockchainID returns the blockchain ID that resolves the alias that was supplied
 func (service *Admin) GetBlockchainID(_ *http.Request, args *GetBlockchainIDArgs, reply *GetBlockchainIDReply) error {
-	service.log.Debug("Admin: GetBlockchainID called")
+	service.log.Info("Admin: GetBlockchainID called")
 	bID, err := service.chainManager.Lookup(args.Alias)
 	reply.BlockchainID = bID.String()
@@ -128,7 +128,7 @@ type PeersReply struct {
 // Peers returns the list of current validators
 func (service *Admin) Peers(_ *http.Request, _ *struct{}, reply *PeersReply) error {
-	service.log.Debug("Admin: Peers called")
+	service.log.Info("Admin: Peers called")
 	reply.Peers = service.networking.Peers()
 	return nil
 }
@@ -145,7 +145,7 @@ type StartCPUProfilerReply struct {
 // StartCPUProfiler starts a cpu profile writing to the specified file
 func (service *Admin) StartCPUProfiler(_ *http.Request, args *StartCPUProfilerArgs, reply *StartCPUProfilerReply) error {
-	service.log.Debug("Admin: StartCPUProfiler called with %s", args.Filename)
+	service.log.Info("Admin: StartCPUProfiler called with %s", args.Filename)
 	reply.Success = true
 	return service.performance.StartCPUProfiler(args.Filename)
 }
@@ -157,7 +157,7 @@ type StopCPUProfilerReply struct {
 // StopCPUProfiler stops the cpu profile
 func (service *Admin) StopCPUProfiler(_ *http.Request, _ *struct{}, reply *StopCPUProfilerReply) error {
-	service.log.Debug("Admin: StopCPUProfiler called")
+	service.log.Info("Admin: StopCPUProfiler called")
 	reply.Success = true
 	return service.performance.StopCPUProfiler()
 }
@@ -174,7 +174,7 @@ type MemoryProfileReply struct {
 // MemoryProfile runs a memory profile writing to the specified file
 func (service *Admin) MemoryProfile(_ *http.Request, args *MemoryProfileArgs, reply *MemoryProfileReply) error {
-	service.log.Debug("Admin: MemoryProfile called with %s", args.Filename)
+	service.log.Info("Admin: MemoryProfile called with %s", args.Filename)
 	reply.Success = true
 	return service.performance.MemoryProfile(args.Filename)
 }
@@ -191,7 +191,7 @@ type LockProfileReply struct {
 // LockProfile runs a mutex profile writing to the specified file
 func (service *Admin) LockProfile(_ *http.Request, args *LockProfileArgs, reply *LockProfileReply) error {
-	service.log.Debug("Admin: LockProfile called with %s", args.Filename)
+	service.log.Info("Admin: LockProfile called with %s", args.Filename)
 	reply.Success = true
 	return service.performance.LockProfile(args.Filename)
 }
@@ -209,7 +209,7 @@ type AliasReply struct {
 // Alias attempts to alias an HTTP endpoint to a new name
 func (service *Admin) Alias(_ *http.Request, args *AliasArgs, reply *AliasReply) error {
-	service.log.Debug("Admin: Alias called with URL: %s, Alias: %s", args.Endpoint, args.Alias)
+	service.log.Info("Admin: Alias called with URL: %s, Alias: %s", args.Endpoint, args.Alias)
 	reply.Success = true
 	return service.httpServer.AddAliasesWithReadLock(args.Endpoint, args.Alias)
 }
@@ -227,7 +227,7 @@ type AliasChainReply struct {
 // AliasChain attempts to alias a chain to a new name
 func (service *Admin) AliasChain(_ *http.Request, args *AliasChainArgs, reply *AliasChainReply) error {
-	service.log.Debug("Admin: AliasChain called with Chain: %s, Alias: %s", args.Chain, args.Alias)
+	service.log.Info("Admin: AliasChain called with Chain: %s, Alias: %s", args.Chain, args.Alias)
 	chainID, err := service.chainManager.Lookup(args.Chain)
 	if err != nil {

View File

@@ -74,7 +74,7 @@ type GetLivenessReply struct {
 // GetLiveness returns a summation of the health of the node
 func (h *Health) GetLiveness(_ *http.Request, _ *GetLivenessArgs, reply *GetLivenessReply) error {
-	h.log.Debug("Health: GetLiveness called")
+	h.log.Info("Health: GetLiveness called")
 	reply.Checks, reply.Healthy = h.health.Results()
 	return nil
 }

View File

@@ -61,6 +61,7 @@ type PublishBlockchainReply struct {
 // PublishBlockchain publishes the finalized accepted transactions from the blockchainID over the IPC
 func (ipc *IPCs) PublishBlockchain(r *http.Request, args *PublishBlockchainArgs, reply *PublishBlockchainReply) error {
+	ipc.log.Info("IPCs: PublishBlockchain called with BlockchainID: %s", args.BlockchainID)
 	chainID, err := ipc.chainManager.Lookup(args.BlockchainID)
 	if err != nil {
 		ipc.log.Error("unknown blockchainID: %s", err)
@@ -116,6 +117,7 @@ type UnpublishBlockchainReply struct {
 // UnpublishBlockchain closes publishing of a blockchainID
 func (ipc *IPCs) UnpublishBlockchain(r *http.Request, args *UnpublishBlockchainArgs, reply *UnpublishBlockchainReply) error {
+	ipc.log.Info("IPCs: UnpublishBlockchain called with BlockchainID: %s", args.BlockchainID)
 	chainID, err := ipc.chainManager.Lookup(args.BlockchainID)
 	if err != nil {
 		ipc.log.Error("unknown blockchainID %s: %s", args.BlockchainID, err)

View File

@@ -8,18 +8,20 @@ import (
 	"fmt"
 	"net/http"
 	"sync"
+	"testing"

 	"github.com/gorilla/rpc/v2"

 	"github.com/ava-labs/gecko/chains/atomic"
 	"github.com/ava-labs/gecko/database"
 	"github.com/ava-labs/gecko/database/encdb"
+	"github.com/ava-labs/gecko/database/memdb"
 	"github.com/ava-labs/gecko/database/prefixdb"
 	"github.com/ava-labs/gecko/ids"
 	"github.com/ava-labs/gecko/snow/engine/common"
 	"github.com/ava-labs/gecko/utils/formatting"
 	"github.com/ava-labs/gecko/utils/logging"
-	"github.com/ava-labs/gecko/vms/components/codec"
+	"github.com/ava-labs/gecko/utils/codec"

 	jsoncodec "github.com/ava-labs/gecko/utils/json"

 	zxcvbn "github.com/nbutton23/zxcvbn-go"
@@ -29,8 +31,17 @@ const (
 	// maxUserPassLen is the maximum length of the username or password allowed
 	maxUserPassLen = 1024

-	// requiredPassScore defines the score a password must achieve to be accepted
-	// as a password with strong characteristics by the zxcvbn package
+	// maxCheckedPassLen limits the length of the password that should be
+	// strength checked.
+	//
+	// As per issue https://github.com/ava-labs/gecko/issues/195 it was found
+	// the longer the length of password the slower zxcvbn.PasswordStrength()
+	// performs. To avoid performance issues, and a DoS vector, we only check
+	// the first 50 characters of the password.
+	maxCheckedPassLen = 50
+
+	// requiredPassScore defines the score a password must achieve to be
+	// accepted as a password with strong characteristics by the zxcvbn package
 	//
 	// The scoring mechanism defined is as follows;
 	//
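The new comment spells out the motivation: zxcvbn.PasswordStrength slows down as the input grows (gecko issue #195), so only a bounded prefix is scored. A minimal standalone sketch of that truncation, assuming the zxcvbn-go module from the imports above; strongEnough and the score threshold argument are illustrative names, not part of the keystore API.

package main

import (
	"fmt"

	zxcvbn "github.com/nbutton23/zxcvbn-go"
)

// maxCheckedPassLen mirrors the constant added in this commit: only the
// first 50 characters are strength-checked to keep the check cheap.
const maxCheckedPassLen = 50

func strongEnough(password string, requiredScore int) bool {
	checkPass := password
	if len(checkPass) > maxCheckedPassLen {
		checkPass = checkPass[:maxCheckedPassLen] // bound the zxcvbn work
	}
	return zxcvbn.PasswordStrength(checkPass, nil).Score >= requiredScore
}

func main() {
	for _, pw := range []string{"password", "Gr8t-Horse-Battery-2020"} {
		fmt.Printf("%q accepted: %v\n", pw, strongEnough(pw, 2))
	}
}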
@@ -135,37 +146,11 @@ func (ks *Keystore) CreateUser(_ *http.Request, args *CreateUserArgs, reply *Cre
 	ks.lock.Lock()
 	defer ks.lock.Unlock()

-	ks.log.Verbo("CreateUser called with %.*s", maxUserPassLen, args.Username)
+	ks.log.Info("Keystore: CreateUser called with %.*s", maxUserPassLen, args.Username)

-	if len(args.Username) > maxUserPassLen || len(args.Password) > maxUserPassLen {
-		return errUserPassMaxLength
-	}
-	if args.Username == "" {
-		return errEmptyUsername
-	}
-	if usr, err := ks.getUser(args.Username); err == nil || usr != nil {
-		return fmt.Errorf("user already exists: %s", args.Username)
-	}
-	if zxcvbn.PasswordStrength(args.Password, nil).Score < requiredPassScore {
-		return errWeakPassword
-	}
-	usr := &User{}
-	if err := usr.Initialize(args.Password); err != nil {
+	if err := ks.AddUser(args.Username, args.Password); err != nil {
 		return err
 	}
-	usrBytes, err := ks.codec.Marshal(usr)
-	if err != nil {
-		return err
-	}
-	if err := ks.userDB.Put([]byte(args.Username), usrBytes); err != nil {
-		return err
-	}
-	ks.users[args.Username] = usr

 	reply.Success = true
 	return nil
 }
@@ -183,7 +168,7 @@ func (ks *Keystore) ListUsers(_ *http.Request, args *ListUsersArgs, reply *ListU
 	ks.lock.Lock()
 	defer ks.lock.Unlock()

-	ks.log.Verbo("ListUsers called")
+	ks.log.Info("Keystore: ListUsers called")

 	reply.Users = []string{}
@@ -211,7 +196,7 @@ func (ks *Keystore) ExportUser(_ *http.Request, args *ExportUserArgs, reply *Exp
 	ks.lock.Lock()
 	defer ks.lock.Unlock()

-	ks.log.Verbo("ExportUser called for %s", args.Username)
+	ks.log.Info("Keystore: ExportUser called for %s", args.Username)

 	usr, err := ks.getUser(args.Username)
 	if err != nil {
@@ -264,7 +249,7 @@ func (ks *Keystore) ImportUser(r *http.Request, args *ImportUserArgs, reply *Imp
 	ks.lock.Lock()
 	defer ks.lock.Unlock()

-	ks.log.Verbo("ImportUser called for %s", args.Username)
+	ks.log.Info("Keystore: ImportUser called for %s", args.Username)

 	if args.Username == "" {
 		return errEmptyUsername
@@ -324,7 +309,7 @@ func (ks *Keystore) DeleteUser(_ *http.Request, args *DeleteUserArgs, reply *Del
 	ks.lock.Lock()
 	defer ks.lock.Unlock()

-	ks.log.Verbo("DeleteUser called with %s", args.Username)
+	ks.log.Info("Keystore: DeleteUser called with %s", args.Username)

 	if args.Username == "" {
 		return errEmptyUsername
@@ -403,3 +388,51 @@ func (ks *Keystore) GetDatabase(bID ids.ID, username, password string) (database
 	return encDB, nil
 }
+
+// AddUser attempts to register this username and password as a new user of the
+// keystore.
+func (ks *Keystore) AddUser(username, password string) error {
+	if len(username) > maxUserPassLen || len(password) > maxUserPassLen {
+		return errUserPassMaxLength
+	}
+
+	if username == "" {
+		return errEmptyUsername
+	}
+
+	if usr, err := ks.getUser(username); err == nil || usr != nil {
+		return fmt.Errorf("user already exists: %s", username)
+	}
+
+	checkPass := password
+	if len(password) > maxCheckedPassLen {
+		checkPass = password[:maxCheckedPassLen]
+	}
+
+	if zxcvbn.PasswordStrength(checkPass, nil).Score < requiredPassScore {
+		return errWeakPassword
+	}
+
+	usr := &User{}
+	if err := usr.Initialize(password); err != nil {
+		return err
+	}
+
+	usrBytes, err := ks.codec.Marshal(usr)
+	if err != nil {
+		return err
+	}
+
+	if err := ks.userDB.Put([]byte(username), usrBytes); err != nil {
+		return err
+	}
+	ks.users[username] = usr
+
+	return nil
+}
+
+// CreateTestKeystore returns a new keystore that can be utilized for testing
+func CreateTestKeystore(t *testing.T) *Keystore {
+	ks := &Keystore{}
+	ks.Initialize(logging.NoLog{}, memdb.New())
+	return ks
+}

View File

@@ -10,9 +10,7 @@ import (
 	"reflect"
 	"testing"

-	"github.com/ava-labs/gecko/database/memdb"
 	"github.com/ava-labs/gecko/ids"
-	"github.com/ava-labs/gecko/utils/logging"
 )

 var (
@@ -22,8 +20,7 @@ var (
 )

 func TestServiceListNoUsers(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	reply := ListUsersReply{}
 	if err := ks.ListUsers(nil, &ListUsersArgs{}, &reply); err != nil {
@@ -35,8 +32,7 @@ func TestServiceListNoUsers(t *testing.T) {
 }

 func TestServiceCreateUser(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -75,8 +71,7 @@ func genStr(n int) string {
 // TestServiceCreateUserArgsChecks generates excessively long usernames or
 // passwords to assure the santity checks on string length are not exceeded
 func TestServiceCreateUserArgsCheck(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -117,8 +112,7 @@ func TestServiceCreateUserArgsCheck(t *testing.T) {
 // TestServiceCreateUserWeakPassword tests creating a new user with a weak
 // password to ensure the password strength check is working
 func TestServiceCreateUserWeakPassword(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -138,8 +132,7 @@ func TestServiceCreateUserWeakPassword(t *testing.T) {
 }

 func TestServiceCreateDuplicate(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -166,8 +159,7 @@ func TestServiceCreateDuplicate(t *testing.T) {
 }

 func TestServiceCreateUserNoName(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	reply := CreateUserReply{}
 	if err := ks.CreateUser(nil, &CreateUserArgs{
@@ -178,8 +170,7 @@ func TestServiceCreateUserNoName(t *testing.T) {
 }

 func TestServiceUseBlockchainDB(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -218,8 +209,7 @@ func TestServiceUseBlockchainDB(t *testing.T) {
 }

 func TestServiceExportImport(t *testing.T) {
-	ks := Keystore{}
-	ks.Initialize(logging.NoLog{}, memdb.New())
+	ks := CreateTestKeystore(t)

 	{
 		reply := CreateUserReply{}
@@ -252,8 +242,7 @@ func TestServiceExportImport(t *testing.T) {
 		t.Fatal(err)
 	}

-	newKS := Keystore{}
-	newKS.Initialize(logging.NoLog{}, memdb.New())
+	newKS := CreateTestKeystore(t)

 	{
 		reply := ImportUserReply{}
@@ -358,11 +347,10 @@ func TestServiceDeleteUser(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.desc, func(t *testing.T) {
-			ks := Keystore{}
-			ks.Initialize(logging.NoLog{}, memdb.New())
+			ks := CreateTestKeystore(t)
 			if tt.setup != nil {
-				if err := tt.setup(&ks); err != nil {
+				if err := tt.setup(ks); err != nil {
 					t.Fatalf("failed to create user setup in keystore: %v", err)
 				}
 			}

View File

@@ -12,7 +12,7 @@ import (
 	"github.com/ava-labs/gecko/ids"
 	"github.com/ava-labs/gecko/utils/hashing"
 	"github.com/ava-labs/gecko/utils/logging"
-	"github.com/ava-labs/gecko/vms/components/codec"
+	"github.com/ava-labs/gecko/utils/codec"
 )

 type rcLock struct {

View File

@@ -14,7 +14,7 @@ import (
 	"github.com/ava-labs/gecko/database/nodb"
 	"github.com/ava-labs/gecko/utils"
 	"github.com/ava-labs/gecko/utils/hashing"
-	"github.com/ava-labs/gecko/vms/components/codec"
+	"github.com/ava-labs/gecko/utils/codec"
 )

 // Database encrypts all values that are provided

View File

@@ -14,7 +14,7 @@ import (
 	"github.com/ava-labs/gecko/utils/units"
 	"github.com/ava-labs/gecko/utils/wrappers"
 	"github.com/ava-labs/gecko/vms/avm"
-	"github.com/ava-labs/gecko/vms/components/codec"
+	"github.com/ava-labs/gecko/utils/codec"
 	"github.com/ava-labs/gecko/vms/nftfx"
 	"github.com/ava-labs/gecko/vms/platformvm"
 	"github.com/ava-labs/gecko/vms/propertyfx"

go.sum
View File

@@ -2,6 +2,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
 github.com/AppsFlyer/go-sundheit v0.2.0 h1:FArqX+HbqZ6U32RC3giEAWRUpkggqxHj91KIvxNgwjU=
 github.com/AppsFlyer/go-sundheit v0.2.0/go.mod h1:rCRkVTMQo7/krF7xQ9X0XEF1an68viFR6/Gy02q+4ds=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/Microsoft/go-winio v0.4.11 h1:zoIOcVf0xPN1tnMVbTtEdI+P8OofVk3NObnwOQ6nK2Q=
 github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=
 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
 github.com/Shopify/sarama v1.26.1/go.mod h1:NbSGBSSndYaIhRcBtY9V0U7AyH+x71bG668AuWys/yU=

View File

@@ -45,7 +45,10 @@ func main() {
 	}

 	// Track if sybil control is enforced
-	if !Config.EnableStaking {
+	if !Config.EnableStaking && Config.EnableP2PTLS {
+		log.Warn("Staking is disabled. Sybil control is not enforced.")
+	}
+	if !Config.EnableStaking && !Config.EnableP2PTLS {
 		log.Warn("Staking and p2p encryption are disabled. Packet spoofing is possible.")
 	}

View File

@@ -37,6 +37,7 @@ const (
 var (
 	Config = node.Config{}
 	Err    error
+	defaultNetworkName     = genesis.TestnetName
 	defaultDbDir           = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "db"))
 	defaultStakingKeyPath  = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "staking", "staker.key"))
 	defaultStakingCertPath = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "staking", "staker.crt"))
@@ -49,7 +50,8 @@ var (
 )

 var (
 	errBootstrapMismatch  = errors.New("more bootstrap IDs provided than bootstrap IPs")
+	errStakingRequiresTLS = errors.New("if staking is enabled, network TLS must also be enabled")
 )

 // GetIPs returns the default IPs for each network
@@ -169,7 +171,7 @@ func init() {
 	version := fs.Bool("version", false, "If true, print version and quit")

 	// NetworkID:
-	networkName := fs.String("network-id", genesis.TestnetName, "Network ID this node will connect to")
+	networkName := fs.String("network-id", defaultNetworkName, "Network ID this node will connect to")

 	// Ava fees:
 	fs.Uint64Var(&Config.AvaTxFee, "ava-tx-fee", 0, "Ava transaction fee, in $nAva")
@@ -200,7 +202,9 @@ func init() {
 	// Staking:
 	consensusPort := fs.Uint("staking-port", 9651, "Port of the consensus server")
-	fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Require TLS to authenticate staking connections")
+	// TODO - keeping same flag for backwards compatibility, should be changed to "staking-enabled"
+	fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Enable staking. If enabled, Network TLS is required.")
+	fs.BoolVar(&Config.EnableP2PTLS, "p2p-tls-enabled", true, "Require TLS to authenticate network communication")
 	fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", defaultStakingKeyPath, "TLS private key for staking")
 	fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", defaultStakingCertPath, "TLS certificate for staking")
@@ -234,7 +238,15 @@ func init() {
 	ferr := fs.Parse(os.Args[1:])

 	if *version { // If --version used, print version and exit
-		fmt.Println(node.Version.String())
+		networkID, err := genesis.NetworkID(defaultNetworkName)
+		if errs.Add(err); err != nil {
+			return
+		}
+		networkGeneration := genesis.NetworkName(networkID)
+		fmt.Printf(
+			"%s [database=%s, network=%s/%s]\n",
+			node.Version, dbVersion, defaultNetworkName, networkGeneration,
+		)
 		os.Exit(0)
 	}
@@ -318,7 +330,13 @@ func init() {
 			*bootstrapIDs = strings.Join(defaultBootstrapIDs, ",")
 		}
 	}
-	if Config.EnableStaking {
+
+	if Config.EnableStaking && !Config.EnableP2PTLS {
+		errs.Add(errStakingRequiresTLS)
+		return
+	}
+
+	if Config.EnableP2PTLS {
 		i := 0
 		cb58 := formatting.CB58{}
 		for _, id := range strings.Split(*bootstrapIDs, ",") {

View File

@@ -21,6 +21,7 @@ import (
 	"github.com/ava-labs/gecko/snow/triggers"
 	"github.com/ava-labs/gecko/snow/validators"
 	"github.com/ava-labs/gecko/utils"
+	"github.com/ava-labs/gecko/utils/formatting"
 	"github.com/ava-labs/gecko/utils/logging"
 	"github.com/ava-labs/gecko/utils/random"
 	"github.com/ava-labs/gecko/utils/timer"
@@ -278,8 +279,11 @@ func (n *network) GetAcceptedFrontier(validatorIDs ids.ShortSet, chainID ids.ID,
 func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerIDs ids.Set) {
 	msg, err := n.b.AcceptedFrontier(chainID, requestID, containerIDs)
 	if err != nil {
-		n.log.Error("attempted to pack too large of an AcceptedFrontier message.\nNumber of containerIDs: %d",
-			containerIDs.Len())
+		n.log.Error("failed to build AcceptedFrontier(%s, %d, %s): %s",
+			chainID,
+			requestID,
+			containerIDs,
+			err)
 		return // Packing message failed
 	}
@@ -291,7 +295,11 @@ func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requ
 		sent = peer.send(msg)
 	}
 	if !sent {
-		n.log.Debug("failed to send an AcceptedFrontier message to: %s", validatorID)
+		n.log.Debug("failed to send AcceptedFrontier(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			containerIDs)
 		n.acceptedFrontier.numFailed.Inc()
 	} else {
 		n.acceptedFrontier.numSent.Inc()
@@ -302,6 +310,11 @@ func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requ
 func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, requestID uint32, containerIDs ids.Set) {
 	msg, err := n.b.GetAccepted(chainID, requestID, containerIDs)
 	if err != nil {
+		n.log.Error("failed to build GetAccepted(%s, %d, %s): %s",
+			chainID,
+			requestID,
+			containerIDs,
+			err)
 		for _, validatorID := range validatorIDs.List() {
 			vID := validatorID
 			n.executor.Add(func() { n.router.GetAcceptedFailed(vID, chainID, requestID) })
@@ -319,6 +332,11 @@ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, request
 			sent = peer.send(msg)
 		}
 		if !sent {
+			n.log.Debug("failed to send GetAccepted(%s, %s, %d, %s)",
+				validatorID,
+				chainID,
+				requestID,
+				containerIDs)
 			n.executor.Add(func() { n.router.GetAcceptedFailed(vID, chainID, requestID) })
 			n.getAccepted.numFailed.Inc()
 		} else {
@@ -331,8 +349,11 @@ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, request
 func (n *network) Accepted(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerIDs ids.Set) {
 	msg, err := n.b.Accepted(chainID, requestID, containerIDs)
 	if err != nil {
-		n.log.Error("attempted to pack too large of an Accepted message.\nNumber of containerIDs: %d",
-			containerIDs.Len())
+		n.log.Error("failed to build Accepted(%s, %d, %s): %s",
+			chainID,
+			requestID,
+			containerIDs,
+			err)
 		return // Packing message failed
 	}
@@ -344,33 +365,17 @@ func (n *network) Accepted(validatorID ids.ShortID, chainID ids.ID, requestID ui
 		sent = peer.send(msg)
 	}
 	if !sent {
-		n.log.Debug("failed to send an Accepted message to: %s", validatorID)
+		n.log.Debug("failed to send Accepted(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			containerIDs)
 		n.accepted.numFailed.Inc()
 	} else {
 		n.accepted.numSent.Inc()
 	}
 }
-// Get implements the Sender interface.
-func (n *network) Get(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) {
-	msg, err := n.b.Get(chainID, requestID, containerID)
-	n.log.AssertNoError(err)
-
-	n.stateLock.Lock()
-	defer n.stateLock.Unlock()
-
-	peer, sent := n.peers[validatorID.Key()]
-	if sent {
-		sent = peer.send(msg)
-	}
-	if !sent {
-		n.log.Debug("failed to send a Get message to: %s", validatorID)
-		n.get.numFailed.Inc()
-	} else {
-		n.get.numSent.Inc()
-	}
-}
-
 // GetAncestors implements the Sender interface.
 func (n *network) GetAncestors(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) {
 	msg, err := n.b.GetAncestors(chainID, requestID, containerID)
@@ -387,36 +392,18 @@ func (n *network) GetAncestors(validatorID ids.ShortID, chainID ids.ID, requestI
 		sent = peer.send(msg)
 	}
 	if !sent {
+		n.log.Debug("failed to send GetAncestors(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			containerID)
+		n.executor.Add(func() { n.router.GetAncestorsFailed(validatorID, chainID, requestID) })
 		n.getAncestors.numFailed.Inc()
-		n.log.Debug("failed to send a GetAncestors message to: %s", validatorID)
 	} else {
 		n.getAncestors.numSent.Inc()
 	}
 }
-// Put implements the Sender interface.
-func (n *network) Put(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) {
-	msg, err := n.b.Put(chainID, requestID, containerID, container)
-	if err != nil {
-		n.log.Error("failed to build Put message because of container of size %d", len(container))
-		return
-	}
-
-	n.stateLock.Lock()
-	defer n.stateLock.Unlock()
-
-	peer, sent := n.peers[validatorID.Key()]
-	if sent {
-		sent = peer.send(msg)
-	}
-	if !sent {
-		n.log.Debug("failed to send a Put message to: %s", validatorID)
-		n.put.numFailed.Inc()
-	} else {
-		n.put.numSent.Inc()
-	}
-}
-
 // MultiPut implements the Sender interface.
 func (n *network) MultiPut(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containers [][]byte) {
 	msg, err := n.b.MultiPut(chainID, requestID, containers)
@@ -433,22 +420,90 @@ func (n *network) MultiPut(validatorID ids.ShortID, chainID ids.ID, requestID ui
 		sent = peer.send(msg)
 	}
 	if !sent {
-		n.log.Debug("failed to send a MultiPut message to: %s", validatorID)
+		n.log.Debug("failed to send MultiPut(%s, %s, %d, %d)",
+			validatorID,
+			chainID,
+			requestID,
+			len(containers))
 		n.multiPut.numFailed.Inc()
 	} else {
 		n.multiPut.numSent.Inc()
 	}
 }
+// Get implements the Sender interface.
+func (n *network) Get(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) {
+	msg, err := n.b.Get(chainID, requestID, containerID)
+	n.log.AssertNoError(err)
+
+	n.stateLock.Lock()
+	defer n.stateLock.Unlock()
+
+	peer, sent := n.peers[validatorID.Key()]
+	if sent {
+		sent = peer.send(msg)
+	}
+	if !sent {
+		n.log.Debug("failed to send Get(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			containerID)
+		n.executor.Add(func() { n.router.GetFailed(validatorID, chainID, requestID) })
+		n.get.numFailed.Inc()
+	} else {
+		n.get.numSent.Inc()
+	}
+}
+
+// Put implements the Sender interface.
+func (n *network) Put(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) {
+	msg, err := n.b.Put(chainID, requestID, containerID, container)
+	if err != nil {
+		n.log.Error("failed to build Put(%s, %d, %s): %s. len(container) : %d",
+			chainID,
+			requestID,
+			containerID,
+			err,
+			len(container))
+		return
+	}
+
+	n.stateLock.Lock()
+	defer n.stateLock.Unlock()
+
+	peer, sent := n.peers[validatorID.Key()]
+	if sent {
+		sent = peer.send(msg)
+	}
+	if !sent {
+		n.log.Debug("failed to send Put(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			containerID)
+		n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container})
+		n.put.numFailed.Inc()
+	} else {
+		n.put.numSent.Inc()
+	}
+}
 // PushQuery implements the Sender interface.
 func (n *network) PushQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) {
 	msg, err := n.b.PushQuery(chainID, requestID, containerID, container)
 	if err != nil {
+		n.log.Error("failed to build PushQuery(%s, %d, %s): %s. len(container): %d",
+			chainID,
+			requestID,
+			containerID,
+			err,
+			len(container))
+		n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container})
 		for _, validatorID := range validatorIDs.List() {
 			vID := validatorID
 			n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) })
 		}
-		n.log.Error("attempted to pack too large of a PushQuery message.\nContainer length: %d", len(container))
 		return // Packing message failed
 	}
@@ -462,7 +517,12 @@ func (n *network) PushQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID
 			sent = peer.send(msg)
 		}
 		if !sent {
-			n.log.Debug("failed sending a PushQuery message to: %s", vID)
+			n.log.Debug("failed to send PushQuery(%s, %s, %d, %s)",
+				validatorID,
+				chainID,
+				requestID,
+				containerID)
+			n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container})
 			n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) })
 			n.pushQuery.numFailed.Inc()
 		} else {
@@ -486,7 +546,11 @@ func (n *network) PullQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID
 			sent = peer.send(msg)
 		}
 		if !sent {
-			n.log.Debug("failed sending a PullQuery message to: %s", vID)
+			n.log.Debug("failed to send PullQuery(%s, %s, %d, %s)",
+				validatorID,
+				chainID,
+				requestID,
+				containerID)
 			n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) })
 			n.pullQuery.numFailed.Inc()
 		} else {
@@ -499,7 +563,11 @@ func (n *network) PullQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID
 func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint32, votes ids.Set) {
 	msg, err := n.b.Chits(chainID, requestID, votes)
 	if err != nil {
-		n.log.Error("failed to build Chits message because of %d votes", votes.Len())
+		n.log.Error("failed to build Chits(%s, %d, %s): %s",
+			chainID,
+			requestID,
+			votes,
+			err)
 		return
 	}
@@ -511,7 +579,11 @@ func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint3
 		sent = peer.send(msg)
 	}
 	if !sent {
-		n.log.Debug("failed to send a Chits message to: %s", validatorID)
+		n.log.Debug("failed to send Chits(%s, %s, %d, %s)",
+			validatorID,
+			chainID,
+			requestID,
+			votes)
 		n.chits.numFailed.Inc()
 	} else {
 		n.chits.numSent.Inc()
@@ -521,7 +593,8 @@ func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint3
 // Gossip attempts to gossip the container to the network
 func (n *network) Gossip(chainID, containerID ids.ID, container []byte) {
 	if err := n.gossipContainer(chainID, containerID, container); err != nil {
-		n.log.Error("error gossiping container %s to %s: %s", containerID, chainID, err)
+		n.log.Debug("failed to Gossip(%s, %s): %s", chainID, containerID, err)
+		n.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container})
 	}
 }
@@ -695,7 +768,9 @@ func (n *network) gossip() {
 		}
 		msg, err := n.b.PeerList(ips)
 		if err != nil {
-			n.log.Warn("failed to gossip PeerList message due to %s", err)
+			n.log.Error("failed to build peer list to gossip: %s. len(ips): %d",
+				err,
+				len(ips))
 			continue
 		}
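Besides the richer log lines, the Get and GetAncestors send paths now report delivery failures to the router instead of silently dropping the request, so the engine does not wait on a reply that will never come. A simplified, self-contained sketch of that pattern; the router, sendGet, and connected names are illustrative stand-ins, not the real network types.

package main

import "fmt"

type router struct{}

// GetFailed is the callback the engine relies on to stop waiting for a reply.
func (router) GetFailed(vdr string, requestID uint32) {
	fmt.Printf("Get(%s, %d) marked as failed\n", vdr, requestID)
}

// sendGet delivers a Get request if the peer is reachable; otherwise it
// immediately reports the failure, mirroring the GetFailed call added here.
func sendGet(r router, connected map[string]bool, vdr string, requestID uint32) {
	if !connected[vdr] {
		r.GetFailed(vdr, requestID)
		return
	}
	fmt.Printf("Get(%d) sent to %s\n", requestID, vdr)
}

func main() {
	connected := map[string]bool{"vdr1": true}
	sendGet(router{}, connected, "vdr0", 7) // unreachable: failure callback fires
	sendGet(router{}, connected, "vdr1", 8)
}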

View File

@@ -34,6 +34,7 @@ type Config struct {
 	// Staking configuration
 	StakingIP       utils.IPDesc
+	EnableP2PTLS    bool
 	EnableStaking   bool
 	StakingKeyFile  string
 	StakingCertFile string

View File

@@ -119,7 +119,7 @@ func (n *Node) initNetworking() error {
 	dialer := network.NewDialer(TCP)

 	var serverUpgrader, clientUpgrader network.Upgrader
-	if n.Config.EnableStaking {
+	if n.Config.EnableP2PTLS {
 		cert, err := tls.LoadX509KeyPair(n.Config.StakingCertFile, n.Config.StakingKeyFile)
 		if err != nil {
 			return err
@@ -253,7 +253,7 @@ func (n *Node) initDatabase() error {
 // Otherwise, it is a hash of the TLS certificate that this node
 // uses for P2P communication
 func (n *Node) initNodeID() error {
-	if !n.Config.EnableStaking {
+	if !n.Config.EnableP2PTLS {
 		n.ID = ids.NewShortID(hashing.ComputeHash160Array([]byte(n.Config.StakingIP.String())))
 		n.Log.Info("Set the node's ID to %s", n.ID)
 		return nil

View File

@@ -231,6 +231,7 @@ func (ta *Topological) pushVotes(
 	kahnNodes map[[32]byte]kahnNode,
 	leaves []ids.ID) ids.Bag {
 	votes := make(ids.UniqueBag)
+	txConflicts := make(map[[32]byte]ids.Set)

 	for len(leaves) > 0 {
 		newLeavesSize := len(leaves) - 1
@@ -245,6 +246,12 @@ func (ta *Topological) pushVotes(
 				// Give the votes to the consumer
 				txID := tx.ID()
 				votes.UnionSet(txID, kahn.votes)
+
+				// Map txID to set of Conflicts
+				txKey := txID.Key()
+				if _, exists := txConflicts[txKey]; !exists {
+					txConflicts[txKey] = ta.cg.Conflicts(tx)
+				}
 			}

 			for _, dep := range vtx.Parents() {
@@ -265,6 +272,18 @@ func (ta *Topological) pushVotes(
 		}
 	}

+	// Create bag of votes for conflicting transactions
+	conflictingVotes := make(ids.UniqueBag)
+	for txHash, conflicts := range txConflicts {
+		txID := ids.NewID(txHash)
+		for conflictTxHash := range conflicts {
+			conflictTxID := ids.NewID(conflictTxHash)
+			conflictingVotes.UnionSet(txID, votes.GetSet(conflictTxID))
+		}
+	}
+
+	votes.Difference(&conflictingVotes)
 	return votes.Bag(ta.params.Alpha)
 }
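The added block discards equivocating votes: if one poll response endorses two conflicting transactions, that response should not count for either of them, which is what the TestAvalancheIgnoreInvalidVoting test below exercises. A simplified sketch of the same filtering using plain maps instead of ids.UniqueBag and ids.Set; the transaction and responder values are illustrative.

package main

import "fmt"

func main() {
	// votes[tx] = set of poll responses that voted for tx.
	votes := map[string]map[int]bool{
		"tx0": {0: true, 2: true},
		"tx1": {1: true, 2: true},
	}
	// conflicts[tx] = transactions that conflict with tx.
	conflicts := map[string][]string{
		"tx0": {"tx1"},
		"tx1": {"tx0"},
	}

	// First collect, per transaction, the responders that also voted for one
	// of its conflicts (the conflictingVotes bag in the diff)...
	conflicting := map[string]map[int]bool{}
	for tx, others := range conflicts {
		conflicting[tx] = map[int]bool{}
		for _, other := range others {
			for responder := range votes[other] {
				conflicting[tx][responder] = true
			}
		}
	}
	// ...then subtract them, as votes.Difference(&conflictingVotes) does.
	for tx, bad := range conflicting {
		for responder := range bad {
			delete(votes[tx], responder)
		}
	}

	fmt.Println(votes) // responder 2 no longer counts for tx0 or tx1
}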

View File

@@ -104,6 +104,78 @@ func TestAvalancheVoting(t *testing.T) {
 	}
 }
func TestAvalancheIgnoreInvalidVoting(t *testing.T) {
params := Parameters{
Parameters: snowball.Parameters{
Metrics: prometheus.NewRegistry(),
K: 3,
Alpha: 2,
BetaVirtuous: 1,
BetaRogue: 1,
},
Parents: 2,
BatchSize: 1,
}
vts := []Vertex{&Vtx{
id: GenerateID(),
status: choices.Accepted,
}, &Vtx{
id: GenerateID(),
status: choices.Accepted,
}}
utxos := []ids.ID{GenerateID()}
ta := Topological{}
ta.Initialize(snow.DefaultContextTest(), params, vts)
tx0 := &snowstorm.TestTx{
Identifier: GenerateID(),
Stat: choices.Processing,
}
tx0.Ins.Add(utxos[0])
vtx0 := &Vtx{
dependencies: vts,
id: GenerateID(),
txs: []snowstorm.Tx{tx0},
height: 1,
status: choices.Processing,
}
tx1 := &snowstorm.TestTx{
Identifier: GenerateID(),
Stat: choices.Processing,
}
tx1.Ins.Add(utxos[0])
vtx1 := &Vtx{
dependencies: vts,
id: GenerateID(),
txs: []snowstorm.Tx{tx1},
height: 1,
status: choices.Processing,
}
ta.Add(vtx0)
ta.Add(vtx1)
sm := make(ids.UniqueBag)
sm.Add(0, vtx0.id)
sm.Add(1, vtx1.id)
// Add Illegal Vote cast by Response 2
sm.Add(2, vtx0.id)
sm.Add(2, vtx1.id)
ta.RecordPoll(sm)
if ta.Finalized() {
t.Fatalf("An avalanche instance finalized too early")
}
}
 func TestAvalancheTransitiveVoting(t *testing.T) {
 	params := Parameters{
 		Parameters: snowball.Parameters{

View File

@@ -27,11 +27,13 @@ func (sb *unarySnowball) Extend(beta int, choice int) BinarySnowball {
 	bs := &binarySnowball{
 		binarySnowflake: binarySnowflake{
 			binarySlush: binarySlush{preference: choice},
+			confidence:  sb.confidence,
 			beta:        beta,
 			finalized:   sb.Finalized(),
 		},
 		preference: choice,
 	}
+	bs.numSuccessfulPolls[choice] = sb.numSuccessfulPolls
 	return bs
 }
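The fix copies the accumulated confidence and successful-poll count into the binary snowball instance that Extend spawns, so the extended instance does not restart from zero. A toy illustration of why that copy matters; the unary and binary structs here are simplified stand-ins, not the real snowball types.

package main

import "fmt"

type unary struct{ confidence, numSuccessfulPolls int }

type binary struct {
	confidence         int
	numSuccessfulPolls [2]int
}

// Extend mirrors the fixed behaviour: accumulated state follows the choice.
func (u unary) Extend(choice int) binary {
	b := binary{confidence: u.confidence}
	b.numSuccessfulPolls[choice] = u.numSuccessfulPolls
	return b
}

func main() {
	u := unary{confidence: 1, numSuccessfulPolls: 2}
	fmt.Printf("%+v\n", u.Extend(0)) // {confidence:1 numSuccessfulPolls:[2 0]}
}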

View File

@@ -42,11 +42,32 @@ func TestUnarySnowball(t *testing.T) {
 	binarySnowball := sbClone.Extend(beta, 0)

+	expected := "SB(Preference = 0, NumSuccessfulPolls[0] = 2, NumSuccessfulPolls[1] = 0, SF(Confidence = 1, Finalized = false, SL(Preference = 0)))"
+	if result := binarySnowball.String(); result != expected {
+		t.Fatalf("Expected:\n%s\nReturned:\n%s", expected, result)
+	}
+
 	binarySnowball.RecordUnsuccessfulPoll()
+
+	for i := 0; i < 3; i++ {
+		if binarySnowball.Preference() != 0 {
+			t.Fatalf("Wrong preference")
+		} else if binarySnowball.Finalized() {
+			t.Fatalf("Should not have finalized")
+		}
+		binarySnowball.RecordSuccessfulPoll(1)
+		binarySnowball.RecordUnsuccessfulPoll()
+	}
+
+	if binarySnowball.Preference() != 1 {
+		t.Fatalf("Wrong preference")
+	} else if binarySnowball.Finalized() {
+		t.Fatalf("Should not have finalized")
+	}
+
 	binarySnowball.RecordSuccessfulPoll(1)
-	if binarySnowball.Finalized() {
+
+	if binarySnowball.Preference() != 1 {
+		t.Fatalf("Wrong preference")
+	} else if binarySnowball.Finalized() {
 		t.Fatalf("Should not have finalized")
 	}
@@ -57,4 +78,9 @@ func TestUnarySnowball(t *testing.T) {
 	} else if !binarySnowball.Finalized() {
 		t.Fatalf("Should have finalized")
 	}
+
+	expected = "SB(NumSuccessfulPolls = 2, SF(Confidence = 1, Finalized = false))"
+	if str := sb.String(); str != expected {
+		t.Fatalf("Wrong state. Expected:\n%s\nGot:\n%s", expected, str)
+	}
 }

View File

@@ -78,9 +78,12 @@ func (i *issuer) Update() {
 		vdrSet.Add(vdr.ID())
 	}

+	toSample := ids.ShortSet{} // Copy to a new variable because we may remove an element in sender.Sender
+	toSample.Union(vdrSet)     // and we don't want that to affect the set of validators we wait for [ie vdrSet]
+
 	i.t.RequestID++
-	if numVdrs := len(vdrs); numVdrs == p.K && i.t.polls.Add(i.t.RequestID, vdrSet.Len()) {
-		i.t.Config.Sender.PushQuery(vdrSet, i.t.RequestID, vtxID, i.vtx.Bytes())
+	if numVdrs := len(vdrs); numVdrs == p.K && i.t.polls.Add(i.t.RequestID, vdrSet) {
+		i.t.Config.Sender.PushQuery(toSample, i.t.RequestID, vtxID, i.vtx.Bytes())
 	} else if numVdrs < p.K {
 		i.t.Config.Context.Log.Error("Query for %s was dropped due to an insufficient number of validators", vtxID)
 	}
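The inline comments explain the copy: the sender may remove entries from the set it is handed, while the poll must keep waiting on the original validator set. A small sketch of the aliasing problem the copy avoids, using a plain map as a stand-in for ids.ShortSet; send and the validator names are illustrative.

package main

import "fmt"

// send pretends to be sender.Sender: it may drop peers it cannot reach.
func send(vdrs map[string]bool) {
	delete(vdrs, "vdr0")
}

func main() {
	vdrSet := map[string]bool{"vdr0": true, "vdr1": true}

	// Give the sender its own copy (toSample in the diff)...
	toSample := map[string]bool{}
	for vdr := range vdrSet {
		toSample[vdr] = true
	}
	send(toSample)

	// ...so the set registered with the poll is untouched.
	fmt.Println(len(vdrSet), len(toSample)) // 2 1
}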

View File

@@ -38,10 +38,10 @@ type polls struct {
 // Add to the current set of polls
 // Returns true if the poll was registered correctly and the network sample
 // should be made.
-func (p *polls) Add(requestID uint32, numPolled int) bool {
+func (p *polls) Add(requestID uint32, vdrs ids.ShortSet) bool {
 	poll, exists := p.m[requestID]
 	if !exists {
-		poll.numPending = numPolled
+		poll.polled = vdrs
 		p.m[requestID] = poll

 		p.numPolls.Set(float64(len(p.m))) // Tracks performance statistics
@@ -59,7 +59,7 @@ func (p *polls) Vote(requestID uint32, vdr ids.ShortID, votes []ids.ID) (ids.Uni
 		return nil, false
 	}

-	poll.Vote(votes)
+	poll.Vote(votes, vdr)
 	if poll.Finished() {
 		p.log.Verbo("Poll is finished")
 		delete(p.m, requestID)
@@ -83,19 +83,19 @@ func (p *polls) String() string {
 // poll represents the current state of a network poll for a vertex
 type poll struct {
 	votes  ids.UniqueBag
-	numPending int
+	polled ids.ShortSet
 }

 // Vote registers a vote for this poll
-func (p *poll) Vote(votes []ids.ID) {
-	if p.numPending > 0 {
-		p.numPending--
-		p.votes.Add(uint(p.numPending), votes...)
+func (p *poll) Vote(votes []ids.ID, vdr ids.ShortID) {
+	if p.polled.Contains(vdr) {
+		p.polled.Remove(vdr)
+		p.votes.Add(uint(p.polled.Len()), votes...)
 	}
 }

 // Finished returns true if the poll has completed, with no more required
 // responses
-func (p poll) Finished() bool { return p.numPending <= 0 }
+func (p poll) Finished() bool { return p.polled.Len() == 0 }

-func (p poll) String() string { return fmt.Sprintf("Waiting on %d chits", p.numPending) }
+func (p poll) String() string { return fmt.Sprintf("Waiting on %d chits", p.polled.Len()) }
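Tracking the outstanding validators as a set rather than a counter is what makes a repeated chit from the same validator harmless: the second vote finds the validator already removed and is ignored, which the TestEngineDoubleChit tests added further down rely on. A simplified stand-alone sketch of that behaviour, using strings instead of ids.ShortID and ids.ID.

package main

import "fmt"

type poll struct {
	polled map[string]bool // validators we are still waiting on
	votes  []string
}

// Vote only counts a response from a validator that is still outstanding.
func (p *poll) Vote(vdr, vote string) {
	if !p.polled[vdr] {
		return // duplicate or unexpected chit: dropped
	}
	delete(p.polled, vdr)
	p.votes = append(p.votes, vote)
}

func (p *poll) Finished() bool { return len(p.polled) == 0 }

func main() {
	p := &poll{polled: map[string]bool{"vdr0": true, "vdr1": true}}
	p.Vote("vdr0", "vtxA")
	p.Vote("vdr0", "vtxA") // double chit from vdr0: ignored
	fmt.Println(p.Finished(), p.votes) // false [vtxA]
	p.Vote("vdr1", "vtxA")
	fmt.Println(p.Finished(), p.votes) // true [vtxA vtxA]
}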

View File

@@ -471,8 +471,11 @@ func (t *Transitive) issueRepoll() {
 		vdrSet.Add(vdr.ID())
 	}

+	vdrCopy := ids.ShortSet{}
+	vdrCopy.Union((vdrSet))
+
 	t.RequestID++
-	if numVdrs := len(vdrs); numVdrs == p.K && t.polls.Add(t.RequestID, vdrSet.Len()) {
+	if numVdrs := len(vdrs); numVdrs == p.K && t.polls.Add(t.RequestID, vdrCopy) {
 		t.Config.Sender.PullQuery(vdrSet, t.RequestID, vtxID)
 	} else if numVdrs < p.K {
 		t.Config.Context.Log.Error("re-query for %s was dropped due to an insufficient number of validators", vtxID)

View File

@@ -3085,3 +3085,120 @@ func TestEngineDuplicatedIssuance(t *testing.T) {
 	te.Notify(common.PendingTxs)
 }
func TestEngineDoubleChit(t *testing.T) {
config := DefaultConfig()
config.Params.Alpha = 2
config.Params.K = 2
vdr0 := validators.GenerateRandomValidator(1)
vdr1 := validators.GenerateRandomValidator(1)
vals := validators.NewSet()
vals.Add(vdr0)
vals.Add(vdr1)
config.Validators = vals
sender := &common.SenderTest{}
sender.T = t
config.Sender = sender
sender.Default(true)
sender.CantGetAcceptedFrontier = false
st := &stateTest{t: t}
config.State = st
st.Default(true)
gVtx := &Vtx{
id: GenerateID(),
status: choices.Accepted,
}
mVtx := &Vtx{
id: GenerateID(),
status: choices.Accepted,
}
vts := []avalanche.Vertex{gVtx, mVtx}
utxos := []ids.ID{GenerateID()}
tx := &TestTx{
TestTx: snowstorm.TestTx{
Identifier: GenerateID(),
Stat: choices.Processing,
},
}
tx.Ins.Add(utxos[0])
vtx := &Vtx{
parents: vts,
id: GenerateID(),
txs: []snowstorm.Tx{tx},
height: 1,
status: choices.Processing,
bytes: []byte{1, 1, 2, 3},
}
st.edge = func() []ids.ID { return []ids.ID{vts[0].ID(), vts[1].ID()} }
st.getVertex = func(id ids.ID) (avalanche.Vertex, error) {
switch {
case id.Equals(gVtx.ID()):
return gVtx, nil
case id.Equals(mVtx.ID()):
return mVtx, nil
}
t.Fatalf("Unknown vertex")
panic("Should have errored")
}
te := &Transitive{}
te.Initialize(config)
te.finishBootstrapping()
reqID := new(uint32)
sender.PushQueryF = func(inVdrs ids.ShortSet, requestID uint32, vtxID ids.ID, _ []byte) {
*reqID = requestID
if inVdrs.Len() != 2 {
t.Fatalf("Wrong number of validators")
}
if !vtxID.Equals(vtx.ID()) {
t.Fatalf("Wrong vertex requested")
}
}
st.getVertex = func(id ids.ID) (avalanche.Vertex, error) {
switch {
case id.Equals(vtx.ID()):
return vtx, nil
}
t.Fatalf("Unknown vertex")
panic("Should have errored")
}
te.insert(vtx)
votes := ids.Set{}
votes.Add(vtx.ID())
if status := tx.Status(); status != choices.Processing {
t.Fatalf("Wrong tx status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr0.ID(), *reqID, votes)
if status := tx.Status(); status != choices.Processing {
t.Fatalf("Wrong tx status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr0.ID(), *reqID, votes)
if status := tx.Status(); status != choices.Processing {
t.Fatalf("Wrong tx status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr1.ID(), *reqID, votes)
if status := tx.Status(); status != choices.Accepted {
t.Fatalf("Wrong tx status: %s ; expected: %s", status, choices.Accepted)
}
}

View File

@@ -22,11 +22,11 @@ type polls struct {
 // Add to the current set of polls
 // Returns true if the poll was registered correctly and the network sample
 // should be made.
-func (p *polls) Add(requestID uint32, numPolled int) bool {
+func (p *polls) Add(requestID uint32, vdrs ids.ShortSet) bool {
 	poll, exists := p.m[requestID]
 	if !exists {
 		poll.alpha = p.alpha
-		poll.numPolled = numPolled
+		poll.polled = vdrs
 		p.m[requestID] = poll

 		p.numPolls.Set(float64(len(p.m))) // Tracks performance statistics
@@ -42,7 +42,7 @@ func (p *polls) Vote(requestID uint32, vdr ids.ShortID, vote ids.ID) (ids.Bag, b
 	if !exists {
 		return ids.Bag{}, false
 	}

-	poll.Vote(vote)
+	poll.Vote(vote, vdr)
 	if poll.Finished() {
 		delete(p.m, requestID)
 		p.numPolls.Set(float64(len(p.m))) // Tracks performance statistics
@@ -60,7 +60,7 @@ func (p *polls) CancelVote(requestID uint32, vdr ids.ShortID) (ids.Bag, bool) {
 		return ids.Bag{}, false
 	}

-	poll.CancelVote()
+	poll.CancelVote(vdr)
 	if poll.Finished() {
 		delete(p.m, requestID)
 		p.numPolls.Set(float64(len(p.m))) // Tracks performance statistics
@@ -83,22 +83,18 @@ func (p *polls) String() string {
 // poll represents the current state of a network poll for a block
 type poll struct {
 	alpha  int
 	votes  ids.Bag
-	numPolled int
+	polled ids.ShortSet
 }

 // Vote registers a vote for this poll
-func (p *poll) CancelVote() {
-	if p.numPolled > 0 {
-		p.numPolled--
-	}
-}
+func (p *poll) CancelVote(vdr ids.ShortID) { p.polled.Remove(vdr) }

 // Vote registers a vote for this poll
-func (p *poll) Vote(vote ids.ID) {
-	if p.numPolled > 0 {
-		p.numPolled--
+func (p *poll) Vote(vote ids.ID, vdr ids.ShortID) {
+	if p.polled.Contains(vdr) {
+		p.polled.Remove(vdr)
 		p.votes.Add(vote)
 	}
 }
@@ -106,13 +102,14 @@ func (p *poll) Vote(vote ids.ID) {
 // Finished returns true if the poll has completed, with no more required
 // responses
 func (p poll) Finished() bool {
+	remaining := p.polled.Len()
 	received := p.votes.Len()
 	_, freq := p.votes.Mode()
-	return p.numPolled == 0 || // All k nodes responded
+	return remaining == 0 || // All k nodes responded
 		freq >= p.alpha || // An alpha majority has returned
-		received+p.numPolled < p.alpha // An alpha majority can never return
+		received+remaining < p.alpha // An alpha majority can never return
 }

 func (p poll) String() string {
-	return fmt.Sprintf("Waiting on %d chits", p.numPolled)
+	return fmt.Sprintf("Waiting on %d chits from %s", p.polled.Len(), p.polled)
 }

View File

@ -542,18 +542,15 @@ func (t *Transitive) pullSample(blkID ids.ID) {
vdrSet.Add(vdr.ID()) vdrSet.Add(vdr.ID())
} }
if numVdrs := len(vdrs); numVdrs != p.K { toSample := ids.ShortSet{}
t.Config.Context.Log.Error("query for %s was dropped due to an insufficient number of validators", blkID) toSample.Union(vdrSet)
return
}
t.RequestID++ t.RequestID++
if !t.polls.Add(t.RequestID, vdrSet.Len()) { if numVdrs := len(vdrs); numVdrs == p.K && t.polls.Add(t.RequestID, vdrSet) {
t.Config.Context.Log.Error("query for %s was dropped due to use of a duplicated requestID", blkID) t.Config.Sender.PullQuery(toSample, t.RequestID, blkID)
return } else if numVdrs < p.K {
t.Config.Context.Log.Error("query for %s was dropped due to an insufficient number of validators", blkID)
} }
t.Config.Sender.PullQuery(vdrSet, t.RequestID, blkID)
} }
// send a push request for this block // send a push request for this block
@ -566,19 +563,15 @@ func (t *Transitive) pushSample(blk snowman.Block) {
vdrSet.Add(vdr.ID()) vdrSet.Add(vdr.ID())
} }
blkID := blk.ID() toSample := ids.ShortSet{}
if numVdrs := len(vdrs); numVdrs != p.K { toSample.Union(vdrSet)
t.Config.Context.Log.Error("query for %s was dropped due to an insufficient number of validators", blkID)
return
}
t.RequestID++ t.RequestID++
if !t.polls.Add(t.RequestID, vdrSet.Len()) { if numVdrs := len(vdrs); numVdrs == p.K && t.polls.Add(t.RequestID, vdrSet) {
t.Config.Context.Log.Error("query for %s was dropped due to use of a duplicated requestID", blkID) t.Config.Sender.PushQuery(toSample, t.RequestID, blk.ID(), blk.Bytes())
return } else if numVdrs < p.K {
t.Config.Context.Log.Error("query for %s was dropped due to an insufficient number of validators", blk.ID())
} }
t.Config.Sender.PushQuery(vdrSet, t.RequestID, blkID, blk.Bytes())
} }
func (t *Transitive) deliver(blk snowman.Block) error { func (t *Transitive) deliver(blk snowman.Block) error {

View File

@ -1522,3 +1522,124 @@ func TestEngineAggressivePolling(t *testing.T) {
t.Fatalf("Should have sent an additional pull query") t.Fatalf("Should have sent an additional pull query")
} }
} }
func TestEngineDoubleChit(t *testing.T) {
config := DefaultConfig()
config.Params = snowball.Parameters{
Metrics: prometheus.NewRegistry(),
K: 2,
Alpha: 2,
BetaVirtuous: 1,
BetaRogue: 2,
}
vdr0 := validators.GenerateRandomValidator(1)
vdr1 := validators.GenerateRandomValidator(1)
vals := validators.NewSet()
config.Validators = vals
vals.Add(vdr0)
vals.Add(vdr1)
sender := &common.SenderTest{}
sender.T = t
config.Sender = sender
sender.Default(true)
vm := &VMTest{}
vm.T = t
config.VM = vm
vm.Default(true)
vm.CantSetPreference = false
gBlk := &Blk{
id: GenerateID(),
status: choices.Accepted,
}
vm.LastAcceptedF = func() ids.ID { return gBlk.ID() }
sender.CantGetAcceptedFrontier = false
vm.GetBlockF = func(id ids.ID) (snowman.Block, error) {
switch {
case id.Equals(gBlk.ID()):
return gBlk, nil
}
t.Fatalf("Unknown block")
panic("Should have errored")
}
te := &Transitive{}
te.Initialize(config)
te.finishBootstrapping()
vm.LastAcceptedF = nil
sender.CantGetAcceptedFrontier = true
blk := &Blk{
parent: gBlk,
id: GenerateID(),
status: choices.Processing,
bytes: []byte{1},
}
queried := new(bool)
queryRequestID := new(uint32)
sender.PushQueryF = func(inVdrs ids.ShortSet, requestID uint32, blkID ids.ID, blkBytes []byte) {
if *queried {
t.Fatalf("Asked multiple times")
}
*queried = true
*queryRequestID = requestID
vdrSet := ids.ShortSet{}
vdrSet.Add(vdr0.ID(), vdr1.ID())
if !inVdrs.Equals(vdrSet) {
t.Fatalf("Asking wrong validator for preference")
}
if !blk.ID().Equals(blkID) {
t.Fatalf("Asking for wrong block")
}
}
te.insert(blk)
vm.GetBlockF = func(id ids.ID) (snowman.Block, error) {
switch {
case id.Equals(gBlk.ID()):
return gBlk, nil
case id.Equals(blk.ID()):
return blk, nil
}
t.Fatalf("Unknown block")
panic("Should have errored")
}
blkSet := ids.Set{}
blkSet.Add(blk.ID())
if status := blk.Status(); status != choices.Processing {
t.Fatalf("Wrong status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr0.ID(), *queryRequestID, blkSet)
if status := blk.Status(); status != choices.Processing {
t.Fatalf("Wrong status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr0.ID(), *queryRequestID, blkSet)
if status := blk.Status(); status != choices.Processing {
t.Fatalf("Wrong status: %s ; expected: %s", status, choices.Processing)
}
te.Chits(vdr1.ID(), *queryRequestID, blkSet)
if status := blk.Status(); status != choices.Accepted {
t.Fatalf("Wrong status: %s ; expected: %s", status, choices.Accepted)
}
}

View File

@ -7,6 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/networking/timeout" "github.com/ava-labs/gecko/snow/networking/timeout"
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
@ -67,7 +69,7 @@ func (sr *ChainRouter) RemoveChain(chainID ids.ID) {
sr.lock.RLock() sr.lock.RLock()
chain, exists := sr.chains[chainID.Key()] chain, exists := sr.chains[chainID.Key()]
if !exists { if !exists {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("can't remove unknown chain %s", chainID)
sr.lock.RUnlock() sr.lock.RUnlock()
return return
} }
@ -95,7 +97,7 @@ func (sr *ChainRouter) GetAcceptedFrontier(validatorID ids.ShortID, chainID ids.
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAcceptedFrontier(validatorID, requestID) chain.GetAcceptedFrontier(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("GetAcceptedFrontier(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -111,7 +113,7 @@ func (sr *ChainRouter) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID,
sr.timeouts.Cancel(validatorID, chainID, requestID) sr.timeouts.Cancel(validatorID, chainID, requestID)
} }
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("AcceptedFrontier(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs)
} }
} }
@ -126,7 +128,7 @@ func (sr *ChainRouter) GetAcceptedFrontierFailed(validatorID ids.ShortID, chainI
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAcceptedFrontierFailed(validatorID, requestID) chain.GetAcceptedFrontierFailed(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Error("GetAcceptedFrontierFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -140,7 +142,7 @@ func (sr *ChainRouter) GetAccepted(validatorID ids.ShortID, chainID ids.ID, requ
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAccepted(validatorID, requestID, containerIDs) chain.GetAccepted(validatorID, requestID, containerIDs)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("GetAccepted(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs)
} }
} }
@ -156,7 +158,7 @@ func (sr *ChainRouter) Accepted(validatorID ids.ShortID, chainID ids.ID, request
sr.timeouts.Cancel(validatorID, chainID, requestID) sr.timeouts.Cancel(validatorID, chainID, requestID)
} }
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("Accepted(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs)
} }
} }
@ -171,7 +173,7 @@ func (sr *ChainRouter) GetAcceptedFailed(validatorID ids.ShortID, chainID ids.ID
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAcceptedFailed(validatorID, requestID) chain.GetAcceptedFailed(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Error("GetAcceptedFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -185,7 +187,7 @@ func (sr *ChainRouter) GetAncestors(validatorID ids.ShortID, chainID ids.ID, req
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAncestors(validatorID, requestID, containerID) chain.GetAncestors(validatorID, requestID, containerID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("GetAncestors(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -202,7 +204,7 @@ func (sr *ChainRouter) MultiPut(validatorID ids.ShortID, chainID ids.ID, request
sr.timeouts.Cancel(validatorID, chainID, requestID) sr.timeouts.Cancel(validatorID, chainID, requestID)
} }
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("MultiPut(%s, %s, %d, %d) dropped due to unknown chain", validatorID, chainID, requestID, len(containers))
} }
} }
@ -216,7 +218,7 @@ func (sr *ChainRouter) GetAncestorsFailed(validatorID ids.ShortID, chainID ids.I
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetAncestorsFailed(validatorID, requestID) chain.GetAncestorsFailed(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Error("GetAncestorsFailed(%s, %s, %d, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -229,7 +231,7 @@ func (sr *ChainRouter) Get(validatorID ids.ShortID, chainID ids.ID, requestID ui
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.Get(validatorID, requestID, containerID) chain.Get(validatorID, requestID, containerID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("Get(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID)
} }
} }
@ -246,7 +248,8 @@ func (sr *ChainRouter) Put(validatorID ids.ShortID, chainID ids.ID, requestID ui
sr.timeouts.Cancel(validatorID, chainID, requestID) sr.timeouts.Cancel(validatorID, chainID, requestID)
} }
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("Put(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID)
sr.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container})
} }
} }
@ -260,7 +263,7 @@ func (sr *ChainRouter) GetFailed(validatorID ids.ShortID, chainID ids.ID, reques
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.GetFailed(validatorID, requestID) chain.GetFailed(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Error("GetFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }
@ -273,7 +276,8 @@ func (sr *ChainRouter) PushQuery(validatorID ids.ShortID, chainID ids.ID, reques
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.PushQuery(validatorID, requestID, containerID, container) chain.PushQuery(validatorID, requestID, containerID, container)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("PushQuery(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID)
sr.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container})
} }
} }
@ -286,7 +290,7 @@ func (sr *ChainRouter) PullQuery(validatorID ids.ShortID, chainID ids.ID, reques
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.PullQuery(validatorID, requestID, containerID) chain.PullQuery(validatorID, requestID, containerID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("PullQuery(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID)
} }
} }
@ -302,7 +306,7 @@ func (sr *ChainRouter) Chits(validatorID ids.ShortID, chainID ids.ID, requestID
sr.timeouts.Cancel(validatorID, chainID, requestID) sr.timeouts.Cancel(validatorID, chainID, requestID)
} }
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Debug("Chits(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, votes)
} }
} }
@ -316,7 +320,7 @@ func (sr *ChainRouter) QueryFailed(validatorID ids.ShortID, chainID ids.ID, requ
if chain, exists := sr.chains[chainID.Key()]; exists { if chain, exists := sr.chains[chainID.Key()]; exists {
chain.QueryFailed(validatorID, requestID) chain.QueryFailed(validatorID, requestID)
} else { } else {
sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) sr.log.Error("QueryFailed(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID)
} }
} }

394
utils/codec/codec.go Normal file
View File

@ -0,0 +1,394 @@
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package codec
import (
"errors"
"fmt"
"reflect"
"unicode"
"github.com/ava-labs/gecko/utils/wrappers"
)
const (
defaultMaxSize = 1 << 18 // default max size, in bytes, of something being marshalled by Marshal()
defaultMaxSliceLength = 1 << 18 // default max length of a slice being marshalled by Marshal(). Should be <= math.MaxUint32.
// initial capacity of the byte slice that values are marshaled into.
// Larger value --> fewer memory allocations but possibly more allocated-but-unused memory
// Smaller value --> need more memory allocations but more efficient use of allocated memory
initialSliceCap = 256
)
var (
errNil = errors.New("can't marshal/unmarshal nil pointer or interface")
errNeedPointer = errors.New("argument to unmarshal should be a pointer")
)
// Codec handles marshaling and unmarshaling of structs
type codec struct {
maxSize int
maxSliceLen int
typeIDToType map[uint32]reflect.Type
typeToTypeID map[reflect.Type]uint32
// Key: a struct type
// Value: Slice where each element is index in the struct type
// of a field that is serialized/deserialized
// e.g. Foo --> [1,5,8] means Foo.Field(1), etc. are to be serialized/deserialized
// We assume this cache is pretty small (a few hundred keys at most)
// and doesn't take up much memory
serializedFieldIndices map[reflect.Type][]int
}
// Codec marshals and unmarshals
type Codec interface {
RegisterType(interface{}) error
Marshal(interface{}) ([]byte, error)
Unmarshal([]byte, interface{}) error
}
// New returns a new codec
func New(maxSize, maxSliceLen int) Codec {
return &codec{
maxSize: maxSize,
maxSliceLen: maxSliceLen,
typeIDToType: map[uint32]reflect.Type{},
typeToTypeID: map[reflect.Type]uint32{},
serializedFieldIndices: map[reflect.Type][]int{},
}
}
// NewDefault returns a new codec with reasonable default values
func NewDefault() Codec { return New(defaultMaxSize, defaultMaxSliceLength) }
// RegisterType is used to register types that may be unmarshaled into an interface
// [val] is a value of the type being registered
func (c *codec) RegisterType(val interface{}) error {
valType := reflect.TypeOf(val)
if _, exists := c.typeToTypeID[valType]; exists {
return fmt.Errorf("type %v has already been registered", valType)
}
c.typeIDToType[uint32(len(c.typeIDToType))] = reflect.TypeOf(val)
c.typeToTypeID[valType] = uint32(len(c.typeIDToType) - 1)
return nil
}
// A few notes:
// 1) See codec_test.go for examples of usage
// 2) We use "marshal" and "serialize" interchangeably, and "unmarshal" and "deserialize" interchangeably
// 3) To include a field of a struct in the serialized form, add the tag `serialize:"true"` to it
// 4) These typed members of a struct may be serialized:
// bool, string, uint[8,16,32,64], int[8,16,32,64],
// structs, slices, arrays, interface.
// structs, slices and arrays can only be serialized if their constituent values can be.
// 5) To marshal an interface, you must pass a pointer to the value
// 6) To unmarshal an interface, you must call codec.RegisterType([instance of the type that fulfills the interface]).
// 7) Serialized fields must be exported
// 8) nil slices are marshaled as empty slices
// To marshal an interface, [value] must be a pointer to the interface
func (c *codec) Marshal(value interface{}) ([]byte, error) {
if value == nil {
return nil, errNil // can't marshal nil
}
p := &wrappers.Packer{MaxSize: c.maxSize, Bytes: make([]byte, 0, initialSliceCap)}
if err := c.marshal(reflect.ValueOf(value), p); err != nil {
return nil, err
}
return p.Bytes, nil
}
// marshal writes the byte representation of [value] to [p]
// [value]'s underlying value must not be a nil pointer or interface
func (c *codec) marshal(value reflect.Value, p *wrappers.Packer) error {
valueKind := value.Kind()
switch valueKind {
case reflect.Interface, reflect.Ptr, reflect.Invalid:
if value.IsNil() { // Can't marshal nil (except nil slices)
return errNil
}
}
switch valueKind {
case reflect.Uint8:
p.PackByte(uint8(value.Uint()))
return p.Err
case reflect.Int8:
p.PackByte(uint8(value.Int()))
return p.Err
case reflect.Uint16:
p.PackShort(uint16(value.Uint()))
return p.Err
case reflect.Int16:
p.PackShort(uint16(value.Int()))
return p.Err
case reflect.Uint32:
p.PackInt(uint32(value.Uint()))
return p.Err
case reflect.Int32:
p.PackInt(uint32(value.Int()))
return p.Err
case reflect.Uint64:
p.PackLong(value.Uint())
return p.Err
case reflect.Int64:
p.PackLong(uint64(value.Int()))
return p.Err
case reflect.String:
p.PackStr(value.String())
return p.Err
case reflect.Bool:
p.PackBool(value.Bool())
return p.Err
case reflect.Uintptr, reflect.Ptr:
return c.marshal(value.Elem(), p)
case reflect.Interface:
underlyingValue := value.Interface()
typeID, ok := c.typeToTypeID[reflect.TypeOf(underlyingValue)] // Get the type ID of the value being marshaled
if !ok {
return fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(underlyingValue).String())
}
p.PackInt(typeID) // Pack type ID so we know what to unmarshal this into
if p.Err != nil {
return p.Err
}
if err := c.marshal(value.Elem(), p); err != nil {
return err
}
return p.Err
case reflect.Slice:
numElts := value.Len() // # elements in the slice/array. 0 if this slice is nil.
if numElts > c.maxSliceLen {
return fmt.Errorf("slice length, %d, exceeds maximum length, %d", numElts, c.maxSliceLen)
}
p.PackInt(uint32(numElts)) // pack # elements
if p.Err != nil {
return p.Err
}
for i := 0; i < numElts; i++ { // Process each element in the slice
if err := c.marshal(value.Index(i), p); err != nil {
return err
}
}
return nil
case reflect.Array:
numElts := value.Len()
if numElts > c.maxSliceLen {
return fmt.Errorf("array length, %d, exceeds maximum length, %d", numElts, c.maxSliceLen)
}
for i := 0; i < numElts; i++ { // Process each element in the array
if err := c.marshal(value.Index(i), p); err != nil {
return err
}
}
return nil
case reflect.Struct:
serializedFields, err := c.getSerializedFieldIndices(value.Type())
if err != nil {
return err
}
for _, fieldIndex := range serializedFields { // Go through all fields of this struct that are serialized
if err := c.marshal(value.Field(fieldIndex), p); err != nil { // Serialize the field and write to byte array
return err
}
}
return nil
default:
return fmt.Errorf("can't marshal unknown kind %s", valueKind)
}
}
// Unmarshal unmarshals [bytes] into [dest], where
// [dest] must be a pointer or interface
func (c *codec) Unmarshal(bytes []byte, dest interface{}) error {
switch {
case len(bytes) > c.maxSize:
return fmt.Errorf("byte array exceeds maximum length, %d", c.maxSize)
case dest == nil:
return errNil
}
destPtr := reflect.ValueOf(dest)
if destPtr.Kind() != reflect.Ptr {
return errNeedPointer
}
p := &wrappers.Packer{MaxSize: c.maxSize, Bytes: bytes}
destVal := destPtr.Elem()
if err := c.unmarshal(p, destVal); err != nil {
return err
}
return nil
}
// Unmarshal from p.Bytes into [value]. [value] must be addressable.
func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error {
switch value.Kind() {
case reflect.Uint8:
value.SetUint(uint64(p.UnpackByte()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal uint8: %s", p.Err)
}
return nil
case reflect.Int8:
value.SetInt(int64(p.UnpackByte()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal int8: %s", p.Err)
}
return nil
case reflect.Uint16:
value.SetUint(uint64(p.UnpackShort()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal uint16: %s", p.Err)
}
return nil
case reflect.Int16:
value.SetInt(int64(p.UnpackShort()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal int16: %s", p.Err)
}
return nil
case reflect.Uint32:
value.SetUint(uint64(p.UnpackInt()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal uint32: %s", p.Err)
}
return nil
case reflect.Int32:
value.SetInt(int64(p.UnpackInt()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal int32: %s", p.Err)
}
return nil
case reflect.Uint64:
value.SetUint(uint64(p.UnpackLong()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal uint64: %s", p.Err)
}
return nil
case reflect.Int64:
value.SetInt(int64(p.UnpackLong()))
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal int64: %s", p.Err)
}
return nil
case reflect.Bool:
value.SetBool(p.UnpackBool())
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal bool: %s", p.Err)
}
return nil
case reflect.Slice:
numElts := int(p.UnpackInt())
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal slice: %s", p.Err)
}
// set [value] to be a slice of the appropriate type/capacity (right now it is nil)
value.Set(reflect.MakeSlice(value.Type(), numElts, numElts))
// Unmarshal each element into the appropriate index of the slice
for i := 0; i < numElts; i++ {
if err := c.unmarshal(p, value.Index(i)); err != nil {
return fmt.Errorf("couldn't unmarshal slice element: %s", err)
}
}
return nil
case reflect.Array:
for i := 0; i < value.Len(); i++ {
if err := c.unmarshal(p, value.Index(i)); err != nil {
return fmt.Errorf("couldn't unmarshal array element: %s", err)
}
}
return nil
case reflect.String:
value.SetString(p.UnpackStr())
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal string: %s", p.Err)
}
return nil
case reflect.Interface:
typeID := p.UnpackInt() // Get the type ID
if p.Err != nil {
return fmt.Errorf("couldn't unmarshal interface: %s", p.Err)
}
// Get a type that implements the interface
implementingType, ok := c.typeIDToType[typeID]
if !ok {
return fmt.Errorf("couldn't unmarshal interface: unknown type ID %d", typeID)
}
// Ensure type actually does implement the interface
if valueType := value.Type(); !implementingType.Implements(valueType) {
return fmt.Errorf("couldn't unmarshal interface: %s does not implement interface %s", implementingType, valueType)
}
intfImplementor := reflect.New(implementingType).Elem() // instance of the proper type
// Unmarshal into the struct
if err := c.unmarshal(p, intfImplementor); err != nil {
return fmt.Errorf("couldn't unmarshal interface: %s", err)
}
// And assign the filled struct to the value
value.Set(intfImplementor)
return nil
case reflect.Struct:
// Get indices of fields that will be unmarshaled into
serializedFieldIndices, err := c.getSerializedFieldIndices(value.Type())
if err != nil {
return fmt.Errorf("couldn't unmarshal struct: %s", err)
}
// Go through the fields and unmarshal into them
for _, index := range serializedFieldIndices {
if err := c.unmarshal(p, value.Field(index)); err != nil {
return fmt.Errorf("couldn't unmarshal struct: %s", err)
}
}
return nil
case reflect.Ptr:
// Get the type this pointer points to
t := value.Type().Elem()
// Create a new pointer to a new value of the underlying type
v := reflect.New(t)
// Fill the value
if err := c.unmarshal(p, v.Elem()); err != nil {
return fmt.Errorf("couldn't unmarshal pointer: %s", err)
}
// Assign to the top-level struct's member
value.Set(v)
return nil
case reflect.Invalid:
return errNil
default:
return fmt.Errorf("can't unmarshal unknown type %s", value.Kind().String())
}
}
// Returns the indices of the serializable fields of [t], which is a struct type
// Returns an error if a field has tag "serialize: true" but the field is unexported
// e.g. getSerializedFieldIndices(Foo) --> [1,5,8] means Foo.Field(1), Foo.Field(5), Foo.Field(8)
// are to be serialized/deserialized
func (c *codec) getSerializedFieldIndices(t reflect.Type) ([]int, error) {
if c.serializedFieldIndices == nil {
c.serializedFieldIndices = make(map[reflect.Type][]int)
}
if serializedFields, ok := c.serializedFieldIndices[t]; ok { // use pre-computed result
return serializedFields, nil
}
numFields := t.NumField()
serializedFields := make([]int, 0, numFields)
for i := 0; i < numFields; i++ { // Go through all fields of this struct
field := t.Field(i)
if field.Tag.Get("serialize") != "true" { // Skip fields we don't need to serialize
continue
}
if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields
return []int{}, fmt.Errorf("can't marshal unexported field %s", field.Name)
}
serializedFields = append(serializedFields, i)
}
c.serializedFieldIndices[t] = serializedFields // cache result
return serializedFields, nil
}
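As a usage reference, here is a minimal sketch of the codec above; the Animal, Dog, and Kennel types and their field names are illustrative only. Note that RegisterType assigns type IDs in registration order, so interface values can only be unmarshaled if the concrete types were registered, in a consistent order, on both the marshaling and unmarshaling side.

package main

import (
	"fmt"

	"github.com/ava-labs/gecko/utils/codec"
)

// Illustrative types; only fields tagged serialize:"true" are encoded.
type Animal interface{ Noise() string }

type Dog struct {
	Name string `serialize:"true"`
}

func (d *Dog) Noise() string { return "woof" }

type Kennel struct {
	Capacity uint32 `serialize:"true"`
	Resident Animal `serialize:"true"` // interface field: concrete type must be registered
	notes    string // unexported and untagged: skipped by the codec
}

func main() {
	c := codec.NewDefault()
	// Register the concrete types that may appear behind interface fields.
	if err := c.RegisterType(&Dog{}); err != nil {
		panic(err)
	}

	in := Kennel{Capacity: 3, Resident: &Dog{Name: "Rex"}}
	bytes, err := c.Marshal(&in)
	if err != nil {
		panic(err)
	}

	out := Kennel{}
	if err := c.Unmarshal(bytes, &out); err != nil {
		panic(err)
	}
	fmt.Println(out.Capacity, out.Resident.(*Dog).Name) // 3 Rex
}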

View File

@ -35,13 +35,22 @@ func BenchmarkMarshal(b *testing.B) {
}, },
MyPointer: &temp, MyPointer: &temp,
} }
var unmarshaledMyStructInstance myStruct
codec := NewDefault() codec := NewDefault()
codec.RegisterType(&MyInnerStruct{}) // Register the types that may be unmarshaled into interfaces codec.RegisterType(&MyInnerStruct{}) // Register the types that may be unmarshaled into interfaces
codec.RegisterType(&MyInnerStruct2{}) codec.RegisterType(&MyInnerStruct2{})
codec.Marshal(myStructInstance) // warm up serializedFields cache
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
codec.Marshal(myStructInstance) bytes, err := codec.Marshal(myStructInstance)
if err != nil {
b.Fatal(err)
}
if err := codec.Unmarshal(bytes, &unmarshaledMyStructInstance); err != nil {
b.Fatal(err)
}
} }
} }

View File

@ -5,6 +5,7 @@ package codec
import ( import (
"bytes" "bytes"
"math"
"reflect" "reflect"
"testing" "testing"
) )
@ -104,36 +105,8 @@ func TestStruct(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(myStructUnmarshaled.Member1, myStructInstance.Member1) { if !reflect.DeepEqual(*myStructUnmarshaled, myStructInstance) {
t.Fatal("expected unmarshaled struct to be same as original struct") t.Fatal("should be same")
} else if !bytes.Equal(myStructUnmarshaled.MySlice, myStructInstance.MySlice) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MySlice2, myStructInstance.MySlice2) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MySlice3, myStructInstance.MySlice3) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MySlice3, myStructInstance.MySlice3) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MySlice4, myStructInstance.MySlice4) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct, myStructInstance.InnerStruct) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct2, myStructInstance.InnerStruct2) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MyArray2, myStructInstance.MyArray2) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MyArray3, myStructInstance.MyArray3) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MyArray4, myStructInstance.MyArray4) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MyInterface, myStructInstance.MyInterface) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MySlice5, myStructInstance.MySlice5) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct3, myStructInstance.InnerStruct3) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} else if !reflect.DeepEqual(myStructUnmarshaled.MyPointer, myStructInstance.MyPointer) {
t.Fatal("expected unmarshaled struct to be same as original struct")
} }
} }
@ -173,6 +146,28 @@ func TestSlice(t *testing.T) {
} }
} }
// Test marshalling/unmarshalling largest possible slice
func TestMaxSizeSlice(t *testing.T) {
mySlice := make([]string, math.MaxUint16, math.MaxUint16)
mySlice[0] = "first!"
mySlice[math.MaxUint16-1] = "last!"
codec := NewDefault()
bytes, err := codec.Marshal(mySlice)
if err != nil {
t.Fatal(err)
}
var sliceUnmarshaled []string
if err := codec.Unmarshal(bytes, &sliceUnmarshaled); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(mySlice, sliceUnmarshaled) {
t.Fatal("expected marshaled and unmarshaled values to match")
}
}
// Test marshalling a bool
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
myBool := true myBool := true
codec := NewDefault() codec := NewDefault()
@ -191,6 +186,7 @@ func TestBool(t *testing.T) {
} }
} }
// Test marshalling an array
func TestArray(t *testing.T) { func TestArray(t *testing.T) {
myArr := [5]uint64{5, 6, 7, 8, 9} myArr := [5]uint64{5, 6, 7, 8, 9}
codec := NewDefault() codec := NewDefault()
@ -209,6 +205,26 @@ func TestArray(t *testing.T) {
} }
} }
// Test marshalling a really big array
func TestBigArray(t *testing.T) {
myArr := [30000]uint64{5, 6, 7, 8, 9}
codec := NewDefault()
bytes, err := codec.Marshal(myArr)
if err != nil {
t.Fatal(err)
}
var myArrUnmarshaled [30000]uint64
if err := codec.Unmarshal(bytes, &myArrUnmarshaled); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(myArr, myArrUnmarshaled) {
t.Fatal("expected marshaled and unmarshaled values to match")
}
}
// Test marshalling a pointer to a struct
func TestPointerToStruct(t *testing.T) { func TestPointerToStruct(t *testing.T) {
myPtr := &MyInnerStruct{Str: "Hello!"} myPtr := &MyInnerStruct{Str: "Hello!"}
codec := NewDefault() codec := NewDefault()
@ -227,6 +243,7 @@ func TestPointerToStruct(t *testing.T) {
} }
} }
// Test marshalling a slice of structs
func TestSliceOfStruct(t *testing.T) { func TestSliceOfStruct(t *testing.T) {
mySlice := []MyInnerStruct3{ mySlice := []MyInnerStruct3{
MyInnerStruct3{ MyInnerStruct3{
@ -257,6 +274,7 @@ func TestSliceOfStruct(t *testing.T) {
} }
} }
// Test marshalling an interface
func TestInterface(t *testing.T) { func TestInterface(t *testing.T) {
codec := NewDefault() codec := NewDefault()
codec.RegisterType(&MyInnerStruct2{}) codec.RegisterType(&MyInnerStruct2{})
@ -278,6 +296,7 @@ func TestInterface(t *testing.T) {
} }
} }
// Test marshalling a slice of interfaces
func TestSliceOfInterface(t *testing.T) { func TestSliceOfInterface(t *testing.T) {
mySlice := []Foo{ mySlice := []Foo{
&MyInnerStruct{ &MyInnerStruct{
@ -304,6 +323,7 @@ func TestSliceOfInterface(t *testing.T) {
} }
} }
// Test marshalling an array of interfaces
func TestArrayOfInterface(t *testing.T) { func TestArrayOfInterface(t *testing.T) {
myArray := [2]Foo{ myArray := [2]Foo{
&MyInnerStruct{ &MyInnerStruct{
@ -330,6 +350,7 @@ func TestArrayOfInterface(t *testing.T) {
} }
} }
// Test marshalling a pointer to an interface
func TestPointerToInterface(t *testing.T) { func TestPointerToInterface(t *testing.T) {
var myinnerStruct Foo = &MyInnerStruct{Str: "Hello!"} var myinnerStruct Foo = &MyInnerStruct{Str: "Hello!"}
var myPtr *Foo = &myinnerStruct var myPtr *Foo = &myinnerStruct
@ -352,6 +373,7 @@ func TestPointerToInterface(t *testing.T) {
} }
} }
// Test marshalling a string
func TestString(t *testing.T) { func TestString(t *testing.T) {
myString := "Ayy" myString := "Ayy"
codec := NewDefault() codec := NewDefault()
@ -370,7 +392,7 @@ func TestString(t *testing.T) {
} }
} }
// Ensure a nil slice is unmarshaled as an empty slice // Ensure a nil slice is unmarshaled to slice with length 0
func TestNilSlice(t *testing.T) { func TestNilSlice(t *testing.T) {
type structWithSlice struct { type structWithSlice struct {
Slice []byte `serialize:"true"` Slice []byte `serialize:"true"`
@ -389,12 +411,12 @@ func TestNilSlice(t *testing.T) {
} }
if structUnmarshaled.Slice == nil || len(structUnmarshaled.Slice) != 0 { if structUnmarshaled.Slice == nil || len(structUnmarshaled.Slice) != 0 {
t.Fatal("expected slice to be empty slice") t.Fatal("expected slice to be non-nil and length 0")
} }
} }
// Ensure that trying to serialize a struct with an unexported member // Ensure that trying to serialize a struct with an unexported member
// that has `serialize:"true"` returns errUnexportedField // that has `serialize:"true"` returns error
func TestSerializeUnexportedField(t *testing.T) { func TestSerializeUnexportedField(t *testing.T) {
type s struct { type s struct {
ExportedField string `serialize:"true"` ExportedField string `serialize:"true"`
@ -407,8 +429,8 @@ func TestSerializeUnexportedField(t *testing.T) {
} }
codec := NewDefault() codec := NewDefault()
if _, err := codec.Marshal(myS); err != errMarshalUnexportedField { if _, err := codec.Marshal(myS); err == nil {
t.Fatalf("expected err to be errUnexportedField but was %v", err) t.Fatalf("expected err but got none")
} }
} }
@ -426,12 +448,12 @@ func TestSerializeOfNoSerializeField(t *testing.T) {
codec := NewDefault() codec := NewDefault()
marshalled, err := codec.Marshal(myS) marshalled, err := codec.Marshal(myS)
if err != nil { if err != nil {
t.Fatalf("Unexpected error %q", err) t.Fatal(err)
} }
unmarshalled := s{} unmarshalled := s{}
err = codec.Unmarshal(marshalled, &unmarshalled) err = codec.Unmarshal(marshalled, &unmarshalled)
if err != nil { if err != nil {
t.Fatalf("Unexpected error %q", err) t.Fatal(err)
} }
expectedUnmarshalled := s{SerializedField: "Serialize me"} expectedUnmarshalled := s{SerializedField: "Serialize me"}
if !reflect.DeepEqual(unmarshalled, expectedUnmarshalled) { if !reflect.DeepEqual(unmarshalled, expectedUnmarshalled) {
@ -443,11 +465,12 @@ type simpleSliceStruct struct {
Arr []uint32 `serialize:"true"` Arr []uint32 `serialize:"true"`
} }
func TestEmptySliceSerialization(t *testing.T) { // Test marshalling of nil slice
func TestNilSliceSerialization(t *testing.T) {
codec := NewDefault() codec := NewDefault()
val := &simpleSliceStruct{} val := &simpleSliceStruct{}
expected := []byte{0, 0, 0, 0} expected := []byte{0, 0, 0, 0} // nil slice marshaled as 0 length slice
result, err := codec.Marshal(val) result, err := codec.Marshal(val)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -456,6 +479,36 @@ func TestEmptySliceSerialization(t *testing.T) {
if !bytes.Equal(expected, result) { if !bytes.Equal(expected, result) {
t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result) t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result)
} }
valUnmarshaled := &simpleSliceStruct{}
if err = codec.Unmarshal(result, &valUnmarshaled); err != nil {
t.Fatal(err)
} else if len(valUnmarshaled.Arr) != 0 {
t.Fatal("should be 0 length")
}
}
// Test marshaling a slice that has 0 elements (but isn't nil)
func TestEmptySliceSerialization(t *testing.T) {
codec := NewDefault()
val := &simpleSliceStruct{Arr: make([]uint32, 0, 1)}
expected := []byte{0, 0, 0, 0} // 0 for size
result, err := codec.Marshal(val)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, result) {
t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result)
}
valUnmarshaled := &simpleSliceStruct{}
if err = codec.Unmarshal(result, &valUnmarshaled); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(valUnmarshaled, val) {
t.Fatal("should be same")
}
} }
type emptyStruct struct{} type emptyStruct struct{}
@ -464,13 +517,14 @@ type nestedSliceStruct struct {
Arr []emptyStruct `serialize:"true"` Arr []emptyStruct `serialize:"true"`
} }
// Test marshaling slice that is not nil and not empty
func TestSliceWithEmptySerialization(t *testing.T) { func TestSliceWithEmptySerialization(t *testing.T) {
codec := NewDefault() codec := NewDefault()
val := &nestedSliceStruct{ val := &nestedSliceStruct{
Arr: make([]emptyStruct, 1000), Arr: make([]emptyStruct, 1000),
} }
expected := []byte{0x00, 0x00, 0x03, 0xE8} expected := []byte{0x00, 0x00, 0x03, 0xE8} //1000 for numElts
result, err := codec.Marshal(val) result, err := codec.Marshal(val)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -485,7 +539,7 @@ func TestSliceWithEmptySerialization(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if len(unmarshaled.Arr) != 1000 { if len(unmarshaled.Arr) != 1000 {
t.Fatalf("Should have created an array of length %d", 1000) t.Fatalf("Should have created a slice of length %d", 1000)
} }
} }
@ -493,20 +547,15 @@ func TestSliceWithEmptySerializationOutOfMemory(t *testing.T) {
codec := NewDefault() codec := NewDefault()
val := &nestedSliceStruct{ val := &nestedSliceStruct{
Arr: make([]emptyStruct, 1000000), Arr: make([]emptyStruct, defaultMaxSliceLength+1),
} }
expected := []byte{0x00, 0x0f, 0x42, 0x40} // 1,000,000 in hex bytes, err := codec.Marshal(val)
result, err := codec.Marshal(val) if err == nil {
if err != nil { t.Fatal("should have failed due to slice length too large")
t.Fatal(err)
}
if !bytes.Equal(expected, result) {
t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result)
} }
unmarshaled := nestedSliceStruct{} unmarshaled := nestedSliceStruct{}
if err := codec.Unmarshal(expected, &unmarshaled); err == nil { if err := codec.Unmarshal(bytes, &unmarshaled); err == nil {
t.Fatalf("Should have errored due to excess memory requested") t.Fatalf("Should have errored due to excess memory requested")
} }
} }
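The expected byte strings in the slice tests above follow directly from the wire format: a slice is written as a 4-byte big-endian element count followed by its elements, so a nil or empty slice marshals to four zero bytes and 1000 empty structs marshal to just the count 0x000003E8. A tiny standalone check:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// The length prefix asserted in TestSliceWithEmptySerialization above.
	prefix := []byte{0x00, 0x00, 0x03, 0xE8}
	fmt.Println(binary.BigEndian.Uint32(prefix)) // 1000
}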

View File

@ -61,26 +61,23 @@ func (p *Packer) CheckSpace(bytes int) {
} }
} }
// Expand ensures that there is [bytes] bytes left of space in the byte array. // Expand ensures that there is [bytes] bytes left of space in the byte slice.
// If this is not allowed due to the maximum size, an error is added to the // If this is not allowed due to the maximum size, an error is added to the packer
// packer // In order to understand this code, it's important to understand the difference
// between a slice's length and its capacity.
func (p *Packer) Expand(bytes int) { func (p *Packer) Expand(bytes int) {
p.CheckSpace(0) neededSize := bytes + p.Offset // Need byte slice's length to be at least [neededSize]
if p.Errored() { switch {
case neededSize <= len(p.Bytes): // Byte slice has sufficient length already
return return
} case neededSize > p.MaxSize: // Lengthening the byte slice would cause it to grow too large
p.Err = errBadLength
neededSize := bytes + p.Offset
if neededSize <= len(p.Bytes) {
return return
} case neededSize <= cap(p.Bytes): // Byte slice has sufficient capacity to lengthen it without mem alloc
if neededSize > p.MaxSize {
p.Add(errBadLength)
} else if neededSize > cap(p.Bytes) {
p.Bytes = append(p.Bytes[:cap(p.Bytes)], make([]byte, neededSize-cap(p.Bytes))...)
} else {
p.Bytes = p.Bytes[:neededSize] p.Bytes = p.Bytes[:neededSize]
return
default: // Add capacity/length to byte slice
p.Bytes = append(p.Bytes[:cap(p.Bytes)], make([]byte, neededSize-cap(p.Bytes))...)
} }
} }
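The rewritten Expand leans on the distinction its new comment calls out, so here is a small standalone illustration of slice length versus capacity (purely illustrative, no Packer involved):

package main

import "fmt"

func main() {
	b := make([]byte, 0, 8)     // length 0, capacity 8
	fmt.Println(len(b), cap(b)) // 0 8

	b = b[:5] // grow length within existing capacity: no allocation
	fmt.Println(len(b), cap(b)) // 5 8

	b = append(b[:cap(b)], make([]byte, 4)...) // exceed capacity: allocates
	fmt.Println(len(b), cap(b))                // 12 and a capacity of at least 12
}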

View File

@ -10,7 +10,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -11,7 +11,7 @@ import (
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
const ( const (

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/database/versiondb" "github.com/ava-labs/gecko/database/versiondb"
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -16,7 +16,7 @@ import (
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -12,7 +12,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -8,7 +8,7 @@ import (
"errors" "errors"
"sort" "sort"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -11,7 +11,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -10,7 +10,7 @@ import (
"github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/utils"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -8,7 +8,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -56,7 +56,7 @@ type IssueTxReply struct {
// IssueTx attempts to issue a transaction into consensus // IssueTx attempts to issue a transaction into consensus
func (service *Service) IssueTx(r *http.Request, args *IssueTxArgs, reply *IssueTxReply) error { func (service *Service) IssueTx(r *http.Request, args *IssueTxArgs, reply *IssueTxReply) error {
service.vm.ctx.Log.Verbo("IssueTx called with %s", args.Tx) service.vm.ctx.Log.Info("AVM: IssueTx called with %s", args.Tx)
txID, err := service.vm.IssueTx(args.Tx.Bytes, nil) txID, err := service.vm.IssueTx(args.Tx.Bytes, nil)
if err != nil { if err != nil {
@ -79,7 +79,7 @@ type GetTxStatusReply struct {
// GetTxStatus returns the status of the specified transaction // GetTxStatus returns the status of the specified transaction
func (service *Service) GetTxStatus(r *http.Request, args *GetTxStatusArgs, reply *GetTxStatusReply) error { func (service *Service) GetTxStatus(r *http.Request, args *GetTxStatusArgs, reply *GetTxStatusReply) error {
service.vm.ctx.Log.Verbo("GetTxStatus called with %s", args.TxID) service.vm.ctx.Log.Info("AVM: GetTxStatus called with %s", args.TxID)
if args.TxID.IsZero() { if args.TxID.IsZero() {
return errNilTxID return errNilTxID
@ -106,7 +106,7 @@ type GetTxReply struct {
// GetTx returns the specified transaction // GetTx returns the specified transaction
func (service *Service) GetTx(r *http.Request, args *GetTxArgs, reply *GetTxReply) error { func (service *Service) GetTx(r *http.Request, args *GetTxArgs, reply *GetTxReply) error {
service.vm.ctx.Log.Verbo("GetTx called with %s", args.TxID) service.vm.ctx.Log.Info("AVM: GetTx called with %s", args.TxID)
if args.TxID.IsZero() { if args.TxID.IsZero() {
return errNilTxID return errNilTxID
@ -136,7 +136,7 @@ type GetUTXOsReply struct {
// GetUTXOs returns the UTXOs referenced by the addresses passed in // GetUTXOs returns the UTXOs referenced by the addresses passed in
func (service *Service) GetUTXOs(r *http.Request, args *GetUTXOsArgs, reply *GetUTXOsReply) error { func (service *Service) GetUTXOs(r *http.Request, args *GetUTXOsArgs, reply *GetUTXOsReply) error {
service.vm.ctx.Log.Verbo("GetUTXOs called with %s", args.Addresses) service.vm.ctx.Log.Info("AVM: GetUTXOs called with %s", args.Addresses)
addrSet := ids.Set{} addrSet := ids.Set{}
for _, addr := range args.Addresses { for _, addr := range args.Addresses {
@ -178,7 +178,7 @@ type GetAssetDescriptionReply struct {
// GetAssetDescription returns a description of the asset whose ID is passed in // GetAssetDescription returns a description of the asset whose ID is passed in
func (service *Service) GetAssetDescription(_ *http.Request, args *GetAssetDescriptionArgs, reply *GetAssetDescriptionReply) error { func (service *Service) GetAssetDescription(_ *http.Request, args *GetAssetDescriptionArgs, reply *GetAssetDescriptionReply) error {
service.vm.ctx.Log.Verbo("GetAssetDescription called with %s", args.AssetID) service.vm.ctx.Log.Info("AVM: GetAssetDescription called with %s", args.AssetID)
assetID, err := service.vm.Lookup(args.AssetID) assetID, err := service.vm.Lookup(args.AssetID)
if err != nil { if err != nil {
@ -222,7 +222,7 @@ type GetBalanceReply struct {
// GetBalance returns the amount of an asset that an address at least partially owns // GetBalance returns the amount of an asset that an address at least partially owns
func (service *Service) GetBalance(r *http.Request, args *GetBalanceArgs, reply *GetBalanceReply) error { func (service *Service) GetBalance(r *http.Request, args *GetBalanceArgs, reply *GetBalanceReply) error {
service.vm.ctx.Log.Verbo("GetBalance called with address: %s assetID: %s", args.Address, args.AssetID) service.vm.ctx.Log.Info("AVM: GetBalance called with address: %s assetID: %s", args.Address, args.AssetID)
address, err := service.vm.Parse(args.Address) address, err := service.vm.Parse(args.Address)
if err != nil { if err != nil {
@ -287,7 +287,7 @@ type GetAllBalancesReply struct {
// Note that balances include assets that the address only _partially_ owns // Note that balances include assets that the address only _partially_ owns
// (ie is one of several addresses specified in a multi-sig) // (ie is one of several addresses specified in a multi-sig)
func (service *Service) GetAllBalances(r *http.Request, args *GetAllBalancesArgs, reply *GetAllBalancesReply) error { func (service *Service) GetAllBalances(r *http.Request, args *GetAllBalancesArgs, reply *GetAllBalancesReply) error {
service.vm.ctx.Log.Verbo("GetAllBalances called with address: %s", args.Address) service.vm.ctx.Log.Info("AVM: GetAllBalances called with address: %s", args.Address)
address, err := service.vm.Parse(args.Address) address, err := service.vm.Parse(args.Address)
if err != nil { if err != nil {
@ -360,7 +360,7 @@ type CreateFixedCapAssetReply struct {
// CreateFixedCapAsset returns ID of the newly created asset // CreateFixedCapAsset returns ID of the newly created asset
func (service *Service) CreateFixedCapAsset(r *http.Request, args *CreateFixedCapAssetArgs, reply *CreateFixedCapAssetReply) error { func (service *Service) CreateFixedCapAsset(r *http.Request, args *CreateFixedCapAssetArgs, reply *CreateFixedCapAssetReply) error {
service.vm.ctx.Log.Verbo("CreateFixedCapAsset called with name: %s symbol: %s number of holders: %d", service.vm.ctx.Log.Info("AVM: CreateFixedCapAsset called with name: %s symbol: %s number of holders: %d",
args.Name, args.Name,
args.Symbol, args.Symbol,
len(args.InitialHolders), len(args.InitialHolders),
@ -445,7 +445,7 @@ type CreateVariableCapAssetReply struct {
// CreateVariableCapAsset returns ID of the newly created asset // CreateVariableCapAsset returns ID of the newly created asset
func (service *Service) CreateVariableCapAsset(r *http.Request, args *CreateVariableCapAssetArgs, reply *CreateVariableCapAssetReply) error { func (service *Service) CreateVariableCapAsset(r *http.Request, args *CreateVariableCapAssetArgs, reply *CreateVariableCapAssetReply) error {
service.vm.ctx.Log.Verbo("CreateFixedCapAsset called with name: %s symbol: %s number of minters: %d", service.vm.ctx.Log.Info("AVM: CreateFixedCapAsset called with name: %s symbol: %s number of minters: %d",
args.Name, args.Name,
args.Symbol, args.Symbol,
len(args.MinterSets), len(args.MinterSets),
@ -523,7 +523,7 @@ type CreateAddressReply struct {
// CreateAddress creates an address for the user [args.Username] // CreateAddress creates an address for the user [args.Username]
func (service *Service) CreateAddress(r *http.Request, args *CreateAddressArgs, reply *CreateAddressReply) error { func (service *Service) CreateAddress(r *http.Request, args *CreateAddressArgs, reply *CreateAddressReply) error {
service.vm.ctx.Log.Verbo("CreateAddress called for user '%s'", args.Username) service.vm.ctx.Log.Info("AVM: CreateAddress called for user '%s'", args.Username)
db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password) db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password)
if err != nil { if err != nil {
@ -603,7 +603,7 @@ type ExportKeyReply struct {
// ExportKey returns a private key from the provided user // ExportKey returns a private key from the provided user
func (service *Service) ExportKey(r *http.Request, args *ExportKeyArgs, reply *ExportKeyReply) error { func (service *Service) ExportKey(r *http.Request, args *ExportKeyArgs, reply *ExportKeyReply) error {
service.vm.ctx.Log.Verbo("ExportKey called for user '%s'", args.Username) service.vm.ctx.Log.Info("AVM: ExportKey called for user '%s'", args.Username)
address, err := service.vm.Parse(args.Address) address, err := service.vm.Parse(args.Address)
if err != nil { if err != nil {
@ -645,7 +645,7 @@ type ImportKeyReply struct {
// ImportKey adds a private key to the provided user // ImportKey adds a private key to the provided user
func (service *Service) ImportKey(r *http.Request, args *ImportKeyArgs, reply *ImportKeyReply) error { func (service *Service) ImportKey(r *http.Request, args *ImportKeyArgs, reply *ImportKeyReply) error {
service.vm.ctx.Log.Verbo("ImportKey called for user '%s'", args.Username) service.vm.ctx.Log.Info("AVM: ImportKey called for user '%s'", args.Username)
db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password) db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password)
if err != nil { if err != nil {
@ -666,13 +666,20 @@ func (service *Service) ImportKey(r *http.Request, args *ImportKeyArgs, reply *I
} }
addresses, _ := user.Addresses(db) addresses, _ := user.Addresses(db)
addresses = append(addresses, sk.PublicKey().Address())
newAddress := sk.PublicKey().Address()
reply.Address = service.vm.Format(newAddress.Bytes())
for _, address := range addresses {
if newAddress.Equals(address) {
return nil
}
}
addresses = append(addresses, newAddress)
if err := user.SetAddresses(db, addresses); err != nil { if err := user.SetAddresses(db, addresses); err != nil {
return fmt.Errorf("problem saving addresses: %w", err) return fmt.Errorf("problem saving addresses: %w", err)
} }
reply.Address = service.vm.Format(sk.PublicKey().Address().Bytes())
return nil return nil
} }
@ -692,7 +699,7 @@ type SendReply struct {
// Send returns the ID of the newly created transaction // Send returns the ID of the newly created transaction
func (service *Service) Send(r *http.Request, args *SendArgs, reply *SendReply) error { func (service *Service) Send(r *http.Request, args *SendArgs, reply *SendReply) error {
service.vm.ctx.Log.Verbo("Send called with username: %s", args.Username) service.vm.ctx.Log.Info("AVM: Send called with username: %s", args.Username)
if args.Amount == 0 { if args.Amount == 0 {
return errInvalidAmount return errInvalidAmount
@ -873,7 +880,7 @@ type CreateMintTxReply struct {
// CreateMintTx returns the newly created unsigned transaction // CreateMintTx returns the newly created unsigned transaction
func (service *Service) CreateMintTx(r *http.Request, args *CreateMintTxArgs, reply *CreateMintTxReply) error { func (service *Service) CreateMintTx(r *http.Request, args *CreateMintTxArgs, reply *CreateMintTxReply) error {
service.vm.ctx.Log.Verbo("CreateMintTx called") service.vm.ctx.Log.Info("AVM: CreateMintTx called")
if args.Amount == 0 { if args.Amount == 0 {
return errInvalidMintAmount return errInvalidMintAmount
@ -990,7 +997,7 @@ type SignMintTxReply struct {
// SignMintTx returns the newly signed transaction // SignMintTx returns the newly signed transaction
func (service *Service) SignMintTx(r *http.Request, args *SignMintTxArgs, reply *SignMintTxReply) error { func (service *Service) SignMintTx(r *http.Request, args *SignMintTxArgs, reply *SignMintTxReply) error {
service.vm.ctx.Log.Verbo("SignMintTx called") service.vm.ctx.Log.Info("AVM: SignMintTx called")
minter, err := service.vm.Parse(args.Minter) minter, err := service.vm.Parse(args.Minter)
if err != nil { if err != nil {
@ -1116,7 +1123,7 @@ type ImportAVAReply struct {
// The AVA must have already been exported from the P-Chain. // The AVA must have already been exported from the P-Chain.
// Returns the ID of the newly created atomic transaction // Returns the ID of the newly created atomic transaction
func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, reply *ImportAVAReply) error { func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, reply *ImportAVAReply) error {
service.vm.ctx.Log.Verbo("ImportAVA called with username: %s", args.Username) service.vm.ctx.Log.Info("AVM: ImportAVA called with username: %s", args.Username)
toBytes, err := service.vm.Parse(args.To) toBytes, err := service.vm.Parse(args.To)
if err != nil { if err != nil {
@ -1268,7 +1275,7 @@ type ExportAVAReply struct {
// After this tx is accepted, the AVA must be imported to the P-chain with an importTx. // After this tx is accepted, the AVA must be imported to the P-chain with an importTx.
// Returns the ID of the newly created atomic transaction // Returns the ID of the newly created atomic transaction
func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, reply *ExportAVAReply) error { func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, reply *ExportAVAReply) error {
service.vm.ctx.Log.Verbo("ExportAVA called with username: %s", args.Username) service.vm.ctx.Log.Info("AVM: ExportAVA called with username: %s", args.Username)
if args.Amount == 0 { if args.Amount == 0 {
return errInvalidAmount return errInvalidAmount

View File

@ -9,8 +9,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/ava-labs/gecko/api/keystore"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/choices" "github.com/ava-labs/gecko/snow/choices"
"github.com/ava-labs/gecko/utils/crypto"
"github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/formatting"
) )
@ -340,3 +342,113 @@ func TestCreateVariableCapAsset(t *testing.T) {
t.Fatalf("Wrong assetID returned from CreateFixedCapAsset %s", reply.AssetID) t.Fatalf("Wrong assetID returned from CreateFixedCapAsset %s", reply.AssetID)
} }
} }
func TestImportAvmKey(t *testing.T) {
_, vm, s := setup(t)
defer func() {
vm.Shutdown()
ctx.Lock.Unlock()
}()
userKeystore := keystore.CreateTestKeystore(t)
username := "bobby"
password := "StrnasfqewiurPasswdn56d"
if err := userKeystore.AddUser(username, password); err != nil {
t.Fatal(err)
}
vm.ctx.Keystore = userKeystore.NewBlockchainKeyStore(vm.ctx.ChainID)
_, err := vm.ctx.Keystore.GetDatabase(username, password)
if err != nil {
t.Fatal(err)
}
factory := crypto.FactorySECP256K1R{}
skIntf, err := factory.NewPrivateKey()
if err != nil {
t.Fatalf("problem generating private key: %w", err)
}
sk := skIntf.(*crypto.PrivateKeySECP256K1R)
args := ImportKeyArgs{
Username: username,
Password: password,
PrivateKey: formatting.CB58{Bytes: sk.Bytes()},
}
reply := ImportKeyReply{}
if err = s.ImportKey(nil, &args, &reply); err != nil {
t.Fatal(err)
}
}
func TestImportAvmKeyNoDuplicates(t *testing.T) {
_, vm, s := setup(t)
defer func() {
vm.Shutdown()
ctx.Lock.Unlock()
}()
userKeystore := keystore.CreateTestKeystore(t)
username := "bobby"
password := "StrnasfqewiurPasswdn56d"
if err := userKeystore.AddUser(username, password); err != nil {
t.Fatal(err)
}
vm.ctx.Keystore = userKeystore.NewBlockchainKeyStore(vm.ctx.ChainID)
_, err := vm.ctx.Keystore.GetDatabase(username, password)
if err != nil {
t.Fatal(err)
}
factory := crypto.FactorySECP256K1R{}
skIntf, err := factory.NewPrivateKey()
if err != nil {
t.Fatalf("problem generating private key: %w", err)
}
sk := skIntf.(*crypto.PrivateKeySECP256K1R)
args := ImportKeyArgs{
Username: username,
Password: password,
PrivateKey: formatting.CB58{Bytes: sk.Bytes()},
}
reply := ImportKeyReply{}
if err = s.ImportKey(nil, &args, &reply); err != nil {
t.Fatal(err)
}
expectedAddress := vm.Format(sk.PublicKey().Address().Bytes())
if reply.Address != expectedAddress {
t.Fatalf("Reply address: %s did not match expected address: %s", reply.Address, expectedAddress)
}
reply2 := ImportKeyReply{}
if err = s.ImportKey(nil, &args, &reply2); err != nil {
t.Fatal(err)
}
if reply2.Address != expectedAddress {
t.Fatalf("Reply address: %s did not match expected address: %s", reply2.Address, expectedAddress)
}
addrsArgs := ListAddressesArgs{
Username: username,
Password: password,
}
addrsReply := ListAddressesResponse{}
if err := s.ListAddresses(nil, &addrsArgs, &addrsReply); err != nil {
t.Fatal(err)
}
if len(addrsReply.Addresses) != 1 {
t.Fatal("Importing the same key twice created duplicate addresses")
}
if addrsReply.Addresses[0] != expectedAddress {
t.Fatal("List addresses returned an incorrect address")
}
}

View File

@ -11,7 +11,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/utils/wrappers"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
cjson "github.com/ava-labs/gecko/utils/json" cjson "github.com/ava-labs/gecko/utils/json"

View File

@ -10,7 +10,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/units" "github.com/ava-labs/gecko/utils/units"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -25,7 +25,7 @@ import (
"github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/timer"
"github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/utils/wrappers"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
cjson "github.com/ava-labs/gecko/utils/json" cjson "github.com/ava-labs/gecko/utils/json"
) )

View File

@ -7,7 +7,7 @@ import (
"testing" "testing"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
func TestAssetVerifyNil(t *testing.T) { func TestAssetVerifyNil(t *testing.T) {

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/choices" "github.com/ava-labs/gecko/snow/choices"
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
// Addressable is the interface a feature extension must provide to be able to // Addressable is the interface a feature extension must provide to be able to

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/database/memdb" "github.com/ava-labs/gecko/database/memdb"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )

View File

@ -11,7 +11,7 @@ import (
"github.com/ava-labs/gecko/database/prefixdb" "github.com/ava-labs/gecko/database/prefixdb"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/choices" "github.com/ava-labs/gecko/snow/choices"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
var ( var (

View File

@ -10,7 +10,7 @@ import (
"github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/utils"
"github.com/ava-labs/gecko/utils/crypto" "github.com/ava-labs/gecko/utils/crypto"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/components/verify"
) )

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -7,7 +7,7 @@ import (
"testing" "testing"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
func TestUTXOIDVerifyNil(t *testing.T) { func TestUTXOIDVerifyNil(t *testing.T) {

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -1,345 +0,0 @@
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package codec
import (
"errors"
"fmt"
"reflect"
"unicode"
"github.com/ava-labs/gecko/utils/wrappers"
)
const (
defaultMaxSize = 1 << 18 // default max size, in bytes, of something being marshalled by Marshal()
defaultMaxSliceLength = 1 << 18 // default max length of a slice being marshalled by Marshal()
)
// ErrBadCodec is returned when one tries to perform an operation
// using an unknown codec
var (
errBadCodec = errors.New("wrong or unknown codec used")
errNil = errors.New("can't marshal nil value")
errUnmarshalNil = errors.New("can't unmarshal into nil")
errNeedPointer = errors.New("must unmarshal into a pointer")
errMarshalUnregisteredType = errors.New("can't marshal an unregistered type")
errUnmarshalUnregisteredType = errors.New("can't unmarshal an unregistered type")
errUnknownType = errors.New("don't know how to marshal/unmarshal this type")
errMarshalUnexportedField = errors.New("can't serialize an unexported field")
errUnmarshalUnexportedField = errors.New("can't deserialize into an unexported field")
errOutOfMemory = errors.New("out of memory")
errSliceTooLarge = errors.New("slice too large")
)
// Codec handles marshaling and unmarshaling of structs
type codec struct {
maxSize int
maxSliceLen int
typeIDToType map[uint32]reflect.Type
typeToTypeID map[reflect.Type]uint32
}
// Codec marshals and unmarshals
type Codec interface {
RegisterType(interface{}) error
Marshal(interface{}) ([]byte, error)
Unmarshal([]byte, interface{}) error
}
// New returns a new codec
func New(maxSize, maxSliceLen int) Codec {
return codec{
maxSize: maxSize,
maxSliceLen: maxSliceLen,
typeIDToType: map[uint32]reflect.Type{},
typeToTypeID: map[reflect.Type]uint32{},
}
}
// NewDefault returns a new codec with reasonable default values
func NewDefault() Codec { return New(defaultMaxSize, defaultMaxSliceLength) }
// RegisterType is used to register types that may be unmarshaled into an interface typed value
// [val] is a value of the type being registered
func (c codec) RegisterType(val interface{}) error {
valType := reflect.TypeOf(val)
if _, exists := c.typeToTypeID[valType]; exists {
return fmt.Errorf("type %v has already been registered", valType)
}
c.typeIDToType[uint32(len(c.typeIDToType))] = reflect.TypeOf(val)
c.typeToTypeID[valType] = uint32(len(c.typeIDToType) - 1)
return nil
}
// A few notes:
// 1) See codec_test.go for examples of usage
// 2) We use "marshal" and "serialize" interchangeably, and "unmarshal" and "deserialize" interchangeably
// 3) To include a field of a struct in the serialized form, add the tag `serialize:"true"` to it
// 4) These typed members of a struct may be serialized:
// bool, string, uint[8,16,32,64], int[8,16,32,64],
// structs, slices, arrays, interface.
// structs, slices and arrays can only be serialized if their constituent parts can be.
// 5) To marshal an interface typed value, you must pass a _pointer_ to the value
// 6) If you want to be able to unmarshal into an interface typed value,
// you must call codec.RegisterType([instance of the type that fulfills the interface]).
// 7) nil slices will be unmarshaled as an empty slice of the appropriate type
// 8) Serialized fields must be exported
// Marshal returns the byte representation of [value]
// If you want to marshal an interface, [value] must be a pointer
// to the interface
func (c codec) Marshal(value interface{}) ([]byte, error) {
if value == nil {
return nil, errNil
}
return c.marshal(reflect.ValueOf(value))
}
// Marshal [value] to bytes
func (c codec) marshal(value reflect.Value) ([]byte, error) {
p := wrappers.Packer{MaxSize: c.maxSize, Bytes: []byte{}}
t := value.Type()
valueKind := value.Kind()
switch valueKind {
case reflect.Interface, reflect.Ptr, reflect.Slice:
if value.IsNil() {
return nil, errNil
}
}
switch valueKind {
case reflect.Uint8:
p.PackByte(uint8(value.Uint()))
return p.Bytes, p.Err
case reflect.Int8:
p.PackByte(uint8(value.Int()))
return p.Bytes, p.Err
case reflect.Uint16:
p.PackShort(uint16(value.Uint()))
return p.Bytes, p.Err
case reflect.Int16:
p.PackShort(uint16(value.Int()))
return p.Bytes, p.Err
case reflect.Uint32:
p.PackInt(uint32(value.Uint()))
return p.Bytes, p.Err
case reflect.Int32:
p.PackInt(uint32(value.Int()))
return p.Bytes, p.Err
case reflect.Uint64:
p.PackLong(value.Uint())
return p.Bytes, p.Err
case reflect.Int64:
p.PackLong(uint64(value.Int()))
return p.Bytes, p.Err
case reflect.Uintptr, reflect.Ptr:
return c.marshal(value.Elem())
case reflect.String:
p.PackStr(value.String())
return p.Bytes, p.Err
case reflect.Bool:
p.PackBool(value.Bool())
return p.Bytes, p.Err
case reflect.Interface:
typeID, ok := c.typeToTypeID[reflect.TypeOf(value.Interface())] // Get the type ID of the value being marshaled
if !ok {
return nil, fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(value.Interface()).String())
}
p.PackInt(typeID)
bytes, err := c.Marshal(value.Interface())
if err != nil {
return nil, err
}
p.PackFixedBytes(bytes)
if p.Errored() {
return nil, p.Err
}
return p.Bytes, err
case reflect.Array, reflect.Slice:
numElts := value.Len() // # elements in the slice/array (assumed to be <= 2^31 - 1)
// If this is a slice, pack the number of elements in the slice
if valueKind == reflect.Slice {
p.PackInt(uint32(numElts))
}
for i := 0; i < numElts; i++ { // Pack each element in the slice/array
eltBytes, err := c.marshal(value.Index(i))
if err != nil {
return nil, err
}
p.PackFixedBytes(eltBytes)
}
return p.Bytes, p.Err
case reflect.Struct:
for i := 0; i < t.NumField(); i++ { // Go through all fields of this struct
field := t.Field(i)
if !shouldSerialize(field) { // Skip fields we don't need to serialize
continue
}
if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields
return nil, errMarshalUnexportedField
}
fieldVal := value.Field(i) // The field we're serializing
if fieldVal.Kind() == reflect.Slice && fieldVal.IsNil() {
p.PackInt(0)
continue
}
fieldBytes, err := c.marshal(fieldVal) // Serialize the field
if err != nil {
return nil, err
}
p.PackFixedBytes(fieldBytes)
}
return p.Bytes, p.Err
case reflect.Invalid:
return nil, errUnmarshalNil
default:
return nil, errUnknownType
}
}
// Unmarshal unmarshals [bytes] into [dest], where
// [dest] must be a pointer or interface
func (c codec) Unmarshal(bytes []byte, dest interface{}) error {
p := &wrappers.Packer{Bytes: bytes}
if len(bytes) > c.maxSize {
return errSliceTooLarge
}
if dest == nil {
return errNil
}
destPtr := reflect.ValueOf(dest)
if destPtr.Kind() != reflect.Ptr {
return errNeedPointer
}
destVal := destPtr.Elem()
err := c.unmarshal(p, destVal)
if err != nil {
return err
}
if p.Offset != len(p.Bytes) {
return fmt.Errorf("has %d leftover bytes after unmarshalling", len(p.Bytes)-p.Offset)
}
return nil
}
// Unmarshal bytes from [p] into [field]
// [field] must be addressable
func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
kind := field.Kind()
switch kind {
case reflect.Uint8:
field.SetUint(uint64(p.UnpackByte()))
case reflect.Int8:
field.SetInt(int64(p.UnpackByte()))
case reflect.Uint16:
field.SetUint(uint64(p.UnpackShort()))
case reflect.Int16:
field.SetInt(int64(p.UnpackShort()))
case reflect.Uint32:
field.SetUint(uint64(p.UnpackInt()))
case reflect.Int32:
field.SetInt(int64(p.UnpackInt()))
case reflect.Uint64:
field.SetUint(p.UnpackLong())
case reflect.Int64:
field.SetInt(int64(p.UnpackLong()))
case reflect.Bool:
field.SetBool(p.UnpackBool())
case reflect.Slice:
sliceLen := int(p.UnpackInt()) // number of elements in the slice
if sliceLen < 0 || sliceLen > c.maxSliceLen {
return errSliceTooLarge
}
// First set [field] to be a slice of the appropriate type/capacity (right now [field] is nil)
slice := reflect.MakeSlice(field.Type(), sliceLen, sliceLen)
field.Set(slice)
// Unmarshal each element into the appropriate index of the slice
for i := 0; i < sliceLen; i++ {
if err := c.unmarshal(p, field.Index(i)); err != nil {
return err
}
}
case reflect.Array:
for i := 0; i < field.Len(); i++ {
if err := c.unmarshal(p, field.Index(i)); err != nil {
return err
}
}
case reflect.String:
field.SetString(p.UnpackStr())
case reflect.Interface:
// Get the type ID
typeID := p.UnpackInt()
// Get a struct that implements the interface
typ, ok := c.typeIDToType[typeID]
if !ok {
return errUnmarshalUnregisteredType
}
// Ensure struct actually does implement the interface
fieldType := field.Type()
if !typ.Implements(fieldType) {
return fmt.Errorf("%s does not implement interface %s", typ, fieldType)
}
concreteInstancePtr := reflect.New(typ) // instance of the proper type
// Unmarshal into the struct
if err := c.unmarshal(p, concreteInstancePtr.Elem()); err != nil {
return err
}
// And assign the filled struct to the field
field.Set(concreteInstancePtr.Elem())
case reflect.Struct:
// Type of this struct
structType := reflect.TypeOf(field.Interface())
// Go through all the fields and unmarshal into each
for i := 0; i < structType.NumField(); i++ {
structField := structType.Field(i)
if !shouldSerialize(structField) { // Skip fields we don't need to unmarshal
continue
}
if unicode.IsLower(rune(structField.Name[0])) { // Only unmarshal into exported field
return errUnmarshalUnexportedField
}
field := field.Field(i) // Get the field
if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field
return err
}
if p.Errored() { // If there was an error just return immediately
return p.Err
}
}
case reflect.Ptr:
// Get the type this pointer points to
underlyingType := field.Type().Elem()
// Create a new pointer to a new value of the underlying type
underlyingValue := reflect.New(underlyingType)
// Fill the value
if err := c.unmarshal(p, underlyingValue.Elem()); err != nil {
return err
}
// Assign to the top-level struct's member
field.Set(underlyingValue)
case reflect.Invalid:
return errUnmarshalNil
default:
return errUnknownType
}
return p.Err
}
// Returns true iff [field] should be serialized
func shouldSerialize(field reflect.StructField) bool {
return field.Tag.Get("serialize") == "true"
}
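Since the removed notes above describe how this codec is used (serialize tags, RegisterType for interface-typed fields), a minimal self-contained sketch is worth keeping alongside the deletion. It assumes the relocated package at github.com/ava-labs/gecko/utils/codec keeps the Codec surface shown above (NewDefault, RegisterType, Marshal, Unmarshal); the Message, Ping, and Envelope types are illustrative only, not gecko types.

package main

import (
	"fmt"

	"github.com/ava-labs/gecko/utils/codec"
)

// Message is an interface-typed payload; concrete implementations must be
// registered with the codec before values of this interface can be unmarshaled.
type Message interface {
	Kind() string
}

// Ping is one concrete Message. Only exported fields tagged `serialize:"true"`
// are written to the wire.
type Ping struct {
	Nonce uint64 `serialize:"true"`
}

func (Ping) Kind() string { return "ping" }

// Envelope wraps an interface-typed field, which is why RegisterType is needed.
type Envelope struct {
	Payload Message `serialize:"true"`
}

func main() {
	c := codec.NewDefault()
	if err := c.RegisterType(&Ping{}); err != nil { // register the concrete type once
		panic(err)
	}

	bytes, err := c.Marshal(&Envelope{Payload: &Ping{Nonce: 7}})
	if err != nil {
		panic(err)
	}

	decoded := Envelope{}
	if err := c.Unmarshal(bytes, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.Payload.Kind()) // "ping"
}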

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/timer"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -234,7 +234,7 @@ type GetCurrentValidatorsReply struct {
// GetCurrentValidators returns the list of current validators // GetCurrentValidators returns the list of current validators
func (service *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidatorsArgs, reply *GetCurrentValidatorsReply) error { func (service *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidatorsArgs, reply *GetCurrentValidatorsReply) error {
service.vm.Ctx.Log.Debug("GetCurrentValidators called") service.vm.Ctx.Log.Info("Platform: GetCurrentValidators called")
if args.SubnetID.IsZero() { if args.SubnetID.IsZero() {
args.SubnetID = DefaultSubnetID args.SubnetID = DefaultSubnetID
@ -298,7 +298,7 @@ type GetPendingValidatorsReply struct {
// GetPendingValidators returns the list of pending validators // GetPendingValidators returns the list of pending validators
func (service *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidatorsArgs, reply *GetPendingValidatorsReply) error { func (service *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidatorsArgs, reply *GetPendingValidatorsReply) error {
service.vm.Ctx.Log.Debug("GetPendingValidators called") service.vm.Ctx.Log.Info("Platform: GetPendingValidators called")
if args.SubnetID.IsZero() { if args.SubnetID.IsZero() {
args.SubnetID = DefaultSubnetID args.SubnetID = DefaultSubnetID
@ -360,7 +360,7 @@ type SampleValidatorsReply struct {
// SampleValidators returns a sampling of the list of current validators // SampleValidators returns a sampling of the list of current validators
func (service *Service) SampleValidators(_ *http.Request, args *SampleValidatorsArgs, reply *SampleValidatorsReply) error { func (service *Service) SampleValidators(_ *http.Request, args *SampleValidatorsArgs, reply *SampleValidatorsReply) error {
service.vm.Ctx.Log.Debug("Sample called with {Size = %d}", args.Size) service.vm.Ctx.Log.Info("Platform: SampleValidators called with {Size = %d}", args.Size)
if args.SubnetID.IsZero() { if args.SubnetID.IsZero() {
args.SubnetID = DefaultSubnetID args.SubnetID = DefaultSubnetID
@ -437,7 +437,7 @@ type ListAccountsReply struct {
// ListAccounts lists all of the accounts controlled by [args.Username] // ListAccounts lists all of the accounts controlled by [args.Username]
func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, reply *ListAccountsReply) error { func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, reply *ListAccountsReply) error {
service.vm.Ctx.Log.Debug("listAccounts called for user '%s'", args.Username) service.vm.Ctx.Log.Info("Platform: ListAccounts called for user '%s'", args.Username)
// db holds the user's info that pertains to the Platform Chain // db holds the user's info that pertains to the Platform Chain
userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password) userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password)
@ -499,7 +499,7 @@ type CreateAccountReply struct {
// The account's ID is [privKey].PublicKey().Address(), where [privKey] is a // The account's ID is [privKey].PublicKey().Address(), where [privKey] is a
// private key controlled by the user. // private key controlled by the user.
func (service *Service) CreateAccount(_ *http.Request, args *CreateAccountArgs, reply *CreateAccountReply) error { func (service *Service) CreateAccount(_ *http.Request, args *CreateAccountArgs, reply *CreateAccountReply) error {
service.vm.Ctx.Log.Debug("createAccount called for user '%s'", args.Username) service.vm.Ctx.Log.Info("Platform: CreateAccount called for user '%s'", args.Username)
// userDB holds the user's info that pertains to the Platform Chain // userDB holds the user's info that pertains to the Platform Chain
userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password) userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password)
@ -569,7 +569,7 @@ type AddDefaultSubnetValidatorArgs struct {
// AddDefaultSubnetValidator returns an unsigned transaction to add a validator to the default subnet // AddDefaultSubnetValidator returns an unsigned transaction to add a validator to the default subnet
// The returned unsigned transaction should be signed using Sign() // The returned unsigned transaction should be signed using Sign()
func (service *Service) AddDefaultSubnetValidator(_ *http.Request, args *AddDefaultSubnetValidatorArgs, reply *CreateTxResponse) error { func (service *Service) AddDefaultSubnetValidator(_ *http.Request, args *AddDefaultSubnetValidatorArgs, reply *CreateTxResponse) error {
service.vm.Ctx.Log.Debug("AddDefaultSubnetValidator called") service.vm.Ctx.Log.Info("Platform: AddDefaultSubnetValidator called")
switch { switch {
case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID
@ -626,7 +626,7 @@ type AddDefaultSubnetDelegatorArgs struct {
// to the default subnet // to the default subnet
// The returned unsigned transaction should be signed using Sign() // The returned unsigned transaction should be signed using Sign()
func (service *Service) AddDefaultSubnetDelegator(_ *http.Request, args *AddDefaultSubnetDelegatorArgs, reply *CreateTxResponse) error { func (service *Service) AddDefaultSubnetDelegator(_ *http.Request, args *AddDefaultSubnetDelegatorArgs, reply *CreateTxResponse) error {
service.vm.Ctx.Log.Debug("AddDefaultSubnetDelegator called") service.vm.Ctx.Log.Info("Platform: AddDefaultSubnetDelegator called")
switch { switch {
case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID
@ -741,7 +741,7 @@ type CreateSubnetArgs struct {
// CreateSubnet returns an unsigned transaction to create a new subnet. // CreateSubnet returns an unsigned transaction to create a new subnet.
// The unsigned transaction must be signed with the key of [args.Payer] // The unsigned transaction must be signed with the key of [args.Payer]
func (service *Service) CreateSubnet(_ *http.Request, args *CreateSubnetArgs, response *CreateTxResponse) error { func (service *Service) CreateSubnet(_ *http.Request, args *CreateSubnetArgs, response *CreateTxResponse) error {
service.vm.Ctx.Log.Debug("platform.createSubnet called") service.vm.Ctx.Log.Info("Platform: CreateSubnet called")
switch { switch {
case args.PayerNonce == 0: case args.PayerNonce == 0:
@ -796,7 +796,7 @@ type ExportAVAArgs struct {
// The unsigned transaction must be signed with the key of the account exporting the AVA // The unsigned transaction must be signed with the key of the account exporting the AVA
// and paying the transaction fee // and paying the transaction fee
func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, response *CreateTxResponse) error { func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, response *CreateTxResponse) error {
service.vm.Ctx.Log.Debug("platform.ExportAVA called") service.vm.Ctx.Log.Info("Platform: ExportAVA called")
switch { switch {
case args.PayerNonce == 0: case args.PayerNonce == 0:
@ -858,7 +858,7 @@ type SignResponse struct {
// Sign [args.bytes] // Sign [args.bytes]
func (service *Service) Sign(_ *http.Request, args *SignArgs, reply *SignResponse) error { func (service *Service) Sign(_ *http.Request, args *SignArgs, reply *SignResponse) error {
service.vm.Ctx.Log.Debug("sign called") service.vm.Ctx.Log.Info("Platform: Sign called")
if args.Signer == "" { if args.Signer == "" {
return errNilSigner return errNilSigner
@ -938,7 +938,7 @@ func (service *Service) signAddDefaultSubnetValidatorTx(tx *addDefaultSubnetVali
// Sign [unsigned] with [key] // Sign [unsigned] with [key]
func (service *Service) signAddDefaultSubnetDelegatorTx(tx *addDefaultSubnetDelegatorTx, key *crypto.PrivateKeySECP256K1R) (*addDefaultSubnetDelegatorTx, error) { func (service *Service) signAddDefaultSubnetDelegatorTx(tx *addDefaultSubnetDelegatorTx, key *crypto.PrivateKeySECP256K1R) (*addDefaultSubnetDelegatorTx, error) {
service.vm.Ctx.Log.Debug("signAddDefaultSubnetValidatorTx called") service.vm.Ctx.Log.Debug("signAddDefaultSubnetDelegatorTx called")
// TODO: Should we check if tx is already signed? // TODO: Should we check if tx is already signed?
unsignedIntf := interface{}(&tx.UnsignedAddDefaultSubnetDelegatorTx) unsignedIntf := interface{}(&tx.UnsignedAddDefaultSubnetDelegatorTx)
@ -961,7 +961,7 @@ func (service *Service) signAddDefaultSubnetDelegatorTx(tx *addDefaultSubnetDele
// Sign [xt] with [key] // Sign [xt] with [key]
func (service *Service) signCreateSubnetTx(tx *CreateSubnetTx, key *crypto.PrivateKeySECP256K1R) (*CreateSubnetTx, error) { func (service *Service) signCreateSubnetTx(tx *CreateSubnetTx, key *crypto.PrivateKeySECP256K1R) (*CreateSubnetTx, error) {
service.vm.Ctx.Log.Debug("signAddDefaultSubnetValidatorTx called") service.vm.Ctx.Log.Debug("signCreateSubnetTx called")
// TODO: Should we check if tx is already signed? // TODO: Should we check if tx is already signed?
unsignedIntf := interface{}(&tx.UnsignedCreateSubnetTx) unsignedIntf := interface{}(&tx.UnsignedCreateSubnetTx)
@ -984,7 +984,7 @@ func (service *Service) signCreateSubnetTx(tx *CreateSubnetTx, key *crypto.Priva
// Sign [tx] with [key] // Sign [tx] with [key]
func (service *Service) signExportTx(tx *ExportTx, key *crypto.PrivateKeySECP256K1R) (*ExportTx, error) { func (service *Service) signExportTx(tx *ExportTx, key *crypto.PrivateKeySECP256K1R) (*ExportTx, error) {
service.vm.Ctx.Log.Debug("platform.signAddDefaultSubnetValidatorTx called") service.vm.Ctx.Log.Debug("signExportTx called")
// TODO: Should we check if tx is already signed? // TODO: Should we check if tx is already signed?
unsignedIntf := interface{}(&tx.UnsignedExportTx) unsignedIntf := interface{}(&tx.UnsignedExportTx)
@ -1075,7 +1075,7 @@ type ImportAVAArgs struct {
// The AVA must have already been exported from the X-Chain. // The AVA must have already been exported from the X-Chain.
// The unsigned transaction must be signed with the key of the tx fee payer. // The unsigned transaction must be signed with the key of the tx fee payer.
func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, response *SignResponse) error { func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, response *SignResponse) error {
service.vm.Ctx.Log.Debug("platform.ImportAVA called") service.vm.Ctx.Log.Info("Platform: ImportAVA called")
switch { switch {
case args.To == "": case args.To == "":
@ -1263,7 +1263,7 @@ type IssueTxResponse struct {
// IssueTx issues the transaction [args.Tx] to the network // IssueTx issues the transaction [args.Tx] to the network
func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *IssueTxResponse) error { func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *IssueTxResponse) error {
service.vm.Ctx.Log.Debug("issueTx called") service.vm.Ctx.Log.Info("Platform: IssueTx called")
genTx := genericTx{} genTx := genericTx{}
if err := Codec.Unmarshal(args.Tx.Bytes, &genTx); err != nil { if err := Codec.Unmarshal(args.Tx.Bytes, &genTx); err != nil {
@ -1275,7 +1275,7 @@ func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *Is
if err := tx.initialize(service.vm); err != nil { if err := tx.initialize(service.vm); err != nil {
return fmt.Errorf("error initializing tx: %s", err) return fmt.Errorf("error initializing tx: %s", err)
} }
service.vm.unissuedEvents.Push(tx) service.vm.unissuedEvents.Add(tx)
response.TxID = tx.ID() response.TxID = tx.ID()
case DecisionTx: case DecisionTx:
if err := tx.initialize(service.vm); err != nil { if err := tx.initialize(service.vm); err != nil {
@ -1290,7 +1290,7 @@ func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *Is
service.vm.unissuedAtomicTxs = append(service.vm.unissuedAtomicTxs, tx) service.vm.unissuedAtomicTxs = append(service.vm.unissuedAtomicTxs, tx)
response.TxID = tx.ID() response.TxID = tx.ID()
default: default:
return errors.New("Could not parse given tx. Must be a TimedTx, DecisionTx, or AtomicTx") return errors.New("Could not parse given tx. Provided tx needs to be a TimedTx, DecisionTx, or AtomicTx")
} }
service.vm.resetTimer() service.vm.resetTimer()
@ -1327,7 +1327,7 @@ type CreateBlockchainArgs struct {
// CreateBlockchain returns an unsigned transaction to create a new blockchain // CreateBlockchain returns an unsigned transaction to create a new blockchain
// Must be signed with the Subnet's control keys and with a key that pays the transaction fee before issuance // Must be signed with the Subnet's control keys and with a key that pays the transaction fee before issuance
func (service *Service) CreateBlockchain(_ *http.Request, args *CreateBlockchainArgs, response *CreateTxResponse) error { func (service *Service) CreateBlockchain(_ *http.Request, args *CreateBlockchainArgs, response *CreateTxResponse) error {
service.vm.Ctx.Log.Debug("createBlockchain called") service.vm.Ctx.Log.Info("Platform: CreateBlockchain called")
switch { switch {
case args.PayerNonce == 0: case args.PayerNonce == 0:
@ -1410,7 +1410,7 @@ type GetBlockchainStatusReply struct {
// GetBlockchainStatus gets the status of a blockchain with the ID [args.BlockchainID]. // GetBlockchainStatus gets the status of a blockchain with the ID [args.BlockchainID].
func (service *Service) GetBlockchainStatus(_ *http.Request, args *GetBlockchainStatusArgs, reply *GetBlockchainStatusReply) error { func (service *Service) GetBlockchainStatus(_ *http.Request, args *GetBlockchainStatusArgs, reply *GetBlockchainStatusReply) error {
service.vm.Ctx.Log.Debug("getBlockchainStatus called") service.vm.Ctx.Log.Info("Platform: GetBlockchainStatus called")
switch { switch {
case args.BlockchainID == "": case args.BlockchainID == "":
@ -1490,7 +1490,7 @@ type ValidatedByResponse struct {
// ValidatedBy returns the ID of the Subnet that validates [args.BlockchainID] // ValidatedBy returns the ID of the Subnet that validates [args.BlockchainID]
func (service *Service) ValidatedBy(_ *http.Request, args *ValidatedByArgs, response *ValidatedByResponse) error { func (service *Service) ValidatedBy(_ *http.Request, args *ValidatedByArgs, response *ValidatedByResponse) error {
service.vm.Ctx.Log.Debug("validatedBy called") service.vm.Ctx.Log.Info("Platform: ValidatedBy called")
switch { switch {
case args.BlockchainID == "": case args.BlockchainID == "":
@ -1522,7 +1522,7 @@ type ValidatesResponse struct {
// Validates returns the IDs of the blockchains validated by [args.SubnetID] // Validates returns the IDs of the blockchains validated by [args.SubnetID]
func (service *Service) Validates(_ *http.Request, args *ValidatesArgs, response *ValidatesResponse) error { func (service *Service) Validates(_ *http.Request, args *ValidatesArgs, response *ValidatesResponse) error {
service.vm.Ctx.Log.Debug("validates called") service.vm.Ctx.Log.Info("Platform: Validates called")
switch { switch {
case args.SubnetID == "": case args.SubnetID == "":
@ -1576,7 +1576,7 @@ type GetBlockchainsResponse struct {
// GetBlockchains returns all of the blockchains that exist // GetBlockchains returns all of the blockchains that exist
func (service *Service) GetBlockchains(_ *http.Request, args *struct{}, response *GetBlockchainsResponse) error { func (service *Service) GetBlockchains(_ *http.Request, args *struct{}, response *GetBlockchainsResponse) error {
service.vm.Ctx.Log.Debug("getBlockchains called") service.vm.Ctx.Log.Info("Platform: GetBlockchains called")
chains, err := service.vm.getChains(service.vm.DB) chains, err := service.vm.getChains(service.vm.DB)
if err != nil { if err != nil {

View File

@ -6,6 +6,9 @@ package platformvm
import ( import (
"encoding/json" "encoding/json"
"testing" "testing"
"time"
"github.com/ava-labs/gecko/utils/formatting"
) )
func TestAddDefaultSubnetValidator(t *testing.T) { func TestAddDefaultSubnetValidator(t *testing.T) {
@ -50,3 +53,184 @@ func TestImportKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestIssueTxKeepsTimedEventsSorted(t *testing.T) {
vm := defaultVM()
vm.Ctx.Lock.Lock()
defer func() {
vm.Shutdown()
vm.Ctx.Lock.Unlock()
}()
service := Service{vm: vm}
pendingValidatorStartTime1 := defaultGenesisTime.Add(3 * time.Second)
pendingValidatorEndTime1 := pendingValidatorStartTime1.Add(MinimumStakingDuration)
nodeIDKey1, _ := vm.factory.NewPrivateKey()
nodeID1 := nodeIDKey1.PublicKey().Address()
addPendingValidatorTx1, err := vm.newAddDefaultSubnetValidatorTx(
defaultNonce+1,
defaultStakeAmount,
uint64(pendingValidatorStartTime1.Unix()),
uint64(pendingValidatorEndTime1.Unix()),
nodeID1,
nodeID1,
NumberOfShares,
testNetworkID,
defaultKey,
)
if err != nil {
t.Fatal(err)
}
txBytes1, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx1})
if err != nil {
t.Fatal(err)
}
args1 := &IssueTxArgs{}
args1.Tx = formatting.CB58{Bytes: txBytes1}
reply1 := IssueTxResponse{}
err = service.IssueTx(nil, args1, &reply1)
if err != nil {
t.Fatal(err)
}
pendingValidatorStartTime2 := defaultGenesisTime.Add(2 * time.Second)
pendingValidatorEndTime2 := pendingValidatorStartTime2.Add(MinimumStakingDuration)
nodeIDKey2, _ := vm.factory.NewPrivateKey()
nodeID2 := nodeIDKey2.PublicKey().Address()
addPendingValidatorTx2, err := vm.newAddDefaultSubnetValidatorTx(
defaultNonce+1,
defaultStakeAmount,
uint64(pendingValidatorStartTime2.Unix()),
uint64(pendingValidatorEndTime2.Unix()),
nodeID2,
nodeID2,
NumberOfShares,
testNetworkID,
defaultKey,
)
if err != nil {
t.Fatal(err)
}
txBytes2, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx2})
if err != nil {
t.Fatal(err)
}
args2 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes2}}
reply2 := IssueTxResponse{}
err = service.IssueTx(nil, &args2, &reply2)
if err != nil {
t.Fatal(err)
}
pendingValidatorStartTime3 := defaultGenesisTime.Add(10 * time.Second)
pendingValidatorEndTime3 := pendingValidatorStartTime3.Add(MinimumStakingDuration)
nodeIDKey3, _ := vm.factory.NewPrivateKey()
nodeID3 := nodeIDKey3.PublicKey().Address()
addPendingValidatorTx3, err := vm.newAddDefaultSubnetValidatorTx(
defaultNonce+1,
defaultStakeAmount,
uint64(pendingValidatorStartTime3.Unix()),
uint64(pendingValidatorEndTime3.Unix()),
nodeID3,
nodeID3,
NumberOfShares,
testNetworkID,
defaultKey,
)
if err != nil {
t.Fatal(err)
}
txBytes3, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx3})
if err != nil {
t.Fatal(err)
}
args3 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes3}}
reply3 := IssueTxResponse{}
err = service.IssueTx(nil, &args3, &reply3)
if err != nil {
t.Fatal(err)
}
pendingValidatorStartTime4 := defaultGenesisTime.Add(1 * time.Second)
pendingValidatorEndTime4 := pendingValidatorStartTime4.Add(MinimumStakingDuration)
nodeIDKey4, _ := vm.factory.NewPrivateKey()
nodeID4 := nodeIDKey4.PublicKey().Address()
addPendingValidatorTx4, err := vm.newAddDefaultSubnetValidatorTx(
defaultNonce+1,
defaultStakeAmount,
uint64(pendingValidatorStartTime4.Unix()),
uint64(pendingValidatorEndTime4.Unix()),
nodeID4,
nodeID4,
NumberOfShares,
testNetworkID,
defaultKey,
)
if err != nil {
t.Fatal(err)
}
txBytes4, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx4})
if err != nil {
t.Fatal(err)
}
args4 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes4}}
reply4 := IssueTxResponse{}
err = service.IssueTx(nil, &args4, &reply4)
if err != nil {
t.Fatal(err)
}
pendingValidatorStartTime5 := defaultGenesisTime.Add(50 * time.Second)
pendingValidatorEndTime5 := pendingValidatorStartTime5.Add(MinimumStakingDuration)
nodeIDKey5, _ := vm.factory.NewPrivateKey()
nodeID5 := nodeIDKey5.PublicKey().Address()
addPendingValidatorTx5, err := vm.newAddDefaultSubnetValidatorTx(
defaultNonce+1,
defaultStakeAmount,
uint64(pendingValidatorStartTime5.Unix()),
uint64(pendingValidatorEndTime5.Unix()),
nodeID5,
nodeID5,
NumberOfShares,
testNetworkID,
defaultKey,
)
if err != nil {
t.Fatal(err)
}
txBytes5, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx5})
if err != nil {
t.Fatal(err)
}
args5 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes5}}
reply5 := IssueTxResponse{}
err = service.IssueTx(nil, &args5, &reply5)
if err != nil {
t.Fatal(err)
}
currentEvent := vm.unissuedEvents.Remove()
for vm.unissuedEvents.Len() > 0 {
nextEvent := vm.unissuedEvents.Remove()
if !currentEvent.StartTime().Before(nextEvent.StartTime()) {
t.Fatal("IssueTx does not keep event heap ordered")
}
currentEvent = nextEvent
}
}

View File

@ -4,7 +4,6 @@
package platformvm package platformvm
import ( import (
"container/heap"
"errors" "errors"
"net/http" "net/http"
@ -174,8 +173,8 @@ func (*StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, repl
return errAccountHasNoValue return errAccountHasNoValue
} }
accounts = append(accounts, newAccount( accounts = append(accounts, newAccount(
account.Address, // ID account.Address, // ID
0, // nonce 0, // nonce
uint64(account.Balance), // balance uint64(account.Balance), // balance
)) ))
} }
@ -210,7 +209,7 @@ func (*StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, repl
return err return err
} }
heap.Push(validators, tx) validators.Add(tx)
} }
// Specify the chains that exist at genesis. // Specify the chains that exist at genesis.

View File

@ -111,3 +111,77 @@ func TestBuildGenesisInvalidEndtime(t *testing.T) {
t.Fatalf("Should have errored due to an invalid end time") t.Fatalf("Should have errored due to an invalid end time")
} }
} }
func TestBuildGenesisReturnsSortedValidators(t *testing.T) {
id := ids.NewShortID([20]byte{1})
account := APIAccount{
Address: id,
Balance: 123456789,
}
weight := json.Uint64(987654321)
validator1 := APIDefaultSubnetValidator{
APIValidator: APIValidator{
StartTime: 0,
EndTime: 20,
Weight: &weight,
ID: id,
},
Destination: id,
}
validator2 := APIDefaultSubnetValidator{
APIValidator: APIValidator{
StartTime: 3,
EndTime: 15,
Weight: &weight,
ID: id,
},
Destination: id,
}
validator3 := APIDefaultSubnetValidator{
APIValidator: APIValidator{
StartTime: 1,
EndTime: 10,
Weight: &weight,
ID: id,
},
Destination: id,
}
args := BuildGenesisArgs{
Accounts: []APIAccount{
account,
},
Validators: []APIDefaultSubnetValidator{
validator1,
validator2,
validator3,
},
Time: 5,
}
reply := BuildGenesisReply{}
ss := StaticService{}
if err := ss.BuildGenesis(nil, &args, &reply); err != nil {
t.Fatalf("BuildGenesis should not have errored")
}
genesis := &Genesis{}
if err := Codec.Unmarshal(reply.Bytes.Bytes, genesis); err != nil {
t.Fatal(err)
}
validators := genesis.Validators
if validators.Len() == 0 {
t.Fatal("Validators should contain 3 validators")
}
currentValidator := validators.Remove()
for validators.Len() > 0 {
nextValidator := validators.Remove()
if currentValidator.EndTime().Unix() > nextValidator.EndTime().Unix() {
t.Fatalf("Validators returned by genesis should be a min heap sorted by end time")
}
currentValidator = nextValidator
}
}

View File

@ -4,7 +4,6 @@
package platformvm package platformvm
import ( import (
"container/heap"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -27,7 +26,7 @@ import (
"github.com/ava-labs/gecko/utils/units" "github.com/ava-labs/gecko/utils/units"
"github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/utils/wrappers"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/core" "github.com/ava-labs/gecko/vms/components/core"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )
@ -406,10 +405,14 @@ func (vm *VM) createChain(tx *CreateChainTx) {
} }
// Bootstrapping marks this VM as bootstrapping // Bootstrapping marks this VM as bootstrapping
func (vm *VM) Bootstrapping() error { return nil } func (vm *VM) Bootstrapping() error {
return vm.fx.Bootstrapping()
}
// Bootstrapped marks this VM as bootstrapped // Bootstrapped marks this VM as bootstrapped
func (vm *VM) Bootstrapped() error { return nil } func (vm *VM) Bootstrapped() error {
return vm.fx.Bootstrapped()
}
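The one-line stubs above are replaced with calls that forward the bootstrap lifecycle to the VM's feature extension. A minimal sketch of that forwarding pattern, with hypothetical names (bootstrapAware, noopFx, vmSketch) standing in for the real gecko fx plumbing, which is not shown in this diff.

package main

import "fmt"

// bootstrapAware is a hypothetical stand-in for the interface the fx exposes.
type bootstrapAware interface {
	Bootstrapping() error
	Bootstrapped() error
}

// noopFx is a placeholder fx that accepts both notifications.
type noopFx struct{}

func (noopFx) Bootstrapping() error { return nil }
func (noopFx) Bootstrapped() error  { return nil }

// vmSketch mirrors the change above: instead of swallowing the notifications,
// the VM relays them to its fx so the fx can adjust its own behavior.
type vmSketch struct {
	fx bootstrapAware
}

func (vm *vmSketch) Bootstrapping() error { return vm.fx.Bootstrapping() }
func (vm *vmSketch) Bootstrapped() error  { return vm.fx.Bootstrapped() }

func main() {
	vm := &vmSketch{fx: noopFx{}}
	fmt.Println(vm.Bootstrapping(), vm.Bootstrapped()) // <nil> <nil>
}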
// Shutdown this blockchain // Shutdown this blockchain
func (vm *VM) Shutdown() error { func (vm *VM) Shutdown() error {
@ -698,7 +701,7 @@ func (vm *VM) resetTimer() {
vm.SnowmanVM.NotifyBlockReady() // Should issue a ProposeAddValidator vm.SnowmanVM.NotifyBlockReady() // Should issue a ProposeAddValidator
return return
} }
// If the tx doesn't meet the syncrony bound, drop it // If the tx doesn't meet the synchrony bound, drop it
vm.unissuedEvents.Remove() vm.unissuedEvents.Remove()
vm.Ctx.Log.Debug("dropping tx to add validator because its start time has passed") vm.Ctx.Log.Debug("dropping tx to add validator because its start time has passed")
} }
@ -780,8 +783,8 @@ func (vm *VM) calculateValidators(db database.Database, timestamp time.Time, sub
if timestamp.Before(nextTx.StartTime()) { if timestamp.Before(nextTx.StartTime()) {
break break
} }
heap.Push(current, nextTx) current.Add(nextTx)
heap.Pop(pending) pending.Remove()
started.Add(nextTx.Vdr().ID()) started.Add(nextTx.Vdr().ID())
} }
return current, pending, started, stopped, nil return current, pending, started, stopped, nil

View File

@ -5,7 +5,6 @@ package platformvm
import ( import (
"bytes" "bytes"
"container/heap"
"errors" "errors"
"testing" "testing"
"time" "time"
@ -193,6 +192,8 @@ func defaultVM() *VM {
panic("no subnets found") panic("no subnets found")
} // end delete } // end delete
vm.registerDBTypes()
return vm return vm
} }
@ -226,7 +227,7 @@ func GenesisCurrentValidators() *EventHeap {
testNetworkID, // network ID testNetworkID, // network ID
key, // key paying tx fee and stake key, // key paying tx fee and stake
) )
heap.Push(validators, validator) validators.Add(validator)
} }
return validators return validators
} }
@ -1011,7 +1012,7 @@ func TestCreateSubnet(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
vm.unissuedEvents.Push(addValidatorTx) vm.unissuedEvents.Add(addValidatorTx)
blk, err = vm.BuildBlock() // should add validator to the new subnet blk, err = vm.BuildBlock() // should add validator to the new subnet
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -9,7 +9,7 @@ import (
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/timer"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/ava-labs/gecko/utils/crypto" "github.com/ava-labs/gecko/utils/crypto"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
func TestCredentialVerify(t *testing.T) { func TestCredentialVerify(t *testing.T) {

View File

@ -12,7 +12,7 @@ import (
"github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/timer"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
var ( var (

View File

@ -7,7 +7,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
func TestTransferInputAmount(t *testing.T) { func TestTransferInputAmount(t *testing.T) {

View File

@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
func TestOutputAmount(t *testing.T) { func TestOutputAmount(t *testing.T) {

View File

@ -6,7 +6,7 @@ package secp256k1fx
import ( import (
"github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/timer"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
) )
// VM that this Fx must be run by // VM that this Fx must be run by

View File

@ -12,7 +12,7 @@ import (
"github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow"
"github.com/ava-labs/gecko/snow/consensus/snowman" "github.com/ava-labs/gecko/snow/consensus/snowman"
"github.com/ava-labs/gecko/snow/engine/common" "github.com/ava-labs/gecko/snow/engine/common"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/components/core" "github.com/ava-labs/gecko/vms/components/core"
) )

View File

@ -19,7 +19,7 @@ import (
"github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/utils/wrappers"
"github.com/ava-labs/gecko/vms/avm" "github.com/ava-labs/gecko/vms/avm"
"github.com/ava-labs/gecko/vms/components/ava" "github.com/ava-labs/gecko/vms/components/ava"
"github.com/ava-labs/gecko/vms/components/codec" "github.com/ava-labs/gecko/utils/codec"
"github.com/ava-labs/gecko/vms/secp256k1fx" "github.com/ava-labs/gecko/vms/secp256k1fx"
) )