mirror of https://github.com/poanetwork/gecko.git
380 lines
11 KiB
Go
380 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ava-labs/gecko/utils/hashing"
|
|
|
|
"github.com/ava-labs/gecko/utils/timer"
|
|
|
|
jwt "github.com/dgrijalva/jwt-go"
|
|
)
|
|
|
|
var (
|
|
testPassword = "password!@#$%$#@!"
|
|
hashedPassword = hashing.ComputeHash256([]byte(testPassword))
|
|
)
|
|
|
|
var (
|
|
// Always returns 200 (http.StatusOK)
|
|
dummyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
|
)
|
|
|
|
func TestNewTokenWrongPassword(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
if _, err := auth.newToken("", []string{"endpoint1, endpoint2"}); err == nil {
|
|
t.Fatal("should have failed because password is wrong")
|
|
} else if _, err := auth.newToken("notThePassword", []string{"endpoint1, endpoint2"}); err == nil {
|
|
t.Fatal("should have failed because password is wrong")
|
|
}
|
|
}
|
|
|
|
func TestNewTokenHappyPath(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
now := time.Now()
|
|
auth.clock.Set(now)
|
|
|
|
// Make a token
|
|
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Parse the token
|
|
token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
|
auth.lock.RLock()
|
|
defer auth.lock.RUnlock()
|
|
return auth.HashedPassword, nil
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("couldn't parse new token: %s", err)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*endpointClaims)
|
|
if !ok {
|
|
t.Fatal("expected auth token's claims to be type endpointClaims but is different type")
|
|
}
|
|
if !reflect.DeepEqual(claims.Endpoints, endpoints) {
|
|
t.Fatal("token has wrong endpoint claims")
|
|
}
|
|
if shouldExpireAt := now.Add(TokenLifespan).Unix(); shouldExpireAt != now.Add(TokenLifespan).Unix() {
|
|
t.Fatalf("token expiration time is wrong")
|
|
}
|
|
}
|
|
|
|
func TestTokenHasWrongSig(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"endpoint1", "endpoint2", "endpoint3"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Try to parse the token using the wrong password
|
|
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
|
auth.lock.RLock()
|
|
defer auth.lock.RUnlock()
|
|
return []byte(""), nil
|
|
}); err == nil {
|
|
t.Fatalf("should have failed because password is wrong")
|
|
}
|
|
|
|
// Try to parse the token using the wrong password
|
|
if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) {
|
|
auth.lock.RLock()
|
|
defer auth.lock.RUnlock()
|
|
return []byte("notThePassword"), nil
|
|
}); err == nil {
|
|
t.Fatalf("should have failed because password is wrong")
|
|
}
|
|
}
|
|
|
|
func TestChangePassword(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
password2 := "fejhkefjhefjhefhje"
|
|
if err := auth.changePassword("", password2); err == nil {
|
|
t.Fatal("should have failed because old password is wrong")
|
|
} else if err := auth.changePassword("notThePassword", password2); err == nil {
|
|
t.Fatal("should have failed because old password is wrong")
|
|
} else if err := auth.changePassword(testPassword, ""); err == nil {
|
|
t.Fatal("should have failed because new password is empty")
|
|
} else if err := auth.changePassword(testPassword, password2); err != nil {
|
|
t.Fatal("should have succeeded")
|
|
}
|
|
|
|
if !bytes.Equal(auth.HashedPassword, hashing.ComputeHash256([]byte(password2))) {
|
|
t.Fatal("password should have been changed")
|
|
}
|
|
|
|
password3 := "ufwhwohwfohawfhwdwd"
|
|
if err := auth.changePassword(testPassword, password3); err == nil {
|
|
t.Fatal("should have failed because old password is wrong")
|
|
} else if err := auth.changePassword(password2, password3); err != nil {
|
|
t.Fatal("should have succeeded")
|
|
}
|
|
|
|
}
|
|
|
|
func TestGetToken(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1:9650/ext/auth", strings.NewReader(""))
|
|
if _, err := getToken(req); err == nil {
|
|
t.Fatal("should have failed because no auth token given")
|
|
}
|
|
|
|
req.Header.Add("Authorization", "")
|
|
if _, err := getToken(req); err == nil {
|
|
t.Fatal("should have failed because auth token invalid")
|
|
}
|
|
|
|
req.Header.Set("Authorization", "this isn't an auth token!")
|
|
if _, err := getToken(req); err == nil {
|
|
t.Fatal("should have failed because auth token invalid")
|
|
}
|
|
|
|
wellFormedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJFbmRwb2ludHMiOlsiKiJdLCJleHAiOjE1OTM0NzU4OTR9.Cqo7TraN_CFN13q3ae4GRJCMgd8ZOlQwBzyC29M6Aps"
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", wellFormedToken))
|
|
if token, err := getToken(req); err != nil {
|
|
t.Fatal("should have been able to parse valid header")
|
|
} else if token != wellFormedToken {
|
|
t.Fatal("parsed token incorrectly")
|
|
}
|
|
}
|
|
|
|
func TestRevokeToken(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := auth.revokeToken(tokenStr, testPassword); err != nil {
|
|
t.Fatal("should have succeeded")
|
|
} else if len(auth.revoked) != 1 || auth.revoked[0] != tokenStr {
|
|
t.Fatal("revoked token list is incorrect")
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerHappyPath(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatal("should have passed authorization")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerRevokedToken(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := auth.revokeToken(tokenStr, testPassword); err != nil {
|
|
t.Fatalf("should have been able to revoke token but got: %s", err)
|
|
}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatal("should have failed authorization because token was revoked")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerExpiredToken(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
clock: timer.Clock{},
|
|
}
|
|
auth.clock.Set(time.Now().Add(-2 * TokenLifespan))
|
|
|
|
// Make a token that expired well in the past
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatal("should have failed authorization because token is expired")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerNoAuthToken(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"}
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatal("should have failed authorization since no auth token given")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"/ext/info"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
unauthorizedEndpoints := []string{"/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
for _, endpoint := range unauthorizedEndpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Fatal("should have failed authorization since this endpoint is not allowed by the token")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerAuthEndpoint(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"}
|
|
tokenStr, err := auth.newToken(testPassword, endpoints)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", fmt.Sprintf("/ext/%s", Endpoint)), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatal("should always allow access to the auth endpoint")
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerAccessAll(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: true,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
// Make a token that allows access to all endpoints
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info"}
|
|
tokenStr, err := auth.newToken(testPassword, []string{"*"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatal("* in token should have allowed access to all endpoints")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWrapHandlerAuthDisabled(t *testing.T) {
|
|
auth := Auth{
|
|
Enabled: false,
|
|
HashedPassword: hashedPassword,
|
|
}
|
|
|
|
endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info", "/ext/auth"}
|
|
|
|
wrappedHandler := auth.WrapHandler(dummyHandler)
|
|
for _, endpoint := range endpoints {
|
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader(""))
|
|
rr := httptest.NewRecorder()
|
|
wrappedHandler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatal("auth is disabled so should allow access to all endpoints")
|
|
}
|
|
}
|
|
}
|