diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml new file mode 100644 index 0000000..f578b55 --- /dev/null +++ b/.github/workflows/main.yaml @@ -0,0 +1,38 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + name: Run tests + runs-on: ubuntu-latest + + steps: + # Checkout the code from the repository + - name: Checkout code + uses: actions/checkout@v3 + + # Set up Python + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python3 - + export PATH="$HOME/.local/bin:$PATH" + + # Install dependencies + - name: Install dependencies + run: | + poetry install --with tests,dev + + # Run tests + - name: Run tests + run: | + poetry run pytest . \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8d4d380..3f1595a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ +### VsCode ## +.vscode + ### PyCharm ### .idea ### Cache ### __pycache__ .ruff_cache +.pytest_cache \ No newline at end of file diff --git a/README.md b/README.md index ef2377b..bf3ee2d 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,9 @@ def add_user( - [x] ^ Same behavior for dependency functions ^ - [x] `:raises HTTPException: 401 Invalid token` in dependency function docstring parses -> `{401: {description: "Invalid token"}}` -- [ ] Avoid false positives (e.g., `raise HTTPException(401)` where `HTTPException` is not from `starlette` or +- [x] Avoid false positives (e.g., `raise HTTPException(401)` where `HTTPException` is not from `starlette` or `fastapi`) -- [ ] Parse custom classes that inherit from `HTTPException` +- [x] Parse custom classes that inherit from `HTTPException` - [ ] Check for custom response models - [ ] Allow to omit some responses from parsing - [ ] Code analysis for complex structures in detail and headers @@ -68,7 +68,9 @@ pip install fastapi-derive-responses ### Basic Usage -It will propagate `{404: {"description": "Item not found"}}` to OpenAPI schema. +You can just raise subclasses of `starlette.HTTPException` in endpoint. + +It will propagate `{404: {"description": "Item not found"}, 400: {"description": "Invalid item id"}}` to OpenAPI schema. ```python from fastapi import FastAPI, HTTPException @@ -77,11 +79,15 @@ from fastapi_derive_responses import AutoDeriveResponsesAPIRoute app = FastAPI() app.router.route_class = AutoDeriveResponsesAPIRoute +class CustomExeption(HTTPException): + ... @app.get("/items/{item_id}") def read_item(item_id: int): if item_id == 0: raise HTTPException(status_code=404, detail="Item not found") + if item_id == -1: + raise CustomExeption(status_code=400, detail="Invalid item id") return {"id": item_id, "name": "Item Name"} ``` @@ -117,16 +123,43 @@ def get_current_user(user_id: Annotated[int, Depends(auth_user)]): return {"id": user_id, "role": "admin"} ``` -Also, you can just raise `HTTPException` in your dependency function: +Also, you can just raise subclasses of `starlette.HTTPException` in your dependency function: ```python +class CustomException(HTTPException): + ... + def auth_user(token: int) -> int: if token < 100: raise HTTPException(401, "Invalid token") user_id = token - 100 if user_id not in statuses: - raise HTTPException(404, "User not found") + raise CustomException(404, "User not found") if user_id in banlist: raise HTTPException(403, "You are banned") return user_id ``` + + +Also, it works then you import your custom exception from other modules: + +```python +# exceptions.py +from starlette.exceptions import HTTPException + +class ImportedCustomException(HTTPException): + ... +``` +```python +# main.py +from exceptions import ImportedCustomException +... + +app = FastAPI(title="Custom Exception App") +app.router.route_class = AutoDeriveResponsesAPIRoute + +@app.get("/") +def raise_custom_exception(): + raise ImportedCustomException(status_code=601, detail="CustomException!") + +``` \ No newline at end of file diff --git a/fastapi_derive_responses/__init__.py b/fastapi_derive_responses/__init__.py index 6606062..d5bc7b1 100644 --- a/fastapi_derive_responses/__init__.py +++ b/fastapi_derive_responses/__init__.py @@ -7,19 +7,77 @@ import re import textwrap from collections import defaultdict +from typing import Callable, Any from fastapi.routing import APIRoute +from starlette.exceptions import HTTPException logger = logging.getLogger("fastapi-derive-responses") -def _responses_from_raise_in_source(function) -> dict: +def _inspect_function_source(function: Callable[..., Any]) -> dict[str, bool]: + """ + Parse the function's source code and inspect all imported and defined classes + to check if they are subclasses of HTTPException. + Return a dict: {class_name: bool, ...} where `True` indicates the class is a + subclass of HTTPException, and `False` indicates it is not. + """ + # Get file contents and AST parsing + path = inspect.getfile(function) + with open(path, "r") as file: + content = file.read() + + file_ast = ast.parse(content) + + # Inspecting imports for subclasses of HTTPException + import_details: list[tuple[str, list[str]]] = [] + for node in ast.walk(file_ast): + if isinstance(node, ast.ImportFrom): + module_path = node.module + imported_names = [alias.name for alias in node.names] + import_details.append((module_path, imported_names)) + + elif isinstance(node, ast.Import): + for alias in node.names: + module_path = alias.name + import_details.append((module_path, [alias.name])) + + inspected_subclasses = defaultdict(bool) + for module_path, imported_names in import_details: + try: + # Import module and accessing inspected names + module = importlib.import_module(module_path) + for name in imported_names: + imported_object = getattr(module, name) + # Check if the imported object is a subclass of HTTPException + if isinstance(imported_object, type) and issubclass(imported_object, HTTPException): + inspected_subclasses[name] = True + logger.debug(f"{name} is a subclass of HTTPException") + except (AttributeError, ModuleNotFoundError, ImportError) as e: + logger.debug(f"Error importing {module_path}: {str(e)}") + + # Inspect defined classes + defined_classes: list[ast.ClassDef] = [node for node in ast.walk(file_ast) if isinstance(node, ast.ClassDef)] + + for classDef in defined_classes: + # Check inheritance + for baseClass in classDef.bases: + if inspected_subclasses.get(baseClass.id): + inspected_subclasses[classDef.name] = True + logger.debug(f"{classDef.name} is a subclass of HTTPException") + break + + return inspected_subclasses + + +def _responses_from_raise_in_source(function: Callable[..., Any]) -> dict: """ Parse the endpoint's source code and extract all HTTPExceptions raised. Return a dict: {status_code: [{"description": str, "headers": dict}, ...], ...} """ derived = defaultdict(list) + exception_classes = _inspect_function_source(function) source = textwrap.dedent(inspect.getsource(function)) as_ast = ast.parse(source) exceptions = [node for node in ast.walk(as_ast) if isinstance(node, ast.Raise)] @@ -30,8 +88,8 @@ def _responses_from_raise_in_source(function) -> dict: try: match exception.exc: case ast.Call(func=ast.Name(func_id, func_ctx), args=call_args, keywords=keywords): - if func_id != "HTTPException": - logger.debug(f"Exception (Call) is not HTTPException: func={func_id}") + if not exception_classes[func_id]: + logger.debug(f"Exception (Call) is not subclass of HTTPException: func={func_id}") continue status_code = detail = headers = None diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..45c9e49 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +# pytest.ini +[pytest] +testpaths = tests +python_files = test*.py +python_functions = test_* +addopts = --disable-warnings +verbosity = 2 \ No newline at end of file diff --git a/tests/exceptions.py b/tests/exceptions.py new file mode 100644 index 0000000..9e41d6a --- /dev/null +++ b/tests/exceptions.py @@ -0,0 +1,5 @@ +from starlette.exceptions import HTTPException + + +class ImportedCustomException(HTTPException): + ... diff --git a/tests/test_custom.py b/tests/test_custom.py new file mode 100644 index 0000000..914f311 --- /dev/null +++ b/tests/test_custom.py @@ -0,0 +1,42 @@ +from fastapi import FastAPI +from starlette.testclient import TestClient + +from fastapi_derive_responses import AutoDeriveResponsesAPIRoute +from tests.exceptions import ImportedCustomException + + +def test_custom_exception_in_source(): + from fastapi import HTTPException + + class CustomException(HTTPException): + ... + + app = FastAPI(title="Custom Exception App") + app.router.route_class = AutoDeriveResponsesAPIRoute + + @app.get("/") + def raise_custom_exception(): + raise CustomException(status_code=601, detail="CustomException!") + + client = TestClient(app) + response = client.get("/openapi.json") + assert response.status_code == 200 + actual_dict = response.json() + responses = actual_dict["paths"]["/"]["get"]["responses"] + assert responses.get("601") == {"description": "CustomException!"} + + +def test_imported_custom_exception_in_source(): + app = FastAPI(title="Custom Exception App") + app.router.route_class = AutoDeriveResponsesAPIRoute + + @app.get("/") + def raise_custom_exception(): + raise ImportedCustomException(status_code=601, detail="CustomException!") + + client = TestClient(app) + response = client.get("/openapi.json") + assert response.status_code == 200 + actual_dict = response.json() + responses = actual_dict["paths"]["/"]["get"]["responses"] + assert responses.get("601") == {"description": "CustomException!"} diff --git a/tests/test_fake.py b/tests/test_fake.py new file mode 100644 index 0000000..b957d3e --- /dev/null +++ b/tests/test_fake.py @@ -0,0 +1,23 @@ +from fastapi import FastAPI +from starlette.testclient import TestClient + +from fastapi_derive_responses import AutoDeriveResponsesAPIRoute + +def test_fake_http_exception_in_source(): + class HTTPException(Exception): + def __init__(self, status_code: int, detail: str): + pass + + app = FastAPI(title="Custom Exception App") + app.router.route_class = AutoDeriveResponsesAPIRoute + + @app.get("/") + def raise_custom_exception(): + raise HTTPException(status_code=601, detail="CustomException!") + + client = TestClient(app) + response = client.get("/openapi.json") + assert response.status_code == 200 + actual_dict = response.json() + responses = actual_dict["paths"]["/"]["get"]["responses"] + assert responses.get("601") is None \ No newline at end of file