Skip to content

Commit

Permalink
Merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
jrdnbradford committed Oct 3, 2024
2 parents d32940d + f4da2e8 commit 09e863c
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 31 deletions.
34 changes: 20 additions & 14 deletions oauthenticator/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _userdata_url_default(self):
""",
)

allow_nested_groups = Bool(
include_nested_groups = Bool(
config=True,
help="""
Include members of nested Google groups in `allowed_google_groups` and
Expand Down Expand Up @@ -365,8 +365,8 @@ async def _fetch_member_groups(
member_email,
user_email_domain,
http=None,
checked_groups=set(),
processed_groups=set(),
checked_groups=None,
processed_groups=None,
):
"""
Return a set with the google groups a given user/group is a member of, including nested groups if allowed.
Expand All @@ -383,6 +383,10 @@ async def _fetch_member_groups(
) as resp:
group_data = await resp.json()

checked_groups = checked_groups or set()
processed_groups = processed_groups or set()

resp = self.service.groups().list(userKey=member_email).execute()
member_groups = {
g['email'].split('@')[0]
for g in group_data.get('groups', [])
Expand All @@ -393,18 +397,20 @@ async def _fetch_member_groups(
checked_groups.update(member_groups)
self.log.debug(f"Checked groups after update: {checked_groups}")

if self.allow_nested_groups:
if self.include_nested_groups:
for group in member_groups:
if group not in processed_groups:
processed_groups.add(group)
nested_groups = await self._fetch_member_groups(
f"{group}@{user_email_domain}",
user_email_domain,
http,
checked_groups,
processed_groups,
)
checked_groups.update(nested_groups)

if group in processed_groups:
continue
processed_groups.add(group)
nested_groups = await self._fetch_member_groups(
f"{group}@{user_email_domain}",
user_email_domain,
http,
checked_groups,
processed_groups,
)
checked_groups.update(nested_groups)

self.log.debug(f"member_email {member_email} is a member of {checked_groups}")
return checked_groups
Expand Down
64 changes: 63 additions & 1 deletion oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""

import base64
import hashlib
import json
import os
import secrets
import uuid
from functools import reduce
from inspect import isawaitable
Expand Down Expand Up @@ -91,6 +93,18 @@ def set_state_cookie(self, state_cookie_value):
STATE_COOKIE_NAME, state_cookie_value, expires_days=1, httponly=True
)

def _generate_pkce_params(self):
# https://datatracker.ietf.org/doc/html/rfc7636#section-4
# It is recommended that the output of the random number generator creates
# a 32-octet sequence which is base64url-encoded to produce a 43-octet URL
# safe string to use as the code verifier.
code_verifier = secrets.token_urlsafe(32)
code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge_base64 = (
base64.urlsafe_b64encode(code_challenge).decode("utf-8").rstrip("=")
)
return code_verifier, code_challenge_base64

def _generate_state_id(self):
return uuid.uuid4().hex

Expand All @@ -115,7 +129,15 @@ def get(self):
state_id = self._generate_state_id()
next_url = self._get_next_url()

cookie_state = _serialize_state({"state_id": state_id, "next_url": next_url})
state = {"state_id": state_id, "next_url": next_url}

if self.authenticator.enable_pkce:
code_verifier, code_challenge = self._generate_pkce_params()
state["code_verifier"] = code_verifier
token_params["code_challenge"] = code_challenge
token_params["code_challenge_method"] = "S256"

cookie_state = _serialize_state(state)
self.set_state_cookie(cookie_state)

authorize_state = _serialize_state({"state_id": state_id})
Expand Down Expand Up @@ -663,6 +685,34 @@ def _allowed_scopes_validation(self, proposal):
""",
)

