diff --git a/go.mod b/go.mod index 9fade8e2..e9a3d5a0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 require ( github.com/google/go-cmp v0.6.0 github.com/stretchr/testify v1.8.4 + golang.org/x/sync v0.5.0 gopkg.in/go-jose/go-jose.v2 v2.6.1 ) diff --git a/go.sum b/go.sum index 0482f5c3..d88934d3 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/go-jose/go-jose.v2 v2.6.1 h1:qEzJlIDmG9q5VO0M/o8tGS65QMHMS1w01TQJB1VPJ4U= diff --git a/jwks/provider.go b/jwks/provider.go index 767bd5d9..808cae75 100644 --- a/jwks/provider.go +++ b/jwks/provider.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "golang.org/x/sync/semaphore" "gopkg.in/go-jose/go-jose.v2" "github.com/auth0/go-jwt-middleware/v2/internal/oidc" @@ -97,11 +98,16 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) { // CachingProvider handles getting JWKS from the specified IssuerURL // and caching them for CacheTTL time. It exposes KeyFunc which adheres // to the keyFunc signature that the Validator requires. +// When the CacheTTL value has been reached, a JWKS refresh will be triggered +// in the background and the existing cached JWKS will be returned until the +// JWKS cache is updated, or if the request errors then it will be evicted from +// the cache. type CachingProvider struct { *Provider CacheTTL time.Duration mu sync.RWMutex cache map[string]cachedJWKS + sem semaphore.Weighted } type cachedJWKS struct { @@ -120,6 +126,7 @@ func NewCachingProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...Prov Provider: NewProvider(issuerURL, opts...), CacheTTL: cacheTTL, cache: map[string]cachedJWKS{}, + sem: *semaphore.NewWeighted(1), } } @@ -132,10 +139,22 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) { issuer := c.IssuerURL.Hostname() if cached, ok := c.cache[issuer]; ok { - if !time.Now().After(cached.expiresAt) { - c.mu.RUnlock() - return cached.jwks, nil + if time.Now().After(cached.expiresAt) && c.sem.TryAcquire(1) { + go func() { + defer c.sem.Release(1) + refreshCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + _, err := c.refreshKey(refreshCtx, issuer) + + if err != nil { + c.mu.Lock() + delete(c.cache, issuer) + c.mu.Unlock() + } + }() } + c.mu.RUnlock() + return cached.jwks, nil } c.mu.RUnlock() diff --git a/jwks/provider_test.go b/jwks/provider_test.go index 66cb244d..76820108 100644 --- a/jwks/provider_test.go +++ b/jwks/provider_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/go-jose/go-jose.v2" @@ -84,7 +85,8 @@ func Test_JWKSProvider(t *testing.T) { } }) - t.Run("It re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) { + t.Run("It eventually re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) { + requestCount = 0 expiredCachedJWKS, err := generateJWKS() require.NoError(t, err) @@ -94,16 +96,20 @@ func Test_JWKSProvider(t *testing.T) { expiresAt: time.Now().Add(-10 * time.Minute), } - actualJWKS, err := provider.KeyFunc(context.Background()) + returnedJWKS, err := provider.KeyFunc(context.Background()) require.NoError(t, err) - if !cmp.Equal(expectedJWKS, actualJWKS) { - t.Fatalf("jwks did not match: %s", cmp.Diff(expectedJWKS, actualJWKS)) + if !cmp.Equal(expiredCachedJWKS, returnedJWKS) { + t.Fatalf("jwks did not match: %s", cmp.Diff(expiredCachedJWKS, returnedJWKS)) } - if !cmp.Equal(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks) { - t.Fatalf("cached jwks did not match: %s", cmp.Diff(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks)) - } + require.EventuallyWithT(t, func(c *assert.CollectT) { + returnedJWKS, err := provider.KeyFunc(context.Background()) + require.NoError(t, err) + + assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS)) + assert.Equal(c, int32(2), requestCount) + }, 1*time.Second, 250*time.Millisecond, "JWKS did not update") cacheExpiresAt := provider.cache[testServerURL.Hostname()].expiresAt if !time.Now().Before(cacheExpiresAt) { @@ -154,6 +160,86 @@ func Test_JWKSProvider(t *testing.T) { } }, ) + + t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with expired cache", func(t *testing.T) { + initialJWKS, err := generateJWKS() + require.NoError(t, err) + requestCount = 0 + + provider := NewCachingProvider(testServerURL, 5*time.Minute) + provider.cache[testServerURL.Hostname()] = cachedJWKS{ + jwks: initialJWKS, + expiresAt: time.Now(), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + _, _ = provider.KeyFunc(context.Background()) + wg.Done() + }() + } + wg.Wait() + + require.EventuallyWithT(t, func(c *assert.CollectT) { + returnedJWKS, err := provider.KeyFunc(context.Background()) + require.NoError(t, err) + + assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS)) + assert.Equal(c, int32(2), requestCount) + }, 1*time.Second, 250*time.Millisecond, "JWKS did not update") + }) + + t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with no cache", func(t *testing.T) { + provider := NewCachingProvider(testServerURL, 5*time.Minute) + requestCount = 0 + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + _, _ = provider.KeyFunc(context.Background()) + wg.Done() + }() + } + wg.Wait() + + if requestCount != 2 { + t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", requestCount) + } + }) + + t.Run("Should delete cache entry if the refresh request fails", func(t *testing.T) { + malformedURL, err := url.Parse(testServer.URL + "/malformed") + require.NoError(t, err) + + expiredCachedJWKS, err := generateJWKS() + require.NoError(t, err) + + provider := NewCachingProvider(malformedURL, 5*time.Minute) + provider.cache[malformedURL.Hostname()] = cachedJWKS{ + jwks: expiredCachedJWKS, + expiresAt: time.Now().Add(-10 * time.Minute), + } + + // Trigger the refresh of the JWKS, which should return the cached JWKS + returnedJWKS, err := provider.KeyFunc(context.Background()) + require.NoError(t, err) + assert.Equal(t, expiredCachedJWKS, returnedJWKS) + + // Eventually it should return a nil JWKS + require.EventuallyWithT(t, func(c *assert.CollectT) { + returnedJWKS, err := provider.KeyFunc(context.Background()) + require.Error(t, err) + + assert.Nil(c, returnedJWKS) + + cachedJWKS := provider.cache[malformedURL.Hostname()].jwks + + assert.Nil(t, cachedJWKS) + }, 1*time.Second, 250*time.Millisecond, "JWKS did not get uncached") + }) } func generateJWKS() (*jose.JSONWebKeySet, error) {