diff --git a/api/auth/auth.go b/api/auth/auth.go index 644b68d..8621c05 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/ava-labs/gecko/utils/timer" jwt "github.com/dgrijalva/jwt-go" ) @@ -35,6 +36,7 @@ var ( type Auth struct { lock sync.RWMutex // Prevent race condition when accessing password Enabled bool // True iff API calls need auth token + clock timer.Clock // Tells the time. Can be faked for testing Password string // The password. Can be changed via API call. revoked []string // List of tokens that have been revoked } @@ -80,7 +82,7 @@ func (auth *Auth) newToken(password string, endpoints []string) (string, error) } claims := endpointClaims{ StandardClaims: jwt.StandardClaims{ - ExpiresAt: time.Now().Add(TokenLifespan).Unix(), + ExpiresAt: auth.clock.Time().Add(TokenLifespan).Unix(), }, } if canAccessAll { @@ -180,11 +182,6 @@ func (auth *Auth) WrapHandler(h http.Handler) http.Handler { io.WriteString(w, "expected auth token's claims to be type endpointClaims but is different type") return } - if l := len(claims.Endpoints); l < 1 || l > maxEndpoints { - w.WriteHeader(http.StatusUnauthorized) - io.WriteString(w, fmt.Sprintf("expected auth token to allow access to between %d and %d endpoints, but does %d", 1, maxEndpoints, l)) - return - } canAccess := false // true iff the token authorizes access to the API for _, endpoint := range claims.Endpoints { if endpoint == "*" || strings.HasSuffix(r.URL.Path, endpoint) { diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..7ec8edb --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,326 @@ +package auth + +import ( + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "github.com/ava-labs/gecko/utils/timer" + + jwt "github.com/dgrijalva/jwt-go" +) + +const ( + password = "password" +) + +var ( + // Always returns 200 (http.StatusOK) + dummyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) +) + +func TestNewTokenWrongPassword(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + if _, err := auth.newToken("", []string{"endpoint1, endpoint2"}); err == nil { + t.Fatal("should have failed because password is wrong") + } else if _, err := auth.newToken("notThePassword", []string{"endpoint1, endpoint2"}); err == nil { + t.Fatal("should have failed because password is wrong") + } +} + +func TestNewTokenHappyPath(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + now := time.Now() + auth.clock.Set(now) + + // Make a token + endpoints := []string{"endpoint1", "endpoint2", "endpoint3"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + // Parse the token + token, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) { + auth.lock.RLock() + defer auth.lock.RUnlock() + return []byte(auth.Password), nil + }) + if err != nil { + t.Fatalf("couldn't parse new token: %s", err) + } + + claims, ok := token.Claims.(*endpointClaims) + if !ok { + t.Fatal("expected auth token's claims to be type endpointClaims but is different type") + } + if !reflect.DeepEqual(claims.Endpoints, endpoints) { + t.Fatal("token has wrong endpoint claims") + } + if shouldExpireAt := now.Add(TokenLifespan).Unix(); shouldExpireAt != now.Add(TokenLifespan).Unix() { + t.Fatalf("token expiration time is wrong") + } +} + +func TestTokenHasWrongSig(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + // Make a token + endpoints := []string{"endpoint1", "endpoint2", "endpoint3"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + // Try to parse the token using the wrong password + if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) { + auth.lock.RLock() + defer auth.lock.RUnlock() + return []byte(""), nil + }); err == nil { + t.Fatalf("should have failed because password is wrong") + } + + // Try to parse the token using the wrong password + if _, err := jwt.ParseWithClaims(tokenStr, &endpointClaims{}, func(*jwt.Token) (interface{}, error) { + auth.lock.RLock() + defer auth.lock.RUnlock() + return []byte("notThePassword"), nil + }); err == nil { + t.Fatalf("should have failed because password is wrong") + } +} + +func TestChangePassword(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + password2 := "password2" + if err := auth.changePassword("", password2); err == nil { + t.Fatal("should have failed because old password is wrong") + } else if err := auth.changePassword("notThePassword", password2); err == nil { + t.Fatal("should have failed because old password is wrong") + } else if err := auth.changePassword(password, ""); err == nil { + t.Fatal("should have failed because new password is empty") + } else if err := auth.changePassword(password, password2); err != nil { + t.Fatal("should have succeeded") + } + + if auth.Password != password2 { + t.Fatal("password should have been changed") + } + + password3 := "password3" + if err := auth.changePassword(password, password3); err == nil { + t.Fatal("should have failed because old password is wrong") + } else if err := auth.changePassword(password2, password3); err != nil { + t.Fatal("should have succeeded") + } + +} + +func TestGetToken(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://127.0.0.1:9650/ext/auth", strings.NewReader("")) + if _, err := getToken(req); err == nil { + t.Fatal("should have failed because no auth token given") + } + + req.Header.Add("Authorization", "") + if _, err := getToken(req); err == nil { + t.Fatal("should have failed because auth token invalid") + } + + req.Header.Set("Authorization", "this isn't an auth token!") + if _, err := getToken(req); err == nil { + t.Fatal("should have failed because auth token invalid") + } + + wellFormedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJFbmRwb2ludHMiOlsiKiJdLCJleHAiOjE1OTM0NzU4OTR9.Cqo7TraN_CFN13q3ae4GRJCMgd8ZOlQwBzyC29M6Aps" + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", wellFormedToken)) + if token, err := getToken(req); err != nil { + t.Fatal("should have been able to parse valid header") + } else if token != wellFormedToken { + t.Fatal("parsed token incorrectly") + } +} + +func TestWrapHandlerHappyPath(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + // Make a token + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + wrappedHandler := auth.WrapHandler(dummyHandler) + + for _, endpoint := range endpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatal("should have passed authorization") + } + } +} + +func TestWrapHandlerExpiredToken(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + clock: timer.Clock{}, + } + auth.clock.Set(time.Now().Add(-2 * TokenLifespan)) + + // Make a token that expired well in the past + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + wrappedHandler := auth.WrapHandler(dummyHandler) + + for _, endpoint := range endpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatal("should have failed authorization because token is expired") + } + } +} + +func TestWrapHandlerNoAuthToken(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics"} + wrappedHandler := auth.WrapHandler(dummyHandler) + for _, endpoint := range endpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatal("should have failed authorization since no auth token given") + } + } +} + +func TestWrapHandlerUnauthorizedEndpoint(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + // Make a token + endpoints := []string{"/ext/info"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + unauthorizedEndpoints := []string{"/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"} + + wrappedHandler := auth.WrapHandler(dummyHandler) + for _, endpoint := range unauthorizedEndpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatal("should have failed authorization since this endpoint is not allowed by the token") + } + } +} + +func TestWrapHandlerAuthEndpoint(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + // Make a token + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/info/foo"} + tokenStr, err := auth.newToken(password, endpoints) + if err != nil { + t.Fatal(err) + } + + wrappedHandler := auth.WrapHandler(dummyHandler) + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", fmt.Sprintf("/ext/%s", Endpoint)), strings.NewReader("")) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatal("should always allow access to the auth endpoint") + } +} + +func TestWrapHandlerAccessAll(t *testing.T) { + auth := Auth{ + Enabled: true, + Password: password, + } + + // Make a token that allows access to all endpoints + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info"} + tokenStr, err := auth.newToken(password, []string{"*"}) + if err != nil { + t.Fatal(err) + } + + wrappedHandler := auth.WrapHandler(dummyHandler) + for _, endpoint := range endpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tokenStr)) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatal("* in token should have allowed access to all endpoints") + } + } +} + +func TestWrapHandlerAuthDisabled(t *testing.T) { + auth := Auth{ + Enabled: false, + Password: password, + } + + endpoints := []string{"/ext/info", "/ext/bc/X", "/ext/metrics", "", "/foo", "/ext/foo/info", "/ext/auth"} + + wrappedHandler := auth.WrapHandler(dummyHandler) + for _, endpoint := range endpoints { + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:9650%s", endpoint), strings.NewReader("")) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatal("auth is disabled so should allow access to all endpoints") + } + } +} diff --git a/api/auth/service.go b/api/auth/service.go index a909521..4caff3e 100644 --- a/api/auth/service.go +++ b/api/auth/service.go @@ -38,7 +38,7 @@ type Success struct { // Password ... type Password struct { - Password string `json:"password"` // The authotization password + Password string `json:"password"` // The authorization password } // NewTokenArgs ... diff --git a/api/server_test.go b/api/server_test.go index dc0ba9c..98856c8 100644 --- a/api/server_test.go +++ b/api/server_test.go @@ -30,7 +30,7 @@ func (s *Service) Call(_ *http.Request, args *Args, reply *Reply) error { func TestCall(t *testing.T) { s := Server{} - s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080) + s.Initialize(logging.NoLog{}, logging.NoFactory{}, "localhost", 8080, false, "") serv := &Service{} newServer := rpc.NewServer()