Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for refresh_in for ConfidentialClient #542

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 93 additions & 3 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
Expand Down Expand Up @@ -138,6 +139,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
RefreshOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
4gust marked this conversation as resolved.
Show resolved Hide resolved
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
Expand Down Expand Up @@ -255,7 +257,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token, "", "rt", "", 3600, 7400)))
4gust marked this conversation as resolved.
Show resolved Hide resolved

client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
Expand All @@ -278,7 +280,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {
}
// new assertion should trigger new token request
token2 := token + "2"
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token2, "", "rt", "", 3600, 360)))
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -708,7 +710,8 @@ func TestNewCredFromCertError(t *testing.T) {
func TestNewCredFromTokenProvider(t *testing.T) {
expectedToken := "expected token"
called := false
expiresIn := 4200
expiresIn := 18000
refreshOn := expiresIn / 2
key := struct{}{}
ctx := context.WithValue(context.Background(), key, true)
cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
Expand All @@ -728,6 +731,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
return exported.TokenProviderResult{
AccessToken: expectedToken,
ExpiresInSeconds: expiresIn,
RefreshInSeconds: refreshOn,
}, nil
})
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
Expand All @@ -744,6 +748,9 @@ func TestNewCredFromTokenProvider(t *testing.T) {
if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn {
t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v)
}
if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshOn-2 || v > refreshOn {
4gust marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, v)
}
if ar.AccessToken != expectedToken {
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
}
Expand All @@ -756,6 +763,89 @@ func TestNewCredFromTokenProvider(t *testing.T) {
}
}

func TestRefreshIn(t *testing.T) {
4gust marked this conversation as resolved.
Show resolved Hide resolved
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
firstToken := "first token"
secondToken := "new token"
lmo := "login.microsoftonline.com"
tenant := "tenant"
refreshIn := 43200
expiresIn := 86400
for _, tt := range []struct {
shouldGetNewToken bool
secondRequestAfter int
shouldReturnError bool
}{
{secondRequestAfter: 40000, shouldGetNewToken: false}, // from cache
{secondRequestAfter: 43400, shouldGetNewToken: true}, // refresh in expired so new token
{secondRequestAfter: 40000, shouldGetNewToken: false, shouldReturnError: true}, // refresh in not expired but refresh failed so new token
{secondRequestAfter: 80000, shouldGetNewToken: true, shouldReturnError: false}, // refresh in expired but refresh failed so new token
{secondRequestAfter: 1003400, shouldGetNewToken: true},
} {
name := "token doesn't need refresh"
t.Run(name, func(t *testing.T) {
originalTime := base.GetCurrentTime
defer func() {
base.GetCurrentTime = originalTime
}()
// Create a mock client and append mock responses
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
if tt.shouldReturnError {
mockClient.AppendResponse(
mock.WithCode(http.StatusBadGateway),
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
} else {
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))),
)
}

// Create the client instance
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
t.Fatal(err)
}
// Acquire the first token
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
// Assert the first token is returned
if ar.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken)
}
if ar.Metadata.RefreshOn.IsZero() {
t.Fatal("RefreshOn shouldn't be zero")
}
if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshIn-10 || v > refreshIn+10 {
t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, v)
}
fixedTime := time.Now().Add(time.Duration(tt.secondRequestAfter) * time.Second)
base.GetCurrentTime = func() time.Time {
return fixedTime
}
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
if ar.Metadata.TokenSource != base.Cache && !tt.shouldGetNewToken {
t.Fatal("should have returned from cache.")
}
if (ar.AccessToken == secondToken) != tt.shouldGetNewToken {
t.Fatalf("wanted %q, got %q", secondToken, ar.AccessToken)
}
})
}
}

