Commit
Add option to ignore duplicates in DownloadConfig
When duplicates are ignored, only one of the examples sharing the same key is kept and no exception is raised.

PiperOrigin-RevId: 624137980
tomvdw authored and The TensorFlow Datasets Authors committed Apr 12, 2024
1 parent 364fb82 commit 3a34edd
Showing 6 changed files with 120 additions and 39 deletions.
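
For illustration, a minimal sketch of how the new option could be used end to end, assuming the usual `tfds.download.DownloadConfig` public alias; `'my_dataset'` is a hypothetical builder name, not part of this commit:

```python
import tensorflow_datasets as tfds

# Hypothetical dataset name; any builder whose generator may yield two
# examples with the same key would previously fail during preparation.
builder = tfds.builder('my_dataset')

# With ignore_duplicates=True, only the first example per key is written
# and no duplicated-keys error is raised.
download_config = tfds.download.DownloadConfig(ignore_duplicates=True)
builder.download_and_prepare(download_config=download_config)
```
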
1 change: 1 addition & 0 deletions tensorflow_datasets/core/dataset_builder.py
@@ -1596,6 +1596,7 @@ def _generate_splits(
beam_runner=download_config.beam_runner,
shard_config=download_config.get_shard_config(),
example_writer=self._example_writer(),
ignore_duplicates=download_config.ignore_duplicates,
)
# Wrap the generation inside a context manager.
# If `beam` is used during generation (when a pipeline gets created),
8 changes: 5 additions & 3 deletions tensorflow_datasets/core/download/download_manager.py
@@ -14,6 +14,7 @@
# limitations under the License.

"""Download manager interface."""

from __future__ import annotations

import concurrent.futures
@@ -100,6 +101,8 @@ class DownloadConfig:
used.
max_shard_size: optional maximum shard size in bytes. If `None`, 1 GiB is
used.
ignore_duplicates: whether to ignore duplicated examples with the same key.
If there are multiple examples with the same key, the first one is kept.
"""

extract_dir: Optional[epath.PathLike] = None
@@ -117,6 +120,7 @@ class DownloadConfig:
num_shards: Optional[int] = None
min_shard_size: int = shard_utils.DEFAULT_MIN_SHARD_SIZE
max_shard_size: int = shard_utils.DEFAULT_MAX_SHARD_SIZE
ignore_duplicates: bool = False

def get_shard_config(self) -> shard_utils.ShardConfig:
return shard_utils.ShardConfig(
@@ -248,9 +252,7 @@ def __init__(

self._download_dir: epath.Path = download_dir
self._extract_dir: epath.Path = extract_dir
self._manual_dir: Optional[epath.Path] = (
manual_dir # pytype: disable=annotation-type-mismatch # attribute-variable-annotations
)
self._manual_dir: Optional[epath.Path] = manual_dir # pytype: disable=annotation-type-mismatch # attribute-variable-annotations
self._manual_dir_instructions = utils.dedent(manual_dir_instructions)
self._download_dir.mkdir(parents=True, exist_ok=True)
self._extract_dir.mkdir(parents=True, exist_ok=True)
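
A quick sketch against the dataclass above, importing it from the `tensorflow_datasets.core.download.download_manager` module shown in this diff: the new field defaults to `False`, so existing configurations keep raising on duplicated keys.

```python
from tensorflow_datasets.core.download import download_manager

# Default: duplicated keys still raise an error during generation.
config = download_manager.DownloadConfig()
assert config.ignore_duplicates is False

# Opt in to keeping only the first example for each duplicated key.
dedup_config = download_manager.DownloadConfig(ignore_duplicates=True)
assert dedup_config.ignore_duplicates is True
```
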
42 changes: 32 additions & 10 deletions tensorflow_datasets/core/shuffle.py
@@ -15,10 +15,11 @@

"""To shuffle records (stable)."""

from collections.abc import Iterator, Sequence
import math
import os
import struct
from typing import Iterator, List, Optional
from typing import Optional
import uuid

from absl import logging
@@ -213,18 +214,28 @@ def del_file(self):
class Shuffler(object):
"""Stores data in temp buckets, restitute it shuffled."""

def __init__(self, dirpath, hash_salt, disable_shuffling: bool = False):
def __init__(
self,
dirpath,
hash_salt,
disable_shuffling: bool = False,
ignore_duplicates: bool = False,
):
"""Initialize Shuffler.
Args:
dirpath (string): directory in which to store temporary files.
hash_salt (string or bytes): salt to hash keys.
disable_shuffling (bool): specify whether to shuffle by hashing the key.
ignore_duplicates: whether to ignore duplicated examples with the same
key. If there are multiple examples with the same key, the first one is
kept. If this is False, then a `DuplicatedKeysError` is raised.
"""
grp_name = uuid.uuid4()
self._hasher = hashing.Hasher(hash_salt)
self._disable_shuffling = disable_shuffling
self._buckets: List[_Bucket] = []
self._ignore_duplicates = ignore_duplicates
self._buckets: list[_Bucket] = []
for i in range(BUCKETS_NUMBER):
bucket_name = 'bucket_%s_%03d.tmp' % (grp_name, i)
path = os.path.join(dirpath, bucket_name)
@@ -234,47 +245,58 @@ def __init__(self, dirpath, hash_salt, disable_shuffling: bool = False):
# To keep data in memory until enough data has been gathered.
self._in_memory = True
self._mem_buffer = []
self._seen_keys: set[int] = set()
self._num_examples = 0

@property
def size(self):
def size(self) -> int:
"""Return total size in bytes of records (not keys)."""
return self._total_bytes

@property
def bucket_lengths(self):
def bucket_lengths(self) -> Sequence[int]:
if self._in_memory:
return [len(self._mem_buffer)]
return [len(b) for b in self._buckets]

def _add_to_bucket(self, hkey, data):
@property
def num_examples(self) -> int:
return self._num_examples

def _add_to_bucket(self, hkey, data) -> None:
bucket_number = get_bucket_number(hkey=hkey, num_buckets=BUCKETS_NUMBER)
self._buckets[bucket_number].add(hkey, data)

def _add_to_mem_buffer(self, hkey, data):
def _add_to_mem_buffer(self, hkey, data) -> None:
self._mem_buffer.append((hkey, data))
if self._total_bytes > MAX_MEM_BUFFER_SIZE:
for hkey, data in self._mem_buffer:
self._add_to_bucket(hkey, data)
self._mem_buffer = None
self._in_memory = False

def add(self, key, data):
def add(self, key, data) -> bool:
"""Add (key, data) to shuffler."""
if self._read_only:
raise AssertionError('add() cannot be called after __iter__.')
if not isinstance(data, six.binary_type):
raise AssertionError(
'Only bytes (not %s) can be stored in Shuffler!' % (type(data))
)
hkey = self._hasher.hash_key(key)
if self._ignore_duplicates:
if hkey in self._seen_keys:
return
self._seen_keys.add(hkey)
if self._disable_shuffling:
# Use the original key and not the hashed key to maintain the order.
hkey = key
else:
hkey = self._hasher.hash_key(key)
self._total_bytes += len(data)
if self._in_memory:
self._add_to_mem_buffer(hkey, data)
else:
self._add_to_bucket(hkey, data)
self._num_examples += 1

def __iter__(self) -> Iterator[type_utils.KeySerializedExample]:
self._read_only = True
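
A small usage sketch of the new flag on `Shuffler`, assuming the module path `tensorflow_datasets.core.shuffle` from this diff; the salt and keys are arbitrary illustration values:

```python
import tempfile

from tensorflow_datasets.core import shuffle

with tempfile.TemporaryDirectory() as tmp_dir:
  shuffler = shuffle.Shuffler(
      dirpath=tmp_dir,
      hash_salt='some_salt',
      ignore_duplicates=True,
  )
  shuffler.add('key-1', b'first payload')
  shuffler.add('key-1', b'second payload')  # Same key: silently dropped.
  shuffler.add('key-2', b'other payload')
  # Only the first 'key-1' example was kept.
  assert shuffler.num_examples == 2
```

With the default `ignore_duplicates=False`, the duplicated key would instead surface as the `DuplicatedKeysError` mentioned in the docstring above.
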
4 changes: 4 additions & 0 deletions tensorflow_datasets/core/split_builder.py
@@ -132,6 +132,7 @@ def __init__(
max_examples_per_split: int | None,
example_writer: writer_lib.ExampleWriter,
shard_config: shard_utils.ShardConfig | None = None,
ignore_duplicates: bool = False,
):
self._split_dict = split_dict
self._features = features
@@ -143,6 +144,7 @@ def __init__(
self._beam_runner = beam_runner
self._beam_pipeline: Optional['beam.Pipeline'] = None
self._shard_config = shard_config
self._ignore_duplicates = ignore_duplicates
self._example_writer = example_writer

@contextlib.contextmanager
@@ -386,6 +388,7 @@ def _build_from_generator(
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)
for key, example in utils.tqdm(
generator,
@@ -428,6 +431,7 @@ def _build_from_pcollection(
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)

def _encode_example(key_ex, encode_fn=self._features.encode_example):
67 changes: 43 additions & 24 deletions tensorflow_datasets/core/writer.py
@@ -73,11 +73,19 @@ def _raise_error_for_duplicated_keys(example1, example2, example_specs):
"""Log information about the examples and raise an AssertionError."""
msg = "Two examples share the same hashed key!"
logging.error(msg)
parser = example_parser.ExampleParser(example_specs)
ex1 = parser.parse_example(example1)
ex2 = parser.parse_example(example2)
logging.error("1st example: %s", ex1)
logging.error("2nd example: %s", ex2)
try:
parser = example_parser.ExampleParser(example_specs)
ex1 = parser.parse_example(example1)
ex2 = parser.parse_example(example2)
logging.error("1st example: %s", ex1)
logging.error("2nd example: %s", ex2)
except ValueError:
logging.error(
"Failed to parse examples! Cannot log them to see the examples behind"
" the duplicated keys. Raw example 1: %s, raw example 2: %s",
example1,
example2,
)
raise AssertionError(msg + " See logs above to view the examples.")


@@ -192,6 +200,7 @@ def __init__(
disable_shuffling: bool,
example_writer: ExampleWriter,
shard_config: shard_utils.ShardConfig | None = None,
ignore_duplicates: bool = False,
):
"""Initializes Writer.
@@ -202,14 +211,16 @@
disable_shuffling (bool): Specifies whether to shuffle the records.
example_writer: class that writes examples to disk or elsewhere.
shard_config: the configuration for creating shards.
ignore_duplicates: whether to ignore duplicated examples with the same
key. If False, a `DuplicatedKeysError` will be raised on duplicates.
"""
self._serializer = serializer
self._shuffler = shuffle.Shuffler(
dirpath=filename_template.data_dir,
hash_salt=hash_salt,
disable_shuffling=disable_shuffling,
ignore_duplicates=ignore_duplicates,
)
self._num_examples = 0
self._filename_template = filename_template
self._shard_config = shard_config or shard_utils.ShardConfig()
self._example_writer = example_writer
@@ -226,13 +237,12 @@ def write(self, key: int | bytes, example: Example):
"""
serialized_example = self._serializer.serialize_example(example=example)
self._shuffler.add(key, serialized_example)
self._num_examples += 1

def finalize(self) -> tuple[list[int], int]:
"""Effectively writes examples to the shards."""
filename = self._filename_template.sharded_filepaths_pattern()
shard_specs = _get_shard_specs(
num_examples=self._num_examples,
num_examples=self._shuffler.num_examples,
total_size=self._shuffler.size,
bucket_lengths=self._shuffler.bucket_lengths,
filename_template=self._filename_template,
@@ -245,7 +255,7 @@ def finalize(self) -> tuple[list[int], int]:
utils.tqdm(
self._shuffler,
desc=f"Shuffling {filename}...",
total=self._num_examples,
total=self._shuffler.num_examples,
unit=" examples",
leave=False,
mininterval=1.0,
@@ -322,6 +332,7 @@ def __init__(
disable_shuffling: bool,
example_writer: ExampleWriter,
shard_config: shard_utils.ShardConfig | None = None,
ignore_duplicates: bool = False,
):
"""Init BeamWriter.
@@ -336,6 +347,8 @@ def __init__(
disable_shuffling: bool, specifies whether to shuffle the records.
example_writer: class that writes examples to storage.
shard_config: the configuration for creating shards.
ignore_duplicates: whether to ignore duplicated examples with the same
key. If False, a `DuplicatedKeysError` will be raised on duplicates.
"""
self._original_state = dict(
serializer=serializer,
Expand All @@ -344,6 +357,7 @@ def __init__(
disable_shuffling=disable_shuffling,
shard_config=shard_config,
example_writer=example_writer,
ignore_duplicates=ignore_duplicates,
)
self._filename_template = filename_template
self._split_info_path = (
@@ -355,6 +369,7 @@ def __init__(
self._disable_shuffling = disable_shuffling
self._shard_config = shard_config or shard_utils.ShardConfig()
self._example_writer = example_writer
self._ignore_duplicates = ignore_duplicates

@functools.lru_cache()
def _get_counter(self, name: str, namespace: str = "BeamWriter"):
@@ -416,29 +431,34 @@ def _write_final_shard(
raise AssertionError("Not a single example present in the PCollection!")
# There may be empty shards, this ensures there are no gaps.
shard_id = non_empty_shard_ids.index(original_shard_id)
examples = sorted(examples)
self._get_distribution(name="ShardLenDistribution").update(len(examples))
# Compare continuous examples
for ex0, ex1 in zip(examples[:-1], examples[1:]):
if ex0[0] == ex1[0]: # Different keys
_raise_error_for_duplicated_keys(
ex0[1], ex1[1], self._serializer.example_specs
)
example_by_key = {}
for key, example in examples:
if key in example_by_key:
if not self._ignore_duplicates:
_raise_error_for_duplicated_keys(
example_by_key[key], example, self._serializer.example_specs
)
else:
example_by_key[key] = example
shard_path = self._filename_template.sharded_filepath(
shard_index=shard_id, num_shards=len(non_empty_shard_ids)
)
with utils.incomplete_file(epath.Path(shard_path)) as tmp_path:
logging.info(
"Writing %d examples to %s.", len(examples), os.fspath(tmp_path)
"Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
)
record_keys = self._example_writer.write(
tmp_path, sorted(example_by_key.items())
)
record_keys = self._example_writer.write(tmp_path, examples)
self.inc_counter(name="written_shards")
# If there are record_keys, create index files.
if record_keys:
index_path = _get_index_path(os.fspath(shard_path))
_write_index_file(index_path, list(record_keys))
shard_size = sum(map(len, examples))
return _ShardInfo(id=shard_id, num_examples=len(examples), size=shard_size)
shard_size = sum(map(len, example_by_key.values()))
return _ShardInfo(
id=shard_id, num_examples=len(example_by_key), size=shard_size
)

def _number_of_shards(self, num_examples: int, total_size: int) -> int:
"""Returns the number of shards."""
@@ -468,11 +488,11 @@ def _assign_shard(
def _store_split_info(
self,
shard_infos: Sequence[_ShardInfo],
total_size: int,
) -> None:
"""Stores the split info to disk."""
shard_infos = sorted(shard_infos, key=lambda x: x.id)
shard_lengths = [info.num_examples for info in shard_infos]
total_size = sum([info.size for info in shard_infos])
with utils.incomplete_file(epath.Path(self._split_info_path)) as tmp_path:
tmp_path.write_text(
json.dumps({"total_size": total_size, "shard_lengths": shard_lengths})
@@ -553,8 +573,7 @@ def write_from_pcollection(self, examples_pcollection):
# (_ShardInfo)
| "CollectShardInfo" >> beam.transforms.combiners.ToList()
# [_ShardInfo]
| "CalculateSplitInfo"
>> beam.ParDo(self._store_split_info, total_size=total_size)
| "CalculateSplitInfo" >> beam.ParDo(self._store_split_info)
)

def finalize(self) -> tuple[list[int], int]:
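
The change in `_write_final_shard` replaces the pairwise comparison of sorted neighbours with a keep-first-per-key dictionary. A standalone paraphrase of that logic, using a hypothetical helper name that is not part of this commit:

```python
def dedup_keep_first(examples, ignore_duplicates):
  """Keeps the first example per key; mirrors the dict logic above."""
  example_by_key = {}
  for key, example in examples:
    if key in example_by_key:
      if not ignore_duplicates:
        raise AssertionError('Two examples share the same hashed key!')
    else:
      example_by_key[key] = example
  # The shard is then written from the sorted (key, example) items.
  return sorted(example_by_key.items())


assert dedup_keep_first(
    [(1, b'a'), (1, b'b'), (2, b'c')], ignore_duplicates=True
) == [(1, b'a'), (2, b'c')]
```

With `ignore_duplicates=False` this keeps the previous behaviour of raising as soon as a duplicated key is seen.
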