mirror of https://github.com/poanetwork/gecko.git
use clock in auth struct. add tests
This commit is contained in:
parent
2640977cde
commit
ba299559e7
|
@ -10,6 +10,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ava-labs/gecko/utils/timer"
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
|
@ -35,6 +36,7 @@ var (
|
|||
type Auth struct {
|
||||
lock sync.RWMutex // Prevent race condition when accessing password
|
||||
Enabled bool // True iff API calls need auth token
|
||||
clock timer.Clock // Tells the time. Can be faked for testing
|
||||
Password string // The password. Can be changed via API call.
|
||||
revoked []string // List of tokens that have been revoked
|
||||
}
|
||||
|
@ -80,7 +82,7 @@ func (auth *Auth) newToken(password string, endpoints []string) (string, error)
|
|||
}
|
||||
claims := endpointClaims{
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
ExpiresAt: time.Now().Add(TokenLifespan).Unix(),
|
||||
ExpiresAt: auth.clock.Time().Add(TokenLifespan).Unix(),
|
||||
},
|
||||
}
|
||||
if canAccessAll {
|
||||
|
@ -180,11 +182,6 @@ func (auth *Auth) WrapHandler(h http.Handler) http.Handler {
|
|||
io.WriteString(w, "expected auth token's claims to be type endpointClaims but is different type")
|
||||
return
|
||||
}
|
||||
if l := len(claims.Endpoints); l < 1 || l > maxEndpoints {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
io.WriteString(w, fmt.Sprintf("expected auth token to allow access to between %d and %d endpoints, but does %d", 1, maxEndpoints, l))
|
||||
return
|
||||
}
|
||||
canAccess := false // true iff the token authorizes access to the API
|
||||
for _, endpoint := range claims.Endpoints {
|
||||
if endpoint == "*" || strings.HasSuffix(r.URL.Path, endpoint) {
|
||||
|
|
|
@ -0,0 +1,326 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ava-labs/gecko/utils/timer"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
const (
|
||||
password = "password"
|
||||
)
|
||||
|
||||
var (
|
||||
// Always returns 200 (http.StatusOK)
|
||||
dummyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
)
|
||||
|
||||
func TestNewTokenWrongPassword(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
if _, err := auth.newToken("", []string{"endpoint1, endpoint2"}); err == nil {
|
||||
t.Fatal("should have failed because password is wrong")
|
||||
} else if _, err := auth.newToken("notThePassword", []string{"endpoint1, endpoint2"}); err == nil {
|
||||
t.Fatal("should have failed because password is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTokenHappyPath(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
now := time.Now()
|
||||
auth.clock.Set(now)
|
||||
|
||||
// Make a token
|
||||
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
||||
auth.lock.RLock()
|
||||
defer auth.lock.RUnlock()
|
||||
return []byte(auth.Password), nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't parse new token: %s", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*endpointClaims)
|
||||
if !ok {
|
||||
t.Fatal("expected auth token's claims to be type endpointClaims but is different type")
|
||||
}
|
||||
if !reflect.DeepEqual(claims.Endpoints, endpoints) {
|
||||
t.Fatal("token has wrong endpoint claims")
|
||||
}
|
||||
if shouldExpireAt := now.Add(TokenLifespan).Unix(); shouldExpireAt != now.Add(TokenLifespan).Unix() {
|
||||
t.Fatalf("token expiration time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHasWrongSig(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Make a token
|
||||
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Try to parse the token using the wrong password
|
||||
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
||||
auth.lock.RLock()
|
||||
defer auth.lock.RUnlock()
|
||||
return []byte(""), nil
|
||||
}); err == nil {
|
||||
t.Fatalf("should have failed because password is wrong")
|
||||
}
|
||||
|
||||
// Try to parse the token using the wrong password
|
||||
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
||||
auth.lock.RLock()
|
||||
defer auth.lock.RUnlock()
|
||||
return []byte("notThePassword"), nil
|
||||
}); err == nil {
|
||||
t.Fatalf("should have failed because password is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePassword(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
password2 := "password2"
|
||||
if err := auth.changePassword("", password2); err == nil {
|
||||
t.Fatal("should have failed because old password is wrong")
|
||||
} else if err := auth.changePassword("notThePassword", password2); err == nil {
|
||||
t.Fatal("should have failed because old password is wrong")
|
||||
} else if err := auth.changePassword(password, ""); err == nil {
|
||||
t.Fatal("should have failed because new password is empty")
|
||||
} else if err := auth.changePassword(password, password2); err != nil {
|
||||
t.Fatal("should have succeeded")
|
||||
}
|
||||
|
||||
if auth.Password != password2 {
|
||||
t.Fatal("password should have been changed")
|
||||
}
|
||||
|
||||
password3 := "password3"
|
||||
if err := auth.changePassword(password, password3); err == nil {
|
||||
t.Fatal("should have failed because old password is wrong")
|
||||
} else if err := auth.changePassword(password2, password3); err != nil {
|
||||
t.Fatal("should have succeeded")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetToken(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1:9650/ext/auth", strings.NewReader(""))
|
||||
if _, err := getToken(req); err == nil {
|
||||
t.Fatal("should have failed because no auth token given")
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "")
|
||||
if _, err := getToken(req); err == nil {
|
||||
t.Fatal("should have failed because auth token invalid")
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "this isn't an auth token!")
|
||||
if _, err := getToken(req); err == nil {
|
||||
t.Fatal("should have failed because auth token invalid")
|
||||
}
|
||||
|
||||
wellFormedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJFbmRwb2ludHMiOlsiKiJdLCJleHAiOjE1OTM0NzU4OTR9.Cqo7TraN_CFN13q3ae4GRJCMgd8ZOlQwBzyC29M6Aps"
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", wellFormedToken))
|
||||
if token, err := getToken(req); err != nil {
|
||||
t.Fatal("should have been able to parse valid header")
|
||||
} else if token != wellFormedToken {
|
||||
t.Fatal("parsed token incorrectly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerHappyPath(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Make a token
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatal("should have passed authorization")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerExpiredToken(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
clock: timer.Clock{},
|
||||
}
|
||||
auth.clock.Set(time.Now().Add(-2 * TokenLifespan))
|
||||
|
||||
// Make a token that expired well in the past
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatal("should have failed authorization because token is expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerNoAuthToken(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
for _, endpoint := range endpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatal("should have failed authorization since no auth token given")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Make a token
|
||||
endpoints := []string{"/ext/info"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
unauthorizedEndpoints := []string{"/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
for _, endpoint := range unauthorizedEndpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatal("should have failed authorization since this endpoint is not allowed by the token")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerAuthEndpoint(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Make a token
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
|
||||
tokenStr, err := auth.newToken(password, endpoints)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", fmt.Sprintf("/ext/%s", Endpoint)), strings.NewReader(""))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatal("should always allow access to the auth endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerAccessAll(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: true,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Make a token that allows access to all endpoints
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info"}
|
||||
tokenStr, err := auth.newToken(password, []string{"*"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
for _, endpoint := range endpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatal("* in token should have allowed access to all endpoints")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapHandlerAuthDisabled(t *testing.T) {
|
||||
auth := Auth{
|
||||
Enabled: false,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info", "/ext/auth"}
|
||||
|
||||
wrappedHandler := auth.WrapHandler(dummyHandler)
|
||||
for _, endpoint := range endpoints {
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatal("auth is disabled so should allow access to all endpoints")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -38,7 +38,7 @@ type Success struct {
|
|||
|
||||
// Password ...
|
||||
type Password struct {
|
||||
Password string `json:"password"` // The authotization password
|
||||
Password string `json:"password"` // The authorization password
|
||||
}
|
||||
|
||||
// NewTokenArgs ...
|
||||
|
|
|
@ -30,7 +30,7 @@ func (s *Service) Call(_ *http.Request, args *Args, reply *Reply) error {
|
|||
|
||||
func TestCall(t *testing.T) {
|
||||
s := Server{}
|
||||
s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080)
|
||||
s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080, false, "")
|
||||
|
||||
serv := &Service{}
|
||||
newServer := rpc.NewServer()
|
||||
|
|
Loading…
Reference in New Issue