Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Django filter typing improvements #661

Merged
merged 2 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 35 additions & 30 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

_NoHint = object()


class DjangoFilterExtension(OpenApiFilterExtension):
"""
Expand All @@ -31,7 +33,9 @@ class DjangoFilterExtension(OpenApiFilterExtension):
- ``TypedMultipleChoiceFilter``: enum, multi handled

In case of warnings or incorrect filter types, you can manually override the underlying
field type with a manual ``extend_schema_field`` decoration.
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
method.

.. code-block::

Expand Down Expand Up @@ -78,12 +82,23 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
}
if has_override(filter_field, 'field'):
annotation = get_override(filter_field, 'field')
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)

if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
annotation = (
get_override(filter_field, 'field') or get_override(filter_method, 'field')
)
if is_basic_type(annotation):
schema = build_basic_type(annotation)
else:
# allow injecting raw schema via @extend_schema_field decorator
schema = annotation
elif filter_method_hint is not _NoHint:
if is_basic_type(filter_method_hint):
schema = build_basic_type(filter_method_hint)
else:
schema = build_basic_type(OpenApiTypes.STR)
elif isinstance(filter_field, tuple(unambiguous_mapping)):
for cls in filter_field.__class__.__mro__:
if cls in unambiguous_mapping:
Expand All @@ -92,23 +107,15 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
# NumberField is underspecified by itself. try to find the
# type that makes the most sense or default to generic NUMBER
if filter_field.method:
schema = self._build_filter_method_type(filterset_class, filter_field)
if schema['type'] not in ['integer', 'number']:
schema = build_basic_type(OpenApiTypes.NUMBER)
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif filter_field.method:
# try to make best effort on the given method
schema = self._build_filter_method_type(filterset_class, filter_field)
schema = build_basic_type(OpenApiTypes.NUMBER)
else:
try:
# the last resort is to lookup the type via the model or queryset field.
Expand All @@ -120,7 +127,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
schema = auto_schema._map_model_field(model_field, direction=None)
except Exception as exc:
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
f'field "{field_name}". Defaulting to string (Exception: {exc})'
Expand Down Expand Up @@ -182,21 +189,19 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
for field_name in field_names
]

def _build_filter_method_type(self, filterset_class, filter_field):
def _get_filter_method(self, filterset_class, filter_field):
if callable(filter_field.method):
filter_method = filter_field.method
return filter_field.method
elif isinstance(filter_field.method, str):
return getattr(filterset_class, filter_field.method)
else:
filter_method = getattr(filterset_class, filter_field.method)
return None

def _get_filter_method_hint(self, filter_method):
try:
filter_method_hints = get_type_hints(filter_method)
return get_type_hints(filter_method)['value']
except: # noqa: E722
filter_method_hints = {}

if 'value' in filter_method_hints and is_basic_type(filter_method_hints['value']):
return build_basic_type(filter_method_hints['value'])
else:
return build_basic_type(OpenApiTypes.STR)
return _NoHint

def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
Expand Down
6 changes: 6 additions & 0 deletions tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class ProductFilter(FilterSet):
int_id = NumberFilter(method='filter_method_typed')
number_id = NumberFilter(method='filter_method_untyped', help_text='some injected help text')
number_id_ext = NumberFilter(method=external_filter_method)
email = CharFilter(method='filter_method_decorated')
# implicit filter declaration
subproduct__sub_price = NumberFilter() # reverse relation
other_sub_product__uuid = UUIDFilter() # forward relation
Expand Down Expand Up @@ -129,6 +130,11 @@ def filter_method_typed(self, queryset, name, value: int):
def filter_method_untyped(self, queryset, name, value):
return queryset.filter(id=int(value)) # pragma: no cover

# email makes no sense here. it's just to test decoration
@extend_schema_field(OpenApiTypes.EMAIL)
def filter_method_decorated(self, queryset, name, value):
return queryset.filter(id=int(value)) # pragma: no cover


@extend_schema(
examples=[
Expand Down
5 changes: 5 additions & 0 deletions tests/contrib/test_django_filters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ paths:
description: Multiple values may be separated by commas.
explode: false
style: form
- in: query
name: email
schema:
type: string
format: email
- in: query
name: in_categories
schema:
Expand Down