From 7a764f8bbcd7335b428ffc127df6d6dccddd1734 Mon Sep 17 00:00:00 2001 From: Alexandre Girard Date: Mon, 23 Oct 2023 14:12:59 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20low-code=20CDK:=20Allow=20connector?= =?UTF-8?q?=20developers=20to=20specify=20the=20type=20of=20an=20added=20f?= =?UTF-8?q?ield=20(#31638)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: girarda Co-authored-by: erohmensing --- .../declarative_component_schema.yaml | 13 ++ .../models/declarative_component_schema.py | 74 ++++++---- .../parsers/model_to_component_factory.py | 26 +++- .../strategies/pagination_strategy.py | 7 +- .../paginators/strategies/stop_condition.py | 6 +- .../declarative/transformations/add_fields.py | 26 ++-- .../transformations/transformation.py | 4 +- .../declarative/interpolation/test_jinja.py | 14 ++ .../test_model_to_component_factory.py | 132 ++++++++++++++++-- .../transformations/test_add_fields.py | 41 ++++-- 10 files changed, 280 insertions(+), 63 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 019e0ba80204..f6c516591b6c 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -69,6 +69,10 @@ definitions: - "{{ record['updates'] }}" - "{{ record['MetaData']['LastUpdatedTime'] }}" - "{{ stream_partition['segment_id'] }}" + value_type: + title: Value Type + description: Type of the value. If not specified, the type will be inferred from the value. + "$ref": "#/definitions/ValueType" $parameters: type: object additionalProperties: true @@ -1987,6 +1991,15 @@ definitions: $parameters: type: object additionalProperties: true + ValueType: + title: Value Type + description: A schema type. + type: string + enum: + - string + - number + - integer + - boolean WaitTimeFromHeader: title: Wait Time Extracted From Response Header description: Extract wait time from a HTTP header in the response. diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 4b7b4f5542a4..aab608bacc35 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -10,37 +10,6 @@ from typing_extensions import Literal -class AddedFieldDefinition(BaseModel): - type: Literal['AddedFieldDefinition'] - path: List[str] = Field( - ..., - description='List of strings defining the path where to add the value on the record.', - examples=[['segment_id'], ['metadata', 'segment_id']], - title='Path', - ) - value: str = Field( - ..., - description="Value of the new field. Use {{ record['existing_field'] }} syntax to refer to other fields in the record.", - examples=[ - "{{ record['updates'] }}", - "{{ record['MetaData']['LastUpdatedTime'] }}", - "{{ stream_partition['segment_id'] }}", - ], - title='Value', - ) - parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') - - -class AddFields(BaseModel): - type: Literal['AddFields'] - fields: List[AddedFieldDefinition] = Field( - ..., - description='List of transformations (path and corresponding value) that will be added to the record.', - title='Fields', - ) - parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') - - class AuthFlowType(Enum): oauth2_0 = 'oauth2.0' oauth1_0 = 'oauth1.0' @@ -694,6 +663,13 @@ class LegacySessionTokenAuthenticator(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') +class ValueType(Enum): + string = 'string' + number = 'number' + integer = 'integer' + boolean = 'boolean' + + class WaitTimeFromHeader(BaseModel): type: Literal['WaitTimeFromHeader'] header: str = Field( @@ -734,6 +710,42 @@ class WaitUntilTimeFromHeader(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') +class AddedFieldDefinition(BaseModel): + type: Literal['AddedFieldDefinition'] + path: List[str] = Field( + ..., + description='List of strings defining the path where to add the value on the record.', + examples=[['segment_id'], ['metadata', 'segment_id']], + title='Path', + ) + value: str = Field( + ..., + description="Value of the new field. Use {{ record['existing_field'] }} syntax to refer to other fields in the record.", + examples=[ + "{{ record['updates'] }}", + "{{ record['MetaData']['LastUpdatedTime'] }}", + "{{ stream_partition['segment_id'] }}", + ], + title='Value', + ) + value_type: Optional[ValueType] = Field( + None, + description='Type of the value. If not specified, the type will be inferred from the value.', + title='Value Type', + ) + parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') + + +class AddFields(BaseModel): + type: Literal['AddFields'] + fields: List[AddedFieldDefinition] = Field( + ..., + description='List of transformations (path and corresponding value) that will be added to the record.', + title='Fields', + ) + parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters') + + class ApiKeyAuthenticator(BaseModel): type: Literal['ApiKeyAuthenticator'] api_token: Optional[str] = Field( diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 763a7e22065f..8ab779dcad6b 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -80,6 +80,7 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import SimpleRetriever as SimpleRetrieverModel from airbyte_cdk.sources.declarative.models.declarative_component_schema import Spec as SpecModel from airbyte_cdk.sources.declarative.models.declarative_component_schema import SubstreamPartitionRouter as SubstreamPartitionRouterModel +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ValueType from airbyte_cdk.sources.declarative.models.declarative_component_schema import WaitTimeFromHeader as WaitTimeFromHeaderModel from airbyte_cdk.sources.declarative.models.declarative_component_schema import WaitUntilTimeFromHeader as WaitUntilTimeFromHeaderModel from airbyte_cdk.sources.declarative.partition_routers import ListPartitionRouter, SinglePartitionRouter, SubstreamPartitionRouter @@ -232,15 +233,36 @@ def _create_component_from_model(self, model: BaseModel, config: Config, **kwarg @staticmethod def create_added_field_definition(model: AddedFieldDefinitionModel, config: Config, **kwargs: Any) -> AddedFieldDefinition: interpolated_value = InterpolatedString.create(model.value, parameters=model.parameters or {}) - return AddedFieldDefinition(path=model.path, value=interpolated_value, parameters=model.parameters or {}) + return AddedFieldDefinition( + path=model.path, + value=interpolated_value, + value_type=ModelToComponentFactory._json_schema_type_name_to_type(model.value_type), + parameters=model.parameters or {}, + ) def create_add_fields(self, model: AddFieldsModel, config: Config, **kwargs: Any) -> AddFields: added_field_definitions = [ - self._create_component_from_model(model=added_field_definition_model, config=config) + self._create_component_from_model( + model=added_field_definition_model, + value_type=ModelToComponentFactory._json_schema_type_name_to_type(added_field_definition_model.value_type), + config=config, + ) for added_field_definition_model in model.fields ] return AddFields(fields=added_field_definitions, parameters=model.parameters or {}) + @staticmethod + def _json_schema_type_name_to_type(value_type: Optional[ValueType]) -> Optional[Type[Any]]: + if not value_type: + return None + names_to_types = { + ValueType.string: str, + ValueType.number: float, + ValueType.integer: int, + ValueType.boolean: bool, + } + return names_to_types[value_type] + @staticmethod def create_api_key_authenticator( model: ApiKeyAuthenticatorModel, config: Config, token_provider: Optional[TokenProvider] = None, **kwargs: Any diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py index fefe885bbc05..3ebe49a059b5 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/pagination_strategy.py @@ -4,9 +4,10 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, List, Mapping, Optional +from typing import Any, List, Optional import requests +from airbyte_cdk.sources.declarative.types import Record @dataclass @@ -23,7 +24,7 @@ def initial_token(self) -> Optional[Any]: """ @abstractmethod - def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Any]: + def next_page_token(self, response: requests.Response, last_records: List[Record]) -> Optional[Any]: """ :param response: response to process :param last_records: records extracted from the response @@ -32,7 +33,7 @@ def next_page_token(self, response: requests.Response, last_records: List[Mappin pass @abstractmethod - def reset(self): + def reset(self) -> None: """ Reset the pagination's inner state """ diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py index 827171bcf705..8732e39ffef5 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/paginators/strategies/stop_condition.py @@ -42,8 +42,12 @@ def next_page_token(self, response: requests.Response, last_records: List[Record return None return self._delegate.next_page_token(response, last_records) - def reset(self): + def reset(self) -> None: self._delegate.reset() def get_page_size(self) -> Optional[int]: return self._delegate.get_page_size() + + @property + def initial_token(self) -> Optional[Any]: + return self._delegate.initial_token diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/add_fields.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/add_fields.py index 7802e4edbc85..109f4fb8ca70 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/add_fields.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/add_fields.py @@ -3,7 +3,7 @@ # from dataclasses import InitVar, dataclass, field -from typing import Any, List, Mapping, Optional, Union +from typing import Any, List, Mapping, Optional, Type, Union import dpath.util from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -17,6 +17,7 @@ class AddedFieldDefinition: path: FieldPointer value: Union[InterpolatedString, str] + value_type: Optional[Type[Any]] parameters: InitVar[Mapping[str, Any]] @@ -26,6 +27,7 @@ class ParsedAddFieldDefinition: path: FieldPointer value: InterpolatedString + value_type: Optional[Type[Any]] parameters: InitVar[Mapping[str, Any]] @@ -85,10 +87,10 @@ class AddFields(RecordTransformation): parameters: InitVar[Mapping[str, Any]] _parsed_fields: List[ParsedAddFieldDefinition] = field(init=False, repr=False, default_factory=list) - def __post_init__(self, parameters: Mapping[str, Any]): + def __post_init__(self, parameters: Mapping[str, Any]) -> None: for add_field in self.fields: if len(add_field.path) < 1: - raise f"Expected a non-zero-length path for the AddFields transformation {add_field}" + raise ValueError(f"Expected a non-zero-length path for the AddFields transformation {add_field}") if not isinstance(add_field.value, InterpolatedString): if not isinstance(add_field.value, str): @@ -96,11 +98,16 @@ def __post_init__(self, parameters: Mapping[str, Any]): else: self._parsed_fields.append( ParsedAddFieldDefinition( - add_field.path, InterpolatedString.create(add_field.value, parameters=parameters), parameters=parameters + add_field.path, + InterpolatedString.create(add_field.value, parameters=parameters), + value_type=add_field.value_type, + parameters=parameters, ) ) else: - self._parsed_fields.append(ParsedAddFieldDefinition(add_field.path, add_field.value, parameters={})) + self._parsed_fields.append( + ParsedAddFieldDefinition(add_field.path, add_field.value, value_type=add_field.value_type, parameters={}) + ) def transform( self, @@ -109,12 +116,15 @@ def transform( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, ) -> Record: + if config is None: + config = {} kwargs = {"record": record, "stream_state": stream_state, "stream_slice": stream_slice} for parsed_field in self._parsed_fields: - value = parsed_field.value.eval(config, **kwargs) + valid_types = (parsed_field.value_type,) if parsed_field.value_type else None + value = parsed_field.value.eval(config, valid_types=valid_types, **kwargs) dpath.util.new(record, parsed_field.path, value) return record - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __eq__(self, other: Any) -> bool: + return bool(self.__dict__ == other.__dict__) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/transformation.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/transformation.py index 560bf39e1b08..dd91864a2537 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/transformation.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/transformations/transformation.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Any, Mapping, Optional -from airbyte_cdk.sources.declarative.types import Config, StreamSlice, StreamState +from airbyte_cdk.sources.declarative.types import Config, Record, StreamSlice, StreamState @dataclass @@ -18,7 +18,7 @@ class RecordTransformation: @abstractmethod def transform( self, - record: Mapping[str, Any], + record: Record, config: Optional[Config] = None, stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/interpolation/test_jinja.py b/airbyte-cdk/python/unit_tests/sources/declarative/interpolation/test_jinja.py index cb312d18977e..097afbb3487f 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/interpolation/test_jinja.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/interpolation/test_jinja.py @@ -19,6 +19,20 @@ def test_get_value_from_config(): assert val == "2022-01-01" +@pytest.mark.parametrize( + "valid_types, expected_value", + [ + pytest.param((str,), "1234J", id="test_value_is_a_string_if_valid_types_is_str"), + pytest.param(None, 1234j, id="test_value_is_interpreted_as_complex_number_by_default"), + ], +) +def test_get_value_with_complex_number(valid_types, expected_value): + s = "{{ config['value'] }}" + config = {"value": "1234J"} + val = interpolation.eval(s, config, valid_types=valid_types) + assert val == expected_value + + def test_get_value_from_stream_slice(): s = "{{ stream_slice['date'] }}" config = {"date": "2022-01-01"} diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index a1c9eec8cd0e..602ab3d50424 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -1362,7 +1362,7 @@ def test_remove_fields(self): expected = [RemoveFields(field_pointers=[["path", "to", "field1"], ["path2"]], parameters={})] assert stream.retriever.record_selector.transformations == expected - def test_add_fields(self): + def test_add_fields_no_value_type(self): content = f""" the_stream: type: DeclarativeStream @@ -1374,26 +1374,142 @@ def test_add_fields(self): - path: ["field1"] value: "static_value" """ - parsed_manifest = YamlDeclarativeSource._parse(content) - resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - resolved_manifest["type"] = "DeclarativeSource" - stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["the_stream"], {}) - - stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + expected = [ + AddFields( + fields=[ + AddedFieldDefinition( + path=["field1"], + value=InterpolatedString(string="static_value", default="static_value", parameters={}), + value_type=None, + parameters={}, + ) + ], + parameters={}, + ) + ] + self._test_add_fields(content, expected) - assert isinstance(stream, DeclarativeStream) + def test_add_fields_value_type_is_string(self): + content = f""" + the_stream: + type: DeclarativeStream + $parameters: + {self.base_parameters} + transformations: + - type: AddFields + fields: + - path: ["field1"] + value: "static_value" + value_type: string + """ expected = [ AddFields( fields=[ AddedFieldDefinition( path=["field1"], value=InterpolatedString(string="static_value", default="static_value", parameters={}), + value_type=str, + parameters={}, + ) + ], + parameters={}, + ) + ] + self._test_add_fields(content, expected) + + def test_add_fields_value_type_is_number(self): + content = f""" + the_stream: + type: DeclarativeStream + $parameters: + {self.base_parameters} + transformations: + - type: AddFields + fields: + - path: ["field1"] + value: "1" + value_type: number + """ + expected = [ + AddFields( + fields=[ + AddedFieldDefinition( + path=["field1"], + value=InterpolatedString(string="1", default="1", parameters={}), + value_type=float, parameters={}, ) ], parameters={}, ) ] + self._test_add_fields(content, expected) + + def test_add_fields_value_type_is_integer(self): + content = f""" + the_stream: + type: DeclarativeStream + $parameters: + {self.base_parameters} + transformations: + - type: AddFields + fields: + - path: ["field1"] + value: "1" + value_type: integer + """ + expected = [ + AddFields( + fields=[ + AddedFieldDefinition( + path=["field1"], + value=InterpolatedString(string="1", default="1", parameters={}), + value_type=int, + parameters={}, + ) + ], + parameters={}, + ) + ] + self._test_add_fields(content, expected) + + def test_add_fields_value_type_is_boolean(self): + content = f""" + the_stream: + type: DeclarativeStream + $parameters: + {self.base_parameters} + transformations: + - type: AddFields + fields: + - path: ["field1"] + value: False + value_type: boolean + """ + expected = [ + AddFields( + fields=[ + AddedFieldDefinition( + path=["field1"], + value=InterpolatedString(string="False", default="False", parameters={}), + value_type=bool, + parameters={}, + ) + ], + parameters={}, + ) + ] + self._test_add_fields(content, expected) + + def _test_add_fields(self, content, expected): + parsed_manifest = YamlDeclarativeSource._parse(content) + resolved_manifest = resolver.preprocess_manifest(parsed_manifest) + resolved_manifest["type"] = "DeclarativeSource" + stream_manifest = transformer.propagate_types_and_parameters("", resolved_manifest["the_stream"], {}) + + stream = factory.create_component(model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config) + + assert isinstance(stream, DeclarativeStream) assert stream.retriever.record_selector.transformations == expected def test_default_schema_loader(self): diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/transformations/test_add_fields.py b/airbyte-cdk/python/unit_tests/sources/declarative/transformations/test_add_fields.py index 1386f2847652..a10422fd71bf 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/transformations/test_add_fields.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/transformations/test_add_fields.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Any, List, Mapping, Tuple +from typing import Any, List, Mapping, Optional, Tuple import pytest from airbyte_cdk.sources.declarative.transformations import AddFields @@ -11,12 +11,22 @@ @pytest.mark.parametrize( - ["input_record", "field", "kwargs", "expected"], + ["input_record", "field", "field_type", "kwargs", "expected"], [ - pytest.param({"k": "v"}, [(["path"], "static_value")], {}, {"k": "v", "path": "static_value"}, id="add new static value"), + pytest.param({"k": "v"}, [(["path"], "static_value")], None, {}, {"k": "v", "path": "static_value"}, id="add new static value"), + pytest.param({"k": "v"}, [(["path"], "{{ 1 }}")], None, {}, {"k": "v", "path": 1}, id="add an expression evaluated as a number"), + pytest.param( + {"k": "v"}, + [(["path"], "{{ 1 }}")], + str, + {}, + {"k": "v", "path": "1"}, + id="add an expression evaluated as a string using the value_type field", + ), pytest.param( {"k": "v"}, [(["path"], "static_value"), (["path2"], "static_value2")], + None, {}, {"k": "v", "path": "static_value", "path2": "static_value2"}, id="add new multiple static values", @@ -24,15 +34,17 @@ pytest.param( {"k": "v"}, [(["nested", "path"], "static_value")], + None, {}, {"k": "v", "nested": {"path": "static_value"}}, id="set static value at nested path", ), - pytest.param({"k": "v"}, [(["k"], "new_value")], {}, {"k": "new_value"}, id="update value which already exists"), - pytest.param({"k": [0, 1]}, [(["k", 3], "v")], {}, {"k": [0, 1, None, "v"]}, id="Set element inside array"), + pytest.param({"k": "v"}, [(["k"], "new_value")], None, {}, {"k": "new_value"}, id="update value which already exists"), + pytest.param({"k": [0, 1]}, [(["k", 3], "v")], None, {}, {"k": [0, 1, None, "v"]}, id="Set element inside array"), pytest.param( {"k": "v"}, [(["k2"], '{{ config["shop"] }}')], + None, {"config": {"shop": "in-n-out"}}, {"k": "v", "k2": "in-n-out"}, id="set a value from the config using bracket notation", @@ -40,6 +52,7 @@ pytest.param( {"k": "v"}, [(["k2"], "{{ config.shop }}")], + None, {"config": {"shop": "in-n-out"}}, {"k": "v", "k2": "in-n-out"}, id="set a value from the config using dot notation", @@ -47,6 +60,7 @@ pytest.param( {"k": "v"}, [(["k2"], '{{ stream_state["cursor"] }}')], + None, {"stream_state": {"cursor": "t0"}}, {"k": "v", "k2": "t0"}, id="set a value from the state using bracket notation", @@ -54,6 +68,7 @@ pytest.param( {"k": "v"}, [(["k2"], "{{ stream_state.cursor }}")], + None, {"stream_state": {"cursor": "t0"}}, {"k": "v", "k2": "t0"}, id="set a value from the state using dot notation", @@ -61,6 +76,7 @@ pytest.param( {"k": "v"}, [(["k2"], '{{ stream_slice["start_date"] }}')], + None, {"stream_slice": {"start_date": "oct1"}}, {"k": "v", "k2": "oct1"}, id="set a value from the stream slice using bracket notation", @@ -68,6 +84,7 @@ pytest.param( {"k": "v"}, [(["k2"], "{{ stream_slice.start_date }}")], + None, {"stream_slice": {"start_date": "oct1"}}, {"k": "v", "k2": "oct1"}, id="set a value from the stream slice using dot notation", @@ -75,6 +92,7 @@ pytest.param( {"k": "v"}, [(["k2"], "{{ record.k }}")], + None, {}, {"k": "v", "k2": "v"}, id="set a value from a field in the record using dot notation", @@ -82,6 +100,7 @@ pytest.param( {"k": "v"}, [(["k2"], '{{ record["k"] }}')], + None, {}, {"k": "v", "k2": "v"}, id="set a value from a field in the record using bracket notation", @@ -89,6 +108,7 @@ pytest.param( {"k": {"nested": "v"}}, [(["k2"], "{{ record.k.nested }}")], + None, {}, {"k": {"nested": "v"}, "k2": "v"}, id="set a value from a nested field in the record using bracket notation", @@ -96,15 +116,20 @@ pytest.param( {"k": {"nested": "v"}}, [(["k2"], '{{ record["k"]["nested"] }}')], + None, {}, {"k": {"nested": "v"}, "k2": "v"}, id="set a value from a nested field in the record using bracket notation", ), - pytest.param({"k": "v"}, [(["k2"], "{{ 2 + 2 }}")], {}, {"k": "v", "k2": 4}, id="set a value from a jinja expression"), + pytest.param({"k": "v"}, [(["k2"], "{{ 2 + 2 }}")], None, {}, {"k": "v", "k2": 4}, id="set a value from a jinja expression"), ], ) def test_add_fields( - input_record: Mapping[str, Any], field: List[Tuple[FieldPointer, str]], kwargs: Mapping[str, Any], expected: Mapping[str, Any] + input_record: Mapping[str, Any], + field: List[Tuple[FieldPointer, str]], + field_type: Optional[str], + kwargs: Mapping[str, Any], + expected: Mapping[str, Any], ): - inputs = [AddedFieldDefinition(path=v[0], value=v[1], parameters={}) for v in field] + inputs = [AddedFieldDefinition(path=v[0], value=v[1], value_type=field_type, parameters={}) for v in field] assert AddFields(fields=inputs, parameters={"alas": "i live"}).transform(input_record, **kwargs) == expected