mirror of https://github.com/poanetwork/gecko.git
cleanup/formatting. Add logs and error handling. Make token lifespan reasonable length.
This commit is contained in:
parent
7038641fd8
commit
9b0981fe7b
|
@ -2,37 +2,45 @@ package auth
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
// TODO: Add method to revoke a token
|
||||
|
||||
const (
|
||||
headerKey = "Authorization"
|
||||
headerValStart = "Bearer "
|
||||
// Endpoint is the base of the auth URL
|
||||
Endpoint = "auth"
|
||||
)
|
||||
|
||||
var (
|
||||
// TokenLifespan is how long a token lives before it expires
|
||||
TokenLifespan = 1 * time.Minute
|
||||
TokenLifespan = time.Hour * 12
|
||||
|
||||
// ErrNoToken is returned by GetToken if no token is provided
|
||||
ErrNoToken = errors.New("auth token not provided")
|
||||
)
|
||||
|
||||
// TODO: Add method to revoke a token
|
||||
|
||||
// Auth ...
|
||||
// Auth handles HTTP API authorization for this node
|
||||
type Auth struct {
|
||||
lock sync.RWMutex
|
||||
Password string
|
||||
lock sync.RWMutex // Prevent race condition when accessing password
|
||||
Enabled bool // True iff API calls need auth token
|
||||
Password string // The password. Can be changed via API call.
|
||||
}
|
||||
|
||||
// GetToken gets the JWT token from the request header
|
||||
// getToken gets the JWT token from the request header
|
||||
// Assumes the header is this form:
|
||||
// "Authorization": "Bearer TOKEN.GOES.HERE"
|
||||
func GetToken(r *http.Request) (string, error) {
|
||||
func getToken(r *http.Request) (string, error) {
|
||||
rawHeader := r.Header.Get("Authorization") // Should be "Bearer AUTH.TOKEN.HERE"
|
||||
if rawHeader == "" {
|
||||
return "", ErrNoToken
|
||||
|
@ -40,5 +48,48 @@ func GetToken(r *http.Request) (string, error) {
|
|||
if !strings.HasPrefix(rawHeader, headerValStart) {
|
||||
return "", errors.New("token is invalid format")
|
||||
}
|
||||
return rawHeader[len(headerValStart):], nil // Should be the actual auth token. Slice guaranteed to not go OOB
|
||||
return rawHeader[len(headerValStart):], nil // Returns actual auth token. Slice guaranteed to not go OOB
|
||||
}
|
||||
|
||||
// WrapHandler wraps a handler. Before passing a request to the handler, check that
|
||||
// an auth token was provided (if necessary) and that it is valid/unexpired.
|
||||
func (auth *Auth) WrapHandler(h http.Handler) http.Handler {
|
||||
if !auth.Enabled { // Auth tokens aren't in use. Do nothing.
|
||||
return h
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if path.Base(r.URL.Path) == Endpoint { // Don't require auth token to hit auth endpoint
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr, err := getToken(r) // Get the token from the header
|
||||
if err == ErrNoToken {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
} else if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, "couldn't parse auth token. Header \"Authorization\" should be \"Bearer TOKEN.GOES.HERE\"")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenStr, func(*jwt.Token) (interface{}, error) { // See if token is well-formed and signature is right
|
||||
auth.lock.RLock()
|
||||
defer auth.lock.RUnlock()
|
||||
return []byte(auth.Password), nil
|
||||
})
|
||||
if err != nil { // Signature is probably wrong
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, fmt.Sprintf("invalid auth token: %s", err))
|
||||
return
|
||||
}
|
||||
if !token.Valid { // Check that token isn't expired
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, "invalid auth token. Is it expired?")
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r) // Authentication successful
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,7 +15,8 @@ import (
|
|||
|
||||
// Service ...
|
||||
type Service struct {
|
||||
*Auth
|
||||
log logging.Logger
|
||||
*Auth // has to be a reference to the same Auth inside the API sever
|
||||
}
|
||||
|
||||
// NewService returns a new auth API service
|
||||
|
@ -24,7 +25,7 @@ func NewService(log logging.Logger, auth *Auth) *common.HTTPHandler {
|
|||
codec := cjson.NewCodec()
|
||||
newServer.RegisterCodec(codec, "application/json")
|
||||
newServer.RegisterCodec(codec, "application/json;charset=UTF-8")
|
||||
newServer.RegisterService(&Service{Auth: auth}, "auth")
|
||||
newServer.RegisterService(&Service{Auth: auth, log: log}, "auth")
|
||||
return &common.HTTPHandler{Handler: newServer}
|
||||
}
|
||||
|
||||
|
@ -35,22 +36,23 @@ type Success struct {
|
|||
|
||||
// NewTokenArgs ...
|
||||
type NewTokenArgs struct {
|
||||
Password string `json:"password"`
|
||||
Password string `json:"password"` // The authotization password
|
||||
}
|
||||
|
||||
// NewTokenResponse ...
|
||||
type NewTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Token string `json:"token"` // The new token. Expires in [TokenLifespan].
|
||||
}
|
||||
|
||||
// NewToken ...
|
||||
// NewToken returns a new token
|
||||
func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *NewTokenResponse) error {
|
||||
s.log.Info("Auth: NewToken called")
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
if args.Password != s.Password {
|
||||
return errors.New("incorrect password")
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ // Make a new token that expires in one week
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ // Make a new token
|
||||
ExpiresAt: time.Now().Add(TokenLifespan).Unix(),
|
||||
})
|
||||
var err error
|
||||
|
@ -60,18 +62,21 @@ func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *NewTokenR
|
|||
|
||||
// ChangePasswordArgs ...
|
||||
type ChangePasswordArgs struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
OldPassword string `json:"oldPassword"` // Current authorization password
|
||||
NewPassword string `json:"newPassword"` // New authorization password
|
||||
}
|
||||
|
||||
// ChangePassword ...
|
||||
func (s *Service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *Success) error {
|
||||
s.log.Info("Auth: ChangePassword called")
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
if args.OldPassword != s.Password {
|
||||
return errors.New("incorrect password")
|
||||
} else if len(args.NewPassword) == 0 {
|
||||
return errors.New("new password can't be empty")
|
||||
}
|
||||
s.Password = args.NewPassword // TODO: Add validation for password
|
||||
s.Password = args.NewPassword
|
||||
reply.Success = true
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -10,11 +10,8 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
|
||||
"github.com/gorilla/handlers"
|
||||
|
||||
"github.com/rs/cors"
|
||||
|
@ -26,8 +23,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
baseURL = "/ext"
|
||||
authEndpoint = "auth"
|
||||
baseURL = "/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -36,58 +32,35 @@ var (
|
|||
|
||||
// Server maintains the HTTP router
|
||||
type Server struct {
|
||||
log logging.Logger
|
||||
factory logging.Factory
|
||||
router *router
|
||||
listenAddress string
|
||||
requireAuthToken bool
|
||||
auth *auth.Auth
|
||||
}
|
||||
|
||||
// Wrap a handler. Before passing a request to the handler, check that
|
||||
func (s *Server) authMiddleware(h http.Handler) http.Handler {
|
||||
if !s.requireAuthToken {
|
||||
return h
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if path.Base(r.URL.Path) == authEndpoint { // Don't require auth token to hit auth endpoint
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
tokenStr, err := auth.GetToken(r) // Get the token from the header
|
||||
if err == auth.ErrNoToken {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
} else if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, "couldn't parse auth token. Header \"Authorization\" should be \"Bearer TOKEN.GOES.HERE\"")
|
||||
return
|
||||
}
|
||||
token, err := jwt.Parse(tokenStr, func(*jwt.Token) (interface{}, error) {
|
||||
return []byte(s.auth.Password), nil
|
||||
})
|
||||
if !token.Valid {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, "auth token is invalid")
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
// log this server writes to
|
||||
log logging.Logger
|
||||
// generates new logs for chains to write to
|
||||
factory logging.Factory
|
||||
// Maps endpoints to handlers
|
||||
router *router
|
||||
// Listens for HTTP traffic on this address
|
||||
listenAddress string
|
||||
// Handles authorization. Must be non-nil after initialization, even if token authorization is off.
|
||||
// Assumes the auth service is the only endpoint whose path ends with /auth
|
||||
auth *auth.Auth
|
||||
}
|
||||
|
||||
// Initialize creates the API server at the provided host and port
|
||||
func (s *Server) Initialize(log logging.Logger, factory logging.Factory, host string, port uint16, requireAuthToken bool, authPassword string) {
|
||||
func (s *Server) Initialize(log logging.Logger, factory logging.Factory, host string, port uint16, authEnabled bool, authPassword string) error {
|
||||
s.log = log
|
||||
s.factory = factory
|
||||
s.listenAddress = fmt.Sprintf("%s:%d", host, port)
|
||||
s.router = newRouter()
|
||||
if requireAuthToken {
|
||||
s.requireAuthToken = requireAuthToken
|
||||
s.auth = &auth.Auth{Password: authPassword}
|
||||
authService := auth.NewService(s.log, s.auth)
|
||||
s.AddRoute(authService, &sync.RWMutex{}, authEndpoint, "", s.log) // TODO check error
|
||||
s.auth = &auth.Auth{
|
||||
Enabled: authEnabled,
|
||||
Password: authPassword,
|
||||
}
|
||||
if authEnabled { // only create auth service if token authorization is required
|
||||
s.log.Info("API authorization is enabled. Auth token must be passed in header of API requests (except requests to auth service.)")
|
||||
authService := auth.NewService(s.log, s.auth)
|
||||
return s.AddRoute(authService, &sync.RWMutex{}, auth.Endpoint, "", s.log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dispatch starts the API server
|
||||
|
@ -98,7 +71,7 @@ func (s *Server) Dispatch() error {
|
|||
}
|
||||
s.log.Info("HTTP API server listening on %q", s.listenAddress)
|
||||
handler := cors.Default().Handler(s.router)
|
||||
handler = s.authMiddleware(handler)
|
||||
handler = s.auth.WrapHandler(handler)
|
||||
return http.Serve(listener, handler)
|
||||
}
|
||||
|
||||
|
@ -110,7 +83,7 @@ func (s *Server) DispatchTLS(certFile, keyFile string) error {
|
|||
}
|
||||
s.log.Info("HTTPS API server listening on %q", s.listenAddress)
|
||||
handler := cors.Default().Handler(s.router)
|
||||
handler = s.authMiddleware(handler)
|
||||
handler = s.auth.WrapHandler(handler)
|
||||
return http.ServeTLS(listener, handler, certFile, keyFile)
|
||||
}
|
||||
|
||||
|
|
11
node/node.go
11
node/node.go
|
@ -390,10 +390,12 @@ func (n *Node) initChains() error {
|
|||
}
|
||||
|
||||
// initAPIServer initializes the server that handles HTTP calls
|
||||
func (n *Node) initAPIServer() {
|
||||
func (n *Node) initAPIServer() error {
|
||||
n.Log.Info("Initializing API server")
|
||||
|
||||
n.APIServer.Initialize(n.Log, n.LogFactory, n.Config.HTTPHost, n.Config.HTTPPort, n.Config.HTTPRequireAuthToken, n.Config.HTTPAuthPassword)
|
||||
if err := n.APIServer.Initialize(n.Log, n.LogFactory, n.Config.HTTPHost, n.Config.HTTPPort, n.Config.HTTPRequireAuthToken, n.Config.HTTPAuthPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go n.Log.RecoverAndPanic(func() {
|
||||
if n.Config.EnableHTTPS {
|
||||
|
@ -408,6 +410,7 @@ func (n *Node) initAPIServer() {
|
|||
n.Log.Fatal("API server initialization failed with %s", err)
|
||||
n.Net.Close()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assumes n.DB, n.vdrs all initialized (non-nil)
|
||||
|
@ -561,7 +564,9 @@ func (n *Node) Initialize(Config *Config, logger logging.Logger, logFactory logg
|
|||
n.initBeacons()
|
||||
|
||||
// Start HTTP APIs
|
||||
n.initAPIServer() // Start the API Server
|
||||
if err := n.initAPIServer(); err != nil { // Start the API Server
|
||||
return fmt.Errorf("couldn't initialize API server: %w", err)
|
||||
}
|
||||
n.initKeystoreAPI() // Start the Keystore API
|
||||
n.initMetricsAPI() // Start the Metrics API
|
||||
|
||||
|
|
Loading…
Reference in New Issue