From 3d37ff89885b319cd457fd71870f034c60320e3d Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 17 Oct 2024 13:17:59 +0200 Subject: [PATCH] more cases for refresh_user - do not refresh if auth_state is disabled (would force re-login every 5 minutes in default config) - always refresh if refresh_token is defined - if refresh_token not available, only check validity of access_token and refresh associated user info --- oauthenticator/oauth2.py | 64 ++++++++++++++++---------- oauthenticator/tests/mocks.py | 52 ++++++++++++++++----- oauthenticator/tests/test_generic.py | 68 ++++++++++++++++++++++++++++ oauthenticator/tests/test_github.py | 2 +- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 5c37f0db..4cceaba6 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -9,7 +9,6 @@ import json import os import secrets -import time import uuid from functools import reduce from inspect import isawaitable @@ -1071,9 +1070,8 @@ def build_refresh_token_request_params(self, refresh_token): # when basic authentication is used # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 if not self.basic_auth: - params.update( - {"client_id": self.client_id, "client_secret": self.client_secret,} - ) + params["client_id"] = self.client_id + params["client_secret"] = self.client_secret return params @@ -1312,48 +1310,68 @@ async def authenticate(self, handler, data=None, **kwargs): """ # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) + token_info = await self.get_token_info(handler, access_token_params) # call the oauth endpoints - return await self._oauth_call(handler, access_token_params, **kwargs) + return await self._token_to_auth_model(token_info) async def refresh_user(self, user, handler=None, **kwargs): """ Renew the Access Token with a valid Refresh Token """ + if not self.enable_auth_state: + # auth state not enabled, can't refresh + return True auth_state = await user.get_auth_state() if not auth_state: self.log.info( f"No auth_state found for user {user.name} refresh, need full authentication", ) return False + refresh_token = auth_state.get("refresh_token", None) + if refresh_token: + refresh_token_params = self.build_refresh_token_request_params( + refresh_token + ) + try: + token_info = await self.get_token_info(handler, refresh_token_params) + except Exception as e: + self.log.info( + f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired." + ) + return False + # refresh_token may not be returned when refreshing a token + if not token_info.get("refresh_token"): + token_info["refresh_token"] = refresh_token + else: + # no refresh token, check access token validity + self.log.debug( + f"No refresh token for user {user.name}, checking access_token validity" + ) + token_info = auth_state.get("token_response") + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # handle more specific errors? + # e.g. expired token! + self.log.info( + f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired." + ) + return False + else: + # return False if auth_model is None for no-longer-authorized + return auth_model or False - refresh_token_params = self.build_refresh_token_request_params( - auth_state['refresh_token'] - ) - return await self._oauth_call(handler, refresh_token_params, **kwargs) - - async def _oauth_call(self, handler, params, **kwargs): + async def _token_to_auth_model(self, token_info): """ Common logic shared by authenticate() and refresh_user() """ - # exchange the oauth code for an access token and get the JSON with info about it - token_info = await self.get_token_info(handler, params) # use the access_token to get userdata info user_info = await self.token_to_user(token_info) # extract the username out of the user_info dict and normalize it username = self.user_info_to_username(user_info) username = self.normalize_username(username) - # check if there any refresh_token in the token_info dict - refresh_token = token_info.get("refresh_token", None) - if self.enable_auth_state and not refresh_token: - self.log.debug( - "Refresh token was empty, will try to pull refresh_token from previous auth_state" - ) - refresh_token = await self.get_prev_refresh_token(handler, username) - if refresh_token: - token_info["refresh_token"] = refresh_token - auth_state = self.build_auth_state_dict(token_info, user_info) if isawaitable(auth_state): auth_state = await auth_state diff --git a/oauthenticator/tests/mocks.py b/oauthenticator/tests/mocks.py index d9174604..5186f5ef 100644 --- a/oauthenticator/tests/mocks.py +++ b/oauthenticator/tests/mocks.py @@ -107,6 +107,7 @@ def setup_oauth_mock( user_path=None, token_type='Bearer', token_request_style='post', + enable_refresh_tokens=False, scope="", ): """setup the mock client for OAuth @@ -134,6 +135,8 @@ def setup_oauth_mock( client.oauth_codes = oauth_codes = {} client.access_tokens = access_tokens = {} + client.refresh_tokens = refresh_tokens = {} + client.enable_refresh_tokens = enable_refresh_tokens def access_token(request): """Handler for access token endpoint @@ -146,26 +149,53 @@ def access_token(request): if not query: query = request.body.decode('utf8') query = parse_qs(query) - if 'code' not in query: + grant_type = query.get("grant_type", [""])[0] + if grant_type == 'authorization_code': + if 'code' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No code in access token request: url={request.url}, body={request.body}", + ) + code = query['code'][0] + if code not in oauth_codes: + return HTTPResponse( + request=request, code=403, reason=f"No such code: {code}" + ) + user = oauth_codes.pop(code) + elif grant_type == 'refresh_token': + if 'refresh_token' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No refresh_token in access token request: url={request.url}, body={request.body}", + ) + refresh_token = query['refresh_token'][0] + if refresh_token not in refresh_token: + return HTTPResponse( + request=request, + code=403, + reason=f"No such refresh_toekn: {refresh_token}", + ) + user = refresh_tokens[refresh_token] + else: return HTTPResponse( request=request, code=400, - reason=f"No code in access token request: url={request.url}, body={request.body}", - ) - code = query['code'][0] - if code not in oauth_codes: - return HTTPResponse( - request=request, code=403, reason=f"No such code: {code}" + reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}", ) # consume code, allocate token - token = uuid.uuid4().hex - user = oauth_codes.pop(code) - access_tokens[token] = user + access_token = uuid.uuid4().hex + access_tokens[access_token] = user model = { - 'access_token': token, + 'access_token': access_token, 'token_type': token_type, } + if client.enable_refresh_tokens: + refresh_token = uuid.uuid4().hex + refresh_tokens[refresh_token] = user + model['refresh_token'] = refresh_token if scope: model['scope'] = scope if 'id_token' in user: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index e3225039..bb14f72f 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -505,6 +505,74 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed): assert await authenticator.check_allowed(name, None) +class MockUser: + """Mock subset of JupyterHub User API from the `auth_model` dict""" + + name: str + + def __init__(self, auth_model): + self._auth_model = auth_model + self.name = auth_model["name"] + + async def get_auth_state(self): + return self._auth_model["auth_state"] + + +@mark.parametrize("enable_refresh_tokens", [True, False]) +async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): + generic_client.enable_refresh_tokens = enable_refresh_tokens + authenticator = get_authenticator(allowed_users={"user1"}) + handled_user_model = user_model("user1", permissions={"groups": ["super_user"]}) + handler = generic_client.handler_for_user(handled_user_model) + auth_model = await authenticator.get_authenticated_user(handler, None) + auth_state = auth_model["auth_state"] + if enable_refresh_tokens: + assert "refresh_token" in auth_state + assert "refresh_token" in auth_state["token_response"] + assert ( + auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"] + ) + else: + assert "refresh_token" not in auth_state["token_response"] + assert auth_state.get("refresh_token") is None + user = MockUser(auth_model) + # case: auth_state not enabled, nothing to refresh + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is True + + # from here on, enable auth state required for refresh to do anything + authenticator.enable_auth_state = True + + # case: no auth state, but auth state enabled needs refresh + auth_without_state = auth_model.copy() + auth_without_state["auth_state"] = None + user_without_state = MockUser(auth_without_state) + refreshed = await authenticator.refresh_user(user_without_state, handler) + assert refreshed is False + + # case: actually refresh + refreshed = await authenticator.refresh_user(user, handler) + assert isinstance(refreshed, dict) + assert refreshed["name"] == auth_model["name"] + refreshed_state = refreshed["auth_state"] + assert "access_token" in refreshed_state + if enable_refresh_tokens: + # refresh_token refreshed the access token + assert refreshed_state["access_token"] != auth_state["access_token"] + assert refreshed_state["refresh_token"] + else: + # refresh with access token succeeds, keeps access token unchanged + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: token used for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) + if enable_refresh_tokens: + generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False + + @mark.parametrize( "test_variation_id,class_config,expect_config,expect_loglevel,expect_message", [ diff --git a/oauthenticator/tests/test_github.py b/oauthenticator/tests/test_github.py index e49fe064..1ca5f590 100644 --- a/oauthenticator/tests/test_github.py +++ b/oauthenticator/tests/test_github.py @@ -141,7 +141,7 @@ async def test_github( assert user_info == handled_user_model assert auth_model["name"] == user_info[authenticator.username_claim] else: - assert auth_model == None + assert auth_model is None def make_link_header(urlinfo, page):