Skip to content

Commit

Permalink
Add deduplicate operator
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <elronbandel@gmail.com>
  • Loading branch information
elronbandel committed Jan 23, 2025
1 parent 38c8aea commit 08379bb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,34 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
yield from stream


class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Instances are streamed in their original order. The first instance that
    carries a given combination of values for the ``by`` fields is yielded;
    any later instance with an equal combination is dropped.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness. Each name is resolved with ``dict_get``, and the resolved values must be hashable.

    Raises:
        TypeError: If a value resolved for one of the ``by`` fields is not hashable.

    Examples:
        >>> from unitxt.operators import Deduplicate
        >>> deduplicator = Deduplicate(by=["field1", "field2"])
        >>> unique_stream = deduplicator.process(input_stream)
        >>> for item in unique_stream:
        ...     print(item)
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Remember the value tuple itself, not hash(tuple): unequal values
            # can share a hash (in CPython hash(-1) == hash(-2)), and keying on
            # the bare hash would silently drop such instances as "duplicates".
            # The set still hashes the tuple internally but confirms membership
            # with an equality check, at the cost of keeping the values alive.
            signature = tuple(dict_get(instance, field) for field in self.by)

            if signature not in seen:
                seen.add(signature)
                yield instance


class Balance(StreamRefiner):
"""A class used to balance streams deterministically.
Expand Down
22 changes: 22 additions & 0 deletions tests/library/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CastFields,
CollateInstances,
Copy,
Deduplicate,
DeterministicBalancer,
DivideAllFieldsBy,
DuplicateInstances,
Expand Down Expand Up @@ -312,6 +313,27 @@ def test_flatten_instances(self):
tester=self,
)

def test_deduplicate_by_fields(self):
    # Uniqueness is judged on the combination of the "a" and "b/c" values.
    stream_instances = [
        {"a": 1, "b": {"c": 2}},  # (1, 2): first occurrence
        {"a": 2, "b": {"c": 3}},  # (2, 3): first occurrence
        {"a": 1, "b": {"c": 2}},  # (1, 2): repeats the first instance
        {"a": 1, "b": {"c": 3}},  # (1, 3): same "a", but a new "b/c" value
    ]

    # Only the repeated (1, 2) instance is expected to be removed.
    expected_instances = [
        {"a": 1, "b": {"c": 2}},
        {"a": 2, "b": {"c": 3}},
        {"a": 1, "b": {"c": 3}},
    ]

    deduplicator = Deduplicate(by=["a", "b/c"])
    check_operator(
        operator=deduplicator,
        inputs=stream_instances,
        targets=expected_instances,
        tester=self,
    )

def test_filter_by_values_with_required_values(self):
inputs = [
{"a": 1, "b": {"c": 2}},
Expand Down

0 comments on commit 08379bb

Please sign in to comment.