Skip to content

Commit

Permalink
Rebuild serializer on model refs update
Browse files Browse the repository at this point in the history
  • Loading branch information
JargeZ committed Aug 15, 2024
1 parent b5b8e66 commit 3c07b9b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/drf_pydantic/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework import serializers
from typing_extensions import dataclass_transform

from drf_pydantic.parse import create_serializer_from_model
from drf_pydantic.parse import SERIALIZER_REGISTRY, create_serializer_from_model


@dataclass_transform(kw_only_default=True, field_specifiers=(pydantic.Field,))
Expand Down Expand Up @@ -50,3 +50,25 @@ def __new__(
class BaseModel(pydantic.BaseModel, metaclass=ModelMetaclass):
# Populated by the metaclass or manually set by the user
drf_serializer: ClassVar[type[serializers.Serializer]]

@classmethod
def model_rebuild(
cls,
*,
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: Optional[dict[str, Any]] = None,
) -> bool | None:
ret = super().model_rebuild(
force=force,
raise_errors=raise_errors,
_parent_namespace_depth=_parent_namespace_depth,
_types_namespace=_types_namespace,
)

if cls in SERIALIZER_REGISTRY:
SERIALIZER_REGISTRY.pop(cls)
cls.drf_serializer = create_serializer_from_model(cls)

return ret
24 changes: 24 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,30 @@ class Person(BaseModel):
assert isinstance(job.fields["salary"], serializers.FloatField)


def test_nested_recursive_model():
class Task(BaseModel):
title: str
parent: "Task"
subtasks: list["Task"]

Task.model_rebuild()

serializer = Task.drf_serializer()

# Parent model
assert serializer.__class__.__name__ == "TaskSerializer"
assert len(serializer.fields) == 3
assert isinstance(serializer.fields["title"], serializers.CharField)
assert isinstance(serializer.fields["parent"], serializers.Serializer)
assert isinstance(serializer.fields["subtasks"], serializers.ListField)

parent: serializers.Serializer = serializer.fields["parent"]
assert parent.__class__.__name__ == "TaskSerializer"
assert len(parent.fields) == 3
assert isinstance(parent.fields["title"], serializers.CharField)
assert isinstance(parent.fields["parent"], serializers.Serializer)


def test_list_of_nested_models():
class Apartment(BaseModel):
floor: int
Expand Down

0 comments on commit 3c07b9b

Please sign in to comment.