Skip to content

Commit

Permalink
Add Enum support (#10)
Browse files Browse the repository at this point in the history
* Add ENUM capability

* Add additional testing

* Clean up EnumField and update tests

* Add type hinting to EnumField, remove patch from test as it's not needed

* Add return type hinting
  • Loading branch information
thommor authored Jul 16, 2023
1 parent 43f93ae commit c6806b5
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/drf_pydantic/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from enum import Enum
from typing import Type, Optional, Union

from rest_framework.fields import empty
from rest_framework.serializers import ChoiceField


class EnumField(ChoiceField):
"""
Custom DRF field that restricts accepted values to that of a defined enum
"""

default_error_messages = {"invalid": "No matching enum type"}

def __init__(self, enum: Type[Enum], **kwargs):
self.enum = enum
kwargs.setdefault("choices", [(x, x.name) for x in self.enum])
super().__init__(**kwargs)

def run_validation(
self, data: Optional[Union[Enum, str, empty]] = empty
) -> Optional[Enum]:
if data and data != empty and not isinstance(data, self.enum):
match_found = False
for x in self.enum:
if x.value == data:
match_found = True
break

if not match_found:
self.fail("invalid")

return super().run_validation(data)

def to_internal_value(self, data: Optional[Union[Enum, str]]) -> Enum:
for choice in self.enum:
if choice == data or choice.name == data or choice.value == data:
return choice
self.fail("invalid")

def to_representation(self, value: Optional[Union[Enum, str]]) -> Optional[str]:
if isinstance(value, self.enum):
return value.value

return value
9 changes: 9 additions & 0 deletions src/drf_pydantic/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import typing
import uuid
import warnings
from enum import Enum

import pydantic

from rest_framework import serializers
from drf_pydantic.fields import EnumField

# Cache serializer classes to ensure that there is a one-to-one relationship
# between pydantic models and serializer classes
Expand Down Expand Up @@ -38,6 +40,8 @@
# Constraint fields
pydantic.ConstrainedStr: serializers.CharField,
pydantic.ConstrainedInt: serializers.IntegerField,
# Enum fields
Enum: EnumField
}


Expand Down Expand Up @@ -122,6 +126,11 @@ def _convert_field(field: pydantic.fields.ModelField) -> serializers.Field:
extra_kwargs["min_length"] = field.type_.min_length
extra_kwargs["max_length"] = field.type_.max_length

if inspect.isclass(field.type_) and issubclass(
field.type_, Enum
):
extra_kwargs['enum'] = field.type_

# Scalar field
if field.outer_type_ is field.type_:
# Normal class
Expand Down
101 changes: 101 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import datetime
import typing
from enum import Enum
from unittest.mock import patch

import pydantic
import pytest

from rest_framework import serializers
from rest_framework.exceptions import ValidationError

from drf_pydantic import BaseModel
from drf_pydantic.fields import EnumField


def test_simple_model():
Expand Down Expand Up @@ -280,3 +284,100 @@ class Cart(BaseModel):

name_field: serializers.Field = items_field.child.fields["name"]
assert isinstance(name_field, serializers.CharField)


def test_enum_model():
class CountryEnum(Enum):
US = 'US'
GB = 'GB'
FR = 'FR'

class NotificationPreferenceEnum(Enum):
NONE = 'no_notifications'
SOME = 'some_notifications'
ALL = 'all_notifications'

class Person(BaseModel):
name: str
email: pydantic.EmailStr
age: int
height: float
date_of_birth: datetime.date
notification_preferences: NotificationPreferenceEnum
original_nationality: typing.Optional[CountryEnum]
nationality: CountryEnum = CountryEnum.GB

serializer = Person.drf_serializer()

assert serializer.__class__.__name__ == "PersonSerializer"
assert len(serializer.fields) == 8

# Regular fields
assert isinstance(serializer.fields["name"], serializers.CharField)
assert isinstance(serializer.fields["email"], serializers.EmailField)
assert isinstance(serializer.fields["age"], serializers.IntegerField)
assert isinstance(serializer.fields["height"], serializers.FloatField)
assert isinstance(serializer.fields["date_of_birth"], serializers.DateField)
assert isinstance(serializer.fields["notification_preferences"], EnumField)
for name in [
"name",
"email",
"age",
"height",
"date_of_birth",
"notification_preferences"
]:
field = serializer.fields[name]
assert field.required is True, name
assert field.default is serializers.empty, name
assert field.allow_null is False, name
if name == 'notification_preferences':
assert field.choices == dict(
[(x, x.name) for x in NotificationPreferenceEnum]
)

# Optional
field: serializers.Field = serializer.fields["original_nationality"]
assert isinstance(field, EnumField)
assert field.allow_null is True
assert field.default is None
assert field.required is False
assert field.choices == dict([(x, x.name) for x in CountryEnum])

# With default
field: serializers.Field = serializer.fields["nationality"]
assert isinstance(field, EnumField)
assert field.allow_null is False
assert field.default == CountryEnum.GB
assert field.required is False
assert field.choices == dict([(x, x.name) for x in CountryEnum])


def test_enum_value():

class SexEnum(Enum):
MALE = 'male'
FEMALE = 'female'
OTHER = 'other'

class Human(BaseModel):
sex: SexEnum
age: int

serializer = Human.drf_serializer

normal_serializer = serializer(data={'sex': SexEnum.MALE, 'age': 25})

assert normal_serializer.is_valid()
assert normal_serializer.validated_data['sex'] == SexEnum.MALE
assert normal_serializer.validated_data['age'] == 25

value_serializer = serializer(data={'sex': 'male', 'age': 25})

assert value_serializer.is_valid()
assert value_serializer.validated_data['sex'] == SexEnum.MALE
assert value_serializer.validated_data['age'] == 25

bad_value_serializer = serializer(data={'sex': 'bad_value', 'age': 25})

assert bad_value_serializer.is_valid() is False

0 comments on commit c6806b5

Please sign in to comment.