Add NoShuffleWriter (non-beam) to write examples to a single shard. #6354

Open · wants to merge 1 commit into master
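The PR has no written description, so here is a minimal usage sketch of the new `NoShuffleWriter`, adapted from the unit test added in this PR. The temporary directory, the `DummySerializer` spec and the byte-string examples are placeholders; inside `split_builder.py` the writer receives feature-encoded examples and a real `ExampleSerializer`.

```python
# Minimal sketch (adapted from the new writer_test.py below): serialize a few
# examples with NoShuffleWriter, then flush them all to one shard on finalize().
import tempfile

from etils import epath
from tensorflow_datasets import testing
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import writer as writer_lib

file_format = file_adapters.FileFormat.TFRECORD
with tempfile.TemporaryDirectory() as tmp_dir:
  filename_template = naming.ShardedFileTemplate(
      dataset_name='foo',
      split='train',
      filetype_suffix=file_format.file_suffix,
      data_dir=epath.Path(tmp_dir),
  )
  writer = writer_lib.NoShuffleWriter(
      serializer=testing.DummySerializer('dummy specs'),
      filename_template=filename_template,
      example_writer=writer_lib.ExampleWriter(file_format=file_format),
  )
  # write() only serializes; nothing touches disk until finalize().
  examples = [writer.write(i, str(i).encode('utf-8')) for i in range(10)]
  # finalize() writes every serialized example to a single shard and returns its path.
  shard_path = writer.finalize(examples=examples)
  print(shard_path)  # e.g. <tmp_dir>/foo-train.tfrecord-00000-of-00001
```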
103 changes: 69 additions & 34 deletions tensorflow_datasets/core/split_builder.py
@@ -434,23 +434,18 @@ def submit_split_generation(
generator=generator,
filename_template=filename_template,
disable_shuffling=disable_shuffling,
nondeterministic_order=nondeterministic_order,
)
# Depending on the type of generator, we use the corresponding
# `_build_from_xyz` method.
if isinstance(generator, Iterable):
if nondeterministic_order:
logging.warning(
'Enabling `nondeterministic_order` for a dataset that does not use'
' beam has no effect.'
)
return self._build_from_generator(**build_kwargs)
else: # Otherwise, beam required
unknown_generator_type = TypeError(
f'Invalid split generator value for split `{split_name}`. '
'Expected generator or apache_beam object. Got: '
f'{type(generator)}'
)
build_kwargs['nondeterministic_order'] = nondeterministic_order
if isinstance(generator, beam.PTransform):
# Generate the beam.PCollection
pcollection = self.beam_pipeline | split_name >> generator
@@ -467,14 +462,18 @@ def _build_from_generator(
generator: Iterable[KeyExample],
filename_template: naming.ShardedFileTemplate,
disable_shuffling: bool,
nondeterministic_order: bool,
) -> _SplitInfoFuture:
"""Split generator for example generators.

Args:
split_name: str,
generator: Iterable[KeyExample],
filename_template: Template to format the filename for a shard.
disable_shuffling: Specifies whether to shuffle the examples,
disable_shuffling: Specifies whether to shuffle the examples.
nondeterministic_order: If True, it will not assure deterministic ordering
when writing examples to disk. This might result in quicker dataset
preparation.

Returns:
future: The future containing the `tfds.core.SplitInfo`.
@@ -495,35 +494,71 @@ def _build_from_generator(
total_num_examples = None

serialized_info = self._features.get_serialized_info()
writer = writer_lib.Writer(
serializer=example_serializer.ExampleSerializer(serialized_info),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)
for key, example in utils.tqdm(
generator,
desc=f'Generating {split_name} examples...',
unit=' examples',
total=total_num_examples,
leave=False,
mininterval=1.0,
):
if nondeterministic_order:
logging.info(
'Order of examples does not matter, writing to a single shard using'
' NoShuffleWriter.'
)
writer = writer_lib.NoShuffleWriter(
serializer=example_serializer.ExampleSerializer(serialized_info),
filename_template=filename_template,
example_writer=self._example_writer,
)
# Encode and serialize the examples.
serialized_examples = []
for key, example in utils.tqdm(
generator,
desc=f'Generating {split_name} examples...',
unit=' examples',
total=total_num_examples,
leave=False,
mininterval=1.0,
):
try:
example = self._features.encode_example(example)
except Exception as e: # pylint: disable=broad-except
utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
serialized_examples.append(writer.write(key, example))
# Write the examples to a single shard.
shard_path = writer.finalize(examples=serialized_examples)
shard_lengths = [len(serialized_examples)]
total_size = shard_path.stat().length
logging.info(
'Done writing %s. Number of examples: %s (total size: %s bytes)',
shard_path,
shard_lengths,
total_size,
)
else:
logging.info('Deterministic ordering is enabled, using Writer')
writer = writer_lib.Writer(
serializer=example_serializer.ExampleSerializer(serialized_info),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)
for key, example in utils.tqdm(
generator,
desc=f'Generating {split_name} examples...',
unit=' examples',
total=total_num_examples,
leave=False,
mininterval=1.0,
):
try:
example = self._features.encode_example(example)
except Exception as e: # pylint: disable=broad-except
utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
writer.write(key, example)
try:
shard_lengths, total_size = writer.finalize()
except Exception as e:  # pylint: disable=broad-except
utils.reraise(
e, prefix=f'Failed to finalize writing of split "{split_name}": '
)

split_info = splits_lib.SplitInfo(
name=split_name,
shard_lengths=shard_lengths,
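In short, the patched `_build_from_generator` now branches on `nondeterministic_order`. The schematic helper below is hypothetical (not code from this PR; feature encoding and error handling are omitted) and only summarizes the two paths:

```python
# Schematic summary (hypothetical helper, not code from this PR) of how
# _build_from_generator now picks a writer based on `nondeterministic_order`.
from typing import Iterable, Tuple


def write_split(
    generator: Iterable[Tuple[int, bytes]],  # already-encoded (key, example) pairs
    shuffling_writer,       # a writer_lib.Writer
    single_shard_writer,    # a writer_lib.NoShuffleWriter
    nondeterministic_order: bool,
):
  if nondeterministic_order:
    # Order does not matter: serialize eagerly, then dump everything into one shard.
    serialized = [single_shard_writer.write(k, ex) for k, ex in generator]
    shard_path = single_shard_writer.finalize(examples=serialized)
    return [len(serialized)], shard_path.stat().length
  else:
    # Deterministic ordering: keep the existing shuffle-then-shard Writer path.
    for k, ex in generator:
      shuffling_writer.write(k, ex)
    return shuffling_writer.finalize()
```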
34 changes: 34 additions & 0 deletions tensorflow_datasets/core/writer.py
@@ -436,6 +436,40 @@ def finalize(self) -> tuple[list[int], int]:
return shard_lengths, self._shuffler.size


class NoShuffleWriter:
"""Shuffles / writes Examples to one single file."""

def __init__(
self,
serializer: example_serializer.Serializer,
filename_template: naming.ShardedFileTemplate,
example_writer: ExampleWriter,
):
"""Initializes Writer.

Args:
serializer: class that can serialize examples.
filename_template: template to format sharded filenames.
example_writer: class that writes examples to disk or elsewhere.
"""
self._serializer = serializer
self._filename_template = filename_template
self._example_writer = example_writer

def write(self, key: int | bytes, example: Example):
"""Writes given example."""
serialized_example = self._serializer.serialize_example(example=example)
return key, serialized_example

def finalize(self, examples: Iterable[KeyExample]) -> epath.Path:
"""Writes the examples to a single shard and returns its path."""
shard_path = self._filename_template.sharded_filepath(
shard_index=0, num_shards=1
)
self._example_writer.write(path=shard_path, examples=examples)
return shard_path


@dataclasses.dataclass
class _ShardInfo:
id: int
35 changes: 35 additions & 0 deletions tensorflow_datasets/core/writer_test.py
@@ -627,6 +627,41 @@ def _build_pcollection(pipeline):
self.assertIn(file_format.file_suffix, f.name)


class NoShuffleWriterTest(parameterized.TestCase):

@parameterized.named_parameters(
('tfrecord', file_adapters.FileFormat.TFRECORD),
)
def test_write(self, file_format: file_adapters.FileFormat):

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = epath.Path(tmp_dir)
filename_template = naming.ShardedFileTemplate(
dataset_name='foo',
split='train',
filetype_suffix=file_format.file_suffix,
data_dir=tmp_dir,
)
writer = writer_lib.NoShuffleWriter(
serializer=testing.DummySerializer('dummy specs'),
filename_template=filename_template,
example_writer=writer_lib.ExampleWriter(file_format=file_format),
)
to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
examples = []
for key, record in to_write:
examples.append(writer.write(key, record))
shard_path = writer.finalize(examples)
self.assertEqual(len(to_write), len(examples))
self.assertEqual(
set([key for key, _ in examples]), set([key for key, _ in to_write])
)
files = list(tmp_dir.iterdir())
self.assertLen(files, 1)
self.assertIn(file_format.file_suffix, files[0].name)
self.assertIn(file_format.file_suffix, shard_path.name)


class CustomExampleWriter(writer_lib.ExampleWriter):

def __init__(self):