diff --git a/src/drf_pydantic/fields.py b/src/drf_pydantic/fields.py new file mode 100644 index 0000000..809bfe4 --- /dev/null +++ b/src/drf_pydantic/fields.py @@ -0,0 +1,45 @@ +from enum import Enum +from typing import Type, Optional, Union + +from rest_framework.fields import empty +from rest_framework.serializers import ChoiceField + + +class EnumField(ChoiceField): + """ + Custom DRF field that restricts accepted values to that of a defined enum + """ + + default_error_messages = {"invalid": "No matching enum type"} + + def __init__(self, enum: Type[Enum], **kwargs): + self.enum = enum + kwargs.setdefault("choices", [(x, x.name) for x in self.enum]) + super().__init__(**kwargs) + + def run_validation( + self, data: Optional[Union[Enum, str, empty]] = empty + ) -> Optional[Enum]: + if data and data != empty and not isinstance(data, self.enum): + match_found = False + for x in self.enum: + if x.value == data: + match_found = True + break + + if not match_found: + self.fail("invalid") + + return super().run_validation(data) + + def to_internal_value(self, data: Optional[Union[Enum, str]]) -> Enum: + for choice in self.enum: + if choice == data or choice.name == data or choice.value == data: + return choice + self.fail("invalid") + + def to_representation(self, value: Optional[Union[Enum, str]]) -> Optional[str]: + if isinstance(value, self.enum): + return value.value + + return value diff --git a/src/drf_pydantic/parse.py b/src/drf_pydantic/parse.py index d7cf535..f6281fa 100644 --- a/src/drf_pydantic/parse.py +++ b/src/drf_pydantic/parse.py @@ -5,10 +5,12 @@ import typing import uuid import warnings +from enum import Enum import pydantic from rest_framework import serializers +from drf_pydantic.fields import EnumField # Cache serializer classes to ensure that there is a one-to-one relationship # between pydantic models and serializer classes @@ -38,6 +40,8 @@ # Constraint fields pydantic.ConstrainedStr: serializers.CharField, pydantic.ConstrainedInt: serializers.IntegerField, + # Enum fields + Enum: EnumField } @@ -122,6 +126,11 @@ def _convert_field(field: pydantic.fields.ModelField) -> serializers.Field: extra_kwargs["min_length"] = field.type_.min_length extra_kwargs["max_length"] = field.type_.max_length + if inspect.isclass(field.type_) and issubclass( + field.type_, Enum + ): + extra_kwargs['enum'] = field.type_ + # Scalar field if field.outer_type_ is field.type_: # Normal class diff --git a/tests/test_models.py b/tests/test_models.py index dea9fbc..9ef34d2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,12 +1,16 @@ import datetime import typing +from enum import Enum +from unittest.mock import patch import pydantic import pytest from rest_framework import serializers +from rest_framework.exceptions import ValidationError from drf_pydantic import BaseModel +from drf_pydantic.fields import EnumField def test_simple_model(): @@ -280,3 +284,100 @@ class Cart(BaseModel): name_field: serializers.Field = items_field.child.fields["name"] assert isinstance(name_field, serializers.CharField) + + +def test_enum_model(): + class CountryEnum(Enum): + US = 'US' + GB = 'GB' + FR = 'FR' + + class NotificationPreferenceEnum(Enum): + NONE = 'no_notifications' + SOME = 'some_notifications' + ALL = 'all_notifications' + + class Person(BaseModel): + name: str + email: pydantic.EmailStr + age: int + height: float + date_of_birth: datetime.date + notification_preferences: NotificationPreferenceEnum + original_nationality: typing.Optional[CountryEnum] + nationality: CountryEnum = CountryEnum.GB + + serializer = Person.drf_serializer() + + assert serializer.__class__.__name__ == "PersonSerializer" + assert len(serializer.fields) == 8 + + # Regular fields + assert isinstance(serializer.fields["name"], serializers.CharField) + assert isinstance(serializer.fields["email"], serializers.EmailField) + assert isinstance(serializer.fields["age"], serializers.IntegerField) + assert isinstance(serializer.fields["height"], serializers.FloatField) + assert isinstance(serializer.fields["date_of_birth"], serializers.DateField) + assert isinstance(serializer.fields["notification_preferences"], EnumField) + for name in [ + "name", + "email", + "age", + "height", + "date_of_birth", + "notification_preferences" + ]: + field = serializer.fields[name] + assert field.required is True, name + assert field.default is serializers.empty, name + assert field.allow_null is False, name + if name == 'notification_preferences': + assert field.choices == dict( + [(x, x.name) for x in NotificationPreferenceEnum] + ) + + # Optional + field: serializers.Field = serializer.fields["original_nationality"] + assert isinstance(field, EnumField) + assert field.allow_null is True + assert field.default is None + assert field.required is False + assert field.choices == dict([(x, x.name) for x in CountryEnum]) + + # With default + field: serializers.Field = serializer.fields["nationality"] + assert isinstance(field, EnumField) + assert field.allow_null is False + assert field.default == CountryEnum.GB + assert field.required is False + assert field.choices == dict([(x, x.name) for x in CountryEnum]) + + +def test_enum_value(): + + class SexEnum(Enum): + MALE = 'male' + FEMALE = 'female' + OTHER = 'other' + + class Human(BaseModel): + sex: SexEnum + age: int + + serializer = Human.drf_serializer + + normal_serializer = serializer(data={'sex': SexEnum.MALE, 'age': 25}) + + assert normal_serializer.is_valid() + assert normal_serializer.validated_data['sex'] == SexEnum.MALE + assert normal_serializer.validated_data['age'] == 25 + + value_serializer = serializer(data={'sex': 'male', 'age': 25}) + + assert value_serializer.is_valid() + assert value_serializer.validated_data['sex'] == SexEnum.MALE + assert value_serializer.validated_data['age'] == 25 + + bad_value_serializer = serializer(data={'sex': 'bad_value', 'age': 25}) + + assert bad_value_serializer.is_valid() is False