Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Google] Make looking up google groups far less blocking #764

Merged
merged 13 commits into from
Oct 8, 2024
105 changes: 61 additions & 44 deletions oauthenticator/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class GoogleOAuthenticator(OAuthenticator, GoogleOAuth2Mixin):
user_auth_state_key = "google_user"
_service_credentials = {}

@default("login_service")
def _login_service_default(self):
Expand Down Expand Up @@ -243,7 +244,7 @@ async def update_auth_model(self, auth_model):

user_groups = set()
if self.allowed_google_groups or self.admin_google_groups:
user_groups = self._fetch_member_groups(user_email, user_domain)
user_groups = await self._fetch_member_groups(user_email, user_domain)
# sets are not JSONable, cast to list for auth_state
user_info["google_groups"] = list(user_groups)

Expand Down Expand Up @@ -314,6 +315,36 @@ async def check_allowed(self, username, auth_model):
# users should be explicitly allowed via config, otherwise they aren't
return False

def _get_service_credentials(self, user_email_domain):
"""
Returns the stored credentials or fetches and stores new ones.

Checks if the credentials are valid before returning them. Refreshes
if necessary and stores the refreshed credentials.
"""
if (
user_email_domain not in self._service_credentials
or not self._is_token_valid(user_email_domain)
):
self._service_credentials[user_email_domain] = (
self._setup_service_credentials(user_email_domain)
)

return self._service_credentials

def _is_token_valid(self, user_email_domain):
"""
Checks if the stored token is valid.
"""
if not self._service_credentials[user_email_domain]:
return False
if not self._service_credentials[user_email_domain].token:
return False
if self._service_credentials[user_email_domain].expired:
return False

return True

def _service_client_credentials(self, scopes, user_email_domain):
"""
Return a configured service client credentials for the API.
Expand All @@ -338,71 +369,57 @@ def _service_client_credentials(self, scopes, user_email_domain):

return credentials

def _service_client(self, service_name, service_version, credentials, http=None):
def _setup_service_credentials(self, user_email_domain):
"""
Return a configured service client for the API.
Set up the oauth credentials for Google API.
"""
credentials = self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)

try:
from googleapiclient.discovery import build
from google.auth.transport import requests
except:
raise ImportError(
"Could not import googleapiclient.discovery's build,"
"Could not import google.auth.transport's requests,"
"you may need to run 'pip install oauthenticator[googlegroups]' or not declare google groups"
)

self.log.debug(
f"service_name is {service_name}, service_version is {service_version}"
)

return build(
serviceName=service_name,
version=service_version,
credentials=credentials,
cache_discovery=False,
http=http,
)

def _setup_service(self, user_email_domain, http=None):
"""
Set up the service client for Google API.
"""
credentials = self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)
service = self._service_client(
service_name='admin',
service_version='directory_v1',
credentials=credentials,
http=http,
)
return service
request = requests.Request()
credentials.refresh(request)
consideRatio marked this conversation as resolved.
Show resolved Hide resolved
self.log.debug(f"Credentials refreshed for {user_email_domain}")
return credentials

def _fetch_member_groups(
async def _fetch_member_groups(
self,
member_email,
user_email_domain,
http=None,
checked_groups=None,
processed_groups=None,
credentials=None,
):
"""
Return a set with the google groups a given user/group is a member of, including nested groups if allowed.
"""
# FIXME: When this function is used and waiting for web request
# responses, JupyterHub gets blocked from doing other things.
# Ideally the web requests should be made using an async client
# that can be awaited while JupyterHub handles other things.
#
if not hasattr(self, 'service'):
self.service = self._setup_service(user_email_domain, http)

# WARNING: There's a race condition here if multiple users login at the same time.
# This is currently ignored.
credentials = credentials or self._get_service_credentials(user_email_domain)
token = credentials[user_email_domain].token
Comment on lines +406 to +409
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@manics I've tried to understand this race condition, but I've failed to see it. Are you sure there is one?

My understanding made explicit: whenever an await is showing up, then and only then is Python possibly switching to execute something else.

checked_groups = checked_groups or set()
processed_groups = processed_groups or set()

resp = self.service.groups().list(userKey=member_email).execute()
headers = {'Authorization': f'Bearer {token}'}
url = f'https://www.googleapis.com/admin/directory/v1/groups?userKey={member_email}'
group_data = await self.httpfetch(
jrdnbradford marked this conversation as resolved.
Show resolved Hide resolved
url, headers=headers, label="fetching google groups"
)

member_groups = {
g['email'].split('@')[0] for g in resp.get('groups', []) if g.get('email')
g['email'].split('@')[0]
for g in group_data.get('groups', [])
if g.get('email')
}
self.log.debug(f"Fetched groups for {member_email}: {member_groups}")

Expand All @@ -414,7 +431,7 @@ def _fetch_member_groups(
if group in processed_groups:
continue
processed_groups.add(group)
nested_groups = self._fetch_member_groups(
nested_groups = await self._fetch_member_groups(
f"{group}@{user_email_domain}",
user_email_domain,
http,
Expand Down
3 changes: 2 additions & 1 deletion oauthenticator/tests/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
from unittest import mock
from unittest.mock import AsyncMock

from pytest import fixture, mark, raises
from traitlets.config import Config
Expand Down Expand Up @@ -211,7 +212,7 @@ async def test_google(
handled_user_model = user_model("user1@example.com", "user1")
handler = google_client.handler_for_user(handled_user_model)
with mock.patch.object(
authenticator, "_fetch_member_groups", lambda *args: {"group1"}
authenticator, "_fetch_member_groups", AsyncMock(return_value={"group1"})
):
auth_model = await authenticator.get_authenticated_user(handler, None)

Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def run(self):
# googlegroups is required for use of GoogleOAuthenticator configured with
# either admin_google_groups and/or allowed_google_groups.
'googlegroups': [
'google-api-python-client',
'google-auth-oauthlib',
],
# mediawiki is required for use of MWOAuthenticator
Expand All @@ -105,7 +104,6 @@ def run(self):
'pytest-cov',
'requests-mock',
# dependencies from googlegroups:
'google-api-python-client',
'google-auth-oauthlib',
# dependencies from mediawiki:
'mwoauth>=0.3.8',
Expand Down