Skip to content

Commit

Permalink
auth support for paasta APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
piax93 committed Nov 7, 2024
1 parent 379b45d commit 9f35622
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tests/api/auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from tron.api.auth import AuthorizationFilter
from tron.api.auth import AuthorizationOutcome


@pytest.fixture
def mock_auth_filter():
with patch("tron.api.auth.requests"):
yield AuthorizationFilter("http://localhost:31337/whatever", True)


def mock_request(path: str, token: str, method: str):
res = MagicMock(path=path.encode(), method=method.encode())
res.getHeader.return_value = token
return res


def test_is_request_authorized(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": True, "reason": "User allowed"}
}
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "aaa.bbb.ccc", "get")
) == AuthorizationOutcome(True, "User allowed")
mock_auth_filter.session.post.assert_called_once_with(
url="http://localhost:31337/whatever",
json={
"input": {
"path": "/allowed",
"backend": "tron",
"token": "aaa.bbb.ccc",
"method": "get",
}
},
timeout=2,
)


def test_is_request_authorized_fail(mock_auth_filter):
mock_auth_filter.session.post.side_effect = Exception
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "eee.ddd.fff", "get")
) == AuthorizationOutcome(False, "Auth backend error")


def test_is_request_authorized_malformed(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {"foo": "bar"}
assert mock_auth_filter.is_request_authorized(
mock_request("/allowed", "eee.ddd.fff", "post")
) == AuthorizationOutcome(False, "Malformed auth response")


def test_is_request_authorized_no_enforce(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": False, "reason": "Missing token"}
}
with patch.object(mock_auth_filter, "enforce", False):
assert mock_auth_filter.is_request_authorized(mock_request("/foobar", "", "post")) == AuthorizationOutcome(
True, "Auth dry-run"
)


def test_is_request_authorized_disabled(mock_auth_filter):
mock_auth_filter.session.post.return_value.json.return_value = {
"result": {"allowed": False, "reason": "Missing token"}
}
with patch.object(mock_auth_filter, "endpoint", None):
assert mock_auth_filter.is_request_authorized(mock_request("/buzz", "", "post")) == AuthorizationOutcome(
True, "Auth not enabled"
)
94 changes: 94 additions & 0 deletions tron/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import logging
import os
from functools import lru_cache
from typing import NamedTuple

import cachetools.func
import requests
from twisted.web.server import Request


logger = logging.getLogger(__name__)
AUTH_CACHE_SIZE = 50000
AUTH_CACHE_TTL = 30 * 60


class AuthorizationOutcome(NamedTuple):
authorized: bool
reason: str


class AuthorizationFilter:
"""API request authorization via external system"""

def __init__(self, endpoint: str, enforce: bool):
"""Constructor
:param str endpoint: HTTP endpoint of external authorization system
:param bool enforce: whether to enforce authorization decisions
"""
self.endpoint = endpoint
self.enforce = enforce
self.session = requests.Session()

@classmethod
@lru_cache(maxsize=1)
def get_from_env(cls) -> "AuthorizationFilter":
return cls(
endpoint=os.getenv("API_AUTH_ENDPOINT", ""),
enforce=bool(os.getenv("API_AUTH_ENFORCE", "")),
)

def is_request_authorized(self, request: Request) -> AuthorizationOutcome:
"""Check if API request is authorized
:param Request request: API request object
:return: auth outcome
"""
if not self.endpoint:
return AuthorizationOutcome(True, "Auth not enabled")
token = (request.getHeader("Authorization") or "").strip()
token = token.split()[-1] if token else "" # removes "Bearer" prefix
auth_outcome = self._is_request_authorized_impl(
# path and method are byte arrays in twisted
path=request.path.decode(),
token=token,
method=request.method.decode(),
)
return auth_outcome if self.enforce else AuthorizationOutcome(True, "Auth dry-run")

@cachetools.func.ttl_cache(maxsize=AUTH_CACHE_SIZE, ttl=AUTH_CACHE_TTL)
def _is_request_authorized_impl(self, path: str, token: str, method: str) -> AuthorizationOutcome:
"""Check if API request is authorized
:param str path: API path
:param str token: authentication token
:param str method: http method
:return: auth outcome
"""
try:
response = self.session.post(
url=self.endpoint,
json={
"input": {
"path": path,
"backend": "tron",
"token": token,
"method": method.lower(),
},
},
timeout=2,
).json()
except Exception as e:
logger.exception(f"Issue communicating with auth endpoint: {e}")
return AuthorizationOutcome(False, "Auth backend error")

if "result" not in response or "allowed" not in response["result"]:
return AuthorizationOutcome(False, "Malformed auth response")

if not response["result"]["allowed"]:
reason = response["result"].get("reason", "Denied")
return AuthorizationOutcome(False, reason)

reason = response["result"].get("reason", "Ok")
return AuthorizationOutcome(True, reason)
13 changes: 13 additions & 0 deletions tron/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tron.api import adapter, controller
from tron.api import requestargs
from tron.api.async_resource import AsyncResource
from tron.api.auth import AuthorizationFilter
from tron.metrics import view_all_metrics
from tron.metrics import meter
from tron.utils import maybe_decode
Expand Down Expand Up @@ -514,6 +515,18 @@ def render_GET(self, request):
}
return respond(request=request, response=response)

def render(self, request):
"""Overriding base `render` method to support auth"""
auth_outcome = AuthorizationFilter.get_from_env().is_request_authorized(request)
if not auth_outcome.authorized:
return respond(
request=request,
response={"reason": auth_outcome.reason},
code=http.FORBIDDEN,
headers={"X-Auth-Failure-Reason": auth_outcome.reason},
)
return super().render(request)


class RootResource(resource.Resource):
def __init__(self, mcp, web_path):
Expand Down
13 changes: 13 additions & 0 deletions tron/commands/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,22 @@ class RequestError(ValueError):
}


def get_sso_auth_token() -> str:
"""Generate an authentication token for the calling user from the Single Sign On provider, if configured"""
from okta_auth import get_and_cache_jwt_default # type: ignore
from tron.commands.cmd_utils import get_client_config

client_id = get_client_config().get("auth_sso_oidc_client_id")
return get_and_cache_jwt_default(client_id) if client_id else "" # type: ignore


def build_url_request(uri, data, headers=None, method=None):
headers = headers or default_headers
enc_data = urllib.parse.urlencode(data).encode() if data else None
if os.getenv("TRONCTL_API_AUTH") and (data or method.upper() == "POST"):
token = get_sso_auth_token()
if token:
headers["Authorization"] = f"Bearer {token}"
return urllib.request.Request(uri, enc_data, headers=headers, method=method)


Expand Down
3 changes: 3 additions & 0 deletions yelp_package/extra_requirements_yelp.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ clusterman-metrics==2.2.1 # used by tron for pre-scaling for Spark runs
dateglob==1.1.1 # required by yelp-logging
geogrid==2.1.0 # required by yelp-logging
monk==3.0.4 # required by yelp-clog
okta-auth==1.0.1 # used for API auth
ply==3.11 # required by thriftpy2
pyjwt==2.9.0 # required by okta-auth
saml-helper==2.3.3 # required by okta-auth
scribereader==1.1.1 # used by tron to get tronjob logs
simplejson==3.19.2 # required by yelp-logging
srv-configs==1.3.4 # required by monk
Expand Down

0 comments on commit 9f35622

Please sign in to comment.