Skip to content

Commit

Permalink
django-filter: raise priority of explicitly given filter method type …
Browse files Browse the repository at this point in the history
…hints #660
  • Loading branch information
tfranzel committed Feb 17, 2022
1 parent 49bb6bf commit 6f1de26
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
50 changes: 23 additions & 27 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

_NoHint = object()


class DjangoFilterExtension(OpenApiFilterExtension):
"""
Expand All @@ -31,7 +33,9 @@ class DjangoFilterExtension(OpenApiFilterExtension):
- ``TypedMultipleChoiceFilter``: enum, multi handled
In case of warnings or incorrect filter types, you can manually override the underlying
field type with a manual ``extend_schema_field`` decoration.
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
method.
.. code-block::
Expand Down Expand Up @@ -79,6 +83,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
}
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)

if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
annotation = (
Expand All @@ -89,6 +94,11 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
else:
# allow injecting raw schema via @extend_schema_field decorator
schema = annotation
elif filter_method_hint is not _NoHint:
if is_basic_type(filter_method_hint):
schema = build_basic_type(filter_method_hint)
else:
schema = build_basic_type(OpenApiTypes.STR)
elif isinstance(filter_field, tuple(unambiguous_mapping)):
for cls in filter_field.__class__.__mro__:
if cls in unambiguous_mapping:
Expand All @@ -97,23 +107,15 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
# NumberField is underspecified by itself. try to find the
# type that makes the most sense or default to generic NUMBER
if filter_method:
schema = self._build_filter_method_type(filterset_class, filter_field)
if schema['type'] not in ['integer', 'number']:
schema = build_basic_type(OpenApiTypes.NUMBER)
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif filter_field.method:
# try to make best effort on the given method
schema = self._build_filter_method_type(filterset_class, filter_field)
schema = build_basic_type(OpenApiTypes.NUMBER)
else:
try:
# the last resort is to lookup the type via the model or queryset field.
Expand All @@ -125,7 +127,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
schema = auto_schema._map_model_field(model_field, direction=None)
except Exception as exc:
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
f'field "{field_name}". Defaulting to string (Exception: {exc})'
Expand Down Expand Up @@ -195,17 +197,11 @@ def _get_filter_method(self, filterset_class, filter_field):
else:
return None

def _build_filter_method_type(self, filterset_class, filter_field):
filter_method = self._get_filter_method(filterset_class, filter_field)
def _get_filter_method_hint(self, filter_method):
try:
filter_method_hints = get_type_hints(filter_method)
return get_type_hints(filter_method)['value']
except: # noqa: E722
filter_method_hints = {}

if 'value' in filter_method_hints and is_basic_type(filter_method_hints['value']):
return build_basic_type(filter_method_hints['value'])
else:
return build_basic_type(OpenApiTypes.STR)
return _NoHint

def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def filter_method_untyped(self, queryset, name, value):
# email makes no sense here. it's just to test decoration
@extend_schema_field(OpenApiTypes.EMAIL)
def filter_method_decorated(self, queryset, name, value):
return queryset.filter(id=int(value))
return queryset.filter(id=int(value)) # pragma: no cover


@extend_schema(
Expand Down

0 comments on commit 6f1de26

Please sign in to comment.