Skip to content

Commit

Permalink
Merge pull request #95 from rsinger86/typed-statements
Browse files Browse the repository at this point in the history
Typed statements
  • Loading branch information
rsinger86 authored Mar 2, 2023
2 parents 43e57f0 + 61d8397 commit 68c4c2b
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 27 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ See [migration notes](https://rsinger86.github.io/drf-access-policy/migration_no

# Changelog <a id="changelog"></a>

## 1.5 (March 2023)

- Adds `Statement` dataclass as alternative to dictionaries. Drops Python 3.5 support.


## 1.4 (March 2023)

- Fixes read-only scenario for FieldAccessMixin. Thanks @hungryseven!
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ See [migration notes](/migration_notes.html) if your policy statements combine m

## Requirements

Python 3.5+
Python 3.6+

## Installation

Expand Down
40 changes: 40 additions & 0 deletions docs/statement_dataclasses.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Statement Dataclass

A `Statement` dataclass can be used instead of dictionaries to define policy statements.

For example, the following policies are equivalent:

```python
from rest_access_policy import Statement


class ArticleAccessPolicy(AccessPolicy):
statements = [
Statement(
action="destroy",
principal=["*"],
effect="allow",
condition="is_author"
)
]

def is_author(self, request, view, action) -> bool:
article = view.get_object()
return request.user == article.author



class ArticleAccessPolicy(AccessPolicy):
statements = [
{
"action": ["destroy"],
"principal": ["*"],
"effect": "allow",
"condition": "is_author"
}
]

def is_author(self, request, view, action) -> bool:
article = view.get_object()
return request.user == article.author
```
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ nav:
- Usage - ViewSets: usage/view_set_usage.md
- Usage - Function-Based Views: usage/function_based_view_usage.md
- Statement Elements: statement_elements.md
- Statement Dataclasses: statement_dataclasses.md
- Policy Evaluation Logic: policy_logic.md
- Custom/Object-Level Conditions: object_level_permissions.md
- Field-Level Permissions: field_level_permissions.md
Expand Down
2 changes: 1 addition & 1 deletion rest_access_policy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .exceptions import AccessPolicyException
from .access_policy import AccessPolicy
from .access_policy import AccessPolicy, Statement
from .access_view_set_mixin import AccessViewSetMixin
from .field_access_mixin import FieldAccessMixin
from .fields import PermittedPkRelatedField, PermittedSlugRelatedField
48 changes: 36 additions & 12 deletions rest_access_policy/access_policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
from typing import List
from dataclasses import asdict, dataclass, field
from typing import List, Union

from django.conf import settings
from django.db.models import prefetch_related_objects
Expand All @@ -8,7 +9,7 @@

from rest_access_policy import AccessPolicyException

from .parsing import BoolAnd, BoolNot, BoolOr, ConditionOperand, BoolOperand
from .parsing import BoolAnd, BoolNot, BoolOperand, BoolOr, ConditionOperand


class AnonymousUser(object):
Expand Down Expand Up @@ -36,8 +37,23 @@ def allowed(self) -> bool:
return self._allowed


@dataclass
class Statement:
principal: Union[List[str], str]
action: Union[List[str], str]
effect: str = "deny" # allow, deny
condition: Union[List[str], str] = field(default_factory=list)
condition_expression: Union[List[str], str] = field(default_factory=list)

def __post_init__(self):
permitted = ("allow", "deny")

if self.effect not in ("allow", "deny"):
raise Exception(f"effect must be one of {permitted}")


class AccessPolicy(permissions.BasePermission):
statements: List[dict] = []
statements: List[Union[dict, Statement]] = []
field_permissions: dict = {}
id = None
group_prefix = "group:"
Expand All @@ -54,7 +70,7 @@ def has_permission(self, request, view) -> bool:
request.access_enforcement = AccessEnforcement(action=action, allowed=allowed)
return allowed

def get_policy_statements(self, request, view) -> List[dict]:
def get_policy_statements(self, request, view) -> List[Union[dict, Statement]]:
return self.statements

def get_user_group_values(self, user) -> List[str]:
Expand Down Expand Up @@ -88,7 +104,7 @@ def _get_invoked_action(self, view) -> str:
raise AccessPolicyException("Could not determine action of request")

def _evaluate_statements(
self, statements: List[dict], request, view, action: str
self, statements: List[Union[dict, Statement]], request, view, action: str
) -> bool:
statements = self._normalize_statements(statements)
matched = self._get_statements_matching_principal(request, statements)
Expand All @@ -109,8 +125,15 @@ def _evaluate_statements(

return True

def _normalize_statements(self, statements=[]) -> List[dict]:
def _normalize_statements(
self, statements: List[Union[dict, Statement]]
) -> List[dict]:
normalized = []

for statement in statements:
if isinstance(statement, Statement):
statement = asdict(statement)

if isinstance(statement["principal"], str):
statement["principal"] = [statement["principal"]]

Expand All @@ -127,7 +150,9 @@ def _normalize_statements(self, statements=[]) -> List[dict]:
elif isinstance(statement["condition_expression"], str):
statement["condition_expression"] = [statement["condition_expression"]]

return statements
normalized.append(statement)

return normalized

@classmethod
def _get_statements_matching_principal(
Expand Down Expand Up @@ -176,7 +201,7 @@ def _get_statements_matching_action(
"""
matched = []
SAFE_METHODS = ("GET", "HEAD", "OPTIONS")
http_method = "<method:%s>" % request.method.lower()
http_method = f"<method:{request.method.lower()}>"

for statement in statements:
if action in statement["action"] or "*" in statement["action"]:
Expand Down Expand Up @@ -264,8 +289,7 @@ def _check_condition(self, condition: str, request, view, action: str):

if type(result) is not bool:
raise AccessPolicyException(
"condition '%s' must return true/false, not %s"
% (condition, type(result))
f"condition '{condition}' must return true/false, not {type(result)}"
)

return result
Expand Down Expand Up @@ -294,6 +318,6 @@ def _get_condition_method(self, method_name: str):
return getattr(module, method_name)

raise AccessPolicyException(
"condition '%s' must be a method on the access policy or be defined in the 'reusable_conditions' module"
% method_name
f"condition '{method_name}' must be a method on the access policy "
f"or be defined in the 'reusable_conditions' module"
)
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ def readme():
"Framework :: Django :: 3.0",
"Framework :: Django :: 3.1",
"Framework :: Django :: 3.2",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
setup(
name="drf-access-policy",
version="1.4.0",
version="1.5.0",
description="Declarative access policies/permissions modeled after AWS' IAM policies.",
author="Robert Singer",
author_email="robertgsinger@gmail.com",
Expand Down
51 changes: 40 additions & 11 deletions test_project/testapp/tests/test_access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from rest_framework.decorators import api_view
from rest_framework.viewsets import ModelViewSet

from rest_access_policy.access_policy import Statement


class FakeRequest(object):
def __init__(self, user: Optional[User], method: str = "GET"):
Expand Down Expand Up @@ -73,7 +75,10 @@ def test_normalize_statements(self):
"principal": "group:admin",
"action": "destroy",
"condition": "is_nice_day",
}
},
Statement(
principal="user:1", action="create", condition="is_competent"
),
]
)

Expand All @@ -85,7 +90,14 @@ def test_normalize_statements(self):
"action": ["destroy"],
"condition": ["is_nice_day"],
"condition_expression": [],
}
},
{
"principal": ["user:1"],
"action": ["create"],
"effect": "deny",
"condition": ["is_competent"],
"condition_expression": [],
},
],
)

Expand All @@ -108,7 +120,9 @@ def test_get_statements_matching_principal_if_user_is_authenticated(self):

policy = AccessPolicy()

result = policy._get_statements_matching_principal(FakeRequest(user), statements)
result = policy._get_statements_matching_principal(
FakeRequest(user), statements
)

self.assertEqual(len(result), 4)
self.assertEqual(result[0]["action"], ["create"])
Expand Down Expand Up @@ -137,7 +151,9 @@ def test_get_statements_matching_principal_if_user_is_staff(self):

policy = AccessPolicy()

result = policy._get_statements_matching_principal(FakeRequest(user), statements)
result = policy._get_statements_matching_principal(
FakeRequest(user), statements
)

self.assertEqual(len(result), 5)
self.assertEqual(result[0]["action"], ["create"])
Expand Down Expand Up @@ -168,7 +184,9 @@ def test_get_statements_matching_principal_if_user_is_admin(self):

policy = AccessPolicy()

result = policy._get_statements_matching_principal(FakeRequest(user), statements)
result = policy._get_statements_matching_principal(
FakeRequest(user), statements
)

self.assertEqual(len(result), 6)
self.assertEqual(result[0]["action"], ["create"])
Expand All @@ -192,7 +210,9 @@ def test_get_statements_matching_principal_if_user_is_anonymous(self):

policy = AccessPolicy()

result = policy._get_statements_matching_principal(FakeRequest(user), statements)
result = policy._get_statements_matching_principal(
FakeRequest(user), statements
)

self.assertEqual(len(result), 2)
self.assertEqual(result[0]["action"], ["list"])
Expand Down Expand Up @@ -472,7 +492,8 @@ class TestPolicy(AccessPolicy):
policy._check_condition("is_sunny", None, None, "action")

self.assertTrue(
"condition 'is_sunny' must be a method on the access policy" in str(context.exception)
"condition 'is_sunny' must be a method on the access policy"
in str(context.exception)
)

def test_check_condition_throws_error_if_returns_non_boolean(self):
Expand Down Expand Up @@ -514,8 +535,12 @@ class TestPolicy(AccessPolicy):

policy = TestPolicy()

self.assertTrue(policy._check_condition("is_a_cat:Garfield", None, None, "action"))
self.assertFalse(policy._check_condition("is_a_cat:Snoopy", None, None, "action"))
self.assertTrue(
policy._check_condition("is_a_cat:Garfield", None, None, "action")
)
self.assertFalse(
policy._check_condition("is_a_cat:Snoopy", None, None, "action")
)

def test_get_condition_method_from_self(self):
class TestPolicy(AccessPolicy):
Expand Down Expand Up @@ -589,7 +614,9 @@ def test_evaluate_statements_true_if_any_allow_and_none_deny(
{"principal": "*", "action": "take_out_the_trash", "effect": "allow"},
]

result = policy._evaluate_statements(statements, FakeRequest(user), None, "create")
result = policy._evaluate_statements(
statements, FakeRequest(user), None, "create"
)
self.assertTrue(result)

def test_has_permission(self):
Expand Down Expand Up @@ -668,7 +695,9 @@ class TestPolicy(AccessPolicy):

def test_has_permission_is_false_when_user_is_none(self):
class TestPolicy(AccessPolicy):
statements = [{"action": "*", "principal": "authenticated", "effect": "allow"}]
statements = [
{"action": "*", "principal": "authenticated", "effect": "allow"}
]

view = FakeViewSet(action="create")
policy = TestPolicy()
Expand Down
33 changes: 33 additions & 0 deletions test_project/testapp/tests/test_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from dataclasses import asdict
from rest_framework.test import APITestCase

from rest_access_policy import AccessViewSetMixin, AccessPolicy, Statement
from rest_framework.viewsets import ViewSet
from rest_framework.permissions import AllowAny


class StatementTestCase(APITestCase):
def test_should_raise_error_if_invalid_effect(self):
with self.assertRaises(Exception) as context:
Statement(principal="*", action="build", effect="veto")

self.assertTrue("effect must be one of" in str(context.exception))

def test_to_dict(self):
statement = Statement(
principal="*",
action="build",
effect="allow",
condition_expression=["method1"],
)

self.assertEqual(
asdict(statement),
{
"principal": "*",
"action": "build",
"effect": "allow",
"condition": [],
"condition_expression": ["method1"],
},
)

0 comments on commit 68c4c2b

Please sign in to comment.