From 08379bb9478a3371de5d73b830a5750fcc4b2251 Mon Sep 17 00:00:00 2001 From: elronbandel Date: Thu, 23 Jan 2025 10:05:36 +0200 Subject: [PATCH] Add deduplicate operator Signed-off-by: elronbandel --- src/unitxt/operators.py | 28 ++++++++++++++++++++++++++++ tests/library/test_operators.py | 22 ++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index 3d2cf09d9..37fbf587b 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -1900,6 +1900,34 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato yield from stream +class Deduplicate(StreamOperator): + """Deduplicate the stream based on the given fields. + + Args: + by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness. + + Examples: + >>> from some_module import Deduplicate + >>> deduplicator = Deduplicate(by=["field1", "field2"]) + >>> unique_stream = deduplicator.process(input_stream) + >>> for item in unique_stream: + >>> print(item) + """ + + by: List[str] + + def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator: + seen = set() + + for instance in stream: + # Compute a lightweight hash for the signature + signature = hash(tuple(dict_get(instance, field) for field in self.by)) + + if signature not in seen: + seen.add(signature) + yield instance + + class Balance(StreamRefiner): """A class used to balance streams deterministically. diff --git a/tests/library/test_operators.py b/tests/library/test_operators.py index abb0842a1..af038dd65 100644 --- a/tests/library/test_operators.py +++ b/tests/library/test_operators.py @@ -13,6 +13,7 @@ CastFields, CollateInstances, Copy, + Deduplicate, DeterministicBalancer, DivideAllFieldsBy, DuplicateInstances, @@ -312,6 +313,27 @@ def test_flatten_instances(self): tester=self, ) + def test_deduplicate_by_fields(self): + inputs = [ + {"a": 1, "b": {"c": 2}}, + {"a": 2, "b": {"c": 3}}, + {"a": 1, "b": {"c": 2}}, # Duplicate based on "a" and "b/c" + {"a": 1, "b": {"c": 3}}, # Duplicate based on "a" and "b/c" + ] + + targets = [ + {"a": 1, "b": {"c": 2}}, + {"a": 2, "b": {"c": 3}}, + {"a": 1, "b": {"c": 3}}, + ] + + check_operator( + operator=Deduplicate(by=["a", "b/c"]), + inputs=inputs, + targets=targets, + tester=self, + ) + def test_filter_by_values_with_required_values(self): inputs = [ {"a": 1, "b": {"c": 2}},