diff --git a/tests/unit/oidc/test_models.py b/tests/unit/oidc/test_models.py index bb12aebbd5ab..fdac08c070c3 100644 --- a/tests/unit/oidc/test_models.py +++ b/tests/unit/oidc/test_models.py @@ -11,10 +11,18 @@ # limitations under the License. import pretend +import pytest from warehouse.oidc import models +def test_check_claim_binary(): + wrapped = models._check_claim_binary(str.__eq__) + + assert wrapped("foo", "bar", pretend.stub()) is False + assert wrapped("foo", "foo", pretend.stub()) is True + + class TestOIDCProvider: def test_oidc_provider_not_default_verifiable(self): provider = models.OIDCProvider(projects=[]) @@ -27,9 +35,9 @@ def test_github_provider_all_known_claims(self): assert models.GitHubProvider.all_known_claims() == { # verifiable claims "repository", - "workflow", "repository_owner", "repository_owner_id", + "job_workflow_ref", # preverified claims "iss", "iat", @@ -51,7 +59,7 @@ def test_github_provider_all_known_claims(self): "event_name", "ref_type", "repository_id", - "job_workflow_ref", + "workflow", } def test_github_provider_computed_properties(self): @@ -120,7 +128,7 @@ def test_github_provider_verifies(self, monkeypatch): workflow_filename="fakeworkflow.yml", ) - noop_check = pretend.call_recorder(lambda l, r: True) + noop_check = pretend.call_recorder(lambda gt, sc, ac: True) verifiable_claims = { claim_name: noop_check for claim_name in provider.__verifiable_claims__ } @@ -132,3 +140,62 @@ def test_github_provider_verifies(self, monkeypatch): } assert provider.verify_claims(signed_claims=signed_claims) assert len(noop_check.calls) == len(verifiable_claims) + + @pytest.mark.parametrize( + ("claim", "ref", "valid"), + [ + # okay: workflow name, followed by a nonempty ref + ( + "foo/bar/.github/workflows/baz.yml@refs/tags/v0.0.1", + "refs/tags/v0.0.1", + True, + ), + ("foo/bar/.github/workflows/baz.yml@refs/pulls/6", "refs/pulls/6", True), + ( + "foo/bar/.github/workflows/baz.yml@refs/heads/main", + "refs/heads/main", + True, + ), + ( + "foo/bar/.github/workflows/baz.yml@notrailingslash", + "notrailingslash", + True, + ), + # bad: workflow name, empty or missing ref + ("foo/bar/.github/workflows/baz.yml@emptyref", "", False), + ("foo/bar/.github/workflows/baz.yml@missingref", None, False), + # bad: workflow name with various attempted impersonations + ( + "foo/bar/.github/workflows/baz.yml@fake.yml@notrailingslash", + "notrailingslash", + False, + ), + ( + "foo/bar/.github/workflows/baz.yml@fake.yml@refs/pulls/6", + "refs/pulls/6", + False, + ), + # bad: missing tail or workflow name or otherwise partial + ("foo/bar/.github/workflows/baz.yml@", "notrailingslash", False), + ("foo/bar/.github/workflows/@", "notrailingslash", False), + ("foo/bar/.github/workflows/", "notrailingslash", False), + ("baz.yml", "notrailingslash", False), + ( + "foo/bar/.github/workflows/baz.yml@malicious.yml@", + "notrailingslash", + False, + ), + ("foo/bar/.github/workflows/baz.yml@@", "notrailingslash", False), + ("", "notrailingslash", False), + ], + ) + def test_github_provider_job_workflow_ref(self, claim, ref, valid): + provider = models.GitHubProvider( + repository_name="bar", + repository_owner="foo", + repository_owner_id=pretend.stub(), + workflow_filename="baz.yml", + ) + + check = models.GitHubProvider.__verifiable_claims__["job_workflow_ref"] + assert check(provider.job_workflow_ref, claim, {"ref": ref}) is valid diff --git a/warehouse/oidc/models.py b/warehouse/oidc/models.py index f3a228963354..86b52c776065 100644 --- a/warehouse/oidc/models.py +++ b/warehouse/oidc/models.py @@ -22,6 +22,38 @@ from warehouse.packaging.models import Project +def _check_claim_binary(binary_func): + """ + Wraps a binary comparison function so that it takes three arguments instead, + ignoring the third. + + This is used solely to make claim verification compatible with "trivial" + checks like `str.__eq__`. + """ + + def wrapper(ground_truth, signed_claim, all_signed_claims): + return binary_func(ground_truth, signed_claim) + + return wrapper + + +def _check_job_workflow_ref(ground_truth, signed_claim, all_signed_claims): + # We expect a string formatted as follows: + # OWNER/REPO/.github/workflows/WORKFLOW.yml@REF + # where REF is the value of the `ref` claim. + + # Defensive: GitHub should never give us an empty job_workflow_ref, + # but we check for one anyways just in case. + if not signed_claim: + return False + + ref = all_signed_claims.get("ref") + if not ref: + return False + + return f"{ground_truth}@{ref}" == signed_claim + + class OIDCProviderProjectAssociation(db.Model): __tablename__ = "oidc_provider_project_association" @@ -52,8 +84,10 @@ class OIDCProvider(db.Model): } # A map of claim names to "check" functions, each of which - # has the signature `check(ground-truth, signed-claim) -> bool`. - __verifiable_claims__: Dict[str, Callable[[Any, Any], bool]] = dict() + # has the signature `check(ground-truth, signed-claim, all-signed-claims) -> bool`. + __verifiable_claims__: Dict[ + str, Callable[[Any, Any, Dict[str, Any]], bool] + ] = dict() # Claims that have already been verified during the JWT signature # verification phase. @@ -115,7 +149,7 @@ def verify_claims(self, signed_claims): ) return False - if not check(getattr(self, claim_name), signed_claim): + if not check(getattr(self, claim_name), signed_claim, signed_claims): return False return True @@ -145,10 +179,10 @@ class GitHubProvider(OIDCProvider): workflow_filename = Column(String) __verifiable_claims__ = { - "repository": str.__eq__, - "workflow": str.__eq__, - "repository_owner": str.__eq__, - "repository_owner_id": str.__eq__, + "repository": _check_claim_binary(str.__eq__), + "repository_owner": _check_claim_binary(str.__eq__), + "repository_owner_id": _check_claim_binary(str.__eq__), + "job_workflow_ref": _check_job_workflow_ref, } __unchecked_claims__ = { @@ -166,8 +200,7 @@ class GitHubProvider(OIDCProvider): "event_name", "ref_type", "repository_id", - # TODO(#11096): Support reusable workflows. - "job_workflow_ref", + "workflow", } @property @@ -179,8 +212,8 @@ def repository(self): return f"{self.repository_owner}/{self.repository_name}" @property - def workflow(self): - return self.workflow_filename + def job_workflow_ref(self): + return f"{self.repository}/.github/workflows/{self.workflow_filename}" def __str__(self): return f"{self.workflow_filename} @ {self.repository}"