Skip to content

Commit

Permalink
Merge pull request #1 from peplxx/main
Browse files Browse the repository at this point in the history
[Feature] Custom Exception Classes Recognition
  • Loading branch information
dantetemplar authored Jan 10, 2025
2 parents 5a21ac8 + adaa1ef commit 3ff5596
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 8 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Tests

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
test:
name: Run tests
runs-on: ubuntu-latest

steps:
# Checkout the code from the repository
- name: Checkout code
uses: actions/checkout@v3

# Set up Python
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
# Install dependencies
- name: Install dependencies
run: |
poetry install --with tests,dev
# Run tests
- name: Run tests
run: |
poetry run pytest .
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
### VsCode ##
.vscode

### PyCharm ###
.idea

### Cache ###
__pycache__
.ruff_cache
.pytest_cache
43 changes: 38 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def add_user(
- [x] ^ Same behavior for dependency functions ^
- [x] `:raises HTTPException: 401 Invalid token` in dependency function docstring parses -> `{401: {description: "Invalid
token"}}`
- [ ] Avoid false positives (e.g., `raise HTTPException(401)` where `HTTPException` is not from `starlette` or
- [x] Avoid false positives (e.g., `raise HTTPException(401)` where `HTTPException` is not from `starlette` or
`fastapi`)
- [ ] Parse custom classes that inherit from `HTTPException`
- [x] Parse custom classes that inherit from `HTTPException`
- [ ] Check for custom response models
- [ ] Allow to omit some responses from parsing
- [ ] Code analysis for complex structures in detail and headers
Expand All @@ -68,7 +68,9 @@ pip install fastapi-derive-responses

### Basic Usage

It will propagate `{404: {"description": "Item not found"}}` to OpenAPI schema.
You can just raise subclasses of `starlette.HTTPException` in endpoint.

It will propagate `{404: {"description": "Item not found"}, 400: {"description": "Invalid item id"}}` to OpenAPI schema.

```python
from fastapi import FastAPI, HTTPException
Expand All @@ -77,11 +79,15 @@ from fastapi_derive_responses import AutoDeriveResponsesAPIRoute
app = FastAPI()
app.router.route_class = AutoDeriveResponsesAPIRoute

class CustomExeption(HTTPException):
...

@app.get("/items/{item_id}")
def read_item(item_id: int):
if item_id == 0:
raise HTTPException(status_code=404, detail="Item not found")
if item_id == -1:
raise CustomExeption(status_code=400, detail="Invalid item id")
return {"id": item_id, "name": "Item Name"}
```

Expand Down Expand Up @@ -117,16 +123,43 @@ def get_current_user(user_id: Annotated[int, Depends(auth_user)]):
return {"id": user_id, "role": "admin"}
```

Also, you can just raise `HTTPException` in your dependency function:
Also, you can just raise subclasses of `starlette.HTTPException` in your dependency function:

```python
class CustomException(HTTPException):
...

def auth_user(token: int) -> int:
if token < 100:
raise HTTPException(401, "Invalid token")
user_id = token - 100
if user_id not in statuses:
raise HTTPException(404, "User not found")
raise CustomException(404, "User not found")
if user_id in banlist:
raise HTTPException(403, "You are banned")
return user_id
```


Also, it works then you import your custom exception from other modules:

```python
# exceptions.py
from starlette.exceptions import HTTPException

class ImportedCustomException(HTTPException):
...
```
```python
# main.py
from exceptions import ImportedCustomException
...

app = FastAPI(title="Custom Exception App")
app.router.route_class = AutoDeriveResponsesAPIRoute

@app.get("/")
def raise_custom_exception():
raise ImportedCustomException(status_code=601, detail="CustomException!")

```
64 changes: 61 additions & 3 deletions fastapi_derive_responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,77 @@
import re
import textwrap
from collections import defaultdict
from typing import Callable, Any

from fastapi.routing import APIRoute
from starlette.exceptions import HTTPException

logger = logging.getLogger("fastapi-derive-responses")


def _responses_from_raise_in_source(function) -> dict:
def _inspect_function_source(function: Callable[..., Any]) -> dict[str, bool]:
"""
Parse the function's source code and inspect all imported and defined classes
to check if they are subclasses of HTTPException.
Return a dict: {class_name: bool, ...} where `True` indicates the class is a
subclass of HTTPException, and `False` indicates it is not.
"""
# Get file contents and AST parsing
path = inspect.getfile(function)
with open(path, "r") as file:
content = file.read()

file_ast = ast.parse(content)

# Inspecting imports for subclasses of HTTPException
import_details: list[tuple[str, list[str]]] = []
for node in ast.walk(file_ast):
if isinstance(node, ast.ImportFrom):
module_path = node.module
imported_names = [alias.name for alias in node.names]
import_details.append((module_path, imported_names))

elif isinstance(node, ast.Import):
for alias in node.names:
module_path = alias.name
import_details.append((module_path, [alias.name]))

inspected_subclasses = defaultdict(bool)
for module_path, imported_names in import_details:
try:
# Import module and accessing inspected names
module = importlib.import_module(module_path)
for name in imported_names:
imported_object = getattr(module, name)
# Check if the imported object is a subclass of HTTPException
if isinstance(imported_object, type) and issubclass(imported_object, HTTPException):
inspected_subclasses[name] = True
logger.debug(f"{name} is a subclass of HTTPException")
except (AttributeError, ModuleNotFoundError, ImportError) as e:
logger.debug(f"Error importing {module_path}: {str(e)}")

# Inspect defined classes
defined_classes: list[ast.ClassDef] = [node for node in ast.walk(file_ast) if isinstance(node, ast.ClassDef)]

for classDef in defined_classes:
# Check inheritance
for baseClass in classDef.bases:
if inspected_subclasses.get(baseClass.id):
inspected_subclasses[classDef.name] = True
logger.debug(f"{classDef.name} is a subclass of HTTPException")
break

return inspected_subclasses


def _responses_from_raise_in_source(function: Callable[..., Any]) -> dict:
"""
Parse the endpoint's source code and extract all HTTPExceptions raised.
Return a dict: {status_code: [{"description": str, "headers": dict}, ...], ...}
"""
derived = defaultdict(list)

exception_classes = _inspect_function_source(function)
source = textwrap.dedent(inspect.getsource(function))
as_ast = ast.parse(source)
exceptions = [node for node in ast.walk(as_ast) if isinstance(node, ast.Raise)]
Expand All @@ -30,8 +88,8 @@ def _responses_from_raise_in_source(function) -> dict:
try:
match exception.exc:
case ast.Call(func=ast.Name(func_id, func_ctx), args=call_args, keywords=keywords):
if func_id != "HTTPException":
logger.debug(f"Exception (Call) is not HTTPException: func={func_id}")
if not exception_classes[func_id]:
logger.debug(f"Exception (Call) is not subclass of HTTPException: func={func_id}")
continue

status_code = detail = headers = None
Expand Down
7 changes: 7 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pytest.ini
[pytest]
testpaths = tests
python_files = test*.py
python_functions = test_*
addopts = --disable-warnings
verbosity = 2
5 changes: 5 additions & 0 deletions tests/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from starlette.exceptions import HTTPException


class ImportedCustomException(HTTPException):
...
42 changes: 42 additions & 0 deletions tests/test_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from fastapi import FastAPI
from starlette.testclient import TestClient

from fastapi_derive_responses import AutoDeriveResponsesAPIRoute
from tests.exceptions import ImportedCustomException


def test_custom_exception_in_source():
from fastapi import HTTPException

class CustomException(HTTPException):
...

app = FastAPI(title="Custom Exception App")
app.router.route_class = AutoDeriveResponsesAPIRoute

@app.get("/")
def raise_custom_exception():
raise CustomException(status_code=601, detail="CustomException!")

client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200
actual_dict = response.json()
responses = actual_dict["paths"]["/"]["get"]["responses"]
assert responses.get("601") == {"description": "CustomException!"}


def test_imported_custom_exception_in_source():
app = FastAPI(title="Custom Exception App")
app.router.route_class = AutoDeriveResponsesAPIRoute

@app.get("/")
def raise_custom_exception():
raise ImportedCustomException(status_code=601, detail="CustomException!")

client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200
actual_dict = response.json()
responses = actual_dict["paths"]["/"]["get"]["responses"]
assert responses.get("601") == {"description": "CustomException!"}
23 changes: 23 additions & 0 deletions tests/test_fake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import FastAPI
from starlette.testclient import TestClient

from fastapi_derive_responses import AutoDeriveResponsesAPIRoute

def test_fake_http_exception_in_source():
class HTTPException(Exception):
def __init__(self, status_code: int, detail: str):
pass

app = FastAPI(title="Custom Exception App")
app.router.route_class = AutoDeriveResponsesAPIRoute

@app.get("/")
def raise_custom_exception():
raise HTTPException(status_code=601, detail="CustomException!")

client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200
actual_dict = response.json()
responses = actual_dict["paths"]["/"]["get"]["responses"]
assert responses.get("601") is None

0 comments on commit 3ff5596

Please sign in to comment.