Skip to content

Commit

Permalink
Merge pull request #661 from tfranzel/django_filter_improvements
Browse files Browse the repository at this point in the history
Django filter typing improvements
  • Loading branch information
tfranzel authored Feb 18, 2022
2 parents 5da99c5 + 6f1de26 commit fb4f437
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
65 changes: 35 additions & 30 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

_NoHint = object()


class DjangoFilterExtension(OpenApiFilterExtension):
"""
Expand All @@ -31,7 +33,9 @@ class DjangoFilterExtension(OpenApiFilterExtension):
- ``TypedMultipleChoiceFilter``: enum, multi handled
In case of warnings or incorrect filter types, you can manually override the underlying
field type with a manual ``extend_schema_field`` decoration.
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
method.
.. code-block::
Expand Down Expand Up @@ -78,12 +82,23 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
}
if has_override(filter_field, 'field'):
annotation = get_override(filter_field, 'field')
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)

if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
annotation = (
get_override(filter_field, 'field') or get_override(filter_method, 'field')
)
if is_basic_type(annotation):
schema = build_basic_type(annotation)
else:
# allow injecting raw schema via @extend_schema_field decorator
schema = annotation
elif filter_method_hint is not _NoHint:
if is_basic_type(filter_method_hint):
schema = build_basic_type(filter_method_hint)
else:
schema = build_basic_type(OpenApiTypes.STR)
elif isinstance(filter_field, tuple(unambiguous_mapping)):
for cls in filter_field.__class__.__mro__:
if cls in unambiguous_mapping:
Expand All @@ -92,23 +107,15 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
# NumberField is underspecified by itself. try to find the
# type that makes the most sense or default to generic NUMBER
if filter_field.method:
schema = self._build_filter_method_type(filterset_class, filter_field)
if schema['type'] not in ['integer', 'number']:
schema = build_basic_type(OpenApiTypes.NUMBER)
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif filter_field.method:
# try to make best effort on the given method
schema = self._build_filter_method_type(filterset_class, filter_field)
schema = build_basic_type(OpenApiTypes.NUMBER)
else:
try:
# the last resort is to lookup the type via the model or queryset field.
Expand All @@ -120,7 +127,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
schema = auto_schema._map_model_field(model_field, direction=None)
except Exception as exc:
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
f'field "{field_name}". Defaulting to string (Exception: {exc})'
Expand Down Expand Up @@ -182,21 +189,19 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
for field_name in field_names
]

def _build_filter_method_type(self, filterset_class, filter_field):
def _get_filter_method(self, filterset_class, filter_field):
if callable(filter_field.method):
filter_method = filter_field.method
return filter_field.method
elif isinstance(filter_field.method, str):
return getattr(filterset_class, filter_field.method)
else:
filter_method = getattr(filterset_class, filter_field.method)
return None

def _get_filter_method_hint(self, filter_method):
try:
filter_method_hints = get_type_hints(filter_method)
return get_type_hints(filter_method)['value']
except: # noqa: E722
filter_method_hints = {}

if 'value' in filter_method_hints and is_basic_type(filter_method_hints['value']):
return build_basic_type(filter_method_hints['value'])
else:
return build_basic_type(OpenApiTypes.STR)
return _NoHint

def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
Expand Down
6 changes: 6 additions & 0 deletions tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class ProductFilter(FilterSet):
int_id = NumberFilter(method='filter_method_typed')
number_id = NumberFilter(method='filter_method_untyped', help_text='some injected help text')
number_id_ext = NumberFilter(method=external_filter_method)
email = CharFilter(method='filter_method_decorated')
# implicit filter declaration
subproduct__sub_price = NumberFilter() # reverse relation
other_sub_product__uuid = UUIDFilter() # forward relation
Expand Down Expand Up @@ -129,6 +130,11 @@ def filter_method_typed(self, queryset, name, value: int):
def filter_method_untyped(self, queryset, name, value):
return queryset.filter(id=int(value)) # pragma: no cover

# email makes no sense here. it's just to test decoration
@extend_schema_field(OpenApiTypes.EMAIL)
def filter_method_decorated(self, queryset, name, value):
return queryset.filter(id=int(value)) # pragma: no cover


@extend_schema(
examples=[
Expand Down
5 changes: 5 additions & 0 deletions tests/contrib/test_django_filters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ paths:
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: email
schema:
type: string
format: email
- in: query
name: in_categories
schema:
Expand Down

0 comments on commit fb4f437

Please sign in to comment.