diff --git a/api/admin/service.go b/api/admin/service.go index e05a440..0718dfd 100644 --- a/api/admin/service.go +++ b/api/admin/service.go @@ -57,7 +57,7 @@ type GetNodeVersionReply struct { // GetNodeVersion returns the version this node is running func (service *Admin) GetNodeVersion(_ *http.Request, _ *struct{}, reply *GetNodeVersionReply) error { - service.log.Debug("Admin: GetNodeVersion called") + service.log.Info("Admin: GetNodeVersion called") reply.Version = service.version.String() return nil @@ -70,7 +70,7 @@ type GetNodeIDReply struct { // GetNodeID returns the node ID of this node func (service *Admin) GetNodeID(_ *http.Request, _ *struct{}, reply *GetNodeIDReply) error { - service.log.Debug("Admin: GetNodeID called") + service.log.Info("Admin: GetNodeID called") reply.NodeID = service.nodeID return nil @@ -83,7 +83,7 @@ type GetNetworkIDReply struct { // GetNetworkID returns the network ID this node is running on func (service *Admin) GetNetworkID(_ *http.Request, _ *struct{}, reply *GetNetworkIDReply) error { - service.log.Debug("Admin: GetNetworkID called") + service.log.Info("Admin: GetNetworkID called") reply.NetworkID = cjson.Uint32(service.networkID) return nil @@ -96,7 +96,7 @@ type GetNetworkNameReply struct { // GetNetworkName returns the network name this node is running on func (service *Admin) GetNetworkName(_ *http.Request, _ *struct{}, reply *GetNetworkNameReply) error { - service.log.Debug("Admin: GetNetworkName called") + service.log.Info("Admin: GetNetworkName called") reply.NetworkName = genesis.NetworkName(service.networkID) return nil @@ -114,7 +114,7 @@ type GetBlockchainIDReply struct { // GetBlockchainID returns the blockchain ID that resolves the alias that was supplied func (service *Admin) GetBlockchainID(_ *http.Request, args *GetBlockchainIDArgs, reply *GetBlockchainIDReply) error { - service.log.Debug("Admin: GetBlockchainID called") + service.log.Info("Admin: GetBlockchainID called") bID, err := service.chainManager.Lookup(args.Alias) reply.BlockchainID = bID.String() @@ -128,7 +128,7 @@ type PeersReply struct { // Peers returns the list of current validators func (service *Admin) Peers(_ *http.Request, _ *struct{}, reply *PeersReply) error { - service.log.Debug("Admin: Peers called") + service.log.Info("Admin: Peers called") reply.Peers = service.networking.Peers() return nil } @@ -145,7 +145,7 @@ type StartCPUProfilerReply struct { // StartCPUProfiler starts a cpu profile writing to the specified file func (service *Admin) StartCPUProfiler(_ *http.Request, args *StartCPUProfilerArgs, reply *StartCPUProfilerReply) error { - service.log.Debug("Admin: StartCPUProfiler called with %s", args.Filename) + service.log.Info("Admin: StartCPUProfiler called with %s", args.Filename) reply.Success = true return service.performance.StartCPUProfiler(args.Filename) } @@ -157,7 +157,7 @@ type StopCPUProfilerReply struct { // StopCPUProfiler stops the cpu profile func (service *Admin) StopCPUProfiler(_ *http.Request, _ *struct{}, reply *StopCPUProfilerReply) error { - service.log.Debug("Admin: StopCPUProfiler called") + service.log.Info("Admin: StopCPUProfiler called") reply.Success = true return service.performance.StopCPUProfiler() } @@ -174,7 +174,7 @@ type MemoryProfileReply struct { // MemoryProfile runs a memory profile writing to the specified file func (service *Admin) MemoryProfile(_ *http.Request, args *MemoryProfileArgs, reply *MemoryProfileReply) error { - service.log.Debug("Admin: MemoryProfile called with %s", args.Filename) + service.log.Info("Admin: MemoryProfile called with %s", args.Filename) reply.Success = true return service.performance.MemoryProfile(args.Filename) } @@ -191,7 +191,7 @@ type LockProfileReply struct { // LockProfile runs a mutex profile writing to the specified file func (service *Admin) LockProfile(_ *http.Request, args *LockProfileArgs, reply *LockProfileReply) error { - service.log.Debug("Admin: LockProfile called with %s", args.Filename) + service.log.Info("Admin: LockProfile called with %s", args.Filename) reply.Success = true return service.performance.LockProfile(args.Filename) } @@ -209,7 +209,7 @@ type AliasReply struct { // Alias attempts to alias an HTTP endpoint to a new name func (service *Admin) Alias(_ *http.Request, args *AliasArgs, reply *AliasReply) error { - service.log.Debug("Admin: Alias called with URL: %s, Alias: %s", args.Endpoint, args.Alias) + service.log.Info("Admin: Alias called with URL: %s, Alias: %s", args.Endpoint, args.Alias) reply.Success = true return service.httpServer.AddAliasesWithReadLock(args.Endpoint, args.Alias) } @@ -227,7 +227,7 @@ type AliasChainReply struct { // AliasChain attempts to alias a chain to a new name func (service *Admin) AliasChain(_ *http.Request, args *AliasChainArgs, reply *AliasChainReply) error { - service.log.Debug("Admin: AliasChain called with Chain: %s, Alias: %s", args.Chain, args.Alias) + service.log.Info("Admin: AliasChain called with Chain: %s, Alias: %s", args.Chain, args.Alias) chainID, err := service.chainManager.Lookup(args.Chain) if err != nil { diff --git a/api/health/service.go b/api/health/service.go index db33640..fdd405b 100644 --- a/api/health/service.go +++ b/api/health/service.go @@ -74,7 +74,7 @@ type GetLivenessReply struct { // GetLiveness returns a summation of the health of the node func (h *Health) GetLiveness(_ *http.Request, _ *GetLivenessArgs, reply *GetLivenessReply) error { - h.log.Debug("Health: GetLiveness called") + h.log.Info("Health: GetLiveness called") reply.Checks, reply.Healthy = h.health.Results() return nil } diff --git a/api/ipcs/server.go b/api/ipcs/server.go index 30bcc5d..72a78c1 100644 --- a/api/ipcs/server.go +++ b/api/ipcs/server.go @@ -61,6 +61,7 @@ type PublishBlockchainReply struct { // PublishBlockchain publishes the finalized accepted transactions from the blockchainID over the IPC func (ipc *IPCs) PublishBlockchain(r *http.Request, args *PublishBlockchainArgs, reply *PublishBlockchainReply) error { + ipc.log.Info("IPCs: PublishBlockchain called with BlockchainID: %s", args.BlockchainID) chainID, err := ipc.chainManager.Lookup(args.BlockchainID) if err != nil { ipc.log.Error("unknown blockchainID: %s", err) @@ -116,6 +117,7 @@ type UnpublishBlockchainReply struct { // UnpublishBlockchain closes publishing of a blockchainID func (ipc *IPCs) UnpublishBlockchain(r *http.Request, args *UnpublishBlockchainArgs, reply *UnpublishBlockchainReply) error { + ipc.log.Info("IPCs: UnpublishBlockchain called with BlockchainID: %s", args.BlockchainID) chainID, err := ipc.chainManager.Lookup(args.BlockchainID) if err != nil { ipc.log.Error("unknown blockchainID %s: %s", args.BlockchainID, err) diff --git a/api/keystore/service.go b/api/keystore/service.go index 16aca06..ac9e4e6 100644 --- a/api/keystore/service.go +++ b/api/keystore/service.go @@ -8,18 +8,20 @@ import ( "fmt" "net/http" "sync" + "testing" "github.com/gorilla/rpc/v2" "github.com/ava-labs/gecko/chains/atomic" "github.com/ava-labs/gecko/database" "github.com/ava-labs/gecko/database/encdb" + "github.com/ava-labs/gecko/database/memdb" "github.com/ava-labs/gecko/database/prefixdb" "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow/engine/common" "github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/logging" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" jsoncodec "github.com/ava-labs/gecko/utils/json" zxcvbn "github.com/nbutton23/zxcvbn-go" @@ -29,8 +31,17 @@ const ( // maxUserPassLen is the maximum length of the username or password allowed maxUserPassLen = 1024 - // requiredPassScore defines the score a password must achieve to be accepted - // as a password with strong characteristics by the zxcvbn package + // maxCheckedPassLen limits the length of the password that should be + // strength checked. + // + // As per issue https://github.com/ava-labs/gecko/issues/195 it was found + // the longer the length of password the slower zxcvbn.PasswordStrength() + // performs. To avoid performance issues, and a DoS vector, we only check + // the first 50 characters of the password. + maxCheckedPassLen = 50 + + // requiredPassScore defines the score a password must achieve to be + // accepted as a password with strong characteristics by the zxcvbn package // // The scoring mechanism defined is as follows; // @@ -135,37 +146,11 @@ func (ks *Keystore) CreateUser(_ *http.Request, args *CreateUserArgs, reply *Cre ks.lock.Lock() defer ks.lock.Unlock() - ks.log.Verbo("CreateUser called with %.*s", maxUserPassLen, args.Username) - - if len(args.Username) > maxUserPassLen || len(args.Password) > maxUserPassLen { - return errUserPassMaxLength - } - - if args.Username == "" { - return errEmptyUsername - } - if usr, err := ks.getUser(args.Username); err == nil || usr != nil { - return fmt.Errorf("user already exists: %s", args.Username) - } - - if zxcvbn.PasswordStrength(args.Password, nil).Score < requiredPassScore { - return errWeakPassword - } - - usr := &User{} - if err := usr.Initialize(args.Password); err != nil { + ks.log.Info("Keystore: CreateUser called with %.*s", maxUserPassLen, args.Username) + if err := ks.AddUser(args.Username, args.Password); err != nil { return err } - usrBytes, err := ks.codec.Marshal(usr) - if err != nil { - return err - } - - if err := ks.userDB.Put([]byte(args.Username), usrBytes); err != nil { - return err - } - ks.users[args.Username] = usr reply.Success = true return nil } @@ -183,7 +168,7 @@ func (ks *Keystore) ListUsers(_ *http.Request, args *ListUsersArgs, reply *ListU ks.lock.Lock() defer ks.lock.Unlock() - ks.log.Verbo("ListUsers called") + ks.log.Info("Keystore: ListUsers called") reply.Users = []string{} @@ -211,7 +196,7 @@ func (ks *Keystore) ExportUser(_ *http.Request, args *ExportUserArgs, reply *Exp ks.lock.Lock() defer ks.lock.Unlock() - ks.log.Verbo("ExportUser called for %s", args.Username) + ks.log.Info("Keystore: ExportUser called for %s", args.Username) usr, err := ks.getUser(args.Username) if err != nil { @@ -264,7 +249,7 @@ func (ks *Keystore) ImportUser(r *http.Request, args *ImportUserArgs, reply *Imp ks.lock.Lock() defer ks.lock.Unlock() - ks.log.Verbo("ImportUser called for %s", args.Username) + ks.log.Info("Keystore: ImportUser called for %s", args.Username) if args.Username == "" { return errEmptyUsername @@ -324,7 +309,7 @@ func (ks *Keystore) DeleteUser(_ *http.Request, args *DeleteUserArgs, reply *Del ks.lock.Lock() defer ks.lock.Unlock() - ks.log.Verbo("DeleteUser called with %s", args.Username) + ks.log.Info("Keystore: DeleteUser called with %s", args.Username) if args.Username == "" { return errEmptyUsername @@ -403,3 +388,51 @@ func (ks *Keystore) GetDatabase(bID ids.ID, username, password string) (database return encDB, nil } + +// AddUser attempts to register this username and password as a new user of the +// keystore. +func (ks *Keystore) AddUser(username, password string) error { + if len(username) > maxUserPassLen || len(password) > maxUserPassLen { + return errUserPassMaxLength + } + + if username == "" { + return errEmptyUsername + } + if usr, err := ks.getUser(username); err == nil || usr != nil { + return fmt.Errorf("user already exists: %s", username) + } + + checkPass := password + if len(password) > maxCheckedPassLen { + checkPass = password[:maxCheckedPassLen] + } + + if zxcvbn.PasswordStrength(checkPass, nil).Score < requiredPassScore { + return errWeakPassword + } + + usr := &User{} + if err := usr.Initialize(password); err != nil { + return err + } + + usrBytes, err := ks.codec.Marshal(usr) + if err != nil { + return err + } + + if err := ks.userDB.Put([]byte(username), usrBytes); err != nil { + return err + } + ks.users[username] = usr + + return nil +} + +// CreateTestKeystore returns a new keystore that can be utilized for testing +func CreateTestKeystore(t *testing.T) *Keystore { + ks := &Keystore{} + ks.Initialize(logging.NoLog{}, memdb.New()) + return ks +} diff --git a/api/keystore/service_test.go b/api/keystore/service_test.go index 9ec5cfa..3e0b18f 100644 --- a/api/keystore/service_test.go +++ b/api/keystore/service_test.go @@ -10,9 +10,7 @@ import ( "reflect" "testing" - "github.com/ava-labs/gecko/database/memdb" "github.com/ava-labs/gecko/ids" - "github.com/ava-labs/gecko/utils/logging" ) var ( @@ -22,8 +20,7 @@ var ( ) func TestServiceListNoUsers(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) reply := ListUsersReply{} if err := ks.ListUsers(nil, &ListUsersArgs{}, &reply); err != nil { @@ -35,8 +32,7 @@ func TestServiceListNoUsers(t *testing.T) { } func TestServiceCreateUser(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -75,8 +71,7 @@ func genStr(n int) string { // TestServiceCreateUserArgsChecks generates excessively long usernames or // passwords to assure the santity checks on string length are not exceeded func TestServiceCreateUserArgsCheck(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -117,8 +112,7 @@ func TestServiceCreateUserArgsCheck(t *testing.T) { // TestServiceCreateUserWeakPassword tests creating a new user with a weak // password to ensure the password strength check is working func TestServiceCreateUserWeakPassword(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -138,8 +132,7 @@ func TestServiceCreateUserWeakPassword(t *testing.T) { } func TestServiceCreateDuplicate(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -166,8 +159,7 @@ func TestServiceCreateDuplicate(t *testing.T) { } func TestServiceCreateUserNoName(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) reply := CreateUserReply{} if err := ks.CreateUser(nil, &CreateUserArgs{ @@ -178,8 +170,7 @@ func TestServiceCreateUserNoName(t *testing.T) { } func TestServiceUseBlockchainDB(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -218,8 +209,7 @@ func TestServiceUseBlockchainDB(t *testing.T) { } func TestServiceExportImport(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) { reply := CreateUserReply{} @@ -252,8 +242,7 @@ func TestServiceExportImport(t *testing.T) { t.Fatal(err) } - newKS := Keystore{} - newKS.Initialize(logging.NoLog{}, memdb.New()) + newKS := CreateTestKeystore(t) { reply := ImportUserReply{} @@ -358,11 +347,10 @@ func TestServiceDeleteUser(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - ks := Keystore{} - ks.Initialize(logging.NoLog{}, memdb.New()) + ks := CreateTestKeystore(t) if tt.setup != nil { - if err := tt.setup(&ks); err != nil { + if err := tt.setup(ks); err != nil { t.Fatalf("failed to create user setup in keystore: %v", err) } } diff --git a/chains/atomic/memory.go b/chains/atomic/memory.go index 448e6c9..9774711 100644 --- a/chains/atomic/memory.go +++ b/chains/atomic/memory.go @@ -12,7 +12,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) type rcLock struct { diff --git a/database/encdb/db.go b/database/encdb/db.go index eb06549..4814805 100644 --- a/database/encdb/db.go +++ b/database/encdb/db.go @@ -14,7 +14,7 @@ import ( "github.com/ava-labs/gecko/database/nodb" "github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/utils/hashing" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) // Database encrypts all values that are provided diff --git a/genesis/genesis.go b/genesis/genesis.go index 4cad047..c4245b9 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go @@ -14,7 +14,7 @@ import ( "github.com/ava-labs/gecko/utils/units" "github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/vms/avm" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/nftfx" "github.com/ava-labs/gecko/vms/platformvm" "github.com/ava-labs/gecko/vms/propertyfx" diff --git a/go.sum b/go.sum index 774be35..d79e9a8 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/AppsFlyer/go-sundheit v0.2.0 h1:FArqX+HbqZ6U32RC3giEAWRUpkggqxHj91KIvxNgwjU= github.com/AppsFlyer/go-sundheit v0.2.0/go.mod h1:rCRkVTMQo7/krF7xQ9X0XEF1an68viFR6/Gy02q+4ds= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Microsoft/go-winio v0.4.11 h1:zoIOcVf0xPN1tnMVbTtEdI+P8OofVk3NObnwOQ6nK2Q= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/sarama v1.26.1/go.mod h1:NbSGBSSndYaIhRcBtY9V0U7AyH+x71bG668AuWys/yU= diff --git a/main/main.go b/main/main.go index 5aca025..98cb581 100644 --- a/main/main.go +++ b/main/main.go @@ -45,7 +45,10 @@ func main() { } // Track if sybil control is enforced - if !Config.EnableStaking { + if !Config.EnableStaking && Config.EnableP2PTLS { + log.Warn("Staking is disabled. Sybil control is not enforced.") + } + if !Config.EnableStaking && !Config.EnableP2PTLS { log.Warn("Staking and p2p encryption are disabled. Packet spoofing is possible.") } diff --git a/main/params.go b/main/params.go index eef8e60..6dcad06 100644 --- a/main/params.go +++ b/main/params.go @@ -37,6 +37,7 @@ const ( var ( Config = node.Config{} Err error + defaultNetworkName = genesis.TestnetName defaultDbDir = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "db")) defaultStakingKeyPath = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "staking", "staker.key")) defaultStakingCertPath = os.ExpandEnv(filepath.Join("$HOME", ".gecko", "staking", "staker.crt")) @@ -49,7 +50,8 @@ var ( ) var ( - errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs") + errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs") + errStakingRequiresTLS = errors.New("if staking is enabled, network TLS must also be enabled") ) // GetIPs returns the default IPs for each network @@ -169,7 +171,7 @@ func init() { version := fs.Bool("version", false, "If true, print version and quit") // NetworkID: - networkName := fs.String("network-id", genesis.TestnetName, "Network ID this node will connect to") + networkName := fs.String("network-id", defaultNetworkName, "Network ID this node will connect to") // Ava fees: fs.Uint64Var(&Config.AvaTxFee, "ava-tx-fee", 0, "Ava transaction fee, in $nAva") @@ -200,7 +202,9 @@ func init() { // Staking: consensusPort := fs.Uint("staking-port", 9651, "Port of the consensus server") - fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Require TLS to authenticate staking connections") + // TODO - keeping same flag for backwards compatibility, should be changed to "staking-enabled" + fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Enable staking. If enabled, Network TLS is required.") + fs.BoolVar(&Config.EnableP2PTLS, "p2p-tls-enabled", true, "Require TLS to authenticate network communication") fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", defaultStakingKeyPath, "TLS private key for staking") fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", defaultStakingCertPath, "TLS certificate for staking") @@ -234,7 +238,15 @@ func init() { ferr := fs.Parse(os.Args[1:]) if *version { // If --version used, print version and exit - fmt.Println(node.Version.String()) + networkID, err := genesis.NetworkID(defaultNetworkName) + if errs.Add(err); err != nil { + return + } + networkGeneration := genesis.NetworkName(networkID) + fmt.Printf( + "%s [database=%s, network=%s/%s]\n", + node.Version, dbVersion, defaultNetworkName, networkGeneration, + ) os.Exit(0) } @@ -318,7 +330,13 @@ func init() { *bootstrapIDs = strings.Join(defaultBootstrapIDs, ",") } } - if Config.EnableStaking { + + if Config.EnableStaking && !Config.EnableP2PTLS { + errs.Add(errStakingRequiresTLS) + return + } + + if Config.EnableP2PTLS { i := 0 cb58 := formatting.CB58{} for _, id := range strings.Split(*bootstrapIDs, ",") { diff --git a/network/network.go b/network/network.go index 0300b09..a280731 100644 --- a/network/network.go +++ b/network/network.go @@ -21,6 +21,7 @@ import ( "github.com/ava-labs/gecko/snow/triggers" "github.com/ava-labs/gecko/snow/validators" "github.com/ava-labs/gecko/utils" + "github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/random" "github.com/ava-labs/gecko/utils/timer" @@ -278,8 +279,11 @@ func (n *network) GetAcceptedFrontier(validatorIDs ids.ShortSet, chainID ids.ID, func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerIDs ids.Set) { msg, err := n.b.AcceptedFrontier(chainID, requestID, containerIDs) if err != nil { - n.log.Error("attempted to pack too large of an AcceptedFrontier message.\nNumber of containerIDs: %d", - containerIDs.Len()) + n.log.Error("failed to build AcceptedFrontier(%s, %d, %s): %s", + chainID, + requestID, + containerIDs, + err) return // Packing message failed } @@ -291,7 +295,11 @@ func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requ sent = peer.send(msg) } if !sent { - n.log.Debug("failed to send an AcceptedFrontier message to: %s", validatorID) + n.log.Debug("failed to send AcceptedFrontier(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerIDs) n.acceptedFrontier.numFailed.Inc() } else { n.acceptedFrontier.numSent.Inc() @@ -302,6 +310,11 @@ func (n *network) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, requ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, requestID uint32, containerIDs ids.Set) { msg, err := n.b.GetAccepted(chainID, requestID, containerIDs) if err != nil { + n.log.Error("failed to build GetAccepted(%s, %d, %s): %s", + chainID, + requestID, + containerIDs, + err) for _, validatorID := range validatorIDs.List() { vID := validatorID n.executor.Add(func() { n.router.GetAcceptedFailed(vID, chainID, requestID) }) @@ -319,6 +332,11 @@ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, request sent = peer.send(msg) } if !sent { + n.log.Debug("failed to send GetAccepted(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerIDs) n.executor.Add(func() { n.router.GetAcceptedFailed(vID, chainID, requestID) }) n.getAccepted.numFailed.Inc() } else { @@ -331,8 +349,11 @@ func (n *network) GetAccepted(validatorIDs ids.ShortSet, chainID ids.ID, request func (n *network) Accepted(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerIDs ids.Set) { msg, err := n.b.Accepted(chainID, requestID, containerIDs) if err != nil { - n.log.Error("attempted to pack too large of an Accepted message.\nNumber of containerIDs: %d", - containerIDs.Len()) + n.log.Error("failed to build Accepted(%s, %d, %s): %s", + chainID, + requestID, + containerIDs, + err) return // Packing message failed } @@ -344,33 +365,17 @@ func (n *network) Accepted(validatorID ids.ShortID, chainID ids.ID, requestID ui sent = peer.send(msg) } if !sent { - n.log.Debug("failed to send an Accepted message to: %s", validatorID) + n.log.Debug("failed to send Accepted(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerIDs) n.accepted.numFailed.Inc() } else { n.accepted.numSent.Inc() } } -// Get implements the Sender interface. -func (n *network) Get(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) { - msg, err := n.b.Get(chainID, requestID, containerID) - n.log.AssertNoError(err) - - n.stateLock.Lock() - defer n.stateLock.Unlock() - - peer, sent := n.peers[validatorID.Key()] - if sent { - sent = peer.send(msg) - } - if !sent { - n.log.Debug("failed to send a Get message to: %s", validatorID) - n.get.numFailed.Inc() - } else { - n.get.numSent.Inc() - } -} - // GetAncestors implements the Sender interface. func (n *network) GetAncestors(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) { msg, err := n.b.GetAncestors(chainID, requestID, containerID) @@ -387,36 +392,18 @@ func (n *network) GetAncestors(validatorID ids.ShortID, chainID ids.ID, requestI sent = peer.send(msg) } if !sent { + n.log.Debug("failed to send GetAncestors(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerID) + n.executor.Add(func() { n.router.GetAncestorsFailed(validatorID, chainID, requestID) }) n.getAncestors.numFailed.Inc() - n.log.Debug("failed to send a GetAncestors message to: %s", validatorID) } else { n.getAncestors.numSent.Inc() } } -// Put implements the Sender interface. -func (n *network) Put(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) { - msg, err := n.b.Put(chainID, requestID, containerID, container) - if err != nil { - n.log.Error("failed to build Put message because of container of size %d", len(container)) - return - } - - n.stateLock.Lock() - defer n.stateLock.Unlock() - - peer, sent := n.peers[validatorID.Key()] - if sent { - sent = peer.send(msg) - } - if !sent { - n.log.Debug("failed to send a Put message to: %s", validatorID) - n.put.numFailed.Inc() - } else { - n.put.numSent.Inc() - } -} - // MultiPut implements the Sender interface. func (n *network) MultiPut(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containers [][]byte) { msg, err := n.b.MultiPut(chainID, requestID, containers) @@ -433,22 +420,90 @@ func (n *network) MultiPut(validatorID ids.ShortID, chainID ids.ID, requestID ui sent = peer.send(msg) } if !sent { - n.log.Debug("failed to send a MultiPut message to: %s", validatorID) + n.log.Debug("failed to send MultiPut(%s, %s, %d, %d)", + validatorID, + chainID, + requestID, + len(containers)) n.multiPut.numFailed.Inc() } else { n.multiPut.numSent.Inc() } } +// Get implements the Sender interface. +func (n *network) Get(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID) { + msg, err := n.b.Get(chainID, requestID, containerID) + n.log.AssertNoError(err) + + n.stateLock.Lock() + defer n.stateLock.Unlock() + + peer, sent := n.peers[validatorID.Key()] + if sent { + sent = peer.send(msg) + } + if !sent { + n.log.Debug("failed to send Get(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerID) + n.executor.Add(func() { n.router.GetFailed(validatorID, chainID, requestID) }) + n.get.numFailed.Inc() + } else { + n.get.numSent.Inc() + } +} + +// Put implements the Sender interface. +func (n *network) Put(validatorID ids.ShortID, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) { + msg, err := n.b.Put(chainID, requestID, containerID, container) + if err != nil { + n.log.Error("failed to build Put(%s, %d, %s): %s. len(container) : %d", + chainID, + requestID, + containerID, + err, + len(container)) + return + } + + n.stateLock.Lock() + defer n.stateLock.Unlock() + + peer, sent := n.peers[validatorID.Key()] + if sent { + sent = peer.send(msg) + } + if !sent { + n.log.Debug("failed to send Put(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerID) + n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container}) + n.put.numFailed.Inc() + } else { + n.put.numSent.Inc() + } +} + // PushQuery implements the Sender interface. func (n *network) PushQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID uint32, containerID ids.ID, container []byte) { msg, err := n.b.PushQuery(chainID, requestID, containerID, container) if err != nil { + n.log.Error("failed to build PushQuery(%s, %d, %s): %s. len(container): %d", + chainID, + requestID, + containerID, + err, + len(container)) + n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container}) for _, validatorID := range validatorIDs.List() { vID := validatorID n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) }) } - n.log.Error("attempted to pack too large of a PushQuery message.\nContainer length: %d", len(container)) return // Packing message failed } @@ -462,7 +517,12 @@ func (n *network) PushQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID sent = peer.send(msg) } if !sent { - n.log.Debug("failed sending a PushQuery message to: %s", vID) + n.log.Debug("failed to send PushQuery(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerID) + n.log.Verbo("container: %s", formatting.DumpBytes{Bytes: container}) n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) }) n.pushQuery.numFailed.Inc() } else { @@ -486,7 +546,11 @@ func (n *network) PullQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID sent = peer.send(msg) } if !sent { - n.log.Debug("failed sending a PullQuery message to: %s", vID) + n.log.Debug("failed to send PullQuery(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + containerID) n.executor.Add(func() { n.router.QueryFailed(vID, chainID, requestID) }) n.pullQuery.numFailed.Inc() } else { @@ -499,7 +563,11 @@ func (n *network) PullQuery(validatorIDs ids.ShortSet, chainID ids.ID, requestID func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint32, votes ids.Set) { msg, err := n.b.Chits(chainID, requestID, votes) if err != nil { - n.log.Error("failed to build Chits message because of %d votes", votes.Len()) + n.log.Error("failed to build Chits(%s, %d, %s): %s", + chainID, + requestID, + votes, + err) return } @@ -511,7 +579,11 @@ func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint3 sent = peer.send(msg) } if !sent { - n.log.Debug("failed to send a Chits message to: %s", validatorID) + n.log.Debug("failed to send Chits(%s, %s, %d, %s)", + validatorID, + chainID, + requestID, + votes) n.chits.numFailed.Inc() } else { n.chits.numSent.Inc() @@ -521,7 +593,8 @@ func (n *network) Chits(validatorID ids.ShortID, chainID ids.ID, requestID uint3 // Gossip attempts to gossip the container to the network func (n *network) Gossip(chainID, containerID ids.ID, container []byte) { if err := n.gossipContainer(chainID, containerID, container); err != nil { - n.log.Error("error gossiping container %s to %s: %s", containerID, chainID, err) + n.log.Debug("failed to Gossip(%s, %s): %s", chainID, containerID, err) + n.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container}) } } @@ -695,7 +768,9 @@ func (n *network) gossip() { } msg, err := n.b.PeerList(ips) if err != nil { - n.log.Warn("failed to gossip PeerList message due to %s", err) + n.log.Error("failed to build peer list to gossip: %s. len(ips): %d", + err, + len(ips)) continue } diff --git a/node/config.go b/node/config.go index 74ff491..2504276 100644 --- a/node/config.go +++ b/node/config.go @@ -34,6 +34,7 @@ type Config struct { // Staking configuration StakingIP utils.IPDesc + EnableP2PTLS bool EnableStaking bool StakingKeyFile string StakingCertFile string diff --git a/node/node.go b/node/node.go index ea0e8fc..5e817fa 100644 --- a/node/node.go +++ b/node/node.go @@ -119,7 +119,7 @@ func (n *Node) initNetworking() error { dialer := network.NewDialer(TCP) var serverUpgrader, clientUpgrader network.Upgrader - if n.Config.EnableStaking { + if n.Config.EnableP2PTLS { cert, err := tls.LoadX509KeyPair(n.Config.StakingCertFile, n.Config.StakingKeyFile) if err != nil { return err @@ -253,7 +253,7 @@ func (n *Node) initDatabase() error { // Otherwise, it is a hash of the TLS certificate that this node // uses for P2P communication func (n *Node) initNodeID() error { - if !n.Config.EnableStaking { + if !n.Config.EnableP2PTLS { n.ID = ids.NewShortID(hashing.ComputeHash160Array([]byte(n.Config.StakingIP.String()))) n.Log.Info("Set the node's ID to %s", n.ID) return nil diff --git a/snow/networking/router/subnet_router.go b/snow/networking/router/chain_router.go similarity index 85% rename from snow/networking/router/subnet_router.go rename to snow/networking/router/chain_router.go index 5bf977c..4505bec 100644 --- a/snow/networking/router/subnet_router.go +++ b/snow/networking/router/chain_router.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/ava-labs/gecko/utils/formatting" + "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow/networking/timeout" "github.com/ava-labs/gecko/utils/logging" @@ -67,7 +69,7 @@ func (sr *ChainRouter) RemoveChain(chainID ids.ID) { sr.lock.RLock() chain, exists := sr.chains[chainID.Key()] if !exists { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("can't remove unknown chain %s", chainID) sr.lock.RUnlock() return } @@ -95,7 +97,7 @@ func (sr *ChainRouter) GetAcceptedFrontier(validatorID ids.ShortID, chainID ids. if chain, exists := sr.chains[chainID.Key()]; exists { chain.GetAcceptedFrontier(validatorID, requestID) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("GetAcceptedFrontier(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID) } } @@ -111,7 +113,7 @@ func (sr *ChainRouter) AcceptedFrontier(validatorID ids.ShortID, chainID ids.ID, sr.timeouts.Cancel(validatorID, chainID, requestID) } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("AcceptedFrontier(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs) } } @@ -132,7 +134,7 @@ func (sr *ChainRouter) GetAcceptedFrontierFailed(validatorID ids.ShortID, chainI return } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Error("GetAcceptedFrontierFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID) } sr.timeouts.Cancel(validatorID, chainID, requestID) } @@ -147,7 +149,7 @@ func (sr *ChainRouter) GetAccepted(validatorID ids.ShortID, chainID ids.ID, requ if chain, exists := sr.chains[chainID.Key()]; exists { chain.GetAccepted(validatorID, requestID, containerIDs) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("GetAccepted(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs) } } @@ -163,7 +165,7 @@ func (sr *ChainRouter) Accepted(validatorID ids.ShortID, chainID ids.ID, request sr.timeouts.Cancel(validatorID, chainID, requestID) } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("Accepted(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerIDs) } } @@ -183,7 +185,7 @@ func (sr *ChainRouter) GetAcceptedFailed(validatorID ids.ShortID, chainID ids.ID return } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Error("GetAcceptedFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID) } sr.timeouts.Cancel(validatorID, chainID, requestID) } @@ -198,7 +200,7 @@ func (sr *ChainRouter) GetAncestors(validatorID ids.ShortID, chainID ids.ID, req if chain, exists := sr.chains[chainID.Key()]; exists { chain.GetAncestors(validatorID, requestID, containerID) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("GetAncestors(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID) } } @@ -215,7 +217,7 @@ func (sr *ChainRouter) MultiPut(validatorID ids.ShortID, chainID ids.ID, request sr.timeouts.Cancel(validatorID, chainID, requestID) } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("MultiPut(%s, %s, %d, %d) dropped due to unknown chain", validatorID, chainID, requestID, len(containers)) } } @@ -234,7 +236,7 @@ func (sr *ChainRouter) GetAncestorsFailed(validatorID ids.ShortID, chainID ids.I return } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Error("GetAncestorsFailed(%s, %s, %d, %d) dropped due to unknown chain", validatorID, chainID, requestID) } sr.timeouts.Cancel(validatorID, chainID, requestID) } @@ -248,7 +250,7 @@ func (sr *ChainRouter) Get(validatorID ids.ShortID, chainID ids.ID, requestID ui if chain, exists := sr.chains[chainID.Key()]; exists { chain.Get(validatorID, requestID, containerID) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("Get(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID) } } @@ -265,7 +267,8 @@ func (sr *ChainRouter) Put(validatorID ids.ShortID, chainID ids.ID, requestID ui sr.timeouts.Cancel(validatorID, chainID, requestID) } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("Put(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID) + sr.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container}) } } @@ -284,7 +287,7 @@ func (sr *ChainRouter) GetFailed(validatorID ids.ShortID, chainID ids.ID, reques return } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Error("GetFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID) } sr.timeouts.Cancel(validatorID, chainID, requestID) } @@ -298,7 +301,8 @@ func (sr *ChainRouter) PushQuery(validatorID ids.ShortID, chainID ids.ID, reques if chain, exists := sr.chains[chainID.Key()]; exists { chain.PushQuery(validatorID, requestID, containerID, container) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("PushQuery(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID) + sr.log.Verbo("container:\n%s", formatting.DumpBytes{Bytes: container}) } } @@ -311,7 +315,7 @@ func (sr *ChainRouter) PullQuery(validatorID ids.ShortID, chainID ids.ID, reques if chain, exists := sr.chains[chainID.Key()]; exists { chain.PullQuery(validatorID, requestID, containerID) } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("PullQuery(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, containerID) } } @@ -327,7 +331,7 @@ func (sr *ChainRouter) Chits(validatorID ids.ShortID, chainID ids.ID, requestID sr.timeouts.Cancel(validatorID, chainID, requestID) } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Debug("Chits(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID, votes) } } @@ -346,7 +350,7 @@ func (sr *ChainRouter) QueryFailed(validatorID ids.ShortID, chainID ids.ID, requ return } } else { - sr.log.Debug("message referenced a chain, %s, this node doesn't validate", chainID) + sr.log.Error("QueryFailed(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID) } sr.timeouts.Cancel(validatorID, chainID, requestID) } diff --git a/utils/codec/codec.go b/utils/codec/codec.go new file mode 100644 index 0000000..6521993 --- /dev/null +++ b/utils/codec/codec.go @@ -0,0 +1,394 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package codec + +import ( + "errors" + "fmt" + "reflect" + "unicode" + + "github.com/ava-labs/gecko/utils/wrappers" +) + +const ( + defaultMaxSize = 1 << 18 // default max size, in bytes, of something being marshalled by Marshal() + defaultMaxSliceLength = 1 << 18 // default max length of a slice being marshalled by Marshal(). Should be <= math.MaxUint32. + + // initial capacity of byte slice that values are marshaled into. + // Larger value --> need less memory allocations but possibly have allocated but unused memory + // Smaller value --> need more memory allocations but more efficient use of allocated memory + initialSliceCap = 256 +) + +var ( + errNil = errors.New("can't marshal/unmarshal nil pointer or interface") + errNeedPointer = errors.New("argument to unmarshal should be a pointer") +) + +// Codec handles marshaling and unmarshaling of structs +type codec struct { + maxSize int + maxSliceLen int + + typeIDToType map[uint32]reflect.Type + typeToTypeID map[reflect.Type]uint32 + + // Key: a struct type + // Value: Slice where each element is index in the struct type + // of a field that is serialized/deserialized + // e.g. Foo --> [1,5,8] means Foo.Field(1), etc. are to be serialized/deserialized + // We assume this cache is pretty small (a few hundred keys at most) + // and doesn't take up much memory + serializedFieldIndices map[reflect.Type][]int +} + +// Codec marshals and unmarshals +type Codec interface { + RegisterType(interface{}) error + Marshal(interface{}) ([]byte, error) + Unmarshal([]byte, interface{}) error +} + +// New returns a new codec +func New(maxSize, maxSliceLen int) Codec { + return &codec{ + maxSize: maxSize, + maxSliceLen: maxSliceLen, + typeIDToType: map[uint32]reflect.Type{}, + typeToTypeID: map[reflect.Type]uint32{}, + serializedFieldIndices: map[reflect.Type][]int{}, + } +} + +// NewDefault returns a new codec with reasonable default values +func NewDefault() Codec { return New(defaultMaxSize, defaultMaxSliceLength) } + +// RegisterType is used to register types that may be unmarshaled into an interface +// [val] is a value of the type being registered +func (c *codec) RegisterType(val interface{}) error { + valType := reflect.TypeOf(val) + if _, exists := c.typeToTypeID[valType]; exists { + return fmt.Errorf("type %v has already been registered", valType) + } + c.typeIDToType[uint32(len(c.typeIDToType))] = reflect.TypeOf(val) + c.typeToTypeID[valType] = uint32(len(c.typeIDToType) - 1) + return nil +} + +// A few notes: +// 1) See codec_test.go for examples of usage +// 2) We use "marshal" and "serialize" interchangeably, and "unmarshal" and "deserialize" interchangeably +// 3) To include a field of a struct in the serialized form, add the tag `serialize:"true"` to it +// 4) These typed members of a struct may be serialized: +// bool, string, uint[8,16,32,64], int[8,16,32,64], +// structs, slices, arrays, interface. +// structs, slices and arrays can only be serialized if their constituent values can be. +// 5) To marshal an interface, you must pass a pointer to the value +// 6) To unmarshal an interface, you must call codec.RegisterType([instance of the type that fulfills the interface]). +// 7) Serialized fields must be exported +// 8) nil slices are marshaled as empty slices + +// To marshal an interface, [value] must be a pointer to the interface +func (c *codec) Marshal(value interface{}) ([]byte, error) { + if value == nil { + return nil, errNil // can't marshal nil + } + + p := &wrappers.Packer{MaxSize: c.maxSize, Bytes: make([]byte, 0, initialSliceCap)} + if err := c.marshal(reflect.ValueOf(value), p); err != nil { + return nil, err + } + + return p.Bytes, nil +} + +// marshal writes the byte representation of [value] to [p] +// [value]'s underlying value must not be a nil pointer or interface +func (c *codec) marshal(value reflect.Value, p *wrappers.Packer) error { + valueKind := value.Kind() + switch valueKind { + case reflect.Interface, reflect.Ptr, reflect.Invalid: + if value.IsNil() { // Can't marshal nil (except nil slices) + return errNil + } + } + + switch valueKind { + case reflect.Uint8: + p.PackByte(uint8(value.Uint())) + return p.Err + case reflect.Int8: + p.PackByte(uint8(value.Int())) + return p.Err + case reflect.Uint16: + p.PackShort(uint16(value.Uint())) + return p.Err + case reflect.Int16: + p.PackShort(uint16(value.Int())) + return p.Err + case reflect.Uint32: + p.PackInt(uint32(value.Uint())) + return p.Err + case reflect.Int32: + p.PackInt(uint32(value.Int())) + return p.Err + case reflect.Uint64: + p.PackLong(value.Uint()) + return p.Err + case reflect.Int64: + p.PackLong(uint64(value.Int())) + return p.Err + case reflect.String: + p.PackStr(value.String()) + return p.Err + case reflect.Bool: + p.PackBool(value.Bool()) + return p.Err + case reflect.Uintptr, reflect.Ptr: + return c.marshal(value.Elem(), p) + case reflect.Interface: + underlyingValue := value.Interface() + typeID, ok := c.typeToTypeID[reflect.TypeOf(underlyingValue)] // Get the type ID of the value being marshaled + if !ok { + return fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(underlyingValue).String()) + } + p.PackInt(typeID) // Pack type ID so we know what to unmarshal this into + if p.Err != nil { + return p.Err + } + if err := c.marshal(value.Elem(), p); err != nil { + return err + } + return p.Err + case reflect.Slice: + numElts := value.Len() // # elements in the slice/array. 0 if this slice is nil. + if numElts > c.maxSliceLen { + return fmt.Errorf("slice length, %d, exceeds maximum length, %d", numElts, c.maxSliceLen) + } + p.PackInt(uint32(numElts)) // pack # elements + if p.Err != nil { + return p.Err + } + for i := 0; i < numElts; i++ { // Process each element in the slice + if err := c.marshal(value.Index(i), p); err != nil { + return err + } + } + return nil + case reflect.Array: + numElts := value.Len() + if numElts > c.maxSliceLen { + return fmt.Errorf("array length, %d, exceeds maximum length, %d", numElts, c.maxSliceLen) + } + for i := 0; i < numElts; i++ { // Process each element in the array + if err := c.marshal(value.Index(i), p); err != nil { + return err + } + } + return nil + case reflect.Struct: + serializedFields, err := c.getSerializedFieldIndices(value.Type()) + if err != nil { + return err + } + for _, fieldIndex := range serializedFields { // Go through all fields of this struct that are serialized + if err := c.marshal(value.Field(fieldIndex), p); err != nil { // Serialize the field and write to byte array + return err + } + } + return nil + default: + return fmt.Errorf("can't marshal unknown kind %s", valueKind) + } +} + +// Unmarshal unmarshals [bytes] into [dest], where +// [dest] must be a pointer or interface +func (c *codec) Unmarshal(bytes []byte, dest interface{}) error { + switch { + case len(bytes) > c.maxSize: + return fmt.Errorf("byte array exceeds maximum length, %d", c.maxSize) + case dest == nil: + return errNil + } + + destPtr := reflect.ValueOf(dest) + if destPtr.Kind() != reflect.Ptr { + return errNeedPointer + } + + p := &wrappers.Packer{MaxSize: c.maxSize, Bytes: bytes} + destVal := destPtr.Elem() + if err := c.unmarshal(p, destVal); err != nil { + return err + } + + return nil +} + +// Unmarshal from p.Bytes into [value]. [value] must be addressable. +func (c *codec) unmarshal(p *wrappers.Packer, value reflect.Value) error { + switch value.Kind() { + case reflect.Uint8: + value.SetUint(uint64(p.UnpackByte())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint8: %s", p.Err) + } + return nil + case reflect.Int8: + value.SetInt(int64(p.UnpackByte())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int8: %s", p.Err) + } + return nil + case reflect.Uint16: + value.SetUint(uint64(p.UnpackShort())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint16: %s", p.Err) + } + return nil + case reflect.Int16: + value.SetInt(int64(p.UnpackShort())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int16: %s", p.Err) + } + return nil + case reflect.Uint32: + value.SetUint(uint64(p.UnpackInt())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint32: %s", p.Err) + } + return nil + case reflect.Int32: + value.SetInt(int64(p.UnpackInt())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int32: %s", p.Err) + } + return nil + case reflect.Uint64: + value.SetUint(uint64(p.UnpackLong())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal uint64: %s", p.Err) + } + return nil + case reflect.Int64: + value.SetInt(int64(p.UnpackLong())) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal int64: %s", p.Err) + } + return nil + case reflect.Bool: + value.SetBool(p.UnpackBool()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal bool: %s", p.Err) + } + return nil + case reflect.Slice: + numElts := int(p.UnpackInt()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal slice: %s", p.Err) + } + // set [value] to be a slice of the appropriate type/capacity (right now it is nil) + value.Set(reflect.MakeSlice(value.Type(), numElts, numElts)) + // Unmarshal each element into the appropriate index of the slice + for i := 0; i < numElts; i++ { + if err := c.unmarshal(p, value.Index(i)); err != nil { + return fmt.Errorf("couldn't unmarshal slice element: %s", err) + } + } + return nil + case reflect.Array: + for i := 0; i < value.Len(); i++ { + if err := c.unmarshal(p, value.Index(i)); err != nil { + return fmt.Errorf("couldn't unmarshal array element: %s", err) + } + } + return nil + case reflect.String: + value.SetString(p.UnpackStr()) + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal string: %s", p.Err) + } + return nil + case reflect.Interface: + typeID := p.UnpackInt() // Get the type ID + if p.Err != nil { + return fmt.Errorf("couldn't unmarshal interface: %s", p.Err) + } + // Get a type that implements the interface + implementingType, ok := c.typeIDToType[typeID] + if !ok { + return fmt.Errorf("couldn't unmarshal interface: unknown type ID %d", typeID) + } + // Ensure type actually does implement the interface + if valueType := value.Type(); !implementingType.Implements(valueType) { + return fmt.Errorf("couldn't unmarshal interface: %s does not implement interface %s", implementingType, valueType) + } + intfImplementor := reflect.New(implementingType).Elem() // instance of the proper type + // Unmarshal into the struct + if err := c.unmarshal(p, intfImplementor); err != nil { + return fmt.Errorf("couldn't unmarshal interface: %s", err) + } + // And assign the filled struct to the value + value.Set(intfImplementor) + return nil + case reflect.Struct: + // Get indices of fields that will be unmarshaled into + serializedFieldIndices, err := c.getSerializedFieldIndices(value.Type()) + if err != nil { + return fmt.Errorf("couldn't unmarshal struct: %s", err) + } + // Go through the fields and umarshal into them + for _, index := range serializedFieldIndices { + if err := c.unmarshal(p, value.Field(index)); err != nil { + return fmt.Errorf("couldn't unmarshal struct: %s", err) + } + } + return nil + case reflect.Ptr: + // Get the type this pointer points to + t := value.Type().Elem() + // Create a new pointer to a new value of the underlying type + v := reflect.New(t) + // Fill the value + if err := c.unmarshal(p, v.Elem()); err != nil { + return fmt.Errorf("couldn't unmarshal pointer: %s", err) + } + // Assign to the top-level struct's member + value.Set(v) + return nil + case reflect.Invalid: + return errNil + default: + return fmt.Errorf("can't unmarshal unknown type %s", value.Kind().String()) + } +} + +// Returns the indices of the serializable fields of [t], which is a struct type +// Returns an error if a field has tag "serialize: true" but the field is unexported +// e.g. getSerializedFieldIndices(Foo) --> [1,5,8] means Foo.Field(1), Foo.Field(5), Foo.Field(8) +// are to be serialized/deserialized +func (c *codec) getSerializedFieldIndices(t reflect.Type) ([]int, error) { + if c.serializedFieldIndices == nil { + c.serializedFieldIndices = make(map[reflect.Type][]int) + } + if serializedFields, ok := c.serializedFieldIndices[t]; ok { // use pre-computed result + return serializedFields, nil + } + numFields := t.NumField() + serializedFields := make([]int, 0, numFields) + for i := 0; i < numFields; i++ { // Go through all fields of this struct + field := t.Field(i) + if field.Tag.Get("serialize") != "true" { // Skip fields we don't need to serialize + continue + } + if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields + return []int{}, fmt.Errorf("can't marshal unexported field %s", field.Name) + } + serializedFields = append(serializedFields, i) + } + c.serializedFieldIndices[t] = serializedFields // cache result + return serializedFields, nil +} diff --git a/vms/components/codec/codec_benchmark_test.go b/utils/codec/codec_benchmark_test.go similarity index 84% rename from vms/components/codec/codec_benchmark_test.go rename to utils/codec/codec_benchmark_test.go index 8e6f9f7..4adfa52 100644 --- a/vms/components/codec/codec_benchmark_test.go +++ b/utils/codec/codec_benchmark_test.go @@ -35,13 +35,22 @@ func BenchmarkMarshal(b *testing.B) { }, MyPointer: &temp, } + var unmarshaledMyStructInstance myStruct codec := NewDefault() codec.RegisterType(&MyInnerStruct{}) // Register the types that may be unmarshaled into interfaces codec.RegisterType(&MyInnerStruct2{}) + codec.Marshal(myStructInstance) // warm up serializedFields cache b.ResetTimer() for n := 0; n < b.N; n++ { - codec.Marshal(myStructInstance) + bytes, err := codec.Marshal(myStructInstance) + if err != nil { + b.Fatal(err) + } + if err := codec.Unmarshal(bytes, &unmarshaledMyStructInstance); err != nil { + b.Fatal(err) + } + } } diff --git a/vms/components/codec/codec_test.go b/utils/codec/codec_test.go similarity index 80% rename from vms/components/codec/codec_test.go rename to utils/codec/codec_test.go index 6fdfeba..edd3f85 100644 --- a/vms/components/codec/codec_test.go +++ b/utils/codec/codec_test.go @@ -5,6 +5,7 @@ package codec import ( "bytes" + "math" "reflect" "testing" ) @@ -104,36 +105,8 @@ func TestStruct(t *testing.T) { t.Fatal(err) } - if !reflect.DeepEqual(myStructUnmarshaled.Member1, myStructInstance.Member1) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !bytes.Equal(myStructUnmarshaled.MySlice, myStructInstance.MySlice) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MySlice2, myStructInstance.MySlice2) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MySlice3, myStructInstance.MySlice3) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MySlice3, myStructInstance.MySlice3) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MySlice4, myStructInstance.MySlice4) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct, myStructInstance.InnerStruct) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct2, myStructInstance.InnerStruct2) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MyArray2, myStructInstance.MyArray2) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MyArray3, myStructInstance.MyArray3) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MyArray4, myStructInstance.MyArray4) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MyInterface, myStructInstance.MyInterface) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MySlice5, myStructInstance.MySlice5) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.InnerStruct3, myStructInstance.InnerStruct3) { - t.Fatal("expected unmarshaled struct to be same as original struct") - } else if !reflect.DeepEqual(myStructUnmarshaled.MyPointer, myStructInstance.MyPointer) { - t.Fatal("expected unmarshaled struct to be same as original struct") + if !reflect.DeepEqual(*myStructUnmarshaled, myStructInstance) { + t.Fatal("should be same") } } @@ -173,6 +146,28 @@ func TestSlice(t *testing.T) { } } +// Test marshalling/unmarshalling largest possible slice +func TestMaxSizeSlice(t *testing.T) { + mySlice := make([]string, math.MaxUint16, math.MaxUint16) + mySlice[0] = "first!" + mySlice[math.MaxUint16-1] = "last!" + codec := NewDefault() + bytes, err := codec.Marshal(mySlice) + if err != nil { + t.Fatal(err) + } + + var sliceUnmarshaled []string + if err := codec.Unmarshal(bytes, &sliceUnmarshaled); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(mySlice, sliceUnmarshaled) { + t.Fatal("expected marshaled and unmarshaled values to match") + } +} + +// Test marshalling a bool func TestBool(t *testing.T) { myBool := true codec := NewDefault() @@ -191,6 +186,7 @@ func TestBool(t *testing.T) { } } +// Test marshalling an array func TestArray(t *testing.T) { myArr := [5]uint64{5, 6, 7, 8, 9} codec := NewDefault() @@ -209,6 +205,26 @@ func TestArray(t *testing.T) { } } +// Test marshalling a really big array +func TestBigArray(t *testing.T) { + myArr := [30000]uint64{5, 6, 7, 8, 9} + codec := NewDefault() + bytes, err := codec.Marshal(myArr) + if err != nil { + t.Fatal(err) + } + + var myArrUnmarshaled [30000]uint64 + if err := codec.Unmarshal(bytes, &myArrUnmarshaled); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(myArr, myArrUnmarshaled) { + t.Fatal("expected marshaled and unmarshaled values to match") + } +} + +// Test marshalling a pointer to a struct func TestPointerToStruct(t *testing.T) { myPtr := &MyInnerStruct{Str: "Hello!"} codec := NewDefault() @@ -227,6 +243,7 @@ func TestPointerToStruct(t *testing.T) { } } +// Test marshalling a slice of structs func TestSliceOfStruct(t *testing.T) { mySlice := []MyInnerStruct3{ MyInnerStruct3{ @@ -257,6 +274,7 @@ func TestSliceOfStruct(t *testing.T) { } } +// Test marshalling an interface func TestInterface(t *testing.T) { codec := NewDefault() codec.RegisterType(&MyInnerStruct2{}) @@ -278,6 +296,7 @@ func TestInterface(t *testing.T) { } } +// Test marshalling a slice of interfaces func TestSliceOfInterface(t *testing.T) { mySlice := []Foo{ &MyInnerStruct{ @@ -304,6 +323,7 @@ func TestSliceOfInterface(t *testing.T) { } } +// Test marshalling an array of interfaces func TestArrayOfInterface(t *testing.T) { myArray := [2]Foo{ &MyInnerStruct{ @@ -330,6 +350,7 @@ func TestArrayOfInterface(t *testing.T) { } } +// Test marshalling a pointer to an interface func TestPointerToInterface(t *testing.T) { var myinnerStruct Foo = &MyInnerStruct{Str: "Hello!"} var myPtr *Foo = &myinnerStruct @@ -352,6 +373,7 @@ func TestPointerToInterface(t *testing.T) { } } +// Test marshalling a string func TestString(t *testing.T) { myString := "Ayy" codec := NewDefault() @@ -370,7 +392,7 @@ func TestString(t *testing.T) { } } -// Ensure a nil slice is unmarshaled as an empty slice +// Ensure a nil slice is unmarshaled to slice with length 0 func TestNilSlice(t *testing.T) { type structWithSlice struct { Slice []byte `serialize:"true"` @@ -389,12 +411,12 @@ func TestNilSlice(t *testing.T) { } if structUnmarshaled.Slice == nil || len(structUnmarshaled.Slice) != 0 { - t.Fatal("expected slice to be empty slice") + t.Fatal("expected slice to be non-nil and length 0") } } // Ensure that trying to serialize a struct with an unexported member -// that has `serialize:"true"` returns errUnexportedField +// that has `serialize:"true"` returns error func TestSerializeUnexportedField(t *testing.T) { type s struct { ExportedField string `serialize:"true"` @@ -407,8 +429,8 @@ func TestSerializeUnexportedField(t *testing.T) { } codec := NewDefault() - if _, err := codec.Marshal(myS); err != errMarshalUnexportedField { - t.Fatalf("expected err to be errUnexportedField but was %v", err) + if _, err := codec.Marshal(myS); err == nil { + t.Fatalf("expected err but got none") } } @@ -426,12 +448,12 @@ func TestSerializeOfNoSerializeField(t *testing.T) { codec := NewDefault() marshalled, err := codec.Marshal(myS) if err != nil { - t.Fatalf("Unexpected error %q", err) + t.Fatal(err) } unmarshalled := s{} err = codec.Unmarshal(marshalled, &unmarshalled) if err != nil { - t.Fatalf("Unexpected error %q", err) + t.Fatal(err) } expectedUnmarshalled := s{SerializedField: "Serialize me"} if !reflect.DeepEqual(unmarshalled, expectedUnmarshalled) { @@ -443,11 +465,12 @@ type simpleSliceStruct struct { Arr []uint32 `serialize:"true"` } -func TestEmptySliceSerialization(t *testing.T) { +// Test marshalling of nil slice +func TestNilSliceSerialization(t *testing.T) { codec := NewDefault() val := &simpleSliceStruct{} - expected := []byte{0, 0, 0, 0} + expected := []byte{0, 0, 0, 0} // nil slice marshaled as 0 length slice result, err := codec.Marshal(val) if err != nil { t.Fatal(err) @@ -456,6 +479,36 @@ func TestEmptySliceSerialization(t *testing.T) { if !bytes.Equal(expected, result) { t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result) } + + valUnmarshaled := &simpleSliceStruct{} + if err = codec.Unmarshal(result, &valUnmarshaled); err != nil { + t.Fatal(err) + } else if len(valUnmarshaled.Arr) != 0 { + t.Fatal("should be 0 length") + } +} + +// Test marshaling a slice that has 0 elements (but isn't nil) +func TestEmptySliceSerialization(t *testing.T) { + codec := NewDefault() + + val := &simpleSliceStruct{Arr: make([]uint32, 0, 1)} + expected := []byte{0, 0, 0, 0} // 0 for size + result, err := codec.Marshal(val) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(expected, result) { + t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result) + } + + valUnmarshaled := &simpleSliceStruct{} + if err = codec.Unmarshal(result, &valUnmarshaled); err != nil { + t.Fatal(err) + } else if !reflect.DeepEqual(valUnmarshaled, val) { + t.Fatal("should be same") + } } type emptyStruct struct{} @@ -464,13 +517,14 @@ type nestedSliceStruct struct { Arr []emptyStruct `serialize:"true"` } +// Test marshaling slice that is not nil and not empty func TestSliceWithEmptySerialization(t *testing.T) { codec := NewDefault() val := &nestedSliceStruct{ Arr: make([]emptyStruct, 1000), } - expected := []byte{0x00, 0x00, 0x03, 0xE8} + expected := []byte{0x00, 0x00, 0x03, 0xE8} //1000 for numElts result, err := codec.Marshal(val) if err != nil { t.Fatal(err) @@ -485,7 +539,7 @@ func TestSliceWithEmptySerialization(t *testing.T) { t.Fatal(err) } if len(unmarshaled.Arr) != 1000 { - t.Fatalf("Should have created an array of length %d", 1000) + t.Fatalf("Should have created a slice of length %d", 1000) } } @@ -493,20 +547,15 @@ func TestSliceWithEmptySerializationOutOfMemory(t *testing.T) { codec := NewDefault() val := &nestedSliceStruct{ - Arr: make([]emptyStruct, 1000000), + Arr: make([]emptyStruct, defaultMaxSliceLength+1), } - expected := []byte{0x00, 0x0f, 0x42, 0x40} // 1,000,000 in hex - result, err := codec.Marshal(val) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(expected, result) { - t.Fatalf("\nExpected: 0x%x\nResult: 0x%x", expected, result) + bytes, err := codec.Marshal(val) + if err == nil { + t.Fatal("should have failed due to slice length too large") } unmarshaled := nestedSliceStruct{} - if err := codec.Unmarshal(expected, &unmarshaled); err == nil { + if err := codec.Unmarshal(bytes, &unmarshaled); err == nil { t.Fatalf("Should have errored due to excess memory requested") } } diff --git a/utils/wrappers/packing.go b/utils/wrappers/packing.go index 1038852..c048f9c 100644 --- a/utils/wrappers/packing.go +++ b/utils/wrappers/packing.go @@ -61,26 +61,23 @@ func (p *Packer) CheckSpace(bytes int) { } } -// Expand ensures that there is [bytes] bytes left of space in the byte array. -// If this is not allowed due to the maximum size, an error is added to the -// packer +// Expand ensures that there is [bytes] bytes left of space in the byte slice. +// If this is not allowed due to the maximum size, an error is added to the packer +// In order to understand this code, its important to understand the difference +// between a slice's length and its capacity. func (p *Packer) Expand(bytes int) { - p.CheckSpace(0) - if p.Errored() { + neededSize := bytes + p.Offset // Need byte slice's length to be at least [neededSize] + switch { + case neededSize <= len(p.Bytes): // Byte slice has sufficient length already return - } - - neededSize := bytes + p.Offset - if neededSize <= len(p.Bytes) { + case neededSize > p.MaxSize: // Lengthening the byte slice would cause it to grow too large + p.Err = errBadLength return - } - - if neededSize > p.MaxSize { - p.Add(errBadLength) - } else if neededSize > cap(p.Bytes) { - p.Bytes = append(p.Bytes[:cap(p.Bytes)], make([]byte, neededSize-cap(p.Bytes))...) - } else { + case neededSize <= cap(p.Bytes): // Byte slice has sufficient capacity to lengthen it without mem alloc p.Bytes = p.Bytes[:neededSize] + return + default: // Add capacity/length to byte slice + p.Bytes = append(p.Bytes[:cap(p.Bytes)], make([]byte, neededSize-cap(p.Bytes))...) } } diff --git a/vms/avm/base_tx.go b/vms/avm/base_tx.go index 33cba51..0ab3fa4 100644 --- a/vms/avm/base_tx.go +++ b/vms/avm/base_tx.go @@ -10,7 +10,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/create_asset_tx.go b/vms/avm/create_asset_tx.go index 9f95a15..77aae2f 100644 --- a/vms/avm/create_asset_tx.go +++ b/vms/avm/create_asset_tx.go @@ -11,7 +11,7 @@ import ( "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) const ( diff --git a/vms/avm/create_asset_tx_test.go b/vms/avm/create_asset_tx_test.go index a26a815..324f403 100644 --- a/vms/avm/create_asset_tx_test.go +++ b/vms/avm/create_asset_tx_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/avm/export_tx.go b/vms/avm/export_tx.go index d5222f4..d788360 100644 --- a/vms/avm/export_tx.go +++ b/vms/avm/export_tx.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/database/versiondb" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/export_tx_test.go b/vms/avm/export_tx_test.go index 4e9d064..fdef399 100644 --- a/vms/avm/export_tx_test.go +++ b/vms/avm/export_tx_test.go @@ -16,7 +16,7 @@ import ( "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/avm/import_tx.go b/vms/avm/import_tx.go index 09dec6e..1729221 100644 --- a/vms/avm/import_tx.go +++ b/vms/avm/import_tx.go @@ -12,7 +12,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/initial_state.go b/vms/avm/initial_state.go index c3d4b16..73ad6e4 100644 --- a/vms/avm/initial_state.go +++ b/vms/avm/initial_state.go @@ -8,7 +8,7 @@ import ( "errors" "sort" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/initial_state_test.go b/vms/avm/initial_state_test.go index 67c4b15..b61876c 100644 --- a/vms/avm/initial_state_test.go +++ b/vms/avm/initial_state_test.go @@ -11,7 +11,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/avm/operation.go b/vms/avm/operation.go index 3b5fc9a..ef9317b 100644 --- a/vms/avm/operation.go +++ b/vms/avm/operation.go @@ -10,7 +10,7 @@ import ( "github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/operation_test.go b/vms/avm/operation_test.go index 8948388..8b85901 100644 --- a/vms/avm/operation_test.go +++ b/vms/avm/operation_test.go @@ -8,7 +8,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/operation_tx.go b/vms/avm/operation_tx.go index 9384f8d..ec419c7 100644 --- a/vms/avm/operation_tx.go +++ b/vms/avm/operation_tx.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/service.go b/vms/avm/service.go index f71d607..9b21414 100644 --- a/vms/avm/service.go +++ b/vms/avm/service.go @@ -56,7 +56,7 @@ type IssueTxReply struct { // IssueTx attempts to issue a transaction into consensus func (service *Service) IssueTx(r *http.Request, args *IssueTxArgs, reply *IssueTxReply) error { - service.vm.ctx.Log.Verbo("IssueTx called with %s", args.Tx) + service.vm.ctx.Log.Info("AVM: IssueTx called with %s", args.Tx) txID, err := service.vm.IssueTx(args.Tx.Bytes, nil) if err != nil { @@ -79,7 +79,7 @@ type GetTxStatusReply struct { // GetTxStatus returns the status of the specified transaction func (service *Service) GetTxStatus(r *http.Request, args *GetTxStatusArgs, reply *GetTxStatusReply) error { - service.vm.ctx.Log.Verbo("GetTxStatus called with %s", args.TxID) + service.vm.ctx.Log.Info("AVM: GetTxStatus called with %s", args.TxID) if args.TxID.IsZero() { return errNilTxID @@ -106,7 +106,7 @@ type GetTxReply struct { // GetTx returns the specified transaction func (service *Service) GetTx(r *http.Request, args *GetTxArgs, reply *GetTxReply) error { - service.vm.ctx.Log.Verbo("GetTx called with %s", args.TxID) + service.vm.ctx.Log.Info("AVM: GetTx called with %s", args.TxID) if args.TxID.IsZero() { return errNilTxID @@ -136,7 +136,7 @@ type GetUTXOsReply struct { // GetUTXOs creates an empty account with the name passed in func (service *Service) GetUTXOs(r *http.Request, args *GetUTXOsArgs, reply *GetUTXOsReply) error { - service.vm.ctx.Log.Verbo("GetUTXOs called with %s", args.Addresses) + service.vm.ctx.Log.Info("AVM: GetUTXOs called with %s", args.Addresses) addrSet := ids.Set{} for _, addr := range args.Addresses { @@ -178,7 +178,7 @@ type GetAssetDescriptionReply struct { // GetAssetDescription creates an empty account with the name passed in func (service *Service) GetAssetDescription(_ *http.Request, args *GetAssetDescriptionArgs, reply *GetAssetDescriptionReply) error { - service.vm.ctx.Log.Verbo("GetAssetDescription called with %s", args.AssetID) + service.vm.ctx.Log.Info("AVM: GetAssetDescription called with %s", args.AssetID) assetID, err := service.vm.Lookup(args.AssetID) if err != nil { @@ -222,7 +222,7 @@ type GetBalanceReply struct { // GetBalance returns the amount of an asset that an address at least partially owns func (service *Service) GetBalance(r *http.Request, args *GetBalanceArgs, reply *GetBalanceReply) error { - service.vm.ctx.Log.Verbo("GetBalance called with address: %s assetID: %s", args.Address, args.AssetID) + service.vm.ctx.Log.Info("AVM: GetBalance called with address: %s assetID: %s", args.Address, args.AssetID) address, err := service.vm.Parse(args.Address) if err != nil { @@ -287,7 +287,7 @@ type GetAllBalancesReply struct { // Note that balances include assets that the address only _partially_ owns // (ie is one of several addresses specified in a multi-sig) func (service *Service) GetAllBalances(r *http.Request, args *GetAllBalancesArgs, reply *GetAllBalancesReply) error { - service.vm.ctx.Log.Verbo("GetAllBalances called with address: %s", args.Address) + service.vm.ctx.Log.Info("AVM: GetAllBalances called with address: %s", args.Address) address, err := service.vm.Parse(args.Address) if err != nil { @@ -360,7 +360,7 @@ type CreateFixedCapAssetReply struct { // CreateFixedCapAsset returns ID of the newly created asset func (service *Service) CreateFixedCapAsset(r *http.Request, args *CreateFixedCapAssetArgs, reply *CreateFixedCapAssetReply) error { - service.vm.ctx.Log.Verbo("CreateFixedCapAsset called with name: %s symbol: %s number of holders: %d", + service.vm.ctx.Log.Info("AVM: CreateFixedCapAsset called with name: %s symbol: %s number of holders: %d", args.Name, args.Symbol, len(args.InitialHolders), @@ -445,7 +445,7 @@ type CreateVariableCapAssetReply struct { // CreateVariableCapAsset returns ID of the newly created asset func (service *Service) CreateVariableCapAsset(r *http.Request, args *CreateVariableCapAssetArgs, reply *CreateVariableCapAssetReply) error { - service.vm.ctx.Log.Verbo("CreateFixedCapAsset called with name: %s symbol: %s number of minters: %d", + service.vm.ctx.Log.Info("AVM: CreateFixedCapAsset called with name: %s symbol: %s number of minters: %d", args.Name, args.Symbol, len(args.MinterSets), @@ -523,7 +523,7 @@ type CreateAddressReply struct { // CreateAddress creates an address for the user [args.Username] func (service *Service) CreateAddress(r *http.Request, args *CreateAddressArgs, reply *CreateAddressReply) error { - service.vm.ctx.Log.Verbo("CreateAddress called for user '%s'", args.Username) + service.vm.ctx.Log.Info("AVM: CreateAddress called for user '%s'", args.Username) db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password) if err != nil { @@ -603,7 +603,7 @@ type ExportKeyReply struct { // ExportKey returns a private key from the provided user func (service *Service) ExportKey(r *http.Request, args *ExportKeyArgs, reply *ExportKeyReply) error { - service.vm.ctx.Log.Verbo("ExportKey called for user '%s'", args.Username) + service.vm.ctx.Log.Info("AVM: ExportKey called for user '%s'", args.Username) address, err := service.vm.Parse(args.Address) if err != nil { @@ -645,7 +645,7 @@ type ImportKeyReply struct { // ImportKey adds a private key to the provided user func (service *Service) ImportKey(r *http.Request, args *ImportKeyArgs, reply *ImportKeyReply) error { - service.vm.ctx.Log.Verbo("ImportKey called for user '%s'", args.Username) + service.vm.ctx.Log.Info("AVM: ImportKey called for user '%s'", args.Username) db, err := service.vm.ctx.Keystore.GetDatabase(args.Username, args.Password) if err != nil { @@ -666,13 +666,20 @@ func (service *Service) ImportKey(r *http.Request, args *ImportKeyArgs, reply *I } addresses, _ := user.Addresses(db) - addresses = append(addresses, sk.PublicKey().Address()) + newAddress := sk.PublicKey().Address() + reply.Address = service.vm.Format(newAddress.Bytes()) + for _, address := range addresses { + if newAddress.Equals(address) { + return nil + } + } + + addresses = append(addresses, newAddress) if err := user.SetAddresses(db, addresses); err != nil { return fmt.Errorf("problem saving addresses: %w", err) } - reply.Address = service.vm.Format(sk.PublicKey().Address().Bytes()) return nil } @@ -692,7 +699,7 @@ type SendReply struct { // Send returns the ID of the newly created transaction func (service *Service) Send(r *http.Request, args *SendArgs, reply *SendReply) error { - service.vm.ctx.Log.Verbo("Send called with username: %s", args.Username) + service.vm.ctx.Log.Info("AVM: Send called with username: %s", args.Username) if args.Amount == 0 { return errInvalidAmount @@ -873,7 +880,7 @@ type CreateMintTxReply struct { // CreateMintTx returns the newly created unsigned transaction func (service *Service) CreateMintTx(r *http.Request, args *CreateMintTxArgs, reply *CreateMintTxReply) error { - service.vm.ctx.Log.Verbo("CreateMintTx called") + service.vm.ctx.Log.Info("AVM: CreateMintTx called") if args.Amount == 0 { return errInvalidMintAmount @@ -990,7 +997,7 @@ type SignMintTxReply struct { // SignMintTx returns the newly signed transaction func (service *Service) SignMintTx(r *http.Request, args *SignMintTxArgs, reply *SignMintTxReply) error { - service.vm.ctx.Log.Verbo("SignMintTx called") + service.vm.ctx.Log.Info("AVM: SignMintTx called") minter, err := service.vm.Parse(args.Minter) if err != nil { @@ -1116,7 +1123,7 @@ type ImportAVAReply struct { // The AVA must have already been exported from the P-Chain. // Returns the ID of the newly created atomic transaction func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, reply *ImportAVAReply) error { - service.vm.ctx.Log.Verbo("ImportAVA called with username: %s", args.Username) + service.vm.ctx.Log.Info("AVM: ImportAVA called with username: %s", args.Username) toBytes, err := service.vm.Parse(args.To) if err != nil { @@ -1268,7 +1275,7 @@ type ExportAVAReply struct { // After this tx is accepted, the AVA must be imported to the P-chain with an importTx. // Returns the ID of the newly created atomic transaction func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, reply *ExportAVAReply) error { - service.vm.ctx.Log.Verbo("ExportAVA called with username: %s", args.Username) + service.vm.ctx.Log.Info("AVM: ExportAVA called with username: %s", args.Username) if args.Amount == 0 { return errInvalidAmount diff --git a/vms/avm/service_test.go b/vms/avm/service_test.go index fdd8053..6e1d387 100644 --- a/vms/avm/service_test.go +++ b/vms/avm/service_test.go @@ -9,8 +9,10 @@ import ( "github.com/stretchr/testify/assert" + "github.com/ava-labs/gecko/api/keystore" "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow/choices" + "github.com/ava-labs/gecko/utils/crypto" "github.com/ava-labs/gecko/utils/formatting" ) @@ -340,3 +342,113 @@ func TestCreateVariableCapAsset(t *testing.T) { t.Fatalf("Wrong assetID returned from CreateFixedCapAsset %s", reply.AssetID) } } + +func TestImportAvmKey(t *testing.T) { + _, vm, s := setup(t) + defer func() { + vm.Shutdown() + ctx.Lock.Unlock() + }() + + userKeystore := keystore.CreateTestKeystore(t) + + username := "bobby" + password := "StrnasfqewiurPasswdn56d" + if err := userKeystore.AddUser(username, password); err != nil { + t.Fatal(err) + } + + vm.ctx.Keystore = userKeystore.NewBlockchainKeyStore(vm.ctx.ChainID) + _, err := vm.ctx.Keystore.GetDatabase(username, password) + if err != nil { + t.Fatal(err) + } + + factory := crypto.FactorySECP256K1R{} + skIntf, err := factory.NewPrivateKey() + if err != nil { + t.Fatalf("problem generating private key: %w", err) + } + sk := skIntf.(*crypto.PrivateKeySECP256K1R) + + args := ImportKeyArgs{ + Username: username, + Password: password, + PrivateKey: formatting.CB58{Bytes: sk.Bytes()}, + } + reply := ImportKeyReply{} + if err = s.ImportKey(nil, &args, &reply); err != nil { + t.Fatal(err) + } +} + +func TestImportAvmKeyNoDuplicates(t *testing.T) { + _, vm, s := setup(t) + defer func() { + vm.Shutdown() + ctx.Lock.Unlock() + }() + + userKeystore := keystore.CreateTestKeystore(t) + + username := "bobby" + password := "StrnasfqewiurPasswdn56d" + if err := userKeystore.AddUser(username, password); err != nil { + t.Fatal(err) + } + + vm.ctx.Keystore = userKeystore.NewBlockchainKeyStore(vm.ctx.ChainID) + _, err := vm.ctx.Keystore.GetDatabase(username, password) + if err != nil { + t.Fatal(err) + } + + factory := crypto.FactorySECP256K1R{} + skIntf, err := factory.NewPrivateKey() + if err != nil { + t.Fatalf("problem generating private key: %w", err) + } + sk := skIntf.(*crypto.PrivateKeySECP256K1R) + + args := ImportKeyArgs{ + Username: username, + Password: password, + PrivateKey: formatting.CB58{Bytes: sk.Bytes()}, + } + reply := ImportKeyReply{} + if err = s.ImportKey(nil, &args, &reply); err != nil { + t.Fatal(err) + } + + expectedAddress := vm.Format(sk.PublicKey().Address().Bytes()) + + if reply.Address != expectedAddress { + t.Fatalf("Reply address: %s did not match expected address: %s", reply.Address, expectedAddress) + } + + reply2 := ImportKeyReply{} + if err = s.ImportKey(nil, &args, &reply2); err != nil { + t.Fatal(err) + } + + if reply2.Address != expectedAddress { + t.Fatalf("Reply address: %s did not match expected address: %s", reply2.Address, expectedAddress) + } + + addrsArgs := ListAddressesArgs{ + Username: username, + Password: password, + } + addrsReply := ListAddressesResponse{} + if err := s.ListAddresses(nil, &addrsArgs, &addrsReply); err != nil { + t.Fatal(err) + } + + if len(addrsReply.Addresses) != 1 { + t.Fatal("Importing the same key twice created duplicate addresses") + } + + if addrsReply.Addresses[0] != expectedAddress { + t.Fatal("List addresses returned an incorrect address") + } +} diff --git a/vms/avm/static_service.go b/vms/avm/static_service.go index 3fd58f3..48b58a9 100644 --- a/vms/avm/static_service.go +++ b/vms/avm/static_service.go @@ -11,7 +11,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/wrappers" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" cjson "github.com/ava-labs/gecko/utils/json" diff --git a/vms/avm/tx.go b/vms/avm/tx.go index c35fd80..f1d0b71 100644 --- a/vms/avm/tx.go +++ b/vms/avm/tx.go @@ -10,7 +10,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/avm/tx_test.go b/vms/avm/tx_test.go index 2f269e9..53e20de 100644 --- a/vms/avm/tx_test.go +++ b/vms/avm/tx_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/units" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/avm/vm.go b/vms/avm/vm.go index b7f7252..4c0820d 100644 --- a/vms/avm/vm.go +++ b/vms/avm/vm.go @@ -25,7 +25,7 @@ import ( "github.com/ava-labs/gecko/utils/timer" "github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" cjson "github.com/ava-labs/gecko/utils/json" ) diff --git a/vms/components/ava/asset_test.go b/vms/components/ava/asset_test.go index 40d6ea8..79ae7d5 100644 --- a/vms/components/ava/asset_test.go +++ b/vms/components/ava/asset_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/ava-labs/gecko/ids" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) func TestAssetVerifyNil(t *testing.T) { diff --git a/vms/components/ava/prefixed_state.go b/vms/components/ava/prefixed_state.go index 92b3491..9906381 100644 --- a/vms/components/ava/prefixed_state.go +++ b/vms/components/ava/prefixed_state.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow/choices" "github.com/ava-labs/gecko/utils/hashing" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) // Addressable is the interface a feature extension must provide to be able to diff --git a/vms/components/ava/prefixed_state_test.go b/vms/components/ava/prefixed_state_test.go index 06cb1df..d3019d5 100644 --- a/vms/components/ava/prefixed_state_test.go +++ b/vms/components/ava/prefixed_state_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/database/memdb" "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/hashing" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/stretchr/testify/assert" ) diff --git a/vms/components/ava/state.go b/vms/components/ava/state.go index fc3b929..df724a4 100644 --- a/vms/components/ava/state.go +++ b/vms/components/ava/state.go @@ -11,7 +11,7 @@ import ( "github.com/ava-labs/gecko/database/prefixdb" "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow/choices" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) var ( diff --git a/vms/components/ava/transferables.go b/vms/components/ava/transferables.go index 4aa906d..85c2414 100644 --- a/vms/components/ava/transferables.go +++ b/vms/components/ava/transferables.go @@ -10,7 +10,7 @@ import ( "github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/utils/crypto" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/verify" ) diff --git a/vms/components/ava/transferables_test.go b/vms/components/ava/transferables_test.go index 80205a6..08d7b69 100644 --- a/vms/components/ava/transferables_test.go +++ b/vms/components/ava/transferables_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/formatting" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/components/ava/utxo_id_test.go b/vms/components/ava/utxo_id_test.go index 7944961..d1be00f 100644 --- a/vms/components/ava/utxo_id_test.go +++ b/vms/components/ava/utxo_id_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/ava-labs/gecko/ids" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) func TestUTXOIDVerifyNil(t *testing.T) { diff --git a/vms/components/ava/utxo_test.go b/vms/components/ava/utxo_test.go index 07b067a..151e219 100644 --- a/vms/components/ava/utxo_test.go +++ b/vms/components/ava/utxo_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/utils/formatting" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/components/codec/codec.go b/vms/components/codec/codec.go deleted file mode 100644 index 72192cb..0000000 --- a/vms/components/codec/codec.go +++ /dev/null @@ -1,345 +0,0 @@ -// (c) 2019-2020, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package codec - -import ( - "errors" - "fmt" - "reflect" - "unicode" - - "github.com/ava-labs/gecko/utils/wrappers" -) - -const ( - defaultMaxSize = 1 << 18 // default max size, in bytes, of something being marshalled by Marshal() - defaultMaxSliceLength = 1 << 18 // default max length of a slice being marshalled by Marshal() -) - -// ErrBadCodec is returned when one tries to perform an operation -// using an unknown codec -var ( - errBadCodec = errors.New("wrong or unknown codec used") - errNil = errors.New("can't marshal nil value") - errUnmarshalNil = errors.New("can't unmarshal into nil") - errNeedPointer = errors.New("must unmarshal into a pointer") - errMarshalUnregisteredType = errors.New("can't marshal an unregistered type") - errUnmarshalUnregisteredType = errors.New("can't unmarshal an unregistered type") - errUnknownType = errors.New("don't know how to marshal/unmarshal this type") - errMarshalUnexportedField = errors.New("can't serialize an unexported field") - errUnmarshalUnexportedField = errors.New("can't deserialize into an unexported field") - errOutOfMemory = errors.New("out of memory") - errSliceTooLarge = errors.New("slice too large") -) - -// Codec handles marshaling and unmarshaling of structs -type codec struct { - maxSize int - maxSliceLen int - - typeIDToType map[uint32]reflect.Type - typeToTypeID map[reflect.Type]uint32 -} - -// Codec marshals and unmarshals -type Codec interface { - RegisterType(interface{}) error - Marshal(interface{}) ([]byte, error) - Unmarshal([]byte, interface{}) error -} - -// New returns a new codec -func New(maxSize, maxSliceLen int) Codec { - return codec{ - maxSize: maxSize, - maxSliceLen: maxSliceLen, - typeIDToType: map[uint32]reflect.Type{}, - typeToTypeID: map[reflect.Type]uint32{}, - } -} - -// NewDefault returns a new codec with reasonable default values -func NewDefault() Codec { return New(defaultMaxSize, defaultMaxSliceLength) } - -// RegisterType is used to register types that may be unmarshaled into an interface typed value -// [val] is a value of the type being registered -func (c codec) RegisterType(val interface{}) error { - valType := reflect.TypeOf(val) - if _, exists := c.typeToTypeID[valType]; exists { - return fmt.Errorf("type %v has already been registered", valType) - } - c.typeIDToType[uint32(len(c.typeIDToType))] = reflect.TypeOf(val) - c.typeToTypeID[valType] = uint32(len(c.typeIDToType) - 1) - return nil -} - -// A few notes: -// 1) See codec_test.go for examples of usage -// 2) We use "marshal" and "serialize" interchangeably, and "unmarshal" and "deserialize" interchangeably -// 3) To include a field of a struct in the serialized form, add the tag `serialize:"true"` to it -// 4) These typed members of a struct may be serialized: -// bool, string, uint[8,16,32,64, int[8,16,32,64], -// structs, slices, arrays, interface. -// structs, slices and arrays can only be serialized if their constituent parts can be. -// 5) To marshal an interface typed value, you must pass a _pointer_ to the value -// 6) If you want to be able to unmarshal into an interface typed value, -// you must call codec.RegisterType([instance of the type that fulfills the interface]). -// 7) nil slices will be unmarshaled as an empty slice of the appropriate type -// 8) Serialized fields must be exported - -// Marshal returns the byte representation of [value] -// If you want to marshal an interface, [value] must be a pointer -// to the interface -func (c codec) Marshal(value interface{}) ([]byte, error) { - if value == nil { - return nil, errNil - } - - return c.marshal(reflect.ValueOf(value)) -} - -// Marshal [value] to bytes -func (c codec) marshal(value reflect.Value) ([]byte, error) { - p := wrappers.Packer{MaxSize: c.maxSize, Bytes: []byte{}} - t := value.Type() - - valueKind := value.Kind() - switch valueKind { - case reflect.Interface, reflect.Ptr, reflect.Slice: - if value.IsNil() { - return nil, errNil - } - } - - switch valueKind { - case reflect.Uint8: - p.PackByte(uint8(value.Uint())) - return p.Bytes, p.Err - case reflect.Int8: - p.PackByte(uint8(value.Int())) - return p.Bytes, p.Err - case reflect.Uint16: - p.PackShort(uint16(value.Uint())) - return p.Bytes, p.Err - case reflect.Int16: - p.PackShort(uint16(value.Int())) - return p.Bytes, p.Err - case reflect.Uint32: - p.PackInt(uint32(value.Uint())) - return p.Bytes, p.Err - case reflect.Int32: - p.PackInt(uint32(value.Int())) - return p.Bytes, p.Err - case reflect.Uint64: - p.PackLong(value.Uint()) - return p.Bytes, p.Err - case reflect.Int64: - p.PackLong(uint64(value.Int())) - return p.Bytes, p.Err - case reflect.Uintptr, reflect.Ptr: - return c.marshal(value.Elem()) - case reflect.String: - p.PackStr(value.String()) - return p.Bytes, p.Err - case reflect.Bool: - p.PackBool(value.Bool()) - return p.Bytes, p.Err - case reflect.Interface: - typeID, ok := c.typeToTypeID[reflect.TypeOf(value.Interface())] // Get the type ID of the value being marshaled - if !ok { - return nil, fmt.Errorf("can't marshal unregistered type '%v'", reflect.TypeOf(value.Interface()).String()) - } - p.PackInt(typeID) - bytes, err := c.Marshal(value.Interface()) - if err != nil { - return nil, err - } - p.PackFixedBytes(bytes) - if p.Errored() { - return nil, p.Err - } - return p.Bytes, err - case reflect.Array, reflect.Slice: - numElts := value.Len() // # elements in the slice/array (assumed to be <= 2^31 - 1) - // If this is a slice, pack the number of elements in the slice - if valueKind == reflect.Slice { - p.PackInt(uint32(numElts)) - } - for i := 0; i < numElts; i++ { // Pack each element in the slice/array - eltBytes, err := c.marshal(value.Index(i)) - if err != nil { - return nil, err - } - p.PackFixedBytes(eltBytes) - } - return p.Bytes, p.Err - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { // Go through all fields of this struct - field := t.Field(i) - if !shouldSerialize(field) { // Skip fields we don't need to serialize - continue - } - if unicode.IsLower(rune(field.Name[0])) { // Can only marshal exported fields - return nil, errMarshalUnexportedField - } - fieldVal := value.Field(i) // The field we're serializing - if fieldVal.Kind() == reflect.Slice && fieldVal.IsNil() { - p.PackInt(0) - continue - } - fieldBytes, err := c.marshal(fieldVal) // Serialize the field - if err != nil { - return nil, err - } - p.PackFixedBytes(fieldBytes) - } - return p.Bytes, p.Err - case reflect.Invalid: - return nil, errUnmarshalNil - default: - return nil, errUnknownType - } -} - -// Unmarshal unmarshals [bytes] into [dest], where -// [dest] must be a pointer or interface -func (c codec) Unmarshal(bytes []byte, dest interface{}) error { - p := &wrappers.Packer{Bytes: bytes} - - if len(bytes) > c.maxSize { - return errSliceTooLarge - } - - if dest == nil { - return errNil - } - - destPtr := reflect.ValueOf(dest) - - if destPtr.Kind() != reflect.Ptr { - return errNeedPointer - } - - destVal := destPtr.Elem() - - err := c.unmarshal(p, destVal) - if err != nil { - return err - } - - if p.Offset != len(p.Bytes) { - return fmt.Errorf("has %d leftover bytes after unmarshalling", len(p.Bytes)-p.Offset) - } - return nil -} - -// Unmarshal bytes from [p] into [field] -// [field] must be addressable -func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error { - kind := field.Kind() - switch kind { - case reflect.Uint8: - field.SetUint(uint64(p.UnpackByte())) - case reflect.Int8: - field.SetInt(int64(p.UnpackByte())) - case reflect.Uint16: - field.SetUint(uint64(p.UnpackShort())) - case reflect.Int16: - field.SetInt(int64(p.UnpackShort())) - case reflect.Uint32: - field.SetUint(uint64(p.UnpackInt())) - case reflect.Int32: - field.SetInt(int64(p.UnpackInt())) - case reflect.Uint64: - field.SetUint(p.UnpackLong()) - case reflect.Int64: - field.SetInt(int64(p.UnpackLong())) - case reflect.Bool: - field.SetBool(p.UnpackBool()) - case reflect.Slice: - sliceLen := int(p.UnpackInt()) // number of elements in the slice - if sliceLen < 0 || sliceLen > c.maxSliceLen { - return errSliceTooLarge - } - - // First set [field] to be a slice of the appropriate type/capacity (right now [field] is nil) - slice := reflect.MakeSlice(field.Type(), sliceLen, sliceLen) - field.Set(slice) - // Unmarshal each element into the appropriate index of the slice - for i := 0; i < sliceLen; i++ { - if err := c.unmarshal(p, field.Index(i)); err != nil { - return err - } - } - case reflect.Array: - for i := 0; i < field.Len(); i++ { - if err := c.unmarshal(p, field.Index(i)); err != nil { - return err - } - } - case reflect.String: - field.SetString(p.UnpackStr()) - case reflect.Interface: - // Get the type ID - typeID := p.UnpackInt() - // Get a struct that implements the interface - typ, ok := c.typeIDToType[typeID] - if !ok { - return errUnmarshalUnregisteredType - } - // Ensure struct actually does implement the interface - fieldType := field.Type() - if !typ.Implements(fieldType) { - return fmt.Errorf("%s does not implement interface %s", typ, fieldType) - } - concreteInstancePtr := reflect.New(typ) // instance of the proper type - // Unmarshal into the struct - if err := c.unmarshal(p, concreteInstancePtr.Elem()); err != nil { - return err - } - // And assign the filled struct to the field - field.Set(concreteInstancePtr.Elem()) - case reflect.Struct: - // Type of this struct - structType := reflect.TypeOf(field.Interface()) - // Go through all the fields and umarshal into each - for i := 0; i < structType.NumField(); i++ { - structField := structType.Field(i) - if !shouldSerialize(structField) { // Skip fields we don't need to unmarshal - continue - } - if unicode.IsLower(rune(structField.Name[0])) { // Only unmarshal into exported field - return errUnmarshalUnexportedField - } - field := field.Field(i) // Get the field - if err := c.unmarshal(p, field); err != nil { // Unmarshal into the field - return err - } - if p.Errored() { // If there was an error just return immediately - return p.Err - } - } - case reflect.Ptr: - // Get the type this pointer points to - underlyingType := field.Type().Elem() - // Create a new pointer to a new value of the underlying type - underlyingValue := reflect.New(underlyingType) - // Fill the value - if err := c.unmarshal(p, underlyingValue.Elem()); err != nil { - return err - } - // Assign to the top-level struct's member - field.Set(underlyingValue) - case reflect.Invalid: - return errUnmarshalNil - default: - return errUnknownType - } - return p.Err -} - -// Returns true iff [field] should be serialized -func shouldSerialize(field reflect.StructField) bool { - return field.Tag.Get("serialize") == "true" -} diff --git a/vms/nftfx/fx_test.go b/vms/nftfx/fx_test.go index d965902..0cfbd87 100644 --- a/vms/nftfx/fx_test.go +++ b/vms/nftfx/fx_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/timer" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index de2d41b..9913608 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -234,7 +234,7 @@ type GetCurrentValidatorsReply struct { // GetCurrentValidators returns the list of current validators func (service *Service) GetCurrentValidators(_ *http.Request, args *GetCurrentValidatorsArgs, reply *GetCurrentValidatorsReply) error { - service.vm.Ctx.Log.Debug("GetCurrentValidators called") + service.vm.Ctx.Log.Info("Platform: GetCurrentValidators called") if args.SubnetID.IsZero() { args.SubnetID = DefaultSubnetID @@ -298,7 +298,7 @@ type GetPendingValidatorsReply struct { // GetPendingValidators returns the list of current validators func (service *Service) GetPendingValidators(_ *http.Request, args *GetPendingValidatorsArgs, reply *GetPendingValidatorsReply) error { - service.vm.Ctx.Log.Debug("GetPendingValidators called") + service.vm.Ctx.Log.Info("Platform: GetPendingValidators called") if args.SubnetID.IsZero() { args.SubnetID = DefaultSubnetID @@ -360,7 +360,7 @@ type SampleValidatorsReply struct { // SampleValidators returns a sampling of the list of current validators func (service *Service) SampleValidators(_ *http.Request, args *SampleValidatorsArgs, reply *SampleValidatorsReply) error { - service.vm.Ctx.Log.Debug("Sample called with {Size = %d}", args.Size) + service.vm.Ctx.Log.Info("Platform: SampleValidators called with {Size = %d}", args.Size) if args.SubnetID.IsZero() { args.SubnetID = DefaultSubnetID @@ -437,7 +437,7 @@ type ListAccountsReply struct { // ListAccounts lists all of the accounts controlled by [args.Username] func (service *Service) ListAccounts(_ *http.Request, args *ListAccountsArgs, reply *ListAccountsReply) error { - service.vm.Ctx.Log.Debug("listAccounts called for user '%s'", args.Username) + service.vm.Ctx.Log.Info("Platform: ListAccounts called for user '%s'", args.Username) // db holds the user's info that pertains to the Platform Chain userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password) @@ -499,7 +499,7 @@ type CreateAccountReply struct { // The account's ID is [privKey].PublicKey().Address(), where [privKey] is a // private key controlled by the user. func (service *Service) CreateAccount(_ *http.Request, args *CreateAccountArgs, reply *CreateAccountReply) error { - service.vm.Ctx.Log.Debug("createAccount called for user '%s'", args.Username) + service.vm.Ctx.Log.Info("Platform: CreateAccount called for user '%s'", args.Username) // userDB holds the user's info that pertains to the Platform Chain userDB, err := service.vm.Ctx.Keystore.GetDatabase(args.Username, args.Password) @@ -569,7 +569,7 @@ type AddDefaultSubnetValidatorArgs struct { // AddDefaultSubnetValidator returns an unsigned transaction to add a validator to the default subnet // The returned unsigned transaction should be signed using Sign() func (service *Service) AddDefaultSubnetValidator(_ *http.Request, args *AddDefaultSubnetValidatorArgs, reply *CreateTxResponse) error { - service.vm.Ctx.Log.Debug("AddDefaultSubnetValidator called") + service.vm.Ctx.Log.Info("Platform: AddDefaultSubnetValidator called") switch { case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID @@ -626,7 +626,7 @@ type AddDefaultSubnetDelegatorArgs struct { // to the default subnet // The returned unsigned transaction should be signed using Sign() func (service *Service) AddDefaultSubnetDelegator(_ *http.Request, args *AddDefaultSubnetDelegatorArgs, reply *CreateTxResponse) error { - service.vm.Ctx.Log.Debug("AddDefaultSubnetDelegator called") + service.vm.Ctx.Log.Info("Platform: AddDefaultSubnetDelegator called") switch { case args.ID.IsZero(): // If ID unspecified, use this node's ID as validator ID @@ -741,7 +741,7 @@ type CreateSubnetArgs struct { // CreateSubnet returns an unsigned transaction to create a new subnet. // The unsigned transaction must be signed with the key of [args.Payer] func (service *Service) CreateSubnet(_ *http.Request, args *CreateSubnetArgs, response *CreateTxResponse) error { - service.vm.Ctx.Log.Debug("platform.createSubnet called") + service.vm.Ctx.Log.Info("Platform: CreateSubnet called") switch { case args.PayerNonce == 0: @@ -796,7 +796,7 @@ type ExportAVAArgs struct { // The unsigned transaction must be signed with the key of the account exporting the AVA // and paying the transaction fee func (service *Service) ExportAVA(_ *http.Request, args *ExportAVAArgs, response *CreateTxResponse) error { - service.vm.Ctx.Log.Debug("platform.ExportAVA called") + service.vm.Ctx.Log.Info("Platform: ExportAVA called") switch { case args.PayerNonce == 0: @@ -858,7 +858,7 @@ type SignResponse struct { // Sign [args.bytes] func (service *Service) Sign(_ *http.Request, args *SignArgs, reply *SignResponse) error { - service.vm.Ctx.Log.Debug("sign called") + service.vm.Ctx.Log.Info("Platform: Sign called") if args.Signer == "" { return errNilSigner @@ -938,7 +938,7 @@ func (service *Service) signAddDefaultSubnetValidatorTx(tx *addDefaultSubnetVali // Sign [unsigned] with [key] func (service *Service) signAddDefaultSubnetDelegatorTx(tx *addDefaultSubnetDelegatorTx, key *crypto.PrivateKeySECP256K1R) (*addDefaultSubnetDelegatorTx, error) { - service.vm.Ctx.Log.Debug("signAddDefaultSubnetValidatorTx called") + service.vm.Ctx.Log.Debug("signAddDefaultSubnetDelegatorTx called") // TODO: Should we check if tx is already signed? unsignedIntf := interface{}(&tx.UnsignedAddDefaultSubnetDelegatorTx) @@ -961,7 +961,7 @@ func (service *Service) signAddDefaultSubnetDelegatorTx(tx *addDefaultSubnetDele // Sign [xt] with [key] func (service *Service) signCreateSubnetTx(tx *CreateSubnetTx, key *crypto.PrivateKeySECP256K1R) (*CreateSubnetTx, error) { - service.vm.Ctx.Log.Debug("signAddDefaultSubnetValidatorTx called") + service.vm.Ctx.Log.Debug("signCreateSubnetTx called") // TODO: Should we check if tx is already signed? unsignedIntf := interface{}(&tx.UnsignedCreateSubnetTx) @@ -984,7 +984,7 @@ func (service *Service) signCreateSubnetTx(tx *CreateSubnetTx, key *crypto.Priva // Sign [tx] with [key] func (service *Service) signExportTx(tx *ExportTx, key *crypto.PrivateKeySECP256K1R) (*ExportTx, error) { - service.vm.Ctx.Log.Debug("platform.signAddDefaultSubnetValidatorTx called") + service.vm.Ctx.Log.Debug("signExportTx called") // TODO: Should we check if tx is already signed? unsignedIntf := interface{}(&tx.UnsignedExportTx) @@ -1075,7 +1075,7 @@ type ImportAVAArgs struct { // The AVA must have already been exported from the X-Chain. // The unsigned transaction must be signed with the key of the tx fee payer. func (service *Service) ImportAVA(_ *http.Request, args *ImportAVAArgs, response *SignResponse) error { - service.vm.Ctx.Log.Debug("platform.ImportAVA called") + service.vm.Ctx.Log.Info("Platform: ImportAVA called") switch { case args.To == "": @@ -1263,7 +1263,7 @@ type IssueTxResponse struct { // IssueTx issues the transaction [args.Tx] to the network func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *IssueTxResponse) error { - service.vm.Ctx.Log.Debug("issueTx called") + service.vm.Ctx.Log.Info("Platform: IssueTx called") genTx := genericTx{} if err := Codec.Unmarshal(args.Tx.Bytes, &genTx); err != nil { @@ -1275,7 +1275,7 @@ func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *Is if err := tx.initialize(service.vm); err != nil { return fmt.Errorf("error initializing tx: %s", err) } - service.vm.unissuedEvents.Push(tx) + service.vm.unissuedEvents.Add(tx) response.TxID = tx.ID() case DecisionTx: if err := tx.initialize(service.vm); err != nil { @@ -1290,7 +1290,7 @@ func (service *Service) IssueTx(_ *http.Request, args *IssueTxArgs, response *Is service.vm.unissuedAtomicTxs = append(service.vm.unissuedAtomicTxs, tx) response.TxID = tx.ID() default: - return errors.New("Could not parse given tx. Must be a TimedTx, DecisionTx, or AtomicTx") + return errors.New("Could not parse given tx. Provided tx needs to be a TimedTx, DecisionTx, or AtomicTx") } service.vm.resetTimer() @@ -1327,7 +1327,7 @@ type CreateBlockchainArgs struct { // CreateBlockchain returns an unsigned transaction to create a new blockchain // Must be signed with the Subnet's control keys and with a key that pays the transaction fee before issuance func (service *Service) CreateBlockchain(_ *http.Request, args *CreateBlockchainArgs, response *CreateTxResponse) error { - service.vm.Ctx.Log.Debug("createBlockchain called") + service.vm.Ctx.Log.Info("Platform: CreateBlockchain called") switch { case args.PayerNonce == 0: @@ -1410,7 +1410,7 @@ type GetBlockchainStatusReply struct { // GetBlockchainStatus gets the status of a blockchain with the ID [args.BlockchainID]. func (service *Service) GetBlockchainStatus(_ *http.Request, args *GetBlockchainStatusArgs, reply *GetBlockchainStatusReply) error { - service.vm.Ctx.Log.Debug("getBlockchainStatus called") + service.vm.Ctx.Log.Info("Platform: GetBlockchainStatus called") switch { case args.BlockchainID == "": @@ -1490,7 +1490,7 @@ type ValidatedByResponse struct { // ValidatedBy returns the ID of the Subnet that validates [args.BlockchainID] func (service *Service) ValidatedBy(_ *http.Request, args *ValidatedByArgs, response *ValidatedByResponse) error { - service.vm.Ctx.Log.Debug("validatedBy called") + service.vm.Ctx.Log.Info("Platform: ValidatedBy called") switch { case args.BlockchainID == "": @@ -1522,7 +1522,7 @@ type ValidatesResponse struct { // Validates returns the IDs of the blockchains validated by [args.SubnetID] func (service *Service) Validates(_ *http.Request, args *ValidatesArgs, response *ValidatesResponse) error { - service.vm.Ctx.Log.Debug("validates called") + service.vm.Ctx.Log.Info("Platform: Validates called") switch { case args.SubnetID == "": @@ -1576,7 +1576,7 @@ type GetBlockchainsResponse struct { // GetBlockchains returns all of the blockchains that exist func (service *Service) GetBlockchains(_ *http.Request, args *struct{}, response *GetBlockchainsResponse) error { - service.vm.Ctx.Log.Debug("getBlockchains called") + service.vm.Ctx.Log.Info("Platform: GetBlockchains called") chains, err := service.vm.getChains(service.vm.DB) if err != nil { diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 9ac4a6c..b6e4a31 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -6,6 +6,9 @@ package platformvm import ( "encoding/json" "testing" + "time" + + "github.com/ava-labs/gecko/utils/formatting" ) func TestAddDefaultSubnetValidator(t *testing.T) { @@ -50,3 +53,184 @@ func TestImportKey(t *testing.T) { t.Fatal(err) } } + +func TestIssueTxKeepsTimedEventsSorted(t *testing.T) { + vm := defaultVM() + vm.Ctx.Lock.Lock() + defer func() { + vm.Shutdown() + vm.Ctx.Lock.Unlock() + }() + + service := Service{vm: vm} + + pendingValidatorStartTime1 := defaultGenesisTime.Add(3 * time.Second) + pendingValidatorEndTime1 := pendingValidatorStartTime1.Add(MinimumStakingDuration) + nodeIDKey1, _ := vm.factory.NewPrivateKey() + nodeID1 := nodeIDKey1.PublicKey().Address() + addPendingValidatorTx1, err := vm.newAddDefaultSubnetValidatorTx( + defaultNonce+1, + defaultStakeAmount, + uint64(pendingValidatorStartTime1.Unix()), + uint64(pendingValidatorEndTime1.Unix()), + nodeID1, + nodeID1, + NumberOfShares, + testNetworkID, + defaultKey, + ) + if err != nil { + t.Fatal(err) + } + + txBytes1, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx1}) + if err != nil { + t.Fatal(err) + } + + args1 := &IssueTxArgs{} + args1.Tx = formatting.CB58{Bytes: txBytes1} + reply1 := IssueTxResponse{} + + err = service.IssueTx(nil, args1, &reply1) + if err != nil { + t.Fatal(err) + } + + pendingValidatorStartTime2 := defaultGenesisTime.Add(2 * time.Second) + pendingValidatorEndTime2 := pendingValidatorStartTime2.Add(MinimumStakingDuration) + nodeIDKey2, _ := vm.factory.NewPrivateKey() + nodeID2 := nodeIDKey2.PublicKey().Address() + addPendingValidatorTx2, err := vm.newAddDefaultSubnetValidatorTx( + defaultNonce+1, + defaultStakeAmount, + uint64(pendingValidatorStartTime2.Unix()), + uint64(pendingValidatorEndTime2.Unix()), + nodeID2, + nodeID2, + NumberOfShares, + testNetworkID, + defaultKey, + ) + if err != nil { + t.Fatal(err) + } + + txBytes2, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx2}) + if err != nil { + t.Fatal(err) + } + + args2 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes2}} + reply2 := IssueTxResponse{} + + err = service.IssueTx(nil, &args2, &reply2) + if err != nil { + t.Fatal(err) + } + + pendingValidatorStartTime3 := defaultGenesisTime.Add(10 * time.Second) + pendingValidatorEndTime3 := pendingValidatorStartTime3.Add(MinimumStakingDuration) + nodeIDKey3, _ := vm.factory.NewPrivateKey() + nodeID3 := nodeIDKey3.PublicKey().Address() + addPendingValidatorTx3, err := vm.newAddDefaultSubnetValidatorTx( + defaultNonce+1, + defaultStakeAmount, + uint64(pendingValidatorStartTime3.Unix()), + uint64(pendingValidatorEndTime3.Unix()), + nodeID3, + nodeID3, + NumberOfShares, + testNetworkID, + defaultKey, + ) + if err != nil { + t.Fatal(err) + } + + txBytes3, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx3}) + if err != nil { + t.Fatal(err) + } + + args3 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes3}} + reply3 := IssueTxResponse{} + + err = service.IssueTx(nil, &args3, &reply3) + if err != nil { + t.Fatal(err) + } + + pendingValidatorStartTime4 := defaultGenesisTime.Add(1 * time.Second) + pendingValidatorEndTime4 := pendingValidatorStartTime4.Add(MinimumStakingDuration) + nodeIDKey4, _ := vm.factory.NewPrivateKey() + nodeID4 := nodeIDKey4.PublicKey().Address() + addPendingValidatorTx4, err := vm.newAddDefaultSubnetValidatorTx( + defaultNonce+1, + defaultStakeAmount, + uint64(pendingValidatorStartTime4.Unix()), + uint64(pendingValidatorEndTime4.Unix()), + nodeID4, + nodeID4, + NumberOfShares, + testNetworkID, + defaultKey, + ) + if err != nil { + t.Fatal(err) + } + + txBytes4, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx4}) + if err != nil { + t.Fatal(err) + } + + args4 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes4}} + reply4 := IssueTxResponse{} + + err = service.IssueTx(nil, &args4, &reply4) + if err != nil { + t.Fatal(err) + } + + pendingValidatorStartTime5 := defaultGenesisTime.Add(50 * time.Second) + pendingValidatorEndTime5 := pendingValidatorStartTime5.Add(MinimumStakingDuration) + nodeIDKey5, _ := vm.factory.NewPrivateKey() + nodeID5 := nodeIDKey5.PublicKey().Address() + addPendingValidatorTx5, err := vm.newAddDefaultSubnetValidatorTx( + defaultNonce+1, + defaultStakeAmount, + uint64(pendingValidatorStartTime5.Unix()), + uint64(pendingValidatorEndTime5.Unix()), + nodeID5, + nodeID5, + NumberOfShares, + testNetworkID, + defaultKey, + ) + if err != nil { + t.Fatal(err) + } + + txBytes5, err := Codec.Marshal(genericTx{Tx: addPendingValidatorTx5}) + if err != nil { + t.Fatal(err) + } + + args5 := IssueTxArgs{Tx: formatting.CB58{Bytes: txBytes5}} + reply5 := IssueTxResponse{} + + err = service.IssueTx(nil, &args5, &reply5) + if err != nil { + t.Fatal(err) + } + + currentEvent := vm.unissuedEvents.Remove() + for vm.unissuedEvents.Len() > 0 { + nextEvent := vm.unissuedEvents.Remove() + if !currentEvent.StartTime().Before(nextEvent.StartTime()) { + t.Fatal("IssueTx does not keep event heap ordered") + } + currentEvent = nextEvent + } +} diff --git a/vms/platformvm/static_service.go b/vms/platformvm/static_service.go index 8acc0a9..1cdeeca 100644 --- a/vms/platformvm/static_service.go +++ b/vms/platformvm/static_service.go @@ -4,7 +4,6 @@ package platformvm import ( - "container/heap" "errors" "net/http" @@ -174,8 +173,8 @@ func (*StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, repl return errAccountHasNoValue } accounts = append(accounts, newAccount( - account.Address, // ID - 0, // nonce + account.Address, // ID + 0, // nonce uint64(account.Balance), // balance )) } @@ -210,7 +209,7 @@ func (*StaticService) BuildGenesis(_ *http.Request, args *BuildGenesisArgs, repl return err } - heap.Push(validators, tx) + validators.Add(tx) } // Specify the chains that exist at genesis. diff --git a/vms/platformvm/static_service_test.go b/vms/platformvm/static_service_test.go index 04433ff..3f64a9b 100644 --- a/vms/platformvm/static_service_test.go +++ b/vms/platformvm/static_service_test.go @@ -111,3 +111,77 @@ func TestBuildGenesisInvalidEndtime(t *testing.T) { t.Fatalf("Should have errored due to an invalid end time") } } + +func TestBuildGenesisReturnsSortedValidators(t *testing.T) { + id := ids.NewShortID([20]byte{1}) + account := APIAccount{ + Address: id, + Balance: 123456789, + } + + weight := json.Uint64(987654321) + validator1 := APIDefaultSubnetValidator{ + APIValidator: APIValidator{ + StartTime: 0, + EndTime: 20, + Weight: &weight, + ID: id, + }, + Destination: id, + } + + validator2 := APIDefaultSubnetValidator{ + APIValidator: APIValidator{ + StartTime: 3, + EndTime: 15, + Weight: &weight, + ID: id, + }, + Destination: id, + } + + validator3 := APIDefaultSubnetValidator{ + APIValidator: APIValidator{ + StartTime: 1, + EndTime: 10, + Weight: &weight, + ID: id, + }, + Destination: id, + } + + args := BuildGenesisArgs{ + Accounts: []APIAccount{ + account, + }, + Validators: []APIDefaultSubnetValidator{ + validator1, + validator2, + validator3, + }, + Time: 5, + } + reply := BuildGenesisReply{} + + ss := StaticService{} + if err := ss.BuildGenesis(nil, &args, &reply); err != nil { + t.Fatalf("BuildGenesis should not have errored") + } + + genesis := &Genesis{} + if err := Codec.Unmarshal(reply.Bytes.Bytes, genesis); err != nil { + t.Fatal(err) + } + validators := genesis.Validators + if validators.Len() == 0 { + t.Fatal("Validators should contain 3 validators") + } + currentValidator := validators.Remove() + for validators.Len() > 0 { + nextValidator := validators.Remove() + if currentValidator.EndTime().Unix() > nextValidator.EndTime().Unix() { + t.Fatalf("Validators returned by genesis should be a min heap sorted by end time") + } + currentValidator = nextValidator + } +} diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index 9f1ce53..baff040 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -4,7 +4,6 @@ package platformvm import ( - "container/heap" "errors" "fmt" "time" @@ -27,7 +26,7 @@ import ( "github.com/ava-labs/gecko/utils/units" "github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/core" "github.com/ava-labs/gecko/vms/secp256k1fx" ) @@ -406,10 +405,14 @@ func (vm *VM) createChain(tx *CreateChainTx) { } // Bootstrapping marks this VM as bootstrapping -func (vm *VM) Bootstrapping() error { return nil } +func (vm *VM) Bootstrapping() error { + return vm.fx.Bootstrapping() +} // Bootstrapped marks this VM as bootstrapped -func (vm *VM) Bootstrapped() error { return nil } +func (vm *VM) Bootstrapped() error { + return vm.fx.Bootstrapped() +} // Shutdown this blockchain func (vm *VM) Shutdown() error { @@ -698,7 +701,7 @@ func (vm *VM) resetTimer() { vm.SnowmanVM.NotifyBlockReady() // Should issue a ProposeAddValidator return } - // If the tx doesn't meet the syncrony bound, drop it + // If the tx doesn't meet the synchrony bound, drop it vm.unissuedEvents.Remove() vm.Ctx.Log.Debug("dropping tx to add validator because its start time has passed") } @@ -780,8 +783,8 @@ func (vm *VM) calculateValidators(db database.Database, timestamp time.Time, sub if timestamp.Before(nextTx.StartTime()) { break } - heap.Push(current, nextTx) - heap.Pop(pending) + current.Add(nextTx) + pending.Remove() started.Add(nextTx.Vdr().ID()) } return current, pending, started, stopped, nil diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index b8bb47c..dcee89a 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -5,7 +5,6 @@ package platformvm import ( "bytes" - "container/heap" "errors" "testing" "time" @@ -193,6 +192,8 @@ func defaultVM() *VM { panic("no subnets found") } // end delete + vm.registerDBTypes() + return vm } @@ -226,7 +227,7 @@ func GenesisCurrentValidators() *EventHeap { testNetworkID, // network ID key, // key paying tx fee and stake ) - heap.Push(validators, validator) + validators.Add(validator) } return validators } @@ -1011,7 +1012,7 @@ func TestCreateSubnet(t *testing.T) { t.Fatal(err) } - vm.unissuedEvents.Push(addValidatorTx) + vm.unissuedEvents.Add(addValidatorTx) blk, err = vm.BuildBlock() // should add validator to the new subnet if err != nil { t.Fatal(err) diff --git a/vms/propertyfx/fx_test.go b/vms/propertyfx/fx_test.go index cfdf5c9..887cf73 100644 --- a/vms/propertyfx/fx_test.go +++ b/vms/propertyfx/fx_test.go @@ -9,7 +9,7 @@ import ( "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/timer" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" ) diff --git a/vms/secp256k1fx/credential_test.go b/vms/secp256k1fx/credential_test.go index 5157fab..e85ce1b 100644 --- a/vms/secp256k1fx/credential_test.go +++ b/vms/secp256k1fx/credential_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/ava-labs/gecko/utils/crypto" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) func TestCredentialVerify(t *testing.T) { diff --git a/vms/secp256k1fx/fx_test.go b/vms/secp256k1fx/fx_test.go index 79e6c89..566b4cb 100644 --- a/vms/secp256k1fx/fx_test.go +++ b/vms/secp256k1fx/fx_test.go @@ -12,7 +12,7 @@ import ( "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/timer" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) var ( diff --git a/vms/secp256k1fx/transer_input_test.go b/vms/secp256k1fx/transer_input_test.go index e954af0..00e894f 100644 --- a/vms/secp256k1fx/transer_input_test.go +++ b/vms/secp256k1fx/transer_input_test.go @@ -7,7 +7,7 @@ import ( "bytes" "testing" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) func TestTransferInputAmount(t *testing.T) { diff --git a/vms/secp256k1fx/transfer_output_test.go b/vms/secp256k1fx/transfer_output_test.go index 7e87875..09bb0ce 100644 --- a/vms/secp256k1fx/transfer_output_test.go +++ b/vms/secp256k1fx/transfer_output_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/ava-labs/gecko/ids" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) func TestOutputAmount(t *testing.T) { diff --git a/vms/secp256k1fx/vm.go b/vms/secp256k1fx/vm.go index bb59166..37aa23b 100644 --- a/vms/secp256k1fx/vm.go +++ b/vms/secp256k1fx/vm.go @@ -6,7 +6,7 @@ package secp256k1fx import ( "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/timer" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" ) // VM that this Fx must be run by diff --git a/vms/timestampvm/vm.go b/vms/timestampvm/vm.go index c571d9a..5376e2f 100644 --- a/vms/timestampvm/vm.go +++ b/vms/timestampvm/vm.go @@ -12,7 +12,7 @@ import ( "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow/consensus/snowman" "github.com/ava-labs/gecko/snow/engine/common" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/components/core" ) diff --git a/xputtest/avmwallet/wallet.go b/xputtest/avmwallet/wallet.go index ef01eb0..c5d2cd9 100644 --- a/xputtest/avmwallet/wallet.go +++ b/xputtest/avmwallet/wallet.go @@ -19,7 +19,7 @@ import ( "github.com/ava-labs/gecko/utils/wrappers" "github.com/ava-labs/gecko/vms/avm" "github.com/ava-labs/gecko/vms/components/ava" - "github.com/ava-labs/gecko/vms/components/codec" + "github.com/ava-labs/gecko/utils/codec" "github.com/ava-labs/gecko/vms/secp256k1fx" )