diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index 7e1620a7..0c0178c6 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -3,6 +3,7 @@ """ import base64 import os +import time from functools import reduce from urllib.parse import urlencode @@ -131,8 +132,10 @@ def _get_user_data(self, token_response): @staticmethod def _create_auth_state(token_response, user_data_response): - access_token = token_response['access_token'] + now = time.time() + access_token = token_response.get('access_token', None) refresh_token = token_response.get('refresh_token', None) + expires_in = token_response.get('expires_in', 0) scope = token_response.get('scope', '') if isinstance(scope, str): scope = scope.split(' ') @@ -141,9 +144,14 @@ def _create_auth_state(token_response, user_data_response): 'access_token': access_token, 'refresh_token': refresh_token, 'oauth_user': user_data_response, + 'expires_at': now + float(expires_in), 'scope': scope, } + @staticmethod + def is_auth_token_expired(auth_state: dict): + return time.time() < float(auth_state.get('expires_at', 0)) + @staticmethod def check_user_in_groups(member_groups, allowed_groups): return bool(set(member_groups) & set(allowed_groups)) diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index eca2d0a9..6b6c2566 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -1,6 +1,7 @@ from functools import partial +from time import time -from pytest import fixture +from pytest import approx, fixture from ..generic import GenericOAuthenticator from .mocks import setup_oauth_mock @@ -56,6 +57,7 @@ async def test_generic(get_authenticator, generic_client): assert 'access_token' in auth_state assert 'oauth_user' in auth_state assert 'refresh_token' in auth_state + assert 'expires_at' in auth_state assert 'scope' in auth_state @@ -199,3 +201,44 @@ async def test_generic_callable_groups_claim_key_with_allowed_groups_and_admin_g user_info = await authenticator.authenticate(handler) assert user_info['name'] == 'zoe' assert user_info['admin'] is True + + +async def test_expires_at(get_authenticator, generic_client): + authenticator = get_authenticator() + + handler = get_simple_handler(generic_client) + + now = time() + user_info = await authenticator.authenticate(handler) + + assert type(user_info.get('auth_state').get('expires_at')) is float + # the expires_at in this mocked example will be the current time which should be created + # pretty much at the same time as the now variable + assert approx(now, 0.01) == user_info.get('auth_state').get('expires_at') + + +async def test_is_auth_token_expired(get_authenticator, generic_client): + authenticator = get_authenticator() + + # mock auth_state result + expired_token_auth_state = { + 'access_token': '4701dcf296cc4a8fa8040a754f6e9ef3', + 'expires_at': 1631611075.6157327, + 'oauth_user': {'scope': 'basic', 'username': 'wash'}, + 'refresh_token': None, + 'scopes': None, + } + assert ( + authenticator.is_auth_token_expired(auth_state=expired_token_auth_state) + is False + ) + valid_token_auth_state = { + 'access_token': '4701dcf296cc4a8fa8040a754f6e9ef3', + 'expires_at': time() + 3600, + 'oauth_user': {'scope': 'basic', 'username': 'wash'}, + 'refresh_token': None, + 'scopes': None, + } + assert ( + authenticator.is_auth_token_expired(auth_state=valid_token_auth_state) is True + )