cleanup/formatting. Add logs and error handling. Make token lifespan reasonable length.

This commit is contained in:
Dan Laine 2020-06-25 18:19:53 -04:00
parent 7038641fd8
commit 9b0981fe7b
4 changed files with 106 additions and 72 deletions

View File

@ -2,37 +2,45 @@ package auth
import (
"errors"
"fmt"
"io"
"net/http"
"path"
"strings"
"sync"
"time"
jwt "github.com/dgrijalva/jwt-go"
)
// TODO: Add method to revoke a token
const (
headerKey = "Authorization"
headerValStart = "Bearer "
// Endpoint is the base of the auth URL
Endpoint = "auth"
)
var (
// TokenLifespan is how long a token lives before it expires
TokenLifespan = 1 * time.Minute
TokenLifespan = time.Hour * 12
// ErrNoToken is returned by GetToken if no token is provided
ErrNoToken = errors.New("auth token not provided")
)
// TODO: Add method to revoke a token
// Auth ...
// Auth handles HTTP API authorization for this node
type Auth struct {
lock sync.RWMutex
Password string
lock sync.RWMutex // Prevent race condition when accessing password
Enabled bool // True iff API calls need auth token
Password string // The password. Can be changed via API call.
}
// GetToken gets the JWT token from the request header
// getToken gets the JWT token from the request header
// Assumes the header is this form:
// "Authorization": "Bearer TOKEN.GOES.HERE"
func GetToken(r *http.Request) (string, error) {
func getToken(r *http.Request) (string, error) {
rawHeader := r.Header.Get("Authorization") // Should be "Bearer AUTH.TOKEN.HERE"
if rawHeader == "" {
return "", ErrNoToken
@ -40,5 +48,48 @@ func GetToken(r *http.Request) (string, error) {
if !strings.HasPrefix(rawHeader, headerValStart) {
return "", errors.New("token is invalid format")
}
return rawHeader[len(headerValStart):], nil // Should be the actual auth token. Slice guaranteed to not go OOB
return rawHeader[len(headerValStart):], nil // Returns actual auth token. Slice guaranteed to not go OOB
}
// WrapHandler wraps a handler. Before passing a request to the handler, check that
// an auth token was provided (if necessary) and that it is valid/unexpired.
func (auth *Auth) WrapHandler(h http.Handler) http.Handler {
if !auth.Enabled { // Auth tokens aren't in use. Do nothing.
return h
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if path.Base(r.URL.Path) == Endpoint { // Don't require auth token to hit auth endpoint
h.ServeHTTP(w, r)
return
}
tokenStr, err := getToken(r) // Get the token from the header
if err == ErrNoToken {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, err.Error())
return
} else if err != nil {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, "couldn't parse auth token. Header \"Authorization\" should be \"Bearer TOKEN.GOES.HERE\"")
return
}
token, err := jwt.Parse(tokenStr, func(*jwt.Token) (interface{}, error) { // See if token is well-formed and signature is right
auth.lock.RLock()
defer auth.lock.RUnlock()
return []byte(auth.Password), nil
})
if err != nil { // Signature is probably wrong
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, fmt.Sprintf("invalid auth token: %s", err))
return
}
if !token.Valid { // Check that token isn't expired
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, "invalid auth token. Is it expired?")
return
}
h.ServeHTTP(w, r) // Authentication successful
})
}

View File

