Skip to content

Commit

Permalink
Merge pull request #764 from jrdnbradford/async-group-lookup
Browse files Browse the repository at this point in the history
[Google] Make looking up google groups far less blocking
  • Loading branch information
consideRatio authored Oct 8, 2024
2 parents cb3a63f + 67cbbc6 commit b9c8905
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 47 deletions.
105 changes: 61 additions & 44 deletions oauthenticator/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class GoogleOAuthenticator(OAuthenticator, GoogleOAuth2Mixin):
user_auth_state_key = "google_user"
_service_credentials = {}

@default("login_service")
def _login_service_default(self):
Expand Down Expand Up @@ -251,7 +252,7 @@ async def update_auth_model(self, auth_model):

user_groups = set()
if self.allowed_google_groups or self.admin_google_groups:
user_groups = self._fetch_member_groups(user_email, user_domain)
user_groups = await self._fetch_member_groups(user_email, user_domain)
# sets are not JSONable, cast to list for auth_state
user_info["google_groups"] = list(user_groups)

Expand Down Expand Up @@ -322,6 +323,36 @@ async def check_allowed(self, username, auth_model):
# users should be explicitly allowed via config, otherwise they aren't
return False

def _get_service_credentials(self, user_email_domain):
"""
Returns the stored credentials or fetches and stores new ones.
Checks if the credentials are valid before returning them. Refreshes
if necessary and stores the refreshed credentials.
"""
if (
user_email_domain not in self._service_credentials
or not self._is_token_valid(user_email_domain)
):
self._service_credentials[user_email_domain] = (
self._setup_service_credentials(user_email_domain)
)

return self._service_credentials

def _is_token_valid(self, user_email_domain):
"""
Checks if the stored token is valid.
"""
if not self._service_credentials[user_email_domain]:
return False
if not self._service_credentials[user_email_domain].token:
return False
if self._service_credentials[user_email_domain].expired:
return False

return True

def _service_client_credentials(self, scopes, user_email_domain):
"""
Return a configured service client credentials for the API.
Expand All @@ -346,71 +377,57 @@ def _service_client_credentials(self, scopes, user_email_domain):

return credentials

def _service_client(self, service_name, service_version, credentials, http=None):
def _setup_service_credentials(self, user_email_domain):
"""
Return a configured service client for the API.
Set up the oauth credentials for Google API.
"""
credentials = self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)

try:
from googleapiclient.discovery import build
from google.auth.transport import requests
except:
raise ImportError(
"Could not import googleapiclient.discovery's build,"
"Could not import google.auth.transport's requests,"
"you may need to run 'pip install oauthenticator[googlegroups]' or not declare google groups"
)

self.log.debug(
f"service_name is {service_name}, service_version is {service_version}"
)

return build(
serviceName=service_name,
version=service_version,
credentials=credentials,
cache_discovery=False,
http=http,
)

def _setup_service(self, user_email_domain, http=None):
"""
Set up the service client for Google API.
"""
credentials = self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)
service = self._service_client(
service_name='admin',
service_version='directory_v1',
credentials=credentials,
http=http,
)
return service
request = requests.Request()
credentials.refresh(request)
self.log.debug(f"Credentials refreshed for {user_email_domain}")
return credentials

def _fetch_member_groups(
async def _fetch_member_groups(
self,
member_email,
user_email_domain,
http=None,
checked_groups=None,
processed_groups=None,
credentials=None,
):
"""
Return a set with the google groups a given user/group is a member of, including nested groups if allowed.
"""
# FIXME: When this function is used and waiting for web request
# responses, JupyterHub gets blocked from doing other things.
# Ideally the web requests should be made using an async client
# that can be awaited while JupyterHub handles other things.
#
if not hasattr(self, 'service'):
self.service = self._setup_service(user_email_domain, http)

# WARNING: There's a race condition here if multiple users login at the same time.
# This is currently ignored.
credentials = credentials or self._get_service_credentials(user_email_domain)
token = credentials[user_email_domain].token
checked_groups = checked_groups or set()
processed_groups = processed_groups or set()

resp = self.service.groups().list(userKey=member_email).execute()
headers = {'Authorization': f'Bearer {token}'}
url = f'https://www.googleapis.com/admin/directory/v1/groups?userKey={member_email}'
group_data = await self.httpfetch(
url, headers=headers, label="fetching google groups"
)

member_groups = {
g['email'].split('@')[0] for g in resp.get('groups', []) if g.get('email')
g['email'].split('@')[0]
for g in group_data.get('groups', [])
if g.get('email')
}
self.log.debug(f"Fetched groups for {member_email}: {member_groups}")

Expand All @@ -422,7 +439,7 @@ def _fetch_member_groups(
if group in processed_groups:
continue
processed_groups.add(group)
nested_groups = self._fetch_member_groups(
nested_groups = await self._fetch_member_groups(
f"{group}@{user_email_domain}",
user_email_domain,
http,
Expand Down
3 changes: 2 additions & 1 deletion oauthenticator/tests/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
from unittest import mock
from unittest.mock import AsyncMock

from pytest import fixture, mark, raises
from traitlets.config import Config
Expand Down Expand Up @@ -211,7 +212,7 @@ async def test_google(
handled_user_model = user_model("user1@example.com", "user1")
handler = google_client.handler_for_user(handled_user_model)
with mock.patch.object(
authenticator, "_fetch_member_groups", lambda *args: {"group1"}
authenticator, "_fetch_member_groups", AsyncMock(return_value={"group1"})
):
auth_model = await authenticator.get_authenticated_user(handler, None)

Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def run(self):
# googlegroups is required for use of GoogleOAuthenticator configured with
# either admin_google_groups and/or allowed_google_groups.
'googlegroups': [
'google-api-python-client',
'google-auth-oauthlib',
],
# mediawiki is required for use of MWOAuthenticator
Expand All @@ -105,7 +104,6 @@ def run(self):
'pytest-cov',
'requests-mock',
# dependencies from googlegroups:
'google-api-python-client',
'google-auth-oauthlib',
# dependencies from mediawiki:
'mwoauth>=0.3.8',
Expand Down

0 comments on commit b9c8905

Please sign in to comment.