From 49bb6bfd0dacb686a3c3813bf12c67833ca01bfd Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 17 Feb 2022 13:12:19 +0100 Subject: [PATCH 1/2] also allow @extend_schema_field on django-filter filter method #660 --- drf_spectacular/contrib/django_filters.py | 21 +++++++++++++++------ tests/contrib/test_django_filters.py | 6 ++++++ tests/contrib/test_django_filters.yml | 5 +++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/drf_spectacular/contrib/django_filters.py b/drf_spectacular/contrib/django_filters.py index 14ffb29e..becae9d4 100644 --- a/drf_spectacular/contrib/django_filters.py +++ b/drf_spectacular/contrib/django_filters.py @@ -78,11 +78,16 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME, filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME, } - if has_override(filter_field, 'field'): - annotation = get_override(filter_field, 'field') + filter_method = self._get_filter_method(filterset_class, filter_field) + + if has_override(filter_field, 'field') or has_override(filter_method, 'field'): + annotation = ( + get_override(filter_field, 'field') or get_override(filter_method, 'field') + ) if is_basic_type(annotation): schema = build_basic_type(annotation) else: + # allow injecting raw schema via @extend_schema_field decorator schema = annotation elif isinstance(filter_field, tuple(unambiguous_mapping)): for cls in filter_field.__class__.__mro__: @@ -92,7 +97,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)): # NumberField is underspecified by itself. try to find the # type that makes the most sense or default to generic NUMBER - if filter_field.method: + if filter_method: schema = self._build_filter_method_type(filterset_class, filter_field) if schema['type'] not in ['integer', 'number']: schema = build_basic_type(OpenApiTypes.NUMBER) @@ -182,12 +187,16 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, for field_name in field_names ] - def _build_filter_method_type(self, filterset_class, filter_field): + def _get_filter_method(self, filterset_class, filter_field): if callable(filter_field.method): - filter_method = filter_field.method + return filter_field.method + elif isinstance(filter_field.method, str): + return getattr(filterset_class, filter_field.method) else: - filter_method = getattr(filterset_class, filter_field.method) + return None + def _build_filter_method_type(self, filterset_class, filter_field): + filter_method = self._get_filter_method(filterset_class, filter_field) try: filter_method_hints = get_type_hints(filter_method) except: # noqa: E722 diff --git a/tests/contrib/test_django_filters.py b/tests/contrib/test_django_filters.py index 5cb20b6c..02ee466a 100644 --- a/tests/contrib/test_django_filters.py +++ b/tests/contrib/test_django_filters.py @@ -88,6 +88,7 @@ class ProductFilter(FilterSet): int_id = NumberFilter(method='filter_method_typed') number_id = NumberFilter(method='filter_method_untyped', help_text='some injected help text') number_id_ext = NumberFilter(method=external_filter_method) + email = CharFilter(method='filter_method_decorated') # implicit filter declaration subproduct__sub_price = NumberFilter() # reverse relation other_sub_product__uuid = UUIDFilter() # forward relation @@ -129,6 +130,11 @@ def filter_method_typed(self, queryset, name, value: int): def filter_method_untyped(self, queryset, name, value): return queryset.filter(id=int(value)) # pragma: no cover + # email makes no sense here. it's just to test decoration + @extend_schema_field(OpenApiTypes.EMAIL) + def filter_method_decorated(self, queryset, name, value): + return queryset.filter(id=int(value)) + @extend_schema( examples=[ diff --git a/tests/contrib/test_django_filters.yml b/tests/contrib/test_django_filters.yml index fe01024c..d4ccbac0 100644 --- a/tests/contrib/test_django_filters.yml +++ b/tests/contrib/test_django_filters.yml @@ -36,6 +36,11 @@ paths: description: Multiple values may be separated by commas. explode: false style: form + - in: query + name: email + schema: + type: string + format: email - in: query name: in_categories schema: From 6f1de26a9524a4bd5ee6379d3fae0b25e87fc1aa Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 17 Feb 2022 14:49:26 +0100 Subject: [PATCH 2/2] django-filter: raise priority of explicitly given filter method type hints #660 --- drf_spectacular/contrib/django_filters.py | 50 +++++++++++------------ tests/contrib/test_django_filters.py | 2 +- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/drf_spectacular/contrib/django_filters.py b/drf_spectacular/contrib/django_filters.py index becae9d4..f1b8bbef 100644 --- a/drf_spectacular/contrib/django_filters.py +++ b/drf_spectacular/contrib/django_filters.py @@ -9,6 +9,8 @@ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter +_NoHint = object() + class DjangoFilterExtension(OpenApiFilterExtension): """ @@ -31,7 +33,9 @@ class DjangoFilterExtension(OpenApiFilterExtension): - ``TypedMultipleChoiceFilter``: enum, multi handled In case of warnings or incorrect filter types, you can manually override the underlying - field type with a manual ``extend_schema_field`` decoration. + field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a + filter method for your filter field, you can attach ``extend_schema_field`` to that filter + method. .. code-block:: @@ -79,6 +83,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME, } filter_method = self._get_filter_method(filterset_class, filter_field) + filter_method_hint = self._get_filter_method_hint(filter_method) if has_override(filter_field, 'field') or has_override(filter_method, 'field'): annotation = ( @@ -89,6 +94,11 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, else: # allow injecting raw schema via @extend_schema_field decorator schema = annotation + elif filter_method_hint is not _NoHint: + if is_basic_type(filter_method_hint): + schema = build_basic_type(filter_method_hint) + else: + schema = build_basic_type(OpenApiTypes.STR) elif isinstance(filter_field, tuple(unambiguous_mapping)): for cls in filter_field.__class__.__mro__: if cls in unambiguous_mapping: @@ -97,23 +107,15 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)): # NumberField is underspecified by itself. try to find the # type that makes the most sense or default to generic NUMBER - if filter_method: - schema = self._build_filter_method_type(filterset_class, filter_field) - if schema['type'] not in ['integer', 'number']: - schema = build_basic_type(OpenApiTypes.NUMBER) + model_field = self._get_model_field(filter_field, model) + if isinstance(model_field, (models.IntegerField, models.AutoField)): + schema = build_basic_type(OpenApiTypes.INT) + elif isinstance(model_field, models.FloatField): + schema = build_basic_type(OpenApiTypes.FLOAT) + elif isinstance(model_field, models.DecimalField): + schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved else: - model_field = self._get_model_field(filter_field, model) - if isinstance(model_field, (models.IntegerField, models.AutoField)): - schema = build_basic_type(OpenApiTypes.INT) - elif isinstance(model_field, models.FloatField): - schema = build_basic_type(OpenApiTypes.FLOAT) - elif isinstance(model_field, models.DecimalField): - schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved - else: - schema = build_basic_type(OpenApiTypes.NUMBER) - elif filter_field.method: - # try to make best effort on the given method - schema = self._build_filter_method_type(filterset_class, filter_field) + schema = build_basic_type(OpenApiTypes.NUMBER) else: try: # the last resort is to lookup the type via the model or queryset field. @@ -125,7 +127,7 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, qs = auto_schema.view.get_queryset() model_field = qs.query.annotations[filter_field.field_name].field schema = auto_schema._map_model_field(model_field, direction=None) - except Exception as exc: + except Exception as exc: # pragma: no cover warn( f'Exception raised while trying resolve model field for django-filter ' f'field "{field_name}". Defaulting to string (Exception: {exc})' @@ -195,17 +197,11 @@ def _get_filter_method(self, filterset_class, filter_field): else: return None - def _build_filter_method_type(self, filterset_class, filter_field): - filter_method = self._get_filter_method(filterset_class, filter_field) + def _get_filter_method_hint(self, filter_method): try: - filter_method_hints = get_type_hints(filter_method) + return get_type_hints(filter_method)['value'] except: # noqa: E722 - filter_method_hints = {} - - if 'value' in filter_method_hints and is_basic_type(filter_method_hints['value']): - return build_basic_type(filter_method_hints['value']) - else: - return build_basic_type(OpenApiTypes.STR) + return _NoHint def _get_model_field(self, filter_field, model): if not filter_field.field_name: diff --git a/tests/contrib/test_django_filters.py b/tests/contrib/test_django_filters.py index 02ee466a..441c8e21 100644 --- a/tests/contrib/test_django_filters.py +++ b/tests/contrib/test_django_filters.py @@ -133,7 +133,7 @@ def filter_method_untyped(self, queryset, name, value): # email makes no sense here. it's just to test decoration @extend_schema_field(OpenApiTypes.EMAIL) def filter_method_decorated(self, queryset, name, value): - return queryset.filter(id=int(value)) + return queryset.filter(id=int(value)) # pragma: no cover @extend_schema(