From edbf67ad4bc8333975e0f75a7ed067f614907d84 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Wed, 17 Jan 2024 13:53:30 -0800 Subject: [PATCH 1/3] Move username_claim being callable to oauth2 autheneticator While trying to use Auth0 for authentication in one of our hubs, we discovered that the most useful username_claim (`sub`) produces usernames that look like `oauth2|cilogon|http://cilogon.org/servera/users/43431` (when using auth0 with CILogon). The last part of `sub` is generally whatever is passed on to auth0, so it's going to be different for different users. I had thought `username_claim` was a callable, but turns out that's only true for GenericOAuthenticator. I think it's pretty useful for every authenticator, so I've just moved that functionality out to the base class instead. I also added a test to verify it works. The test is in GenericOAuthenticator because it was the easiest place to put it, but it works across authenticators. This also means it is fully backwards compatible. --- oauthenticator/generic.py | 11 ----------- oauthenticator/oauth2.py | 22 +++++++++++++++------- oauthenticator/tests/test_generic.py | 24 ++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index 31945d6a..4c7bb81c 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -113,17 +113,6 @@ def _default_http_client(self): """, ) - def user_info_to_username(self, user_info): - """ - Overrides OAuthenticator.user_info_to_username to support the - GenericOAuthenticator unique feature of allowing username_claim to be a - callable function. - """ - if callable(self.username_claim): - return self.username_claim(user_info) - else: - return super().user_info_to_username(user_info) - def get_user_groups(self, user_info): """ Returns a set of groups the user belongs to based on claim_groups_key diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 9ba66ecf..1aa210c3 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -18,7 +18,7 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest from tornado.httputil import url_concat from tornado.log import app_log -from traitlets import Any, Bool, Dict, List, Unicode, default +from traitlets import Any, Bool, Dict, List, Unicode, default, Union, Callable def guess_callback_uri(protocol, host, hub_server_url): @@ -376,14 +376,17 @@ def _token_url_default(self): def _userdata_url_default(self): return os.environ.get("OAUTH2_USERDATA_URL", "") - username_claim = Unicode( - "username", + username_claim = Union( + [Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()], config=True, help=""" - The key to get the JupyterHub username from in the data response to the - request made to :attr:`userdata_url`. + When `userdata_url` returns a json response, the username will be taken + from this key. - Examples include: email, username, nickname + Can be a string key name or a callable that accepts the returned + userdata json (as a dict) and returns the username. The callable is + useful e.g. for extracting the username from a nested object in the + response or doing other post processing. What keys are available will depend on the scopes requested and the authenticator used. @@ -768,7 +771,12 @@ def user_info_to_username(self, user_info): Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ - username = user_info.get(self.username_claim, None) + + + if callable(self.username_claim): + username = self.username_claim(user_info) + else: + username = user_info.get(self.username_claim, None) if not username: message = (f"No {self.username_claim} found in {user_info}",) self.log.error(message) diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 42dc7781..11d62755 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -12,6 +12,7 @@ def user_model(username, **kwargs): """Return a user model""" return { "username": username, + "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], **kwargs, @@ -186,6 +187,29 @@ async def test_generic( else: assert auth_model == None +async def test_username_claim_callable( + get_authenticator, + generic_client, +): + c = Config() + c.GenericOAuthenticator = Config() + def username_claim(user_info): + username = user_info["sub"] + if username.startswith("oauth2|cilogon"): + cilogon_sub = username.rsplit("|", 1)[-1] + cilogon_sub_parts = cilogon_sub.split("/") + username = f"oauth2|cilogon|{cilogon_sub_parts[3]}|{cilogon_sub_parts[5]}" + return username + c.GenericOAuthenticator.username_claim = username_claim + c.GenericOAuthenticator.allow_all = True + authenticator = get_authenticator(config=c) + + handled_user_model = user_model("user1") + handler = generic_client.handler_for_user(handled_user_model) + auth_model = await authenticator.get_authenticated_user(handler, None) + + assert auth_model["name"] == "oauth2|cilogon|servera|43431" + async def test_generic_data(get_authenticator, generic_client): c = Config() From 4417ec88e2b739f3ac1452cf4ce4dc83636d69e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 21:58:47 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- oauthenticator/oauth2.py | 3 +-- oauthenticator/tests/test_generic.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 1aa210c3..cff6b032 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -18,7 +18,7 @@ from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest from tornado.httputil import url_concat from tornado.log import app_log -from traitlets import Any, Bool, Dict, List, Unicode, default, Union, Callable +from traitlets import Any, Bool, Callable, Dict, List, Unicode, Union, default def guess_callback_uri(protocol, host, hub_server_url): @@ -772,7 +772,6 @@ def user_info_to_username(self, user_info): Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ - if callable(self.username_claim): username = self.username_claim(user_info) else: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index 11d62755..c30fe23e 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -187,12 +187,14 @@ async def test_generic( else: assert auth_model == None + async def test_username_claim_callable( get_authenticator, generic_client, ): c = Config() c.GenericOAuthenticator = Config() + def username_claim(user_info): username = user_info["sub"] if username.startswith("oauth2|cilogon"): @@ -200,6 +202,7 @@ def username_claim(user_info): cilogon_sub_parts = cilogon_sub.split("/") username = f"oauth2|cilogon|{cilogon_sub_parts[3]}|{cilogon_sub_parts[5]}" return username + c.GenericOAuthenticator.username_claim = username_claim c.GenericOAuthenticator.allow_all = True authenticator = get_authenticator(config=c) From 77a43d11c8435f8d270e2f3a2114baf60723ef94 Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Thu, 18 Jan 2024 14:03:00 -0800 Subject: [PATCH 3/3] Remove redefenition of `username_claim` in Generic --- oauthenticator/generic.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/oauthenticator/generic.py b/oauthenticator/generic.py index 4c7bb81c..4798280d 100644 --- a/oauthenticator/generic.py +++ b/oauthenticator/generic.py @@ -60,20 +60,6 @@ def _login_service_default(self): """, ) - username_claim = Union( - [Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()], - config=True, - help=""" - When `userdata_url` returns a json response, the username will be taken - from this key. - - Can be a string key name or a callable that accepts the returned - userdata json (as a dict) and returns the username. The callable is - useful e.g. for extracting the username from a nested object in the - response. - """, - ) - @default("http_client") def _default_http_client(self): return AsyncHTTPClient(