@ -15,7 +15,8 @@ import (
// Service ...
type Service struct {
*Auth
log logging.Logger
*Auth // has to be a reference to the same Auth inside the API sever
}
// NewService returns a new auth API service
@ -24,7 +25,7 @@ func NewService(log logging.Logger, auth *Auth) *common.HTTPHandler {
codec := cjson.NewCodec()
newServer.RegisterCodec(codec, "application/json")
newServer.RegisterCodec(codec, "application/json;charset=UTF-8")
newServer.RegisterService(&Service{Auth: auth}, "auth")
newServer.RegisterService(&Service{Auth: auth, log: log}, "auth")
return &common.HTTPHandler{Handler: newServer}
}
@ -35,22 +36,23 @@ type Success struct {
// NewTokenArgs ...
type NewTokenArgs struct {
Password string `json:"password"`
Password string `json:"password"` // The authotization password
}
// NewTokenResponse ...
type NewTokenResponse struct {
Token string `json:"token"`
Token string `json:"token"` // The new token. Expires in [TokenLifespan].
}
// NewToken ...
// NewToken returns a new token
func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *NewTokenResponse) error {
s.log.Info("Auth: NewToken called")
s.lock.RLock()
defer s.lock.RUnlock()
if args.Password != s.Password {
return errors.New("incorrect password")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ // Make a new token that expires in one week
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ // Make a new token
ExpiresAt: time.Now().Add(TokenLifespan).Unix(),
})
var err error
@ -60,18 +62,21 @@ func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *NewTokenR
// ChangePasswordArgs ...
type ChangePasswordArgs struct {
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
OldPassword string `json:"oldPassword"` // Current authorization password
NewPassword string `json:"newPassword"` // New authorization password
}
// ChangePassword ...
func (s *Service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *Success) error {
s.log.Info("Auth: ChangePassword called")
s.lock.Lock()
defer s.lock.Unlock()
if args.OldPassword != s.Password {
return errors.New("incorrect password")
} else if len(args.NewPassword) == 0 {
return errors.New("new password can't be empty")
}
s.Password = args.NewPassword // TODO: Add validation for password
s.Password = args.NewPassword
reply.Success = true
return nil
}

View File

@ -10,11 +10,8 @@ import (
"net"
"net/http"
"net/url"
"path"
"sync"
jwt "github.com/dgrijalva/jwt-go"
"github.com/gorilla/handlers"
"github.com/rs/cors"
@ -26,8 +23,7 @@ import (
)
const (
baseURL = "/ext"
authEndpoint = "auth"
baseURL = "/ext"
)
var (
@ -36,58 +32,35 @@ var (
// Server maintains the HTTP router
type Server struct {
log logging.Logger
factory logging.Factory
router *router
listenAddress string
requireAuthToken bool
auth *auth.Auth
}
// Wrap a handler. Before passing a request to the handler, check that
func (s *Server) authMiddleware(h http.Handler) http.Handler {
if !s.requireAuthToken {
return h
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if path.Base(r.URL.Path) == authEndpoint { // Don't require auth token to hit auth endpoint
h.ServeHTTP(w, r)
return
}
tokenStr, err := auth.GetToken(r) // Get the token from the header
if err == auth.ErrNoToken {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, err.Error())
return
} else if err != nil {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, "couldn't parse auth token. Header \"Authorization\" should be \"Bearer TOKEN.GOES.HERE\"")
return
}
token, err := jwt.Parse(tokenStr, func(*jwt.Token) (interface{}, error) {
return []byte(s.auth.Password), nil
})
if !token.Valid {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, "auth token is invalid")
return
}
h.ServeHTTP(w, r)
})
// log this server writes to
log logging.Logger
// generates new logs for chains to write to
factory logging.Factory
// Maps endpoints to handlers
router *router
// Listens for HTTP traffic on this address
listenAddress string
// Handles authorization. Must be non-nil after initialization, even if token authorization is off.
// Assumes the auth service is the only endpoint whose path ends with /auth
auth *auth.Auth
}
// Initialize creates the API server at the provided host and port
func (s *Server) Initialize(log logging.Logger, factory logging.Factory, host string, port uint16, requireAuthToken bool, authPassword string) {
func (s *Server) Initialize(log logging.Logger, factory logging.Factory, host string, port uint16, authEnabled bool, authPassword string) error {
s.log = log
s.factory = factory
s.listenAddress = fmt.Sprintf("%s:%d", host, port)
s.router = newRouter()
if requireAuthToken {
s.requireAuthToken = requireAuthToken
s.auth = &auth.Auth{Password: authPassword}
authService := auth.NewService(s.log, s.auth)
s.AddRoute(authService, &sync.RWMutex{}, authEndpoint, "", s.log) // TODO check error
s.auth = &auth.Auth{
Enabled: authEnabled,
Password: authPassword,
}
if authEnabled { // only create auth service if token authorization is required
s.log.Info("API authorization is enabled. Auth token must be passed in header of API requests (except requests to auth service.)")
authService := auth.NewService(s.log, s.auth)
return s.AddRoute(authService, &sync.RWMutex{}, auth.Endpoint, "", s.log)
}
return nil
}
// Dispatch starts the API server
@ -98,7 +71,7 @@ func (s *Server) Dispatch() error {
}
s.log.Info("HTTP API server listening on %q", s.listenAddress)
handler := cors.Default().Handler(s.router)
handler = s.authMiddleware(handler)
handler = s.auth.WrapHandler(handler)
return http.Serve(listener, handler)
}
@ -110,7 +83,7 @@ func (s *Server) DispatchTLS(certFile, keyFile string) error {
}
s.log.Info("HTTPS API server listening on %q", s.listenAddress)
handler := cors.Default().Handler(s.router)
handler = s.authMiddleware(handler)
handler = s.auth.WrapHandler(handler)
return http.ServeTLS(listener, handler, certFile, keyFile)
}

View File

@ -390,10 +390,12 @@ func (n *Node) initChains() error {
}
// initAPIServer initializes the server that handles HTTP calls
func (n *Node) initAPIServer() {
func (n *Node) initAPIServer() error {
n.Log.Info("Initializing API server")
n.APIServer.Initialize(n.Log, n.LogFactory, n.Config.HTTPHost, n.Config.HTTPPort, n.Config.HTTPRequireAuthToken, n.Config.HTTPAuthPassword)
if err := n.APIServer.Initialize(n.Log, n.LogFactory, n.Config.HTTPHost, n.Config.HTTPPort, n.Config.HTTPRequireAuthToken, n.Config.HTTPAuthPassword); err != nil {
return err
}
go n.Log.RecoverAndPanic(func() {
if n.Config.EnableHTTPS {
@ -408,6 +410,7 @@ func (n *Node) initAPIServer() {
n.Log.Fatal("API server initialization failed with %s", err)
n.Net.Close()
})
return nil
}
// Assumes n.DB, n.vdrs all initialized (non-nil)
@ -561,7 +564,9 @@ func (n *Node) Initialize(Config *Config, logger logging.Logger, logFactory logg
n.initBeacons()
// Start HTTP APIs
n.initAPIServer() // Start the API Server
if err := n.initAPIServer(); err != nil { // Start the API Server
return fmt.Errorf("couldn't initialize API server: %w", err)
}
n.initKeystoreAPI() // Start the Keystore API
n.initMetricsAPI() // Start the Metrics API