diff --git a/cmd/server/main.go b/cmd/server/main.go index 949b21b..83cb91f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -105,12 +105,11 @@ func fileExists(filename string) bool { func main() { opts := &Options{} flag.StringVar(&opts.bindAddr, "bind-addr", "127.0.0.1:9067", "the address to listen on") - flag.StringVar(&opts.tlsCertPath, "tls-cert", "./cert.pem", "the path to a TLS certificate") - flag.StringVar(&opts.tlsKeyPath, "tls-key", "./cert.key", "the path to a TLS key file") + flag.StringVar(&opts.tlsCertPath, "tls-cert", "", "the path to a TLS certificate") + flag.StringVar(&opts.tlsKeyPath, "tls-key", "", "the path to a TLS key file") flag.Uint64Var(&opts.logLevel, "log-level", uint64(logrus.InfoLevel), "log level (logrus 1-7)") flag.StringVar(&opts.logPath, "log-file", "./server.log", "log file to write to") flag.StringVar(&opts.zcashConfPath, "conf-file", "./zcash.conf", "conf file to pull RPC creds from") - flag.BoolVar(&opts.veryInsecure, "no-tls-very-insecure", false, "run without the required TLS certificate, only for debugging, DO NOT use in production") flag.BoolVar(&opts.wantVersion, "version", false, "version (major.minor.patch)") flag.IntVar(&opts.cacheSize, "cache-size", 80000, "number of blocks to hold in the cache") @@ -124,19 +123,16 @@ func main() { } filesThatShouldExist := []string{ - opts.tlsCertPath, - opts.tlsKeyPath, - opts.logPath, opts.zcashConfPath, } + if opts.tlsCertPath != "" { + filesThatShouldExist = append(filesThatShouldExist, opts.tlsCertPath) + } + if opts.tlsKeyPath != "" { + filesThatShouldExist = append(filesThatShouldExist, opts.tlsKeyPath) + } for _, filename := range filesThatShouldExist { - if !fileExists(opts.logPath) { - os.OpenFile(opts.logPath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) - } - if opts.veryInsecure && (filename == opts.tlsCertPath || filename == opts.tlsKeyPath) { - continue - } if !fileExists(filename) { os.Stderr.WriteString(fmt.Sprintf("\n ** File does not exist: %s\n\n", filename)) flag.Usage() @@ -162,11 +158,15 @@ func main() { // gRPC initialization var server *grpc.Server + var transportCreds credentials.TransportCredentials + var err error - if opts.veryInsecure { - server = grpc.NewServer(LoggingInterceptor()) + if (opts.tlsCertPath == "") && (opts.tlsKeyPath == "") { + log.Warning("Certificate and key not provided, generating self signed values") + tlsCert := common.GenerateCerts() + transportCreds = credentials.NewServerTLSFromCert(tlsCert) } else { - transportCreds, err := credentials.NewServerTLSFromFile(opts.tlsCertPath, opts.tlsKeyPath) + transportCreds, err = credentials.NewServerTLSFromFile(opts.tlsCertPath, opts.tlsKeyPath) if err != nil { log.WithFields(logrus.Fields{ "cert_file": opts.tlsCertPath, @@ -174,8 +174,8 @@ func main() { "error": err, }).Fatal("couldn't load TLS credentials") } - server = grpc.NewServer(grpc.Creds(transportCreds), LoggingInterceptor()) } + server = grpc.NewServer(grpc.Creds(transportCreds), LoggingInterceptor()) // Enable reflection for debugging if opts.logLevel >= uint64(logrus.WarnLevel) { diff --git a/common/generatecerts.go b/common/generatecerts.go new file mode 100644 index 0000000..2b9cb79 --- /dev/null +++ b/common/generatecerts.go @@ -0,0 +1,72 @@ +package common + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log" + "math/big" + "time" +) + +// GenerateCerts create self signed certificate for local development use +func GenerateCerts() (cert *tls.Certificate) { + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + publicKey := &privKey.PublicKey + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %s", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Lighwalletd developer"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Local().Add(time.Hour * 24 * 365), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // List of hostnames and IPs for the cert + template.DNSNames = append(template.DNSNames, "localhost") + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privKey) + if err != nil { + log.Fatalf("Failed to create certificate: %s", err) + } + + // PEM encode the certificate (this is a standard TLS encoding) + b := pem.Block{Type: "CERTIFICATE", Bytes: certDER} + certPEM := pem.EncodeToMemory(&b) + fmt.Printf("%s\n", certPEM) + + // PEM encode the private key + privBytes, err := x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + log.Fatalf("Unable to marshal private key: %v", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", Bytes: privBytes, + }) + + // Create a TLS cert using the private key and certificate + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + log.Fatalf("invalid key pair: %v", err) + } + + cert = &tlsCert + return cert + +}