-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from unittest.mock import MagicMock | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from tron.api.auth import AuthorizationFilter | ||
from tron.api.auth import AuthorizationOutcome | ||
|
||
|
||
@pytest.fixture | ||
def mock_auth_filter(): | ||
with patch("tron.api.auth.requests"): | ||
yield AuthorizationFilter("http://localhost:31337/whatever", True) | ||
|
||
|
||
def mock_request(path: str, token: str, method: str): | ||
res = MagicMock(path=path.encode(), method=method.encode()) | ||
res.getHeader.return_value = token | ||
return res | ||
|
||
|
||
def test_is_request_authorized(mock_auth_filter): | ||
mock_auth_filter.session.post.return_value.json.return_value = { | ||
"result": {"allowed": True, "reason": "User allowed"} | ||
} | ||
assert mock_auth_filter.is_request_authorized( | ||
mock_request("/allowed", "aaa.bbb.ccc", "get") | ||
) == AuthorizationOutcome(True, "User allowed") | ||
mock_auth_filter.session.post.assert_called_once_with( | ||
url="http://localhost:31337/whatever", | ||
json={ | ||
"input": { | ||
"path": "/allowed", | ||
"backend": "tron", | ||
"token": "aaa.bbb.ccc", | ||
"method": "get", | ||
} | ||
}, | ||
timeout=2, | ||
) | ||
|
||
|
||
def test_is_request_authorized_fail(mock_auth_filter): | ||
mock_auth_filter.session.post.side_effect = Exception | ||
assert mock_auth_filter.is_request_authorized( | ||
mock_request("/allowed", "eee.ddd.fff", "get") | ||
) == AuthorizationOutcome(False, "Auth backend error") | ||
|
||
|
||
def test_is_request_authorized_malformed(mock_auth_filter): | ||
mock_auth_filter.session.post.return_value.json.return_value = {"foo": "bar"} | ||
assert mock_auth_filter.is_request_authorized( | ||
mock_request("/allowed", "eee.ddd.fff", "post") | ||
) == AuthorizationOutcome(False, "Malformed auth response") | ||
|
||
|
||
def test_is_request_authorized_no_enforce(mock_auth_filter): | ||
mock_auth_filter.session.post.return_value.json.return_value = { | ||
"result": {"allowed": False, "reason": "Missing token"} | ||
} | ||
with patch.object(mock_auth_filter, "enforce", False): | ||
assert mock_auth_filter.is_request_authorized(mock_request("/foobar", "", "post")) == AuthorizationOutcome( | ||
True, "Auth dry-run" | ||
) | ||
|
||
|
||
def test_is_request_authorized_disabled(mock_auth_filter): | ||
mock_auth_filter.session.post.return_value.json.return_value = { | ||
"result": {"allowed": False, "reason": "Missing token"} | ||
} | ||
with patch.object(mock_auth_filter, "endpoint", None): | ||
assert mock_auth_filter.is_request_authorized(mock_request("/buzz", "", "post")) == AuthorizationOutcome( | ||
True, "Auth not enabled" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import logging | ||
import os | ||
from functools import lru_cache | ||
from typing import NamedTuple | ||
|
||
import cachetools.func | ||
import requests | ||
from twisted.web.server import Request | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
AUTH_CACHE_SIZE = 50000 | ||
AUTH_CACHE_TTL = 30 * 60 | ||
|
||
|
||
class AuthorizationOutcome(NamedTuple): | ||
authorized: bool | ||
reason: str | ||
|
||
|
||
class AuthorizationFilter: | ||
"""API request authorization via external system""" | ||
|
||
def __init__(self, endpoint: str, enforce: bool): | ||
"""Constructor | ||
:param str endpoint: HTTP endpoint of external authorization system | ||
:param bool enforce: whether to enforce authorization decisions | ||
""" | ||
self.endpoint = endpoint | ||
self.enforce = enforce | ||
self.session = requests.Session() | ||
|
||
@classmethod | ||
@lru_cache(maxsize=1) | ||
def get_from_env(cls) -> "AuthorizationFilter": | ||
return cls( | ||
endpoint=os.getenv("API_AUTH_ENDPOINT", ""), | ||
enforce=bool(os.getenv("API_AUTH_ENFORCE", "")), | ||
) | ||
|
||
def is_request_authorized(self, request: Request) -> AuthorizationOutcome: | ||
"""Check if API request is authorized | ||
:param Request request: API request object | ||
:return: auth outcome | ||
""" | ||
if not self.endpoint: | ||
return AuthorizationOutcome(True, "Auth not enabled") | ||
token = (request.getHeader("Authorization") or "").strip() | ||
token = token.split()[-1] if token else "" # removes "Bearer" prefix | ||
auth_outcome = self._is_request_authorized_impl( | ||
# path and method are byte arrays in twisted | ||
path=request.path.decode(), | ||
token=token, | ||
method=request.method.decode(), | ||
) | ||
return auth_outcome if self.enforce else AuthorizationOutcome(True, "Auth dry-run") | ||
|
||
@cachetools.func.ttl_cache(maxsize=AUTH_CACHE_SIZE, ttl=AUTH_CACHE_TTL) | ||
def _is_request_authorized_impl(self, path: str, token: str, method: str) -> AuthorizationOutcome: | ||
"""Check if API request is authorized | ||
:param str path: API path | ||
:param str token: authentication token | ||
:param str method: http method | ||
:return: auth outcome | ||
""" | ||
try: | ||
response = self.session.post( | ||
url=self.endpoint, | ||
json={ | ||
"input": { | ||
"path": path, | ||
"backend": "tron", | ||
"token": token, | ||
"method": method.lower(), | ||
}, | ||
}, | ||
timeout=2, | ||
).json() | ||
except Exception as e: | ||
logger.exception(f"Issue communicating with auth endpoint: {e}") | ||
return AuthorizationOutcome(False, "Auth backend error") | ||
|
||
if "result" not in response or "allowed" not in response["result"]: | ||
return AuthorizationOutcome(False, "Malformed auth response") | ||
|
||
if not response["result"]["allowed"]: | ||
reason = response["result"].get("reason", "Denied") | ||
return AuthorizationOutcome(False, reason) | ||
|
||
reason = response["result"].get("reason", "Ok") | ||
return AuthorizationOutcome(True, reason) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters