Skip to content

Commit

Permalink
Merge pull request #699 from consideRatio/pr/cilogon-default-allowed-idp
Browse files Browse the repository at this point in the history
[CILogon] add config to specify default idp under allowed_idps
  • Loading branch information
consideRatio authored Nov 22, 2023
2 parents 596d9c8 + 5ced631 commit 48056dd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
29 changes: 25 additions & 4 deletions oauthenticator/cilogon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@
yaml = YAML(typ="safe", pure=True)


def _get_select_idp_param(allowed_idps):
"""
The "selected_idp" query parameter included when the user is redirected to
CILogon should be a comma separated string of idps to choose from, where the
first entry is pre-selected as the default choice. The ordering of the
remaining idps has no meaning.
"""
# pick the first idp marked as default, or fallback to the first idp
default_keys = [k for k, v in allowed_idps.items() if v.get("default")]
default_key = next(iter(default_keys), next(iter(allowed_idps)))

# put the default idp first followed by the other idps
other_keys = [k for k, _ in allowed_idps.items() if k != default_key]
selected_idp = ",".join([default_key] + other_keys)

return selected_idp


class CILogonLoginHandler(OAuthLoginHandler):
"""See https://www.cilogon.org/oidc for general information."""

Expand All @@ -30,10 +48,9 @@ def authorize_redirect(self, *args, **kwargs):
# include it, we then modify kwargs' extra_params dictionary
extra_params = kwargs.setdefault('extra_params', {})

# selected_idp should be a comma separated string
allowed_idps = ",".join(self.authenticator.allowed_idps.keys())
extra_params["selected_idp"] = allowed_idps

extra_params["selected_idp"] = _get_select_idp_param(
self.authenticator.allowed_idps
)
if self.authenticator.skin:
extra_params["skin"] = self.authenticator.skin

Expand Down Expand Up @@ -125,6 +142,7 @@ def _validate_scope(self, proposal):
"domain": "utoronto.ca",
},
"allow_all": True,
"default": True,
},
"http://google.com/accounts/o8/id": {
"username_derivation": {
Expand All @@ -151,6 +169,9 @@ def _validate_scope(self, proposal):
This is a description of the configuration you can pass to
`allowed_idps`.
* `default`: bool (optional)
Determines the identity provider to be pre-selected in a list for
users arriving to CILogons login screen.
* `username_derivation`: string (required)
* `username_claim`: string (required)
The claim in the `userinfo` response from which to define the
Expand Down
43 changes: 42 additions & 1 deletion oauthenticator/tests/test_cilogon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from traitlets.config import Config
from traitlets.traitlets import TraitError

from ..cilogon import CILogonOAuthenticator
from ..cilogon import CILogonOAuthenticator, _get_select_idp_param
from .mocks import setup_oauth_mock


Expand Down Expand Up @@ -688,3 +688,44 @@ async def test_allowed_idps_username_derivation_actions(cilogon_client):
auth_model = await authenticator.get_authenticated_user(handler, None)
print(json.dumps(auth_model, sort_keys=True, indent=4))
assert auth_model['name'] == 'jtkirk'


@mark.parametrize(
"test_variation_id,allowed_idps,expected_return_value",
[
(
"default-specified",
{
'https://example4.org': {},
'https://example3.org': {'default': False},
'https://example2.org': {'default': True},
'https://example1.org': {},
},
"https://example2.org,https://example4.org,https://example3.org,https://example1.org",
),
(
"no-truthy-default-specified",
{
'https://example4.org': {},
'https://example3.org': {'default': False},
'https://example2.org': {},
'https://example1.org': {},
},
"https://example4.org,https://example3.org,https://example2.org,https://example1.org",
),
(
"no-default-specified-pick-first-entry",
{
'https://example4.org': {},
'https://example3.org': {},
'https://example2.org': {},
'https://example1.org': {},
},
"https://example4.org,https://example3.org,https://example2.org,https://example1.org",
),
],
)
async def test__get_selected_idp_param(
test_variation_id, allowed_idps, expected_return_value
):
assert _get_select_idp_param(allowed_idps) == expected_return_value

0 comments on commit 48056dd

Please sign in to comment.