func TestNewCredFromTokenProviderError(t *testing.T) {
expectedError := "something went wrong"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
Expand Down
40 changes: 36 additions & 4 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ package base

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"sync"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
Expand Down Expand Up @@ -83,9 +84,10 @@ type AcquireTokenOnBehalfOfParameters struct {
// AuthResult contains the results of one token acquisition operation in PublicClientApplication
// or ConfidentialClientApplication. For details see https://aka.ms/msal-net-authenticationresult
type AuthResult struct {
Account shared.Account
IDToken accesstokens.IDToken
AccessToken string
Account shared.Account
IDToken accesstokens.IDToken
AccessToken string
//RefreshOn indicates the recommended time to request a new access token, or zero if no refresh time is suggested
ExpiresOn time.Time
GrantedScopes []string
DeclinedScopes []string
Expand All @@ -94,6 +96,7 @@ type AuthResult struct {

// AuthResultMetadata which contains meta data for the AuthResult
type AuthResultMetadata struct {
RefreshOn time.Time
TokenSource TokenSource
}

Expand Down Expand Up @@ -133,6 +136,7 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
DeclinedScopes: nil,
Metadata: AuthResultMetadata{
TokenSource: Cache,
RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T,
},
}, nil
}
Expand All @@ -150,6 +154,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
GrantedScopes: tokenResponse.GrantedScopes.Slice,
Metadata: AuthResultMetadata{
TokenSource: IdentityProvider,
RefreshOn: tokenResponse.RefreshOn.T,
},
}, nil
}
Expand Down Expand Up @@ -345,6 +350,24 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
if silent.Claims == "" {
ar, err = AuthResultFromStorage(storageTokenResponse)
if err == nil {
if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) {
if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil {
return b.AuthResultFromToken(ctx, authParams, tr, true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is there any code path where AuthResultFromToken is called with the "cacheWrite" flag set to false? If not, consider getting rid of that param.

} else if callErr, ok := er.(*errors.CallErr); ok {
// Check if the error is of type CallErr and matches the relevant status codes
switch callErr.Resp.StatusCode {
case http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
bgavrilMS marked this conversation as resolved.
Show resolved Hide resolved
default:
// Handle non-retryable errors
return AuthResult{}, er
}
}
}
ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
Expand Down Expand Up @@ -458,6 +481,15 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au
return ar, err
}

// This function wraps time.Now() and is used for refreshing the application
// was created to test the function against refreshin
var GetCurrentTime = time.Now
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Making sure this isn't part of the public API?
  2. Does it make sense to keep it higher up?


// shouldRefresh returns true if the token should be refreshed.
func shouldRefresh(t time.Time) bool {
return !t.IsZero() && t.Before(GetCurrentTime())
}

func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
Expand Down
37 changes: 37 additions & 0 deletions apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,40 @@ func TestAuthResultFromStorage(t *testing.T) {
}
}
}

// TestShouldRefresh tests the shouldRefresh function
func TestShouldRefresh(t *testing.T) {
// Get the current time to use for comparison
now := time.Now()

tests := []struct {
name string
input time.Time
expected bool
}{
{
name: "Zero time",
input: time.Time{}, // Zero time
expected: false, // Should return false because it's zero time
},
{
name: "Future time",
input: now.Add(time.Hour), // 1 hour in the future
expected: false, // Should return false because it's in the future
},
{
name: "Past time",
input: now.Add(-time.Hour), // 1 hour in the past
expected: true, // Should return true because it's in the past
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldRefresh(tt.input)
if result != tt.expected {
t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected)
}
})
}
}
4 changes: 3 additions & 1 deletion apps/internal/base/internal/storage/items.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type AccessToken struct {
ClientID string `json:"client_id,omitempty"`
Secret string `json:"secret,omitempty"`
Scopes string `json:"target,omitempty"`
RefreshOn internalTime.Unix `json:"refresh_on,omitempty"`
ExpiresOn internalTime.Unix `json:"expires_on,omitempty"`
ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"`
CachedAt internalTime.Unix `json:"cached_at,omitempty"`
Expand All @@ -83,7 +84,7 @@ type AccessToken struct {
}

// NewAccessToken is the constructor for AccessToken.
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshOn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
return AccessToken{
HomeAccountID: homeID,
Environment: env,
Expand All @@ -93,6 +94,7 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex
Secret: token,
Scopes: scopes,
CachedAt: internalTime.Unix{T: cachedAt.UTC()},
RefreshOn: internalTime.Unix{T: refreshOn.UTC()},
ExpiresOn: internalTime.Unix{T: expiresOn.UTC()},
ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()},
TokenType: tokenType,
Expand Down
3 changes: 3 additions & 0 deletions apps/internal/base/internal/storage/items_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ var (
)

func TestCreateAccessToken(t *testing.T) {

testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testRefreshOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC)
actualAt := NewAccessToken("testHID",
"env",
"realm",
"clientID",
testCachedAt,
testRefreshOn,
testExpiresOn,
testExtExpiresOn,
"user.read",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
realm,
clientID,
cachedAt,
tokenResponse.RefreshOn.T,
tokenResponse.ExpiresOn.T,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ func TestReadPartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid user.read",
"secret",
"Bearer",
Expand Down Expand Up @@ -211,6 +212,7 @@ func TestWritePartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid",
"secret",
"tokenType",
Expand Down
1 change: 1 addition & 0 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
realm,
clientID,
cachedAt,
tokenResponse.RefreshOn.T,
tokenResponse.ExpiresOn.T,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
Loading
Loading