diff --git a/requirements.txt b/requirements.txt index 76725833c..6a6f0585a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ cffi==1.12.3 cfn-lint==0.24.4 chardet==3.0.4 constantly==15.1.0 -cryptography==39.0.1 +cryptography==43.0.3 dataclasses==0.6 DateTime==4.3 decorator==4.4.0 diff --git a/tests/api/auth_test.py b/tests/api/auth_test.py new file mode 100644 index 000000000..77a946176 --- /dev/null +++ b/tests/api/auth_test.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from tron.api.auth import AuthorizationFilter +from tron.api.auth import AuthorizationOutcome + + +@pytest.fixture +def mock_auth_filter(): + with patch("tron.api.auth.requests"): + yield AuthorizationFilter("http://localhost:31337/whatever", True) + + +def mock_request(path: str, token: str, method: str): + res = MagicMock(path=path.encode(), method=method.encode()) + res.getHeader.return_value = token + return res + + +def test_is_request_authorized(mock_auth_filter): + mock_auth_filter.session.post.return_value.json.return_value = { + "result": {"allowed": True, "reason": "User allowed"} + } + assert mock_auth_filter.is_request_authorized( + mock_request("/allowed", "aaa.bbb.ccc", "get") + ) == AuthorizationOutcome(True, "User allowed") + mock_auth_filter.session.post.assert_called_once_with( + url="http://localhost:31337/whatever", + json={ + "input": { + "path": "/allowed", + "backend": "tron", + "token": "aaa.bbb.ccc", + "method": "get", + } + }, + timeout=2, + ) + + +def test_is_request_authorized_fail(mock_auth_filter): + mock_auth_filter.session.post.side_effect = Exception + assert mock_auth_filter.is_request_authorized( + mock_request("/allowed", "eee.ddd.fff", "get") + ) == AuthorizationOutcome(False, "Auth backend error") + + +def test_is_request_authorized_malformed(mock_auth_filter): + mock_auth_filter.session.post.return_value.json.return_value = {"foo": "bar"} + assert mock_auth_filter.is_request_authorized( + mock_request("/allowed", "eee.ddd.fff", "post") + ) == AuthorizationOutcome(False, "Malformed auth response") + + +def test_is_request_authorized_no_enforce(mock_auth_filter): + mock_auth_filter.session.post.return_value.json.return_value = { + "result": {"allowed": False, "reason": "Missing token"} + } + with patch.object(mock_auth_filter, "enforce", False): + assert mock_auth_filter.is_request_authorized(mock_request("/foobar", "", "post")) == AuthorizationOutcome( + True, "Auth dry-run" + ) + + +def test_is_request_authorized_disabled(mock_auth_filter): + mock_auth_filter.session.post.return_value.json.return_value = { + "result": {"allowed": False, "reason": "Missing token"} + } + with patch.object(mock_auth_filter, "endpoint", None): + assert mock_auth_filter.is_request_authorized(mock_request("/buzz", "", "post")) == AuthorizationOutcome( + True, "Auth not enabled" + ) diff --git a/tron/api/auth.py b/tron/api/auth.py new file mode 100644 index 000000000..35864790f --- /dev/null +++ b/tron/api/auth.py @@ -0,0 +1,94 @@ +import logging +import os +from functools import lru_cache +from typing import NamedTuple + +import cachetools.func +import requests +from twisted.web.server import Request + + +logger = logging.getLogger(__name__) +AUTH_CACHE_SIZE = 50000 +AUTH_CACHE_TTL = 30 * 60 + + +class AuthorizationOutcome(NamedTuple): + authorized: bool + reason: str + + +class AuthorizationFilter: + """API request authorization via external system""" + + def __init__(self, endpoint: str, enforce: bool): + """Constructor + + :param str endpoint: HTTP endpoint of external authorization system + :param bool enforce: whether to enforce authorization decisions + """ + self.endpoint = endpoint + self.enforce = enforce + self.session = requests.Session() + + @classmethod + @lru_cache(maxsize=1) + def get_from_env(cls) -> "AuthorizationFilter": + return cls( + endpoint=os.getenv("API_AUTH_ENDPOINT", ""), + enforce=bool(os.getenv("API_AUTH_ENFORCE", "")), + ) + + def is_request_authorized(self, request: Request) -> AuthorizationOutcome: + """Check if API request is authorized + + :param Request request: API request object + :return: auth outcome + """ + if not self.endpoint: + return AuthorizationOutcome(True, "Auth not enabled") + token = (request.getHeader("Authorization") or "").strip() + token = token.split()[-1] if token else "" # removes "Bearer" prefix + auth_outcome = self._is_request_authorized_impl( + # path and method are byte arrays in twisted + path=request.path.decode(), + token=token, + method=request.method.decode(), + ) + return auth_outcome if self.enforce else AuthorizationOutcome(True, "Auth dry-run") + + @cachetools.func.ttl_cache(maxsize=AUTH_CACHE_SIZE, ttl=AUTH_CACHE_TTL) + def _is_request_authorized_impl(self, path: str, token: str, method: str) -> AuthorizationOutcome: + """Check if API request is authorized + + :param str path: API path + :param str token: authentication token + :param str method: http method + :return: auth outcome + """ + try: + response = self.session.post( + url=self.endpoint, + json={ + "input": { + "path": path, + "backend": "tron", + "token": token, + "method": method.lower(), + }, + }, + timeout=2, + ).json() + except Exception as e: + logger.exception(f"Issue communicating with auth endpoint: {e}") + return AuthorizationOutcome(False, "Auth backend error") + + if "result" not in response or "allowed" not in response["result"]: + return AuthorizationOutcome(False, "Malformed auth response") + + if not response["result"]["allowed"]: + reason = response["result"].get("reason", "Denied") + return AuthorizationOutcome(False, reason) + + reason = response["result"].get("reason", "Ok") + return AuthorizationOutcome(True, reason) diff --git a/tron/api/resource.py b/tron/api/resource.py index 0460c910f..13c70bf4f 100644 --- a/tron/api/resource.py +++ b/tron/api/resource.py @@ -26,6 +26,7 @@ from tron.api import adapter, controller from tron.api import requestargs from tron.api.async_resource import AsyncResource +from tron.api.auth import AuthorizationFilter from tron.metrics import view_all_metrics from tron.metrics import meter from tron.utils import maybe_decode @@ -514,6 +515,18 @@ def render_GET(self, request): } return respond(request=request, response=response) + def render(self, request): + """Overriding base `render` method to support auth""" + auth_outcome = AuthorizationFilter.get_from_env().is_request_authorized(request) + if not auth_outcome.authorized: + return respond( + request=request, + response={"reason": auth_outcome.reason}, + code=http.FORBIDDEN, + headers={"X-Auth-Failure-Reason": auth_outcome.reason}, + ) + return super().render(request) + class RootResource(resource.Resource): def __init__(self, mcp, web_path): diff --git a/tron/commands/client.py b/tron/commands/client.py index f80c8dad3..b92361984 100644 --- a/tron/commands/client.py +++ b/tron/commands/client.py @@ -37,9 +37,22 @@ class RequestError(ValueError): } +def get_sso_auth_token() -> str: + """Generate an authentication token for the calling user from the Single Sign On provider, if configured""" + from okta_auth import get_and_cache_jwt_default # type: ignore + from tron.commands.cmd_utils import get_client_config + + client_id = get_client_config().get("auth_sso_oidc_client_id") + return get_and_cache_jwt_default(client_id) if client_id else "" # type: ignore + + def build_url_request(uri, data, headers=None, method=None): headers = headers or default_headers enc_data = urllib.parse.urlencode(data).encode() if data else None + if os.getenv("TRONCTL_API_AUTH") and (data or method.upper() == "POST"): + token = get_sso_auth_token() + if token: + headers["Authorization"] = f"Bearer {token}" return urllib.request.Request(uri, enc_data, headers=headers, method=method) diff --git a/yelp_package/extra_requirements_yelp.txt b/yelp_package/extra_requirements_yelp.txt index f4d008a36..9af99c009 100644 --- a/yelp_package/extra_requirements_yelp.txt +++ b/yelp_package/extra_requirements_yelp.txt @@ -2,7 +2,10 @@ clusterman-metrics==2.2.1 # used by tron for pre-scaling for Spark runs dateglob==1.1.1 # required by yelp-logging geogrid==2.1.0 # required by yelp-logging monk==3.0.4 # required by yelp-clog +okta-auth==1.0.1 # used for API auth ply==3.11 # required by thriftpy2 +pyjwt==2.9.0 # required by okta-auth +saml-helper==2.3.3 # required by okta-auth scribereader==1.1.1 # used by tron to get tronjob logs simplejson==3.19.2 # required by yelp-logging srv-configs==1.3.4 # required by monk