Skip to content

Commit

Permalink
Merge pull request #717 from yuvipanda/move-username-callable
Browse files Browse the repository at this point in the history
Make `username_claim` callable in all Authenticators except CILogon, like it has been in Generic
  • Loading branch information
consideRatio authored Feb 7, 2024
2 parents fc20682 + 77a43d1 commit 9b96d47
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
25 changes: 0 additions & 25 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,6 @@ def _login_service_default(self):
""",
)

username_claim = Union(
[Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()],
config=True,
help="""
When `userdata_url` returns a json response, the username will be taken
from this key.
Can be a string key name or a callable that accepts the returned
userdata json (as a dict) and returns the username. The callable is
useful e.g. for extracting the username from a nested object in the
response.
""",
)

@default("http_client")
def _default_http_client(self):
return AsyncHTTPClient(
Expand Down Expand Up @@ -114,17 +100,6 @@ def _default_http_client(self):
""",
)

def user_info_to_username(self, user_info):
"""
Overrides OAuthenticator.user_info_to_username to support the
GenericOAuthenticator unique feature of allowing username_claim to be a
callable function.
"""
if callable(self.username_claim):
return self.username_claim(user_info)
else:
return super().user_info_to_username(user_info)

def get_user_groups(self, user_info):
"""
Returns a set of groups the user belongs to based on claim_groups_key
Expand Down
21 changes: 14 additions & 7 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest
from tornado.httputil import url_concat
from tornado.log import app_log
from traitlets import Any, Bool, Dict, List, Unicode, default
from traitlets import Any, Bool, Callable, Dict, List, Unicode, Union, default


def guess_callback_uri(protocol, host, hub_server_url):
Expand Down Expand Up @@ -377,14 +377,17 @@ def _token_url_default(self):
def _userdata_url_default(self):
return os.environ.get("OAUTH2_USERDATA_URL", "")

username_claim = Unicode(
"username",
username_claim = Union(
[Unicode(os.environ.get('OAUTH2_USERNAME_KEY', 'username')), Callable()],
config=True,
help="""
The key to get the JupyterHub username from in the data response to the
request made to :attr:`userdata_url`.
When `userdata_url` returns a json response, the username will be taken
from this key.
Examples include: email, username, nickname
Can be a string key name or a callable that accepts the returned
userdata json (as a dict) and returns the username. The callable is
useful e.g. for extracting the username from a nested object in the
response or doing other post processing.
What keys are available will depend on the scopes requested and the
authenticator used.
Expand Down Expand Up @@ -769,7 +772,11 @@ def user_info_to_username(self, user_info):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
username = user_info.get(self.username_claim, None)

if callable(self.username_claim):
username = self.username_claim(user_info)
else:
username = user_info.get(self.username_claim, None)
if not username:
message = (f"No {self.username_claim} found in {user_info}",)
self.log.error(message)
Expand Down
27 changes: 27 additions & 0 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def user_model(username, **kwargs):
"""Return a user model"""
return {
"username": username,
"sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431",
"scope": "basic",
"groups": ["group1"],
**kwargs,
Expand Down Expand Up @@ -205,6 +206,32 @@ async def test_generic(
assert auth_model == None


async def test_username_claim_callable(
get_authenticator,
generic_client,
):
c = Config()
c.GenericOAuthenticator = Config()

def username_claim(user_info):
username = user_info["sub"]
if username.startswith("oauth2|cilogon"):
cilogon_sub = username.rsplit("|", 1)[-1]
cilogon_sub_parts = cilogon_sub.split("/")
username = f"oauth2|cilogon|{cilogon_sub_parts[3]}|{cilogon_sub_parts[5]}"
return username

c.GenericOAuthenticator.username_claim = username_claim
c.GenericOAuthenticator.allow_all = True
authenticator = get_authenticator(config=c)

handled_user_model = user_model("user1")
handler = generic_client.handler_for_user(handled_user_model)
auth_model = await authenticator.get_authenticated_user(handler, None)

assert auth_model["name"] == "oauth2|cilogon|servera|43431"


async def test_generic_data(get_authenticator, generic_client):
c = Config()
c.GenericOAuthenticator.allow_all = True
Expand Down

0 comments on commit 9b96d47

Please sign in to comment.