diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index bd7bb9ad..2a223663 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -9,6 +9,7 @@ from jupyterhub.auth import LocalAuthenticator from tornado.httpclient import HTTPRequest from traitlets import default +from traitlets import List from traitlets import Unicode from .oauth2 import OAuthenticator @@ -50,6 +51,22 @@ def _token_url_default(self): self.tenant_id ) + allowed_groups = List( + Unicode(), + config=True, + help="Automatically allow members of selected groups", + ) + + admin_groups = List( + Unicode(), + config=True, + help="Groups whose members should have Jupyterhub admin privileges", + ) + + @staticmethod + def check_user_in_groups(member_groups, allowed_groups): + return bool(set(member_groups) & set(allowed_groups)) + async def authenticate(self, handler, data=None): code = handler.get_argument("code") @@ -94,6 +111,15 @@ async def authenticate(self, handler, data=None): # results in a decoded JWT for the user data auth_state['user'] = decoded + groups = self.allowed_groups + self.admin_groups + if groups: + ad_groups = decoded.get('groups') + if self.check_user_in_groups(ad_groups, groups): + userdict['admin'] = self.check_user_in_groups( + ad_groups, self.admin_groups + ) + else: + userdict = None return userdict diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index 35c8a2bf..5055325e 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -3,10 +3,12 @@ import re import time import uuid +from functools import partial from unittest import mock import jwt import pytest +from pytest import fixture from ..azuread import AzureAdOAuthenticator from .mocks import setup_oauth_mock @@ -19,26 +21,30 @@ def test_tenant_id_from_env(): assert aad.tenant_id == tenant_id -def user_model(tenant_id, client_id, name): +def user_model(tenant_id, client_id, name, **kwargs): """Return a user model""" # model derived from https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#v20 now = int(time.time()) + + user = { + "ver": "2.0", + "iss": f"https://login.microsoftonline.com/{tenant_id}/v2.0", + "sub": "AAAAAAAAAAAAAAAAAAAAAIkzqFVrSaSaFHy782bbtaQ", + "aud": client_id, + "exp": now + 3600, + "iat": now, + "nbf": now, + "name": name, + "preferred_username": name, + "oid": str(uuid.uuid1()), + "tid": tenant_id, + "nonce": "123523", + "aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY", + } + user.update(kwargs) + id_token = jwt.encode( - { - "ver": "2.0", - "iss": f"https://login.microsoftonline.com/{tenant_id}/v2.0", - "sub": "AAAAAAAAAAAAAAAAAAAAAIkzqFVrSaSaFHy782bbtaQ", - "aud": client_id, - "exp": now + 3600, - "iat": now, - "nbf": now, - "name": name, - "preferred_username": name, - "oid": str(uuid.uuid1()), - "tid": tenant_id, - "nonce": "123523", - "aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY", - }, + user, os.urandom(5), ).decode("ascii") @@ -48,6 +54,15 @@ def user_model(tenant_id, client_id, name): } +def _get_authenticator(**kwargs): + return AzureAdOAuthenticator( + tenant_id=str(uuid.uuid1()), + client_id=str(uuid.uuid1()), + client_secret=str(uuid.uuid1()), + **kwargs, + ) + + @pytest.fixture def azure_client(client): setup_oauth_mock( @@ -59,6 +74,11 @@ def azure_client(client): return client +@fixture +def get_authenticator(azure_client, **kwargs): + return partial(_get_authenticator, http_client=azure_client) + + @pytest.mark.parametrize( 'username_claim', [ @@ -68,12 +88,8 @@ def azure_client(client): 'preferred_username', ], ) -async def test_azuread(username_claim, azure_client): - authenticator = AzureAdOAuthenticator( - tenant_id=str(uuid.uuid1()), - client_id=str(uuid.uuid1()), - client_secret=str(uuid.uuid1()), - ) +async def test_azuread(get_authenticator, username_claim, azure_client): + authenticator = get_authenticator() if username_claim: authenticator.username_claim = username_claim @@ -95,3 +111,53 @@ async def test_azuread(username_claim, azure_client): name = user_info['name'] assert name == jwt_user[authenticator.username_claim] + + +@pytest.mark.parametrize( + 'allowed_groups,admin_groups,azuread_groups,expected', + [ + ( + [], + ['jupyterhub-admin'], + ['jupyterhub-admin'], + lambda r: bool(r) and r['admin'], + ), + ([], ['jupyterhub-admin'], ['jupyter-admin'], lambda r: not bool(r)), + (['jupyterhub'], [], ['jupyterhub'], lambda r: bool(r) and not r['admin']), + (['jupyterhub'], [], ['jupyter'], lambda r: not bool(r)), + ([], [], ['jupyterhub'], lambda r: bool(r)), + ( + ['jupyterhub'], + ['jupyterhub-admin'], + ['jupyterhub', 'jupyterhub-admin'], + lambda r: bool(r) and r['admin'], + ), + (['jupyterhub'], [], [], lambda r: not bool(r)), + ([], [], [], lambda r: bool(r) and r.get('admin') is None), + ], +) +async def test_azuread_groups( + get_authenticator, + azure_client, + allowed_groups, + admin_groups, + azuread_groups, + expected, +): + authenticator = get_authenticator( + scope=['openid', 'profile'], + allowed_groups=allowed_groups, + admin_groups=admin_groups, + ) + + handler = azure_client.handler_for_user( + user_model( + tenant_id=authenticator.tenant_id, + client_id=authenticator.client_id, + name="somebody", + groups=azuread_groups, + ) + ) + + r = await authenticator.authenticate(handler) + assert expected(r)