From 0977eb4dceb75f4d334e140fa2a85ef0384949f1 Mon Sep 17 00:00:00 2001 From: Chris Moran Date: Wed, 3 Jul 2024 07:21:42 -0400 Subject: [PATCH] feat: support multiple issuer:audience combinations by introducing an option for the expectedClaims. WithExpectedClaims can be called with multiple jwt.Expected parameters to allow different Issuer:Audience combinations to validate tokens feat: support multiple issuers in a provider using WithAdditionalIssuers option Every effort has been made to ensure backwards compatibility. Some error messages will be different due to the wrapping of errors when multiple jwt.Expected are set. When validating the jwt, if an error is encountered, instead of returning immediately, the current error is wrapped. This is good and bad. Good because all verification failure causes are captured in a single wrapped error; Bad because all verification failure causes are captured in a single monolithic wrapped error. Unwrapping the error can be tedious if many jwt.Expected are included. There is likely a better way but this suits my purposes. A few more test cases will likely be needed in order to achieve true confidence in this change --- README.md | 1 - examples/gin-example/main.go | 54 ++++- examples/gin-example/middleware.go | 54 ++++- examples/http-example/main.go | 3 +- examples/http-jwks-example/main.go | 3 +- extractor.go | 4 +- extractor_test.go | 2 +- jwks/provider.go | 61 +++++- validator/option.go | 14 ++ validator/validator.go | 159 ++++++++++++--- validator/validator_test.go | 304 ++++++++++++++++++++++++++++- 11 files changed, 606 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index fbbb7b1a..72f20676 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ import ( "log" "net/http" - "github.com/auth0/go-jwt-middleware/v2" "github.com/auth0/go-jwt-middleware/v2/validator" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" ) diff --git a/examples/gin-example/main.go b/examples/gin-example/main.go index 03cc34e2..a9afffc2 100644 --- a/examples/gin-example/main.go +++ b/examples/gin-example/main.go @@ -39,9 +39,29 @@ import ( // "username": "user123", // "shouldReject": true // } +// +// You can also try out the /multiple endpoint. This endpoint accepts tokens signed by multiple issuers. Try the +// token below which has a different issuer: +// +// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1tdWx0aXBsZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtbXVsdGlwbGUtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.9zV_bY1wAmQlMCPlXOppx1Y9_z_T_wNng9-yfQk4I0c +// +// which is signed with 'secret' and has the data: +// +// { +// "iss": "go-jwt-middleware-multiple-example", +// "aud": "audience-multiple-example", +// "sub": "1234567890", +// "name": "John Doe", +// "iat": 1516239022, +// "username": "user123" +// } +// +// You can also try the previous tokens with the /multiple endpoint. The first token will be valid the second will fail because +// the custom validator rejects it (shouldReject: true) func main() { router := gin.Default() + router.GET("/", checkJWT(), func(ctx *gin.Context) { claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) if !ok { @@ -52,7 +72,37 @@ func main() { return } - customClaims, ok := claims.CustomClaims.(*CustomClaimsExample) + localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample) + if !ok { + ctx.AbortWithStatusJSON( + http.StatusInternalServerError, + map[string]string{"message": "Failed to cast custom JWT claims to specific type."}, + ) + return + } + + if len(localCustomClaims.Username) == 0 { + ctx.AbortWithStatusJSON( + http.StatusBadRequest, + map[string]string{"message": "Username in JWT claims was empty."}, + ) + return + } + + ctx.JSON(http.StatusOK, claims) + }) + + router.GET("/multiple", checkJWTMultiple(), func(ctx *gin.Context) { + claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) + if !ok { + ctx.AbortWithStatusJSON( + http.StatusInternalServerError, + map[string]string{"message": "Failed to get validated JWT claims."}, + ) + return + } + + localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample) if !ok { ctx.AbortWithStatusJSON( http.StatusInternalServerError, @@ -61,7 +111,7 @@ func main() { return } - if len(customClaims.Username) == 0 { + if len(localCustomClaims.Username) == 0 { ctx.AbortWithStatusJSON( http.StatusBadRequest, map[string]string{"message": "Username in JWT claims was empty."}, diff --git a/examples/gin-example/middleware.go b/examples/gin-example/middleware.go index 104cd07c..1752f7b2 100644 --- a/examples/gin-example/middleware.go +++ b/examples/gin-example/middleware.go @@ -2,6 +2,7 @@ package main import ( "context" + "gopkg.in/go-jose/go-jose.v2/jwt" "log" "net/http" "time" @@ -16,10 +17,12 @@ var ( signingKey = []byte("secret") // The issuer of our token. - issuer = "go-jwt-middleware-example" + issuer = "go-jwt-middleware-example" + issuerTwo = "go-jwt-middleware-multiple-example" // The audience of our token. - audience = []string{"audience-example"} + audience = []string{"audience-example"} + audienceTwo = []string{"audience-multiple-example"} // Our token must be signed using this data. keyFunc = func(ctx context.Context) (interface{}, error) { @@ -76,3 +79,50 @@ func checkJWT() gin.HandlerFunc { } } } + +func checkJWTMultiple() gin.HandlerFunc { + // Set up the validator. + jwtValidator, err := validator.NewValidator( + keyFunc, + validator.HS256, + validator.WithCustomClaims(customClaims), + validator.WithAllowedClockSkew(30*time.Second), + validator.WithExpectedClaims(jwt.Expected{ + Issuer: issuer, + Audience: audience, + }, jwt.Expected{ + Issuer: issuerTwo, + Audience: audienceTwo, + }), + ) + if err != nil { + log.Fatalf("failed to set up the validator: %v", err) + } + + errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("Encountered error while validating JWT: %v", err) + } + + middleware := jwtmiddleware.New( + jwtValidator.ValidateToken, + jwtmiddleware.WithErrorHandler(errorHandler), + ) + + return func(ctx *gin.Context) { + encounteredError := true + var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + encounteredError = false + ctx.Request = r + ctx.Next() + } + + middleware.CheckJWT(handler).ServeHTTP(ctx.Writer, ctx.Request) + + if encounteredError { + ctx.AbortWithStatusJSON( + http.StatusUnauthorized, + map[string]string{"message": "JWT is invalid."}, + ) + } + } +} diff --git a/examples/http-example/main.go b/examples/http-example/main.go index b7ad5eb9..d824b668 100644 --- a/examples/http-example/main.go +++ b/examples/http-example/main.go @@ -8,9 +8,8 @@ import ( "net/http" "time" - "github.com/auth0/go-jwt-middleware/v2" - "github.com/auth0/go-jwt-middleware/v2/validator" jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" + "github.com/auth0/go-jwt-middleware/v2/validator" ) var ( diff --git a/examples/http-jwks-example/main.go b/examples/http-jwks-example/main.go index 81776dcc..93ee1440 100644 --- a/examples/http-jwks-example/main.go +++ b/examples/http-jwks-example/main.go @@ -7,10 +7,9 @@ import ( "net/url" "time" - "github.com/auth0/go-jwt-middleware/v2" + jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" "github.com/auth0/go-jwt-middleware/v2/jwks" "github.com/auth0/go-jwt-middleware/v2/validator" - jwtmiddleware "github.com/auth0/go-jwt-middleware/v2" ) var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/extractor.go b/extractor.go index 376e513c..33882665 100644 --- a/extractor.go +++ b/extractor.go @@ -23,7 +23,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) { authHeaderParts := strings.Fields(authHeader) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") + return "", errors.New("authorization header format must be Bearer {token}") } return authHeaderParts[1], nil @@ -34,7 +34,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) { func CookieTokenExtractor(cookieName string) TokenExtractor { return func(r *http.Request) (string, error) { cookie, err := r.Cookie(cookieName) - if err == http.ErrNoCookie { + if errors.Is(err, http.ErrNoCookie) { return "", nil // No cookie, then no JWT, so no error. } diff --git a/extractor_test.go b/extractor_test.go index 3101847d..adca0443 100644 --- a/extractor_test.go +++ b/extractor_test.go @@ -38,7 +38,7 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) { "Authorization": []string{"i-am-a-token"}, }, }, - wantError: "Authorization header format must be Bearer {token}", + wantError: "authorization header format must be Bearer {token}", }, } diff --git a/jwks/provider.go b/jwks/provider.go index 808cae75..aa30b9b8 100644 --- a/jwks/provider.go +++ b/jwks/provider.go @@ -21,9 +21,10 @@ import ( // getting and caching JWKS which can help reduce request time and potential // rate limiting from your provider. type Provider struct { - IssuerURL *url.URL // Required. - CustomJWKSURI *url.URL // Optional. - Client *http.Client + IssuerURL *url.URL // Required. + CustomJWKSURI *url.URL // Optional. + AdditionalProviders []Provider // Optional + Client *http.Client } // ProviderOption is how options for the Provider are set up. @@ -32,14 +33,24 @@ type ProviderOption func(*Provider) // NewProvider builds and returns a new *Provider. func NewProvider(issuerURL *url.URL, opts ...ProviderOption) *Provider { p := &Provider{ - IssuerURL: issuerURL, - Client: &http.Client{}, + Client: &http.Client{}, + AdditionalProviders: make([]Provider, 0), + } + + if issuerURL != nil { + p.IssuerURL = issuerURL } for _, opt := range opts { opt(p) } + for _, provider := range p.AdditionalProviders { + if provider.Client == nil { + provider.Client = p.Client + } + } + return p } @@ -56,6 +67,21 @@ func WithCustomJWKSURI(jwksURI *url.URL) ProviderOption { func WithCustomClient(c *http.Client) ProviderOption { return func(p *Provider) { p.Client = c + for _, provider := range p.AdditionalProviders { + provider.Client = c + } + } +} + +// WithAdditionalProviders allows validation with mutliple IssuerURLs if desired. If multiple issuers are specified, +// a jwt may be signed by any of them and be considered valid +func WithAdditionalProviders(issuerURL *url.URL, customJWKSURI *url.URL) ProviderOption { + return func(p *Provider) { + p.AdditionalProviders = append(p.AdditionalProviders, Provider{ + IssuerURL: issuerURL, + CustomJWKSURI: customJWKSURI, + Client: p.Client, + }) } } @@ -63,6 +89,25 @@ func WithCustomClient(c *http.Client) ProviderOption { // While it returns an interface to adhere to keyFunc, as long as the // error is nil the type will be *jose.JSONWebKeySet. func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) { + rawJwks, err := p.keyFunc(ctx) + + if len(p.AdditionalProviders) == 0 { + return rawJwks, err + } else { + var jwks *jose.JSONWebKeySet + jwks = rawJwks.(*jose.JSONWebKeySet) + for _, provider := range p.AdditionalProviders { + if rawJwks, err = provider.keyFunc(ctx); err != nil { + continue + } else { + jwks.Keys = append(jwks.Keys, rawJwks.(*jose.JSONWebKeySet).Keys...) + } + } + return jwks, err + } +} + +func (p *Provider) keyFunc(ctx context.Context) (interface{}, error) { jwksURI := p.CustomJWKSURI if jwksURI == nil { wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.Client, *p.IssuerURL) @@ -85,10 +130,12 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) { if err != nil { return nil, err } - defer response.Body.Close() + defer func() { + _ = response.Body.Close() + }() var jwks jose.JSONWebKeySet - if err := json.NewDecoder(response.Body).Decode(&jwks); err != nil { + if err = json.NewDecoder(response.Body).Decode(&jwks); err != nil { return nil, fmt.Errorf("could not decode jwks: %w", err) } diff --git a/validator/option.go b/validator/option.go index 12c1cc61..bd318299 100644 --- a/validator/option.go +++ b/validator/option.go @@ -1,6 +1,7 @@ package validator import ( + "gopkg.in/go-jose/go-jose.v2/jwt" "time" ) @@ -26,3 +27,16 @@ func WithCustomClaims(f func() CustomClaims) Option { v.customClaims = f } } + +// WithExpectedClaims allows fine-grained customization of the expected claims +func WithExpectedClaims(expectedClaims ...jwt.Expected) Option { + return func(v *Validator) { + if len(expectedClaims) == 0 { + return + } + if v.expectedClaims == nil { + v.expectedClaims = make([]jwt.Expected, 0) + } + v.expectedClaims = append(v.expectedClaims, expectedClaims...) + } +} diff --git a/validator/validator.go b/validator/validator.go index 2a302493..b28dc948 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "gopkg.in/go-jose/go-jose.v2/jwt" @@ -30,7 +31,7 @@ const ( type Validator struct { keyFunc func(context.Context) (interface{}, error) // Required. signatureAlgorithm SignatureAlgorithm // Required. - expectedClaims jwt.Expected // Internal. + expectedClaims []jwt.Expected // Internal. customClaims func() CustomClaims // Optional. allowedClockSkew time.Duration // Optional. } @@ -66,11 +67,61 @@ func New( if keyFunc == nil { return nil, errors.New("keyFunc is required but was nil") } - if issuerURL == "" { - return nil, errors.New("issuer url is required but was empty") + if _, ok := allowedSigningAlgorithms[signatureAlgorithm]; !ok { + return nil, errors.New("unsupported signature algorithm") + } + + v := &Validator{ + keyFunc: keyFunc, + signatureAlgorithm: signatureAlgorithm, + expectedClaims: make([]jwt.Expected, 0), } - if len(audience) == 0 { + + for _, opt := range opts { + opt(v) + } + + if len(v.expectedClaims) == 0 && issuerURL == "" { + return nil, errors.New("issuer url is required but was empty") + } else if len(v.expectedClaims) == 0 && len(audience) == 0 { return nil, errors.New("audience is required but was empty") + } else if len(issuerURL) > 0 && len(audience) > 0 { + v.expectedClaims = append(v.expectedClaims, jwt.Expected{ + Issuer: issuerURL, + Audience: audience, + }) + } + + if len(v.expectedClaims) == 0 { + return nil, errors.New("expected claims but none provided") + } + + for i, expected := range v.expectedClaims { + if expected.Issuer == "" { + return nil, fmt.Errorf("issuer url %d is required but was empty", i) + } + if len(expected.Audience) == 0 { + return nil, fmt.Errorf("audience %d is required but was empty", i) + } + } + + return v, nil +} + +// NewValidator sets up a new Validator with the required keyFunc +// and signatureAlgorithm as well as custom options. +// This function has been added to provide an alternate function without the required issuer or audience parameters +// so they can be included in the opts parameter via WithExpectedClaims +// This function operates exactly like New with the exception of the two parameters issuer and audience and this function +// expects the inclusion of WithExpectedClaims with at least one valid expected claim. +// A valid expected claim would include an issuer and at least one audience +func NewValidator( + keyFunc func(context.Context) (interface{}, error), + signatureAlgorithm SignatureAlgorithm, + opts ...Option, +) (*Validator, error) { + if keyFunc == nil { + return nil, errors.New("keyFunc is required but was nil") } if _, ok := allowedSigningAlgorithms[signatureAlgorithm]; !ok { return nil, errors.New("unsupported signature algorithm") @@ -79,16 +130,26 @@ func New( v := &Validator{ keyFunc: keyFunc, signatureAlgorithm: signatureAlgorithm, - expectedClaims: jwt.Expected{ - Issuer: issuerURL, - Audience: audience, - }, + expectedClaims: make([]jwt.Expected, 0), } for _, opt := range opts { opt(v) } + if len(v.expectedClaims) == 0 { + return nil, errors.New("expected claims but none provided") + } + + for i, expected := range v.expectedClaims { + if expected.Issuer == "" { + return nil, fmt.Errorf("issuer url %d is required but was empty", i) + } + if len(expected.Audience) == 0 { + return nil, fmt.Errorf("audience %d is required but was empty", i) + } + } + return v, nil } @@ -134,38 +195,74 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte return validatedClaims, nil } -func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error { - expectedClaims := expected - expectedClaims.Time = time.Now() +func validateClaimsWithLeeway(actualClaims jwt.Claims, expectedIn []jwt.Expected, leeway time.Duration) error { + now := time.Now() + var currentError error + for _, expected := range expectedIn { + expectedClaims := expected + expectedClaims.Time = now - if actualClaims.Issuer != expectedClaims.Issuer { - return jwt.ErrInvalidIssuer - } + if actualClaims.Issuer != expectedClaims.Issuer { + currentError = createOrWrapError(currentError, jwt.ErrInvalidIssuer, actualClaims.Issuer, expectedClaims.Issuer) + continue + } - foundAudience := false - for _, value := range expectedClaims.Audience { - if actualClaims.Audience.Contains(value) { - foundAudience = true - break + foundAudience := false + for _, value := range expectedClaims.Audience { + if actualClaims.Audience.Contains(value) { + foundAudience = true + break + } + } + if !foundAudience { + currentError = createOrWrapError( + currentError, + jwt.ErrInvalidAudience, + strings.Join(actualClaims.Audience, ","), + strings.Join(expectedClaims.Audience, ","), + ) + continue } - } - if !foundAudience { - return jwt.ErrInvalidAudience - } - if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) { - return jwt.ErrNotValidYet - } + if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) { + return createOrWrapError( + currentError, + jwt.ErrNotValidYet, + actualClaims.NotBefore.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } - if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) { - return jwt.ErrExpired + if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) { + return createOrWrapError( + currentError, + jwt.ErrExpired, + actualClaims.Expiry.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } + + if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) { + return createOrWrapError( + currentError, + jwt.ErrIssuedInTheFuture, + actualClaims.IssuedAt.Time().String(), + expectedClaims.Time.Add(leeway).String(), + ) + } + + return nil } - if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) { - return jwt.ErrIssuedInTheFuture + return currentError +} + +func createOrWrapError(base, current error, actual, expected string) error { + if base == nil { + return current } - return nil + return errors.Join(base, fmt.Errorf("%v: %s vs %s", current, actual, expected)) } func validateSigningMethod(validAlg, tokenAlg string) error { diff --git a/validator/validator_test.go b/validator/validator_test.go index 08feeb14..84d986b2 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -234,7 +234,228 @@ func TestValidator_ValidateToken(t *testing.T) { } } -func TestNewValidator(t *testing.T) { +func TestNewValidator_ValidateToken(t *testing.T) { + const ( + issuer = "https://go-jwt-middleware.eu.auth0.com/" + audience = "https://go-jwt-middleware-api/" + subject = "1234567890" + issuerB = "https://go-jwt-middleware.us.auth0.com/" + audienceB = "https://go-jwt-middleware-api-b/" + subjectB = "0987654321" + ) + + testCases := []struct { + name string + token string + keyFunc func(context.Context) (interface{}, error) + algorithm SignatureAlgorithm + customClaims func() CustomClaims + expectedError error + expectedClaims *ValidatedClaims + }{ + { + name: "it successfully validates a token", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + }, + }, + { + name: "it successfully validates a token with custom claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return &testClaims{} + }, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + CustomClaims: &testClaims{ + Scope: "read:messages", + }, + }, + }, + { + name: "it throws an error when token has a different signing algorithm than the validator", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: RS256, + expectedError: errors.New(`signing method is invalid: expected "RS256" signing algorithm but token specified "HS256"`), + }, + { + name: "it throws an error when it cannot parse the token", + token: "", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("could not parse the token: go-jose/go-jose: compact JWS format must have three parts"), + }, + { + name: "it throws an error when it fails to fetch the keys from the key func", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.-R2K2tZHDrgsEh9JNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return nil, errors.New("key func error message") + }, + algorithm: HS256, + expectedError: errors.New("failed to deserialize token claims: error getting the keys from the key func: key func error message"), + }, + { + name: "it throws an error when it fails to deserialize the claims because the signature is invalid", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdfQ.vR2K2tZHDrgsEh9zNWcyk4aljtR6gZK0s2anNGlfwz0", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("failed to deserialize token claims: could not get token claims: go-jose/go-jose: error in cryptographic primitive"), + }, + { + name: "it throws an error when it fails to validate the registered claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIn0.VoIwDVmb--26wGrv93NmjNZYa4nrzjLw4JANgEjPI28", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: errors.New("go-jose/go-jose/jwt: validation failed, invalid audience claim (aud)"), + }, + { + name: "it throws an error when it fails to validate the custom claims", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return &testClaims{ + ReturnError: errors.New("custom claims error message"), + } + }, + expectedError: errors.New("custom claims not validated: custom claims error message"), + }, + { + name: "it successfully validates a token even if customClaims() returns nil", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJzY29wZSI6InJlYWQ6bWVzc2FnZXMifQ.oqtUZQ-Q8un4CPduUBdGVq5gXpQVIFT_QSQjkOXFT5I", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + customClaims: func() CustomClaims { + return nil + }, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + }, + CustomClaims: nil, + }, + }, + { + name: "it successfully validates a token with exp, nbf and iat", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo5NjY3OTM3Njg2fQ.FKZogkm08gTfYfPU6eYu7OHCjJKnKGLiC0IfoIOPEhs", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedClaims: &ValidatedClaims{ + RegisteredClaims: RegisteredClaims{ + Issuer: issuer, + Subject: subject, + Audience: []string{audience}, + Expiry: 9667937686, + NotBefore: 1666939000, + IssuedAt: 1666937686, + }, + }, + }, + { + name: "it throws an error when token is not valid yet", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6OTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.yUizJ-zK_33tv1qBVvDKO0RuCWtvJ02UQKs8gBadgGY", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrNotValidYet), + }, + { + name: "it throws an error when token is expired", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo2Njc5Mzc2ODZ9.SKvz82VOXRi_sjvZWIsPG9vSWAXKKgVS4DkGZcwFKL8", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrExpired), + }, + { + name: "it throws an error when token is issued in the future", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjkxNjY2OTM3Njg2LCJuYmYiOjE2NjY5MzkwMDAsImV4cCI6ODY2NzkzNzY4Nn0.ieFV7XNJxiJyw8ARq9yHw-01Oi02e3P2skZO10ypxL8", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrIssuedInTheFuture), + }, + { + name: "it throws an error when token issuer is invalid", + token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6WyJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLWFwaS8iXSwiaWF0Ijo5MTY2NjkzNzY4NiwibmJmIjoxNjY2OTM5MDAwLCJleHAiOjg2Njc5Mzc2ODZ9.b5gXNrUNfd_jyCWZF-6IPK_UFfvTr9wBQk9_QgRQ8rA", + keyFunc: func(context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + algorithm: HS256, + expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer), + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + validator, err := NewValidator( + testCase.keyFunc, + testCase.algorithm, + WithCustomClaims(testCase.customClaims), + WithAllowedClockSkew(time.Second), + WithExpectedClaims(jwt.Expected{ + Issuer: issuer, + Audience: []string{audience, "another-audience"}, + }, jwt.Expected{ + Issuer: issuerB, + Audience: []string{audienceB, "another-audienceb"}, + }), + ) + require.NoError(t, err) + + tokenClaims, err := validator.ValidateToken(context.Background(), testCase.token) + if testCase.expectedError != nil { + assert.ErrorContains(t, err, testCase.expectedError.Error()) + assert.Nil(t, tokenClaims) + } else { + require.NoError(t, err) + assert.Exactly(t, testCase.expectedClaims, tokenClaims) + } + }) + } +} + +func TestNew(t *testing.T) { const ( issuer = "https://go-jwt-middleware.eu.auth0.com/" audience = "https://go-jwt-middleware-api/" @@ -260,12 +481,12 @@ func TestNewValidator(t *testing.T) { assert.EqualError(t, err, "unsupported signature algorithm") }) - t.Run("it throws an error when the issuerURL is empty", func(t *testing.T) { + t.Run("it throws an error when the issuerURL is empty and no expectedClaims option", func(t *testing.T) { _, err := New(keyFunc, algorithm, "", []string{audience}) assert.EqualError(t, err, "issuer url is required but was empty") }) - t.Run("it throws an error when the audience is nil", func(t *testing.T) { + t.Run("it throws an error when the audience is nil if no expectedClaims option included", func(t *testing.T) { _, err := New(keyFunc, algorithm, issuer, nil) assert.EqualError(t, err, "audience is required but was empty") }) @@ -274,4 +495,81 @@ func TestNewValidator(t *testing.T) { _, err := New(keyFunc, algorithm, issuer, []string{}) assert.EqualError(t, err, "audience is required but was empty") }) + + t.Run("it throws an error when the issuerURL is empty and an expectedClaims option with only an audience", func(t *testing.T) { + _, err := New(keyFunc, algorithm, "", []string{}, WithExpectedClaims(jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 0 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience", func(t *testing.T) { + _, err := New(keyFunc, algorithm, issuer, []string{}, WithExpectedClaims(jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 0 is required but was empty") + }) + + t.Run("it throws no error when the issuerURL is empty but expectedClaims option included", func(t *testing.T) { + _, err := New(keyFunc, algorithm, "", []string{audience}, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) + + t.Run("it throws no error when the audience is nil but expectedClaims option included", func(t *testing.T) { + _, err := New(keyFunc, algorithm, issuer, nil, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) +} + +func TestNewValidator(t *testing.T) { + const ( + issuer = "https://go-jwt-middleware.eu.auth0.com/" + audience = "https://go-jwt-middleware-api/" + algorithm = HS256 + ) + + var keyFunc = func(context.Context) (interface{}, error) { + return []byte("secret"), nil + } + + t.Run("it throws an error when the keyFunc is nil", func(t *testing.T) { + _, err := NewValidator(nil, algorithm) + assert.EqualError(t, err, "keyFunc is required but was nil") + }) + + t.Run("it throws an error when the signature algorithm is empty", func(t *testing.T) { + _, err := NewValidator(keyFunc, "") + assert.EqualError(t, err, "unsupported signature algorithm") + }) + + t.Run("it throws an error when the signature algorithm is unsupported", func(t *testing.T) { + _, err := NewValidator(keyFunc, "none") + assert.EqualError(t, err, "unsupported signature algorithm") + }) + + t.Run("it throws an error when there are no expected claims", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm) + assert.EqualError(t, err, "expected claims but none provided") + }) + + t.Run("it throws an error when expectedClaims option with only an audience", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 0 is required but was empty") + }) + + t.Run("it throws an error when expectedClaims option with only an audience in the second jwt.Expected", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}}, jwt.Expected{Audience: []string{audience}})) + assert.EqualError(t, err, "issuer url 1 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 0 is required but was empty") + }) + + t.Run("it throws an error when the audience is empty and the expectedClaims are missing an audience in the second jwt.Expected", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}}, jwt.Expected{Issuer: issuer})) + assert.EqualError(t, err, "audience 1 is required but was empty") + }) + + t.Run("it throws no error when input is correct", func(t *testing.T) { + _, err := NewValidator(keyFunc, algorithm, WithExpectedClaims(jwt.Expected{Issuer: issuer, Audience: []string{audience}})) + assert.NoError(t, err, "no error was expected") + }) }