use clock in auth struct. add tests

This commit is contained in:
Dan Laine 2020-06-29 15:18:33 -04:00
parent 2640977cde
commit ba299559e7
4 changed files with 331 additions and 8 deletions

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/ava-labs/gecko/utils/timer"
jwt "github.com/dgrijalva/jwt-go" jwt "github.com/dgrijalva/jwt-go"
) )
@ -35,6 +36,7 @@ var (
type Auth struct { type Auth struct {
lock sync.RWMutex // Prevent race condition when accessing password lock sync.RWMutex // Prevent race condition when accessing password
Enabled bool // True iff API calls need auth token Enabled bool // True iff API calls need auth token
clock timer.Clock // Tells the time. Can be faked for testing
Password string // The password. Can be changed via API call. Password string // The password. Can be changed via API call.
revoked []string // List of tokens that have been revoked revoked []string // List of tokens that have been revoked
} }
@ -80,7 +82,7 @@ func (auth *Auth) newToken(password string, endpoints []string) (string, error)
} }
claims := endpointClaims{ claims := endpointClaims{
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
ExpiresAt: time.Now().Add(TokenLifespan).Unix(), ExpiresAt: auth.clock.Time().Add(TokenLifespan).Unix(),
}, },
} }
if canAccessAll { if canAccessAll {
@ -180,11 +182,6 @@ func (auth *Auth) WrapHandler(h http.Handler) http.Handler {
io.WriteString(w, "expected auth token's claims to be type endpointClaims but is different type") io.WriteString(w, "expected auth token's claims to be type endpointClaims but is different type")
return return
} }
if l := len(claims.Endpoints); l < 1 || l > maxEndpoints {
w.WriteHeader(http.StatusUnauthorized)
io.WriteString(w, fmt.Sprintf("expected auth token to allow access to between %d and %d endpoints, but does %d", 1, maxEndpoints, l))
return
}
canAccess := false // true iff the token authorizes access to the API canAccess := false // true iff the token authorizes access to the API
for _, endpoint := range claims.Endpoints { for _, endpoint := range claims.Endpoints {
if endpoint == "*" || strings.HasSuffix(r.URL.Path, endpoint) { if endpoint == "*" || strings.HasSuffix(r.URL.Path, endpoint) {

326
api/auth/auth_test.go Normal file
View File

@ -0,0 +1,326 @@
package auth
import (
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/ava-labs/gecko/utils/timer"
jwt "github.com/dgrijalva/jwt-go"
)
const (
password = "password"
)
var (
// Always returns 200 (http.StatusOK)
dummyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
)
func TestNewTokenWrongPassword(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
if _, err := auth.newToken("", []string{"endpoint1, endpoint2"}); err == nil {
t.Fatal("should have failed because password is wrong")
} else if _, err := auth.newToken("notThePassword", []string{"endpoint1, endpoint2"}); err == nil {
t.Fatal("should have failed because password is wrong")
}
}
func TestNewTokenHappyPath(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
now := time.Now()
auth.clock.Set(now)
// Make a token
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
// Parse the token
token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
auth.lock.RLock()
defer auth.lock.RUnlock()
return []byte(auth.Password), nil
})
if err != nil {
t.Fatalf("couldn't parse new token: %s", err)
}
claims, ok := token.Claims.(*endpointClaims)
if !ok {
t.Fatal("expected auth token's claims to be type endpointClaims but is different type")
}
if !reflect.DeepEqual(claims.Endpoints, endpoints) {
t.Fatal("token has wrong endpoint claims")
}
if shouldExpireAt := now.Add(TokenLifespan).Unix(); shouldExpireAt != now.Add(TokenLifespan).Unix() {
t.Fatalf("token expiration time is wrong")
}
}
func TestTokenHasWrongSig(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
// Make a token
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
// Try to parse the token using the wrong password
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
auth.lock.RLock()
defer auth.lock.RUnlock()
return []byte(""), nil
}); err == nil {
t.Fatalf("should have failed because password is wrong")
}
// Try to parse the token using the wrong password
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
auth.lock.RLock()
defer auth.lock.RUnlock()
return []byte("notThePassword"), nil
}); err == nil {
t.Fatalf("should have failed because password is wrong")
}
}
func TestChangePassword(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
password2 := "password2"
if err := auth.changePassword("", password2); err == nil {
t.Fatal("should have failed because old password is wrong")
} else if err := auth.changePassword("notThePassword", password2); err == nil {
t.Fatal("should have failed because old password is wrong")
} else if err := auth.changePassword(password, ""); err == nil {
t.Fatal("should have failed because new password is empty")
} else if err := auth.changePassword(password, password2); err != nil {
t.Fatal("should have succeeded")
}
if auth.Password != password2 {
t.Fatal("password should have been changed")
}
password3 := "password3"
if err := auth.changePassword(password, password3); err == nil {
t.Fatal("should have failed because old password is wrong")
} else if err := auth.changePassword(password2, password3); err != nil {
t.Fatal("should have succeeded")
}
}
func TestGetToken(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1:9650/ext/auth", strings.NewReader(""))
if _, err := getToken(req); err == nil {
t.Fatal("should have failed because no auth token given")
}
req.Header.Add("Authorization", "")
if _, err := getToken(req); err == nil {
t.Fatal("should have failed because auth token invalid")
}
req.Header.Set("Authorization", "this isn't an auth token!")
if _, err := getToken(req); err == nil {
t.Fatal("should have failed because auth token invalid")
}
wellFormedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJFbmRwb2ludHMiOlsiKiJdLCJleHAiOjE1OTM0NzU4OTR9.Cqo7TraN_CFN13q3ae4GRJCMgd8ZOlQwBzyC29M6Aps"
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", wellFormedToken))
if token, err := getToken(req); err != nil {
t.Fatal("should have been able to parse valid header")
} else if token != wellFormedToken {
t.Fatal("parsed token incorrectly")
}
}
func TestWrapHandlerHappyPath(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
// Make a token
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range endpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatal("should have passed authorization")
}
}
}
func TestWrapHandlerExpiredToken(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
clock: timer.Clock{},
}
auth.clock.Set(time.Now().Add(-2 * TokenLifespan))
// Make a token that expired well in the past
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range endpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatal("should have failed authorization because token is expired")
}
}
}
func TestWrapHandlerNoAuthToken(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range endpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatal("should have failed authorization since no auth token given")
}
}
}
func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
// Make a token
endpoints := []string{"/ext/info"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
unauthorizedEndpoints := []string{"/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range unauthorizedEndpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatal("should have failed authorization since this endpoint is not allowed by the token")
}
}
}
func TestWrapHandlerAuthEndpoint(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
// Make a token
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
tokenStr, err := auth.newToken(password, endpoints)
if err != nil {
t.Fatal(err)
}
wrappedHandler := auth.WrapHandler(dummyHandler)
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", fmt.Sprintf("/ext/%s", Endpoint)), strings.NewReader(""))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatal("should always allow access to the auth endpoint")
}
}
func TestWrapHandlerAccessAll(t *testing.T) {
auth := Auth{
Enabled: true,
Password: password,
}
// Make a token that allows access to all endpoints
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info"}
tokenStr, err := auth.newToken(password, []string{"*"})
if err != nil {
t.Fatal(err)
}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range endpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatal("* in token should have allowed access to all endpoints")
}
}
}
func TestWrapHandlerAuthDisabled(t *testing.T) {
auth := Auth{
Enabled: false,
Password: password,
}
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info", "/ext/auth"}
wrappedHandler := auth.WrapHandler(dummyHandler)
for _, endpoint := range endpoints {
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatal("auth is disabled so should allow access to all endpoints")
}
}
}

View File

@ -38,7 +38,7 @@ type Success struct {
// Password ... // Password ...
type Password struct { type Password struct {
Password string `json:"password"` // The authotization password Password string `json:"password"` // The authorization password
} }
// NewTokenArgs ... // NewTokenArgs ...

View File

@ -30,7 +30,7 @@ func (s *Service) Call(_ *http.Request, args *Args, reply *Reply) error {
func TestCall(t *testing.T) { func TestCall(t *testing.T) {
s := Server{} s := Server{}
s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080) s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080, false, "")
serv := &Service{} serv := &Service{}
newServer := rpc.NewServer() newServer := rpc.NewServer()