diff --git a/inference_schema/parameter_types/_constants.py b/inference_schema/parameter_types/_constants.py index 058def4..7387617 100644 --- a/inference_schema/parameter_types/_constants.py +++ b/inference_schema/parameter_types/_constants.py @@ -7,3 +7,8 @@ TIME_FORMAT = "%H:%M:%S.%f %z" ERR_PYTHON_DATA_NOT_JSON_SERIALIZABLE = "Invalid python data sample provided: ensure that the data is fully JSON " \ "serializable to be able to generate swagger schema from it. Actual error: {}" + + +class SWAGGER_FORMAT_CONSTANTS: + NUMPY_FORMAT = "numpy.ndarray" + PANDAS_FORMAT = "pandas.DataFrame:{}" diff --git a/inference_schema/parameter_types/numpy_parameter_type.py b/inference_schema/parameter_types/numpy_parameter_type.py index d04cab4..7b68bf7 100644 --- a/inference_schema/parameter_types/numpy_parameter_type.py +++ b/inference_schema/parameter_types/numpy_parameter_type.py @@ -5,6 +5,7 @@ import numpy as np from .abstract_parameter_type import AbstractParameterType from ._swagger_from_dtype import Dtype2Swagger +from ._constants import SWAGGER_FORMAT_CONSTANTS class NumpyParameterType(AbstractParameterType): @@ -96,6 +97,7 @@ def input_to_swagger(self): swagger_schema = Dtype2Swagger.handle_swagger_array(swagger_item_type, shape) items_count = len(self.sample_input) swagger_schema['example'] = self._get_swagger_sample(self.sample_input, items_count, swagger_schema['items']) + swagger_schema["format"] = SWAGGER_FORMAT_CONSTANTS.NUMPY_FORMAT return swagger_schema @classmethod diff --git a/inference_schema/parameter_types/pandas_parameter_type.py b/inference_schema/parameter_types/pandas_parameter_type.py index 4a2d99e..2bb6d51 100644 --- a/inference_schema/parameter_types/pandas_parameter_type.py +++ b/inference_schema/parameter_types/pandas_parameter_type.py @@ -6,6 +6,7 @@ import pandas as pd from .abstract_parameter_type import AbstractParameterType from ._util import get_swagger_for_list, get_swagger_for_nested_dict +from ._constants import SWAGGER_FORMAT_CONSTANTS class PandasParameterType(AbstractParameterType): @@ -143,5 +144,5 @@ def input_to_swagger(self): elif data_type.startswith('timedelta'): swagger_schema['properties']['data']['items']['properties'][str(column_name)]['format'] = \ 'timedelta' - + swagger_schema["format"] = SWAGGER_FORMAT_CONSTANTS.PANDAS_FORMAT.format(self.orient) return swagger_schema diff --git a/tests/resources/sample_nested_input_schema.json b/tests/resources/sample_nested_input_schema.json index 7991ca3..b44d8a6 100644 --- a/tests/resources/sample_nested_input_schema.json +++ b/tests/resources/sample_nested_input_schema.json @@ -7,6 +7,7 @@ "properties": { "input1": { "type": "array", + "format": "pandas.DataFrame:records", "items": { "type": "object", "required": ["name", "state"], @@ -22,6 +23,7 @@ }, "input2": { "type": "array", + "format": "numpy.ndarray", "items": { "type": "object", "properties": { diff --git a/tests/resources/sample_nested_output_schema.json b/tests/resources/sample_nested_output_schema.json index e0ec73d..1769ec7 100644 --- a/tests/resources/sample_nested_output_schema.json +++ b/tests/resources/sample_nested_output_schema.json @@ -4,6 +4,7 @@ "properties": { "output1": { "type": "array", + "format": "pandas.DataFrame:records", "items": { "type": "object", "required": ["state"], @@ -16,6 +17,7 @@ }, "output2": { "type": "array", + "format": "numpy.ndarray", "items": { "type": "object", "properties": { diff --git a/tests/resources/sample_numpy_input_schema.json b/tests/resources/sample_numpy_input_schema.json index c2b845c..8fad515 100644 --- a/tests/resources/sample_numpy_input_schema.json +++ b/tests/resources/sample_numpy_input_schema.json @@ -3,6 +3,7 @@ "properties": { "param": { "type": "array", + "format": "numpy.ndarray", "items": { "type": "object", "properties": { diff --git a/tests/resources/sample_numpy_output_schema.json b/tests/resources/sample_numpy_output_schema.json index 2725bb1..12abf17 100644 --- a/tests/resources/sample_numpy_output_schema.json +++ b/tests/resources/sample_numpy_output_schema.json @@ -1,5 +1,6 @@ { "type": "array", + "format": "numpy.ndarray", "items": { "type": "object", "properties": { diff --git a/tests/resources/sample_pandas_datetime_schema.json b/tests/resources/sample_pandas_datetime_schema.json index 93c0a75..9664992 100644 --- a/tests/resources/sample_pandas_datetime_schema.json +++ b/tests/resources/sample_pandas_datetime_schema.json @@ -3,6 +3,7 @@ "properties": { "param": { "type": "array", + "format": "pandas.DataFrame:records", "items": { "type": "object", "required": [ diff --git a/tests/resources/sample_pandas_input_schema.json b/tests/resources/sample_pandas_input_schema.json index 8f35485..944c749 100644 --- a/tests/resources/sample_pandas_input_schema.json +++ b/tests/resources/sample_pandas_input_schema.json @@ -3,6 +3,7 @@ "properties": { "param": { "type": "array", + "format": "pandas.DataFrame:records", "items": { "type": "object", "properties": { diff --git a/tests/resources/sample_pandas_output_schema.json b/tests/resources/sample_pandas_output_schema.json index cbd8c19..dd5bf4d 100644 --- a/tests/resources/sample_pandas_output_schema.json +++ b/tests/resources/sample_pandas_output_schema.json @@ -1,5 +1,6 @@ { "type": "array", + "format": "pandas.DataFrame:records", "items": { "type": "object", "properties": {