From a0de32862710970dc8d7b0e0db2826f6be963d9e Mon Sep 17 00:00:00 2001 From: Benjamin Pelletier Date: Sat, 4 Nov 2023 01:41:02 +0000 Subject: [PATCH] Improve base class support --- src/implicitdict/__init__.py | 3 +++ src/implicitdict/jsonschema.py | 3 +++ tests/test_inheritance.py | 18 +++++++++++++++++- tests/test_jsonschema.py | 13 ++++++++++++- tests/test_types.py | 19 +++++++++++++++++++ 5 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/implicitdict/__init__.py b/src/implicitdict/__init__.py index d4aaa6c..a43d575 100644 --- a/src/implicitdict/__init__.py +++ b/src/implicitdict/__init__.py @@ -226,6 +226,9 @@ def _parse_value(value, value_type: Type): # value is an ImplicitDict return ImplicitDict.parse(value, value_type) + if hasattr(value_type, "__orig_bases__") and value_type.__orig_bases__: + return value_type(_parse_value(value, value_type.__orig_bases__[0])) + else: # value is a non-generic type that is not an ImplicitDict return value_type(value) if value_type else value diff --git a/src/implicitdict/jsonschema.py b/src/implicitdict/jsonschema.py index 4a1046c..5963fb9 100644 --- a/src/implicitdict/jsonschema.py +++ b/src/implicitdict/jsonschema.py @@ -187,6 +187,9 @@ def _schema_for(value_type: Type, schema_vars_resolver: SchemaVarsResolver, sche if value_type == dict or issubclass(value_type, dict): return {"type": "object"}, False + if hasattr(value_type, "__orig_bases__") and value_type.__orig_bases__: + return _schema_for(value_type.__orig_bases__[0], schema_vars_resolver, schema_repository, context) + raise NotImplementedError(f"Automatic JSON schema generation for {value_type} type is not yet implemented") diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index a655d34..dd0ccc7 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -2,7 +2,8 @@ from implicitdict import ImplicitDict -from .test_types import InheritanceData, MySubclass +from .test_types import InheritanceData, MySubclass, SpecialSubclassesContainer, SpecialListClass, MySpecialClass, \ + SpecialComplexListClass def test_inheritance(): @@ -49,3 +50,18 @@ def test_inheritance(): assert subclass2.buzz == "burrs" assert subclass.has_default_baseclass == "In MyData 3" assert subclass.has_default_subclass == "In MySubclass 3" + + +def test_inherited_classes(): + data: SpecialSubclassesContainer = SpecialSubclassesContainer.example_value() + assert isinstance(data.special_list, SpecialListClass) + assert data.special_list.hello() == "SpecialListClass" + for item in data.special_list: + assert isinstance(item, MySpecialClass) + assert item.is_special + + assert isinstance(data.special_complex_list, SpecialComplexListClass) + assert data.special_complex_list.hello() == "SpecialComplexListClass" + for item in data.special_complex_list: + assert isinstance(item, MySubclass) + assert item.hello() == "MySubclass" diff --git a/tests/test_jsonschema.py b/tests/test_jsonschema.py index d274d27..5b53cb9 100644 --- a/tests/test_jsonschema.py +++ b/tests/test_jsonschema.py @@ -6,7 +6,8 @@ from implicitdict import ImplicitDict import jsonschema -from .test_types import ContainerData, InheritanceData, NestedDefinitionsData, NormalUsageData, OptionalData, PropertiesData, SpecialTypesData +from .test_types import ContainerData, InheritanceData, NestedDefinitionsData, NormalUsageData, OptionalData, \ + PropertiesData, SpecialTypesData, SpecialSubclassesContainer def _resolver(t: Type) -> SchemaVars: @@ -72,6 +73,16 @@ def test_inheritance(): data = InheritanceData.example_value() _verify_schema_validation(data, InheritanceData) + data = SpecialSubclassesContainer.example_value() + _verify_schema_validation(data, SpecialSubclassesContainer) + repo = {} + implicitdict.jsonschema.make_json_schema(SpecialSubclassesContainer, _resolver, repo) + name = _resolver(SpecialSubclassesContainer).name + schema = repo[name] + props = schema["properties"] + assert "special_list" in props + assert "special_complex_list" in props + def test_optional(): for data in OptionalData.example_values().values(): diff --git a/tests/test_types.py b/tests/test_types.py index 3de419f..5e54cf6 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -55,6 +55,25 @@ def hello(self) -> str: return "MySubclass" +class SpecialListClass(List[MySpecialClass]): + def hello(self) -> str: + return "SpecialListClass" + + +class SpecialComplexListClass(List[MySubclass]): + def hello(self) -> str: + return "SpecialComplexListClass" + + +class SpecialSubclassesContainer(ImplicitDict): + special_list: SpecialListClass + special_complex_list: SpecialComplexListClass + + @staticmethod + def example_value(): + return ImplicitDict.parse({'special_list': ['foo'], 'special_complex_list': [{'foo': 'oof'}]}, SpecialSubclassesContainer) + + class MutabilityData(ImplicitDict): primitive: str list_of_primitives: List[str]