diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 4cceaba6..123557cb 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1327,8 +1327,21 @@ async def refresh_user(self, user, handler=None, **kwargs): f"No auth_state found for user {user.name} refresh, need full authentication", ) return False + + token_info = auth_state.get("token_response") + auth_model = None + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # usually this means the access token has expired + # handle more specific errors? + self.log.info( + f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." + ) refresh_token = auth_state.get("refresh_token", None) - if refresh_token: + if refresh_token and not auth_model: + self.log.info(f"Refreshing oauth access token for {user.name}") + # access_token expired, try refreshing with refresh_token refresh_token_params = self.build_refresh_token_request_params( refresh_token ) @@ -1336,30 +1349,25 @@ async def refresh_user(self, user, handler=None, **kwargs): token_info = await self.get_token_info(handler, refresh_token_params) except Exception as e: self.log.info( - f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired." + f"Error using refresh_token for {user.name}: {e}. Requiring fresh login." ) return False # refresh_token may not be returned when refreshing a token + # in which case, keep the current one if not token_info.get("refresh_token"): token_info["refresh_token"] = refresh_token - else: - # no refresh token, check access token validity - self.log.debug( - f"No refresh token for user {user.name}, checking access_token validity" - ) - token_info = auth_state.get("token_response") - try: - auth_model = await self._token_to_auth_model(token_info) - except Exception as e: - # handle more specific errors? - # e.g. expired token! - self.log.info( - f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired." - ) - return False - else: - # return False if auth_model is None for no-longer-authorized - return auth_model or False + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # this means we were issued a fresh access token, + # but it didn't work! Fail harder? + self.log.error( + f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login." + ) + return False + + # return False if auth_model is None for "needs new login" + return auth_model or False async def _token_to_auth_model(self, token_info): """ diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index bb14f72f..e5c99602 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -15,14 +15,15 @@ def user_model(username, **kwargs): """Return a user model""" - return { + model = { "username": username, "aud": client_id, "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], - **kwargs, } + model.update(kwargs) + return model @fixture(params=["id_token", "userdata_url"]) @@ -522,10 +523,13 @@ async def get_auth_state(self): async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): generic_client.enable_refresh_tokens = enable_refresh_tokens authenticator = get_authenticator(allowed_users={"user1"}) - handled_user_model = user_model("user1", permissions={"groups": ["super_user"]}) - handler = generic_client.handler_for_user(handled_user_model) + authenticator.manage_groups = True + authenticator.auth_state_groups_key = "oauth_user.groups" + oauth_userinfo = user_model("user1", groups=["round1"]) + handler = generic_client.handler_for_user(oauth_userinfo) auth_model = await authenticator.get_authenticated_user(handler, None) auth_state = auth_model["auth_state"] + assert auth_model["groups"] == ["round1"] if enable_refresh_tokens: assert "refresh_token" in auth_state assert "refresh_token" in auth_state["token_response"] @@ -551,26 +555,39 @@ async def test_refresh_user(get_authenticator, generic_client, enable_refresh_to assert refreshed is False # case: actually refresh + oauth_userinfo["groups"] = ["refreshed"] refreshed = await authenticator.refresh_user(user, handler) - assert isinstance(refreshed, dict) + assert refreshed assert refreshed["name"] == auth_model["name"] + assert refreshed["groups"] == ["refreshed"] refreshed_state = refreshed["auth_state"] assert "access_token" in refreshed_state + # refresh with access token succeeds, keeps tokens unchanged + assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token") + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: access token is no longer valid, triggers refresh + oauth_userinfo["groups"] = ["token_refreshed"] + generic_client.access_tokens.pop(refreshed_state["access_token"]) + refreshed = await authenticator.refresh_user(user, handler) if enable_refresh_tokens: - # refresh_token refreshed the access token - assert refreshed_state["access_token"] != auth_state["access_token"] - assert refreshed_state["refresh_token"] + # access_token refreshed + assert refreshed + refreshed_state = refreshed["auth_state"] + assert ( + refreshed_state["access_token"] != auth_model["auth_state"]["access_token"] + ) + assert refreshed["groups"] == ["token_refreshed"] else: - # refresh with access token succeeds, keeps access token unchanged - assert refreshed_state["access_token"] == auth_state["access_token"] + assert refreshed is False - # case: token used for refresh is no longer valid - user = MockUser(refreshed) - generic_client.access_tokens.pop(refreshed_state["access_token"]) if enable_refresh_tokens: + # case: token used for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) - refreshed = await authenticator.refresh_user(user, handler) - assert refreshed is False + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False @mark.parametrize(