From 6bd84af8ecbbb5c08bc1c9379e3d19d92ed3de7b Mon Sep 17 00:00:00 2001 From: Dan Laine Date: Thu, 7 May 2020 14:34:32 -0400 Subject: [PATCH] generate staking key at ~/.gecko/staking/staker.key if no key given. --- main/params.go | 109 ++++++++++++++++++++++++++++---------- staking/gen_staker_key.go | 74 ++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 27 deletions(-) create mode 100644 staking/gen_staker_key.go diff --git a/main/params.go b/main/params.go index a216720..f6b9daa 100644 --- a/main/params.go +++ b/main/params.go @@ -19,6 +19,7 @@ import ( "github.com/ava-labs/gecko/nat" "github.com/ava-labs/gecko/node" "github.com/ava-labs/gecko/snow/networking/router" + "github.com/ava-labs/gecko/staking" "github.com/ava-labs/gecko/utils" "github.com/ava-labs/gecko/utils/formatting" "github.com/ava-labs/gecko/utils/hashing" @@ -34,8 +35,14 @@ const ( // Results of parsing the CLI var ( - Config = node.Config{} - Err error + Config = node.Config{} + Err error + defaultStakingKeyPath = "~/.gecko/staking/staker.key" + defaultStakingCertPath = "~/.gecko/staking/staker.crt" +) + +var ( + errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs") ) // GetIPs returns the default IPs for each network @@ -54,17 +61,15 @@ func GetIPs(networkID uint32) []string { } } -var ( - errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs") -) - // Parse the CLI arguments func init() { errs := &wrappers.Errs{} defer func() { Err = errs.Err }() loggingConfig, err := logging.DefaultConfig() - errs.Add(err) + if errs.Add(err); errs.Errored() { + return + } fs := flag.NewFlagSet("gecko", flag.ContinueOnError) @@ -100,8 +105,8 @@ func init() { // Staking: consensusPort := fs.Uint("staking-port", 9651, "Port of the consensus server") fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Require TLS to authenticate staking connections") - fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", "keys/staker.key", "TLS private key file for staking connections") - fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", "keys/staker.crt", "TLS certificate file for staking connections") + fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", defaultStakingKeyPath, "TLS private key for staking") + fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", defaultStakingCertPath, "TLS certificate for staking") // Plugins: fs.StringVar(&Config.PluginDir, "plugin-dir", "./build/plugins", "Plugin directory for Ava VMs") @@ -142,22 +147,24 @@ func init() { } networkID, err := genesis.NetworkID(*networkName) - errs.Add(err) + if errs.Add(err); errs.Errored() { + return + } Config.NetworkID = networkID // DB: - if *db && err == nil { - // TODO: Add better params here - if *dbDir == defaultDbDir { - if *dbDir, err = homedir.Expand(defaultDbDir); err != nil { - errs.Add(fmt.Errorf("couldn't resolve default db path: %v", err)) - } + if *db { + *dbDir, err = homedir.Expand(*dbDir) + if errs.Add(fmt.Errorf("couldn't resolve db path: %w", err)); errs.Errored() { + return } dbPath := path.Join(*dbDir, genesis.NetworkName(Config.NetworkID), dbVersion) db, err := leveldb.New(dbPath, 0, 0, 0) + if errs.Add(fmt.Errorf("couldn't create db: %w", err)); errs.Errored() { + return + } Config.DB = db - errs.Add(err) } else { Config.DB = memdb.New() } @@ -169,7 +176,7 @@ func init() { if *consensusIP == "" { ip, err = Config.Nat.IP() if err != nil { - ip = net.IPv4zero + ip = net.IPv4zero // Couldn't get my IP...set to 0.0.0.0 } } else { ip = net.ParseIP(*consensusIP) @@ -177,7 +184,9 @@ func init() { if ip == nil { errs.Add(fmt.Errorf("Invalid IP Address %s", *consensusIP)) + return } + Config.StakingIP = utils.IPDesc{ IP: ip, Port: uint16(*consensusPort), @@ -190,7 +199,10 @@ func init() { for _, ip := range strings.Split(*bootstrapIPs, ",") { if ip != "" { addr, err := utils.ToIPDesc(ip) - errs.Add(err) + if err != nil { + errs.Add(fmt.Errorf("couldn't parse ip: %w", err)) + return + } Config.BootstrapPeers = append(Config.BootstrapPeers, &node.Peer{ IP: addr, }) @@ -209,20 +221,27 @@ func init() { cb58 := formatting.CB58{} for _, id := range strings.Split(*bootstrapIDs, ",") { if id != "" { - errs.Add(cb58.FromString(id)) - cert, err := ids.ToShortID(cb58.Bytes) - errs.Add(err) - + err = cb58.FromString(id) + if err != nil { + errs.Add(fmt.Errorf("couldn't parse bootstrap peer id to bytes: %w", err)) + return + } + peerID, err := ids.ToShortID(cb58.Bytes) + if err != nil { + errs.Add(fmt.Errorf("couldn't parse bootstrap peer id: %w", err)) + return + } if len(Config.BootstrapPeers) <= i { errs.Add(errBootstrapMismatch) - continue + return } - Config.BootstrapPeers[i].ID = cert + Config.BootstrapPeers[i].ID = peerID i++ } } if len(Config.BootstrapPeers) != i { errs.Add(fmt.Errorf("More bootstrap IPs, %d, provided than bootstrap IDs, %d", len(Config.BootstrapPeers), i)) + return } } else { for _, peer := range Config.BootstrapPeers { @@ -230,6 +249,38 @@ func init() { } } + // Staking + Config.StakingKeyFile, err = homedir.Expand(Config.StakingKeyFile) + if err != nil { + errs.Add(fmt.Errorf("couldn't resolve staking key path: %w", err)) + return + } + Config.StakingCertFile, err = homedir.Expand(Config.StakingCertFile) + if err != nil { + errs.Add(fmt.Errorf("couldn't resolve staking cert path: %v", err)) + return + } + defaultKeyPath, _ := homedir.Expand(defaultStakingKeyPath) + defaultStakingCertPath, _ := homedir.Expand(defaultStakingCertPath) + + switch { + // If staking key/cert locations are specified but not found, error + case Config.StakingKeyFile != defaultKeyPath || Config.StakingCertFile != defaultStakingCertPath: + if _, err := os.Stat(Config.StakingKeyFile); os.IsNotExist(err) { + errs.Add(fmt.Errorf("couldn't find staking key at %s", Config.StakingKeyFile)) + return + } else if _, err := os.Stat(Config.StakingCertFile); os.IsNotExist(err) { + errs.Add(fmt.Errorf("couldn't find staking certificate at %s", Config.StakingCertFile)) + return + } + default: + // Only creates staking key/cert if [stakingKeyPath] doesn't exist + if err := staking.GenerateStakingKeyCert(Config.StakingKeyFile, Config.StakingCertFile); err != nil { + errs.Add(fmt.Errorf("couldn't generate staking key/cert: %w", err)) + return + } + } + // HTTP: Config.HTTPPort = uint16(*httpPort) @@ -238,14 +289,18 @@ func init() { loggingConfig.Directory = *logsDir } logFileLevel, err := logging.ToLevel(*logLevel) - errs.Add(err) + if errs.Add(err); errs.Errored() { + return + } loggingConfig.LogLevel = logFileLevel if *logDisplayLevel == "" { *logDisplayLevel = *logLevel } displayLevel, err := logging.ToLevel(*logDisplayLevel) - errs.Add(err) + if errs.Add(err); errs.Errored() { + return + } loggingConfig.DisplayLevel = displayLevel Config.LoggingConfig = loggingConfig diff --git a/staking/gen_staker_key.go b/staking/gen_staker_key.go new file mode 100644 index 0000000..8969ea3 --- /dev/null +++ b/staking/gen_staker_key.go @@ -0,0 +1,74 @@ +package staking + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// GenerateStakingKeyCert generates a self-signed TLS key/cert pair to use in staking +// The key and files will be placed at [keyPath] and [certPath], respectively +// If there is already a file at [keyPath], returns nil +func GenerateStakingKeyCert(keyPath, certPath string) error { + // If there is already a file at [keyPath], do nothing + if _, err := os.Stat(keyPath); !os.IsNotExist(err) { + return nil + } + + // Create key to sign cert with + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return fmt.Errorf("couldn't generate rsa key: %w", err) + } + + // Create self-signed staking cert + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0), + NotBefore: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + NotAfter: time.Now().AddDate(100, 0, 0), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment, + BasicConstraintsValid: true, + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &key.PublicKey, key) + if err != nil { + return fmt.Errorf("couldn't create certificate: %w", err) + } + + // Write cert to disk + if err := os.MkdirAll(filepath.Dir(certPath), 0755); err != nil { + return fmt.Errorf("couldn't create path for key/cert: %w", err) + } + certOut, err := os.Create(certPath) + if err != nil { + return fmt.Errorf("couldn't create cert file: %w", err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}); err != nil { + return fmt.Errorf("couldn't write cert file: %w", err) + } + if err := certOut.Close(); err != nil { + return fmt.Errorf("couldn't close cert file: %w", err) + } + + // Write key to disk + keyOut, err := os.Create(keyPath) + if err != nil { + return fmt.Errorf("couldn't create key file: %w", err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return fmt.Errorf("couldn't marshal private key: %w", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + return fmt.Errorf("couldn't write private key: %w", err) + } + if err := keyOut.Close(); err != nil { + return fmt.Errorf("couldn't close key file: %w", err) + } + return nil +}