diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index 237c7d54..2fd07be7 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -61,20 +61,6 @@ def _login_service_default(self): """, ) - username_claim = Union( - [Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()], - config=True, - help=""" - When `userdata_url` returns a json response, the username will be taken - from this key. - - Can be a string key name or a callable that accepts the returned - userdata json (as a dict) and returns the username. The callable is - useful e.g. for extracting the username from a nested object in the - response. - """, - ) - @default("http_client") def _default_http_client(self): return AsyncHTTPClient( @@ -114,17 +100,6 @@ def _default_http_client(self): """, ) - def user_info_to_username(self, user_info): - """ - Overrides OAuthenticator.user_info_to_username to support the - GenericOAuthenticator unique feature of allowing username_claim to be a - callable function. - """ - if callable(self.username_claim): - return self.username_claim(user_info) - else: - return super().user_info_to_username(user_info) - def get_user_groups(self, user_info): """ Returns a set of groups the user belongs to based on claim_groups_key diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 60edeb36..19ce762c 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -19,7 +19,7 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest from tornado.httputil import url_concat from tornado.log import app_log -from traitlets import Any, Bool, Dict, List, Unicode, default +from traitlets import Any, Bool, Callable, Dict, List, Unicode, Union, default def guess_callback_uri(protocol, host, hub_server_url): @@ -377,14 +377,17 @@ def _token_url_default(self): def _userdata_url_default(self): return os.environ.get("OAUTH2_USERDATA_URL", "") - username_claim = Unicode( - "username", + username_claim = Union( + [Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()], config=True, help=""" - The key to get the JupyterHub username from in the data response to the - request made to :attr:`userdata_url`. + When `userdata_url` returns a json response, the username will be taken + from this key. - Examples include: email, username, nickname + Can be a string key name or a callable that accepts the returned + userdata json (as a dict) and returns the username. The callable is + useful e.g. for extracting the username from a nested object in the + response or doing other post processing. What keys are available will depend on the scopes requested and the authenticator used. @@ -769,7 +772,11 @@ def user_info_to_username(self, user_info): Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ - username = user_info.get(self.username_claim, None) + + if callable(self.username_claim): + username = self.username_claim(user_info) + else: + username = user_info.get(self.username_claim, None) if not username: message = (f"No {self.username_claim} found in {user_info}",) self.log.error(message) diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index fcc96652..6ab5474d 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -12,6 +12,7 @@ def user_model(username, **kwargs): """Return a user model""" return { "username": username, + "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], **kwargs, @@ -205,6 +206,32 @@ async def test_generic( assert auth_model == None +async def test_username_claim_callable( + get_authenticator, + generic_client, +): + c = Config() + c.GenericOAuthenticator = Config() + + def username_claim(user_info): + username = user_info["sub"] + if username.startswith("oauth2|cilogon"): + cilogon_sub = username.rsplit("|", 1)[-1] + cilogon_sub_parts = cilogon_sub.split("/") + username = f"oauth2|cilogon|{cilogon_sub_parts[3]}|{cilogon_sub_parts[5]}" + return username + + c.GenericOAuthenticator.username_claim = username_claim + c.GenericOAuthenticator.allow_all = True + authenticator = get_authenticator(config=c) + + handled_user_model = user_model("user1") + handler = generic_client.handler_for_user(handled_user_model) + auth_model = await authenticator.get_authenticated_user(handler, None) + + assert auth_model["name"] == "oauth2|cilogon|servera|43431" + + async def test_generic_data(get_authenticator, generic_client): c = Config() c.GenericOAuthenticator.allow_all = True