Skip to content

Commit

Permalink
only refresh expired access tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Oct 17, 2024
1 parent 3d37ff8 commit a8fe500
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 35 deletions.
48 changes: 28 additions & 20 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,39 +1327,47 @@ async def refresh_user(self, user, handler=None, **kwargs):
f"No auth_state found for user {user.name} refresh, need full authentication",
)
return False

token_info = auth_state.get("token_response")
auth_model = None
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# usually this means the access token has expired
# handle more specific errors?
self.log.info(
f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible."
)
refresh_token = auth_state.get("refresh_token", None)
if refresh_token:
if refresh_token and not auth_model:
self.log.info(f"Refreshing oauth access token for {user.name}")
# access_token expired, try refreshing with refresh_token
refresh_token_params = self.build_refresh_token_request_params(
refresh_token
)
try:
token_info = await self.get_token_info(handler, refresh_token_params)
except Exception as e:
self.log.info(
f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired."
f"Error using refresh_token for {user.name}: {e}. Requiring fresh login."
)
return False
# refresh_token may not be returned when refreshing a token
# in which case, keep the current one
if not token_info.get("refresh_token"):
token_info["refresh_token"] = refresh_token
else:
# no refresh token, check access token validity
self.log.debug(
f"No refresh token for user {user.name}, checking access_token validity"
)
token_info = auth_state.get("token_response")
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# handle more specific errors?
# e.g. expired token!
self.log.info(
f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired."
)
return False
else:
# return False if auth_model is None for no-longer-authorized
return auth_model or False
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# this means we were issued a fresh access token,
# but it didn't work! Fail harder?
self.log.error(
f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login."
)
return False

# return False if auth_model is None for "needs new login"
return auth_model or False

async def _token_to_auth_model(self, token_info):
"""
Expand Down
47 changes: 32 additions & 15 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

def user_model(username, **kwargs):
"""Return a user model"""
return {
model = {
"username": username,
"aud": client_id,
"sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431",
"scope": "basic",
"groups": ["group1"],
**kwargs,
}
model.update(kwargs)
return model


@fixture(params=["id_token", "userdata_url"])
Expand Down Expand Up @@ -522,10 +523,13 @@ async def get_auth_state(self):
async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens):
generic_client.enable_refresh_tokens = enable_refresh_tokens
authenticator = get_authenticator(allowed_users={"user1"})
handled_user_model = user_model("user1", permissions={"groups": ["super_user"]})
handler = generic_client.handler_for_user(handled_user_model)
authenticator.manage_groups = True
authenticator.auth_state_groups_key = "oauth_user.groups"
oauth_userinfo = user_model("user1", groups=["round1"])
handler = generic_client.handler_for_user(oauth_userinfo)
auth_model = await authenticator.get_authenticated_user(handler, None)
auth_state = auth_model["auth_state"]
assert auth_model["groups"] == ["round1"]
if enable_refresh_tokens:
assert "refresh_token" in auth_state
assert "refresh_token" in auth_state["token_response"]
Expand All @@ -551,26 +555,39 @@ async def test_refresh_user(get_authenticator, generic_client, enable_refresh_to
assert refreshed is False

# case: actually refresh
oauth_userinfo["groups"] = ["refreshed"]
refreshed = await authenticator.refresh_user(user, handler)
assert isinstance(refreshed, dict)
assert refreshed
assert refreshed["name"] == auth_model["name"]
assert refreshed["groups"] == ["refreshed"]
refreshed_state = refreshed["auth_state"]
assert "access_token" in refreshed_state
# refresh with access token succeeds, keeps tokens unchanged
assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token")
assert refreshed_state["access_token"] == auth_state["access_token"]

# case: access token is no longer valid, triggers refresh
oauth_userinfo["groups"] = ["token_refreshed"]
generic_client.access_tokens.pop(refreshed_state["access_token"])
refreshed = await authenticator.refresh_user(user, handler)
if enable_refresh_tokens:
# refresh_token refreshed the access token
assert refreshed_state["access_token"] != auth_state["access_token"]
assert refreshed_state["refresh_token"]
# access_token refreshed
assert refreshed
refreshed_state = refreshed["auth_state"]
assert (
refreshed_state["access_token"] != auth_model["auth_state"]["access_token"]
)
assert refreshed["groups"] == ["token_refreshed"]
else:
# refresh with access token succeeds, keeps access token unchanged
assert refreshed_state["access_token"] == auth_state["access_token"]
assert refreshed is False

# case: token used for refresh is no longer valid
user = MockUser(refreshed)
generic_client.access_tokens.pop(refreshed_state["access_token"])
if enable_refresh_tokens:
# case: token used for refresh is no longer valid
user = MockUser(refreshed)
generic_client.access_tokens.pop(refreshed_state["access_token"])
generic_client.refresh_tokens.pop(refreshed_state["refresh_token"])
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is False
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is False


@mark.parametrize(
Expand Down

0 comments on commit a8fe500

Please sign in to comment.