Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for refresh_in for ConfidentialClient #542

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
RefreshIn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
Expand Down Expand Up @@ -463,6 +464,7 @@ func TestADFSTokenCaching(t *testing.T) {
AccessToken: "at1",
RefreshToken: "rt",
TokenType: "bearer",
RefreshIn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
Expand Down Expand Up @@ -812,6 +814,62 @@ func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.Rep
return nil
}

func TestAcquireTokenSilentRefreshIn(t *testing.T) {

for _, test := range []struct {
expireOn int
refreshIn int
}{
{3600, 1},
{7200, 3600},
} {
cache := make(testCache)
accessToken := "*"
lmo := "login.microsoftonline.com"
tenantA, tenantB := "a", "b"
authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB)
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", test.expireOn, test.refreshIn)))

cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
// The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
account := ar.Account
if actual := account.Realm; actual != tenantA {
t.Fatalf(`unexpected realm "%s"`, actual)
}

// a client configured for a different tenant should be able to authenticate silently with the shared cache's data
client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
// this should succeed because the cache contains an access token from tenantA
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA)))
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA))
if err != nil && test.refreshIn > 1 {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
}
}

func TestWithCache(t *testing.T) {
cache := make(testCache)
accessToken := "*"
Expand Down
6 changes: 5 additions & 1 deletion apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type AuthResult struct {
Account shared.Account
IDToken accesstokens.IDToken
AccessToken string
RefreshIn time.Time
4gust marked this conversation as resolved.
Show resolved Hide resolved
ExpiresOn time.Time
GrantedScopes []string
DeclinedScopes []string
Expand Down Expand Up @@ -128,6 +129,7 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
Account: account,
IDToken: idToken,
AccessToken: accessToken,
RefreshIn: storageTokenResponse.AccessToken.RefreshIn.T,
ExpiresOn: storageTokenResponse.AccessToken.ExpiresOn.T,
GrantedScopes: grantedScopes,
DeclinedScopes: nil,
Expand Down Expand Up @@ -346,7 +348,9 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
ar, err = AuthResultFromStorage(storageTokenResponse)
if err == nil {
ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
if ar.RefreshIn.IsZero() || ar.RefreshIn.After(time.Now()) {
return ar, err
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion apps/internal/base/internal/storage/items.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type AccessToken struct {
ClientID string `json:"client_id,omitempty"`
Secret string `json:"secret,omitempty"`
Scopes string `json:"target,omitempty"`
RefreshIn internalTime.Unix `json:"refresh_in,omitempty"`
4gust marked this conversation as resolved.
Show resolved Hide resolved
ExpiresOn internalTime.Unix `json:"expires_on,omitempty"`
ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"`
CachedAt internalTime.Unix `json:"cached_at,omitempty"`
Expand All @@ -83,7 +84,7 @@ type AccessToken struct {
}

// NewAccessToken is the constructor for AccessToken.
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshIn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
return AccessToken{
HomeAccountID: homeID,
Environment: env,
Expand All @@ -93,6 +94,7 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex
Secret: token,
Scopes: scopes,
CachedAt: internalTime.Unix{T: cachedAt.UTC()},
RefreshIn: internalTime.Unix{T: refreshIn.UTC()},
ExpiresOn: internalTime.Unix{T: expiresOn.UTC()},
ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()},
TokenType: tokenType,
Expand Down
3 changes: 3 additions & 0 deletions apps/internal/base/internal/storage/items_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ var (
)

func TestCreateAccessToken(t *testing.T) {

testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testRefreshIn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC)
actualAt := NewAccessToken("testHID",
"env",
"realm",
"clientID",
testCachedAt,
testRefreshIn,
testExpiresOn,
testExtExpiresOn,
"user.read",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
realm,
clientID,
cachedAt,
tokenResponse.RefreshIn.T,
tokenResponse.ExpiresOn.T,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ func TestReadPartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid user.read",
"secret",
"Bearer",
Expand Down Expand Up @@ -211,6 +212,7 @@ func TestWritePartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid",
"secret",
"tokenType",
Expand Down
1 change: 1 addition & 0 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
realm,
clientID,
cachedAt,
tokenResponse.RefreshIn.T,
tokenResponse.ExpiresOn.T,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
21 changes: 15 additions & 6 deletions apps/internal/base/internal/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ func TestReadAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid user.read",
"secret",
"tokenType",
Expand All @@ -241,6 +242,7 @@ func TestReadAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid user.read",
"secret2",
"",
Expand Down Expand Up @@ -343,6 +345,7 @@ func TestWriteAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid",
"secret",
"tokenType",
Expand Down Expand Up @@ -848,6 +851,7 @@ func TestIsAccessTokenValid(t *testing.T) {
cachedAt := time.Now()
badCachedAt := time.Now().Add(500 * time.Second)
expiresOn := time.Now().Add(1000 * time.Second)
refreshIn := time.Now().Add(1000 * time.Second)
badExpiresOn := time.Now().Add(200 * time.Second)
extended := time.Now()

Expand All @@ -858,16 +862,16 @@ func TestIsAccessTokenValid(t *testing.T) {
}{
{
desc: "Success",
token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, expiresOn, extended, "openid", "secret", "tokenType", ""),
token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""),
},
{
desc: "ExpiresOnUnixTimestamp has expired",
token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, badExpiresOn, extended, "openid", "secret", "tokenType", ""),
token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, badExpiresOn, extended, "openid", "secret", "tokenType", ""),
err: true,
},
{
desc: "Success",
token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, expiresOn, extended, "openid", "secret", "tokenType", ""),
token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""),
err: true,
},
}
Expand All @@ -890,6 +894,7 @@ func TestRead(t *testing.T) {
"realm",
"cid",
time.Now(),
time.Now(),
time.Now().Add(1000*time.Second),
time.Now(),
"openid profile",
Expand Down Expand Up @@ -1008,13 +1013,16 @@ func TestWrite(t *testing.T) {
PreferredUsername: "username",
}
expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)}
timeRemaining := expiresOn.T.Sub(now) / 2
refreshIn := internalTime.DurationTime{T: now.Add(timeRemaining)}
tokenResponse := accesstokens.TokenResponse{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
IDToken: idToken,
FamilyID: "fid",
ClientInfo: clientInfo,
GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}},
RefreshIn: refreshIn,
ExpiresOn: expiresOn,
ExtExpiresOn: internalTime.DurationTime{T: now},
TokenType: "Bearer",
Expand All @@ -1039,6 +1047,7 @@ func TestWrite(t *testing.T) {
"realm",
"cid",
now,
now.Add(500*time.Second),
now.Add(1000*time.Second),
now,
"openid profile",
Expand Down Expand Up @@ -1136,7 +1145,7 @@ func TestRemoveRefreshTokens(t *testing.T) {
func TestRemoveAccessTokens(t *testing.T) {
now := time.Now()
storageManager := newForTest(nil)
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid", "secret", "tokenType", "")
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid", "secret", "tokenType", "")
key := testAccessToken.Key()
contract := &Contract{
AccessTokens: map[string]AccessToken{
Expand Down Expand Up @@ -1187,7 +1196,7 @@ func TestRemoveAccountObject(t *testing.T) {

func TestRemoveAccount(t *testing.T) {
now := time.Now()
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret", "tokenType", "")
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "")
testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret")
testAppMeta := NewAppMetaData("fid", "cid", "env")
testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
Expand Down Expand Up @@ -1229,7 +1238,7 @@ func TestRemoveAccount(t *testing.T) {

func TestRemoveEmptyAccount(t *testing.T) {
now := time.Now()
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret", "tokenType", "")
testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "")
testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret")
testAppMeta := NewAppMetaData("fid", "cid", "env")
testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
Expand Down
17 changes: 17 additions & 0 deletions apps/internal/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, e
return []byte(body)
}

func GetAccessTokenBodyWithRefreshIn(accessToken, idToken, refreshToken, clientInfo string, expiresIn int, refreshIn int) []byte {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's quite a lot of dup code. Why don't you add a parameter to the existing method instead? or have one call the other.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change this

body := fmt.Sprintf(
`{"access_token": "%s","expires_in": %d,"refresh_in":%d ,"expires_on": %d,"token_type": "Bearer"`,
accessToken, expiresIn, refreshIn, time.Now().Add(time.Duration(expiresIn)*time.Second).Unix(),
)
if clientInfo != "" {
body += fmt.Sprintf(`, "client_info": "%s"`, clientInfo)
}
if idToken != "" {
body += fmt.Sprintf(`, "id_token": "%s"`, idToken)
}
if refreshToken != "" {
body += fmt.Sprintf(`, "refresh_token": "%s"`, refreshToken)
}
body += "}"
return []byte(body)
}
func GetIDToken(tenant, issuer string) string {
now := time.Now().Unix()
payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, now+3600, now, issuer, tenant))
Expand Down
59 changes: 59 additions & 0 deletions apps/internal/oauth/ops/accesstokens/accesstokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,50 @@ func TestTokenResponseUnmarshal(t *testing.T) {
},
jwtDecoder: jwtDecoderFake,
},
{
desc: "Success",
4gust marked this conversation as resolved.
Show resolved Hide resolved
payload: `
{
"access_token": "secret",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you want to add refresh_in here?

"expires_in": 3600,
"ext_expires_in": 86399,
"client_info": {"uid": "uid","utid": "utid"},
"scope": "openid profile"
}`,
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Unix(3600, 0)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
UID: "uid",
UTID: "utid",
},
},
jwtDecoder: jwtDecoderFake,
},
{
desc: "Success",
4gust marked this conversation as resolved.
Show resolved Hide resolved
payload: `
{
"access_token": "secret",
"expires_in": 36000,
"ext_expires_in": 86399,
"client_info": {"uid": "uid","utid": "utid"},
"scope": "openid profile"
}`,
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Unix(36000, 0)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
UID: "uid",
UTID: "utid",
},
},
jwtDecoder: jwtDecoderFake,
},
}

for _, test := range tests {
Expand All @@ -795,6 +839,21 @@ func TestTokenResponseUnmarshal(t *testing.T) {
case err != nil:
continue
}
now := time.Now()
timeRemaining := got.ExpiresOn.T.Sub(now)
if got.ExpiresOn.T.Before(time.Now().Add(time.Hour * 2)) {
expectedRefreshIn := now.Add(timeRemaining)
const tolerance = 100 * time.Millisecond
if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance {
t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn)
}
} else {
expectedRefreshIn := now.Add(timeRemaining / 2)
const tolerance = 100 * time.Millisecond
if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance {
t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn)
}
}

// Note: IncludeUnexported prevents minor differences in time.Time due to internal fields.
if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
Expand Down
Loading
Loading