enable_pkce = Bool(
True,
config=True,
help="""
Enable Proof Key for Code Exchange (PKCE) for the OAuth2 authorization code flow.
For more information, see `RFC 7636 <https://datatracker.ietf.org/doc/html/rfc7636>`_.
PKCE can be used even if the authorization server does not support it. According to
`section 3.1 of RFC 6749 <https://www.rfc-editor.org/rfc/rfc6749#section-3.1>`_:
The authorization server MUST ignore unrecognized request parameters.
Additionally, `section 5 of RFC 7636 <https://datatracker.ietf.org/doc/html/rfc7636#section-5>`_ states:
As the OAuth 2.0 [RFC6749] server responses are unchanged by this
specification, client implementations of this specification do not
need to know if the server has implemented this specification or not
and SHOULD send the additional parameters as defined in Section 4 to
all servers.
Note that S256 is the only code challenge method supported. As per `section 4.2 of RFC 6749
<https://www.rfc-editor.org/rfc/rfc6749#section-3.1>`_:
If the client is capable of using "S256", it MUST use "S256", as
"S256" is Mandatory To Implement (MTI) on the server.
""",
)

client_id_env = ""
client_id = Unicode(
config=True,
Expand Down Expand Up @@ -980,6 +1030,18 @@ def build_access_tokens_request_params(self, handler, data=None):
"data": data,
}

if self.enable_pkce:
# https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
cookie_state = handler.get_state_cookie()
if not cookie_state:
raise web.HTTPError(400, "OAuth state missing from cookies")

code_verifier = _deserialize_state(cookie_state).get("code_verifier")
if not code_verifier:
raise web.HTTPError(400, "Missing code_verifier")

params.update([("code_verifier", code_verifier)])

# the client_id and client_secret should not be included in the access token request params
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
Expand Down
7 changes: 7 additions & 0 deletions oauthenticator/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from tornado.log import app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient

from ..oauth2 import _serialize_state

RegExpType = type(re.compile('.'))


Expand Down Expand Up @@ -222,6 +224,11 @@ def handler_for_user(user):
method="GET", uri=f"https://hub.example.com?code={code}"
)
handler.hub = Mock(server=Mock(base_url='/hub/'), base_url='/hub/')
handler.get_state_cookie = Mock(
return_value=_serialize_state(
{"state_id": "123", "next_url": "/ABC", "code_verifier": "123"}
)
)
return handler

client.handler_for_user = handler_for_user
Expand Down
36 changes: 21 additions & 15 deletions oauthenticator/tests/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,46 @@ async def test_serialize_state():
TEST_NEXT_URL = '/ABC'


async def test_login_states():
@mark.parametrize("enable_pkce", [True, False])
async def test_login_states(enable_pkce):
login_request_uri = f"http://myhost/login?next={TEST_NEXT_URL}"
authenticator = OAuthenticator()
authenticator = OAuthenticator(enable_pkce=enable_pkce)
login_handler = mock_handler(
OAuthLoginHandler,
uri=login_request_uri,
authenticator=authenticator,
)

login_handler._generate_state_id = Mock(return_value=TEST_STATE_ID)

code_verifier, code_challenge = login_handler._generate_pkce_params()
login_handler._generate_pkce_params = Mock(
return_value=(code_verifier, code_challenge)
)
login_handler.set_state_cookie = Mock()
login_handler.authorize_redirect = Mock()

login_handler.get() # no await, we've mocked the authorizer_redirect to NOT be async

expected_cookie_value = _serialize_state(
{
'state_id': TEST_STATE_ID,
'next_url': TEST_NEXT_URL,
}
)
expected_state = {
'state_id': TEST_STATE_ID,
'next_url': TEST_NEXT_URL,
}
if enable_pkce:
expected_state['code_verifier'] = code_verifier
expected_cookie_value = _serialize_state(expected_state)

login_handler.set_state_cookie.assert_called_once_with(expected_cookie_value)

expected_state_param_value = _serialize_state(
{
'state_id': TEST_STATE_ID,
}
)
expected_state_param_value = {
'state': _serialize_state({'state_id': TEST_STATE_ID})
}
if enable_pkce:
expected_state_param_value['code_challenge'] = code_challenge
expected_state_param_value['code_challenge_method'] = 'S256'

login_handler.authorize_redirect.assert_called_once()
assert (
login_handler.authorize_redirect.call_args.kwargs['extra_params']['state']
login_handler.authorize_redirect.call_args.kwargs['extra_params']
== expected_state_param_value
)

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
aiohttp
# jsonschema is used for validating authenticator configurations
jsonschema
jupyterhub>=2.2
Expand Down

0 comments on commit 09e863c

Please sign in to comment.