generate staking key at ~/.gecko/staking/staker.key if no key given.

This commit is contained in:
Dan Laine 2020-05-07 14:34:32 -04:00
parent 02f162db1a
commit 6bd84af8ec
2 changed files with 156 additions and 27 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/ava-labs/gecko/nat"
"github.com/ava-labs/gecko/node"
"github.com/ava-labs/gecko/snow/networking/router"
"github.com/ava-labs/gecko/staking"
"github.com/ava-labs/gecko/utils"
"github.com/ava-labs/gecko/utils/formatting"
"github.com/ava-labs/gecko/utils/hashing"
@ -34,8 +35,14 @@ const (
// Results of parsing the CLI
var (
Config = node.Config{}
Err error
Config = node.Config{}
Err error
defaultStakingKeyPath = "~/.gecko/staking/staker.key"
defaultStakingCertPath = "~/.gecko/staking/staker.crt"
)
var (
errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs")
)
// GetIPs returns the default IPs for each network
@ -54,17 +61,15 @@ func GetIPs(networkID uint32) []string {
}
}
var (
errBootstrapMismatch = errors.New("more bootstrap IDs provided than bootstrap IPs")
)
// Parse the CLI arguments
func init() {
errs := &wrappers.Errs{}
defer func() { Err = errs.Err }()
loggingConfig, err := logging.DefaultConfig()
errs.Add(err)
if errs.Add(err); errs.Errored() {
return
}
fs := flag.NewFlagSet("gecko", flag.ContinueOnError)
@ -100,8 +105,8 @@ func init() {
// Staking:
consensusPort := fs.Uint("staking-port", 9651, "Port of the consensus server")
fs.BoolVar(&Config.EnableStaking, "staking-tls-enabled", true, "Require TLS to authenticate staking connections")
fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", "keys/staker.key", "TLS private key file for staking connections")
fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", "keys/staker.crt", "TLS certificate file for staking connections")
fs.StringVar(&Config.StakingKeyFile, "staking-tls-key-file", defaultStakingKeyPath, "TLS private key for staking")
fs.StringVar(&Config.StakingCertFile, "staking-tls-cert-file", defaultStakingCertPath, "TLS certificate for staking")
// Plugins:
fs.StringVar(&Config.PluginDir, "plugin-dir", "./build/plugins", "Plugin directory for Ava VMs")
@ -142,22 +147,24 @@ func init() {
}
networkID, err := genesis.NetworkID(*networkName)
errs.Add(err)
if errs.Add(err); errs.Errored() {
return
}
Config.NetworkID = networkID
// DB:
if *db && err == nil {
// TODO: Add better params here
if *dbDir == defaultDbDir {
if *dbDir, err = homedir.Expand(defaultDbDir); err != nil {
errs.Add(fmt.Errorf("couldn't resolve default db path: %v", err))
}
if *db {
*dbDir, err = homedir.Expand(*dbDir)
if errs.Add(fmt.Errorf("couldn't resolve db path: %w", err)); errs.Errored() {
return
}
dbPath := path.Join(*dbDir, genesis.NetworkName(Config.NetworkID), dbVersion)
db, err := leveldb.New(dbPath, 0, 0, 0)
if errs.Add(fmt.Errorf("couldn't create db: %w", err)); errs.Errored() {
return
}
Config.DB = db
errs.Add(err)
} else {
Config.DB = memdb.New()
}
@ -169,7 +176,7 @@ func init() {
if *consensusIP == "" {
ip, err = Config.Nat.IP()
if err != nil {
ip = net.IPv4zero
ip = net.IPv4zero // Couldn't get my IP...set to 0.0.0.0
}
} else {
ip = net.ParseIP(*consensusIP)
@ -177,7 +184,9 @@ func init() {
if ip == nil {
errs.Add(fmt.Errorf("Invalid IP Address %s", *consensusIP))
return
}
Config.StakingIP = utils.IPDesc{
IP: ip,
Port: uint16(*consensusPort),
@ -190,7 +199,10 @@ func init() {
for _, ip := range strings.Split(*bootstrapIPs, ",") {
if ip != "" {
addr, err := utils.ToIPDesc(ip)
errs.Add(err)
if err != nil {
errs.Add(fmt.Errorf("couldn't parse ip: %w", err))
return
}
Config.BootstrapPeers = append(Config.BootstrapPeers, &node.Peer{
IP: addr,
})
@ -209,20 +221,27 @@ func init() {
cb58 := formatting.CB58{}
for _, id := range strings.Split(*bootstrapIDs, ",") {
if id != "" {
errs.Add(cb58.FromString(id))
cert, err := ids.ToShortID(cb58.Bytes)
errs.Add(err)
err = cb58.FromString(id)
if err != nil {
errs.Add(fmt.Errorf("couldn't parse bootstrap peer id to bytes: %w", err))
return
}
peerID, err := ids.ToShortID(cb58.Bytes)
if err != nil {
errs.Add(fmt.Errorf("couldn't parse bootstrap peer id: %w", err))
return
}
if len(Config.BootstrapPeers) <= i {
errs.Add(errBootstrapMismatch)
continue
return
}
Config.BootstrapPeers[i].ID = cert
Config.BootstrapPeers[i].ID = peerID
i++
}
}
if len(Config.BootstrapPeers) != i {
errs.Add(fmt.Errorf("More bootstrap IPs, %d, provided than bootstrap IDs, %d", len(Config.BootstrapPeers), i))
return
}
} else {
for _, peer := range Config.BootstrapPeers {
@ -230,6 +249,38 @@ func init() {
}
}
// Staking
Config.StakingKeyFile, err = homedir.Expand(Config.StakingKeyFile)
if err != nil {
errs.Add(fmt.Errorf("couldn't resolve staking key path: %w", err))
return
}
Config.StakingCertFile, err = homedir.Expand(Config.StakingCertFile)
if err != nil {
errs.Add(fmt.Errorf("couldn't resolve staking cert path: %v", err))
return
}
defaultKeyPath, _ := homedir.Expand(defaultStakingKeyPath)
defaultStakingCertPath, _ := homedir.Expand(defaultStakingCertPath)
switch {
// If staking key/cert locations are specified but not found, error
case Config.StakingKeyFile != defaultKeyPath || Config.StakingCertFile != defaultStakingCertPath:
if _, err := os.Stat(Config.StakingKeyFile); os.IsNotExist(err) {
errs.Add(fmt.Errorf("couldn't find staking key at %s", Config.StakingKeyFile))
return
} else if _, err := os.Stat(Config.StakingCertFile); os.IsNotExist(err) {
errs.Add(fmt.Errorf("couldn't find staking certificate at %s", Config.StakingCertFile))
return
}
default:
// Only creates staking key/cert if [stakingKeyPath] doesn't exist
if err := staking.GenerateStakingKeyCert(Config.StakingKeyFile, Config.StakingCertFile); err != nil {
errs.Add(fmt.Errorf("couldn't generate staking key/cert: %w", err))
return
}
}
// HTTP:
Config.HTTPPort = uint16(*httpPort)
@ -238,14 +289,18 @@ func init() {
loggingConfig.Directory = *logsDir
}
logFileLevel, err := logging.ToLevel(*logLevel)
errs.Add(err)
if errs.Add(err); errs.Errored() {
return
}
loggingConfig.LogLevel = logFileLevel
if *logDisplayLevel == "" {
*logDisplayLevel = *logLevel
}
displayLevel, err := logging.ToLevel(*logDisplayLevel)
errs.Add(err)
if errs.Add(err); errs.Errored() {
return
}
loggingConfig.DisplayLevel = displayLevel
Config.LoggingConfig = loggingConfig

74
staking/gen_staker_key.go Normal file
View File

@ -0,0 +1,74 @@
package staking
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)
// GenerateStakingKeyCert generates a self-signed TLS key/cert pair to use in staking
// The key and files will be placed at [keyPath] and [certPath], respectively
// If there is already a file at [keyPath], returns nil
func GenerateStakingKeyCert(keyPath, certPath string) error {
// If there is already a file at [keyPath], do nothing
if _, err := os.Stat(keyPath); !os.IsNotExist(err) {
return nil
}
// Create key to sign cert with
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return fmt.Errorf("couldn't generate rsa key: %w", err)
}
// Create self-signed staking cert
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(0),
NotBefore: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC),
NotAfter: time.Now().AddDate(100, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment,
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &key.PublicKey, key)
if err != nil {
return fmt.Errorf("couldn't create certificate: %w", err)
}
// Write cert to disk
if err := os.MkdirAll(filepath.Dir(certPath), 0755); err != nil {
return fmt.Errorf("couldn't create path for key/cert: %w", err)
}
certOut, err := os.Create(certPath)
if err != nil {
return fmt.Errorf("couldn't create cert file: %w", err)
}
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}); err != nil {
return fmt.Errorf("couldn't write cert file: %w", err)
}
if err := certOut.Close(); err != nil {
return fmt.Errorf("couldn't close cert file: %w", err)
}
// Write key to disk
keyOut, err := os.Create(keyPath)
if err != nil {
return fmt.Errorf("couldn't create key file: %w", err)
}
privBytes, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return fmt.Errorf("couldn't marshal private key: %w", err)
}
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
return fmt.Errorf("couldn't write private key: %w", err)
}
if err := keyOut.Close(); err != nil {
return fmt.Errorf("couldn't close key file: %w", err)
}
return nil
}