Skip to content

Commit

Permalink
more cases for refresh_user
Browse files Browse the repository at this point in the history
- do not refresh if auth_state is disabled (would force re-login every 5 minutes in default config)
- always refresh if refresh_token is defined
- if refresh_token not available, only check validity of access_token and refresh associated user info
  • Loading branch information
minrk committed Oct 17, 2024
1 parent 6156c93 commit 3d37ff8
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 35 deletions.
64 changes: 41 additions & 23 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import json
import os
import secrets
import time
import uuid
from functools import reduce
from inspect import isawaitable
Expand Down Expand Up @@ -1071,9 +1070,8 @@ def build_refresh_token_request_params(self, refresh_token):
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
if not self.basic_auth:
params.update(
{"client_id": self.client_id, "client_secret": self.client_secret,}
)
params["client_id"] = self.client_id
params["client_secret"] = self.client_secret

return params

Expand Down Expand Up @@ -1312,48 +1310,68 @@ async def authenticate(self, handler, data=None, **kwargs):
"""
# build the parameters to be used in the request exchanging the oauth code for the access token
access_token_params = self.build_access_tokens_request_params(handler, data)
token_info = await self.get_token_info(handler, access_token_params)
# call the oauth endpoints
return await self._oauth_call(handler, access_token_params, **kwargs)
return await self._token_to_auth_model(token_info)

async def refresh_user(self, user, handler=None, **kwargs):
"""
Renew the Access Token with a valid Refresh Token
"""
if not self.enable_auth_state:
# auth state not enabled, can't refresh
return True
auth_state = await user.get_auth_state()
if not auth_state:
self.log.info(
f"No auth_state found for user {user.name} refresh, need full authentication",
)
return False
refresh_token = auth_state.get("refresh_token", None)
if refresh_token:
refresh_token_params = self.build_refresh_token_request_params(
refresh_token
)
try:
token_info = await self.get_token_info(handler, refresh_token_params)
except Exception as e:
self.log.info(
f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired."
)
return False
# refresh_token may not be returned when refreshing a token
if not token_info.get("refresh_token"):
token_info["refresh_token"] = refresh_token
else:
# no refresh token, check access token validity
self.log.debug(
f"No refresh token for user {user.name}, checking access_token validity"
)
token_info = auth_state.get("token_response")
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# handle more specific errors?
# e.g. expired token!
self.log.info(
f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired."
)
return False
else:
# return False if auth_model is None for no-longer-authorized
return auth_model or False

refresh_token_params = self.build_refresh_token_request_params(
auth_state['refresh_token']
)
return await self._oauth_call(handler, refresh_token_params, **kwargs)

async def _oauth_call(self, handler, params, **kwargs):
async def _token_to_auth_model(self, token_info):
"""
Common logic shared by authenticate() and refresh_user()
"""

# exchange the oauth code for an access token and get the JSON with info about it
token_info = await self.get_token_info(handler, params)
# use the access_token to get userdata info
user_info = await self.token_to_user(token_info)
# extract the username out of the user_info dict and normalize it
username = self.user_info_to_username(user_info)
username = self.normalize_username(username)

# check if there any refresh_token in the token_info dict
refresh_token = token_info.get("refresh_token", None)
if self.enable_auth_state and not refresh_token:
self.log.debug(
"Refresh token was empty, will try to pull refresh_token from previous auth_state"
)
refresh_token = await self.get_prev_refresh_token(handler, username)
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if isawaitable(auth_state):
auth_state = await auth_state
Expand Down
52 changes: 41 additions & 11 deletions oauthenticator/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def setup_oauth_mock(
user_path=None,
token_type='Bearer',
token_request_style='post',
enable_refresh_tokens=False,
scope="",
):
"""setup the mock client for OAuth
Expand Down Expand Up @@ -134,6 +135,8 @@ def setup_oauth_mock(

client.oauth_codes = oauth_codes = {}
client.access_tokens = access_tokens = {}
client.refresh_tokens = refresh_tokens = {}
client.enable_refresh_tokens = enable_refresh_tokens

def access_token(request):
"""Handler for access token endpoint
Expand All @@ -146,26 +149,53 @@ def access_token(request):
if not query:
query = request.body.decode('utf8')
query = parse_qs(query)
if 'code' not in query:
grant_type = query.get("grant_type", [""])[0]
if grant_type == 'authorization_code':
if 'code' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
)
user = oauth_codes.pop(code)
elif grant_type == 'refresh_token':
if 'refresh_token' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No refresh_token in access token request: url={request.url}, body={request.body}",
)
refresh_token = query['refresh_token'][0]
if refresh_token not in refresh_token:
return HTTPResponse(
request=request,
code=403,
reason=f"No such refresh_toekn: {refresh_token}",
)
user = refresh_tokens[refresh_token]
else:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}",
)

# consume code, allocate token
token = uuid.uuid4().hex
user = oauth_codes.pop(code)
access_tokens[token] = user
access_token = uuid.uuid4().hex
access_tokens[access_token] = user
model = {
'access_token': token,
'access_token': access_token,
'token_type': token_type,
}
if client.enable_refresh_tokens:
refresh_token = uuid.uuid4().hex
refresh_tokens[refresh_token] = user
model['refresh_token'] = refresh_token
if scope:
model['scope'] = scope
if 'id_token' in user:
Expand Down
68 changes: 68 additions & 0 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,74 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed):
assert await authenticator.check_allowed(name, None)


class MockUser:
"""Mock subset of JupyterHub User API from the `auth_model` dict"""

name: str

def __init__(self, auth_model):
self._auth_model = auth_model
self.name = auth_model["name"]

async def get_auth_state(self):
return self._auth_model["auth_state"]


@mark.parametrize("enable_refresh_tokens", [True, False])
async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens):
generic_client.enable_refresh_tokens = enable_refresh_tokens
authenticator = get_authenticator(allowed_users={"user1"})
handled_user_model = user_model("user1", permissions={"groups": ["super_user"]})
handler = generic_client.handler_for_user(handled_user_model)
auth_model = await authenticator.get_authenticated_user(handler, None)
auth_state = auth_model["auth_state"]
if enable_refresh_tokens:
assert "refresh_token" in auth_state
assert "refresh_token" in auth_state["token_response"]
assert (
auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"]
)
else:
assert "refresh_token" not in auth_state["token_response"]
assert auth_state.get("refresh_token") is None
user = MockUser(auth_model)
# case: auth_state not enabled, nothing to refresh
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is True

# from here on, enable auth state required for refresh to do anything
authenticator.enable_auth_state = True

# case: no auth state, but auth state enabled needs refresh
auth_without_state = auth_model.copy()
auth_without_state["auth_state"] = None
user_without_state = MockUser(auth_without_state)
refreshed = await authenticator.refresh_user(user_without_state, handler)
assert refreshed is False

# case: actually refresh
refreshed = await authenticator.refresh_user(user, handler)
assert isinstance(refreshed, dict)
assert refreshed["name"] == auth_model["name"]
refreshed_state = refreshed["auth_state"]
assert "access_token" in refreshed_state
if enable_refresh_tokens:
# refresh_token refreshed the access token
assert refreshed_state["access_token"] != auth_state["access_token"]
assert refreshed_state["refresh_token"]
else:
# refresh with access token succeeds, keeps access token unchanged
assert refreshed_state["access_token"] == auth_state["access_token"]

# case: token used for refresh is no longer valid
user = MockUser(refreshed)
generic_client.access_tokens.pop(refreshed_state["access_token"])
if enable_refresh_tokens:
generic_client.refresh_tokens.pop(refreshed_state["refresh_token"])
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is False


@mark.parametrize(
"test_variation_id,class_config,expect_config,expect_loglevel,expect_message",
[
Expand Down
2 changes: 1 addition & 1 deletion oauthenticator/tests/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def test_github(
assert user_info == handled_user_model
assert auth_model["name"] == user_info[authenticator.username_claim]
else:
assert auth_model == None
assert auth_model is None


def make_link_header(urlinfo, page):
Expand Down

0 comments on commit 3d37ff8

Please sign in to comment.