From f3fb8329a42b6846ecbeeb6079e6bd65b1815a12 Mon Sep 17 00:00:00 2001
From: The TensorFlow Datasets Authors
Date: Tue, 19 Nov 2024 13:06:11 -0800
Subject: [PATCH] Add NoShuffleWriter (non-beam) to write examples to a single
 shard.

PiperOrigin-RevId: 698118549
---
 tensorflow_datasets/core/split_builder.py | 103 +++++++++++++++-------
 tensorflow_datasets/core/writer.py        |  34 +++++++
 tensorflow_datasets/core/writer_test.py   |  35 ++++++++
 3 files changed, 138 insertions(+), 34 deletions(-)

diff --git a/tensorflow_datasets/core/split_builder.py b/tensorflow_datasets/core/split_builder.py
index 58059e56a05..598aa27377f 100644
--- a/tensorflow_datasets/core/split_builder.py
+++ b/tensorflow_datasets/core/split_builder.py
@@ -434,15 +434,11 @@ def submit_split_generation(
         generator=generator,
         filename_template=filename_template,
         disable_shuffling=disable_shuffling,
+        nondeterministic_order=nondeterministic_order,
     )
     # Depending on the type of generator, we use the corresponding
     # `_build_from_xyz` method.
     if isinstance(generator, Iterable):
-      if nondeterministic_order:
-        logging.warning(
-            'Enabling `nondeterministic_order` for a dataset that does not use'
-            ' beam has no effect.'
-        )
       return self._build_from_generator(**build_kwargs)
     else:  # Otherwise, beam required
       unknown_generator_type = TypeError(
@@ -450,7 +446,6 @@
           'Expected generator or apache_beam object. Got: '
           f'{type(generator)}'
       )
-    build_kwargs['nondeterministic_order'] = nondeterministic_order
     if isinstance(generator, beam.PTransform):
       # Generate the beam.PCollection
       pcollection = self.beam_pipeline | split_name >> generator
@@ -467,6 +462,7 @@ def _build_from_generator(
       generator: Iterable[KeyExample],
       filename_template: naming.ShardedFileTemplate,
       disable_shuffling: bool,
+      nondeterministic_order: bool,
   ) -> _SplitInfoFuture:
     """Split generator for example generators.
 
@@ -474,7 +470,10 @@ def _build_from_generator(
       split_name: str,
       generator: Iterable[KeyExample],
       filename_template: Template to format the filename for a shard.
-      disable_shuffling: Specifies whether to shuffle the examples,
+      disable_shuffling: Specifies whether to shuffle the examples.
+      nondeterministic_order: If True, examples are not guaranteed to be
+        written to disk in a deterministic order. This might result in quicker
+        dataset preparation.
 
     Returns:
       future: The future containing the `tfds.core.SplitInfo`.
@@ -495,35 +494,71 @@ def _build_from_generator(
     total_num_examples = None
     serialized_info = self._features.get_serialized_info()
-    writer = writer_lib.Writer(
-        serializer=example_serializer.ExampleSerializer(serialized_info),
-        filename_template=filename_template,
-        hash_salt=split_name,
-        disable_shuffling=disable_shuffling,
-        shard_config=self._shard_config,
-        example_writer=self._example_writer,
-        ignore_duplicates=self._ignore_duplicates,
-    )
-    for key, example in utils.tqdm(
-        generator,
-        desc=f'Generating {split_name} examples...',
-        unit=' examples',
-        total=total_num_examples,
-        leave=False,
-        mininterval=1.0,
-    ):
+    if nondeterministic_order:
+      logging.info(
+          'Order of examples does not matter, writing to a single shard using'
+          ' NoShuffleWriter.'
+      )
+      writer = writer_lib.NoShuffleWriter(
+          serializer=example_serializer.ExampleSerializer(serialized_info),
+          filename_template=filename_template,
+          example_writer=self._example_writer,
+      )
+      # Encode and serialize the examples.
+      serialized_examples = []
+      for key, example in utils.tqdm(
+          generator,
+          desc=f'Generating {split_name} examples...',
+          unit=' examples',
+          total=total_num_examples,
+          leave=False,
+          mininterval=1.0,
+      ):
+        try:
+          example = self._features.encode_example(example)
+        except Exception as e:  # pylint: disable=broad-except
+          utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
+        serialized_examples.append(writer.write(key, example))
+      # Write the examples to a single shard.
+      shard_path = writer.finalize(examples=serialized_examples)
+      shard_lengths = [len(serialized_examples)]
+      total_size = shard_path.stat().length
+      logging.info(
+          'Done writing %s. Number of examples: %s (total size: %s)',
+          shard_path,
+          shard_lengths,
+          total_size,
+      )
+    else:
+      logging.info('Deterministic ordering is enabled, using Writer.')
+      writer = writer_lib.Writer(
+          serializer=example_serializer.ExampleSerializer(serialized_info),
+          filename_template=filename_template,
+          hash_salt=split_name,
+          disable_shuffling=disable_shuffling,
+          shard_config=self._shard_config,
+          example_writer=self._example_writer,
+          ignore_duplicates=self._ignore_duplicates,
+      )
+      for key, example in utils.tqdm(
+          generator,
+          desc=f'Generating {split_name} examples...',
+          unit=' examples',
+          total=total_num_examples,
+          leave=False,
+          mininterval=1.0,
+      ):
+        try:
+          example = self._features.encode_example(example)
+        except Exception as e:  # pylint: disable=broad-except
+          utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
+        writer.write(key, example)
       try:
-        example = self._features.encode_example(example)
+        shard_lengths, total_size = writer.finalize()
       except Exception as e:  # pylint: disable=broad-except
-        utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
-      writer.write(key, example)
-    try:
-      shard_lengths, total_size = writer.finalize()
-    except Exception as e:  # pylint: disable=broad-except
-      utils.reraise(
-          e, prefix=f'Failed to finalize writing of split "{split_name}": '
-      )
-
+        utils.reraise(
+            e, prefix=f'Failed to finalize writing of split "{split_name}": '
+        )
     split_info = splits_lib.SplitInfo(
         name=split_name,
         shard_lengths=shard_lengths,
diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py
index 2a427c8d6ab..4492d91aee5 100644
--- a/tensorflow_datasets/core/writer.py
+++ b/tensorflow_datasets/core/writer.py
@@ -436,6 +436,40 @@ def finalize(self) -> tuple[list[int], int]:
     return shard_lengths, self._shuffler.size
 
 
+class NoShuffleWriter:
+  """Writes examples to a single file, without shuffling them."""
+
+  def __init__(
+      self,
+      serializer: example_serializer.Serializer,
+      filename_template: naming.ShardedFileTemplate,
+      example_writer: ExampleWriter,
+  ):
+    """Initializes NoShuffleWriter.
+
+    Args:
+      serializer: class that can serialize examples.
+      filename_template: template to format sharded filenames.
+      example_writer: class that writes examples to disk or elsewhere.
+    """
+    self._serializer = serializer
+    self._filename_template = filename_template
+    self._example_writer = example_writer
+
+  def write(self, key: int | bytes, example: Example):
+    """Serializes the given example; writing happens in `finalize`."""
+    serialized_example = self._serializer.serialize_example(example=example)
+    return key, serialized_example
+
+  def finalize(self, examples: Iterable[KeyExample]) -> epath.Path:
+    """Writes the examples to a single shard and returns its path."""
+    shard_path = self._filename_template.sharded_filepath(
+        shard_index=0, num_shards=1
+    )
+    self._example_writer.write(path=shard_path, examples=examples)
+    return shard_path
+
+
 @dataclasses.dataclass
 class _ShardInfo:
   id: int
diff --git a/tensorflow_datasets/core/writer_test.py b/tensorflow_datasets/core/writer_test.py
index c0b5fd30068..fda159a91dd 100644
--- a/tensorflow_datasets/core/writer_test.py
+++ b/tensorflow_datasets/core/writer_test.py
@@ -627,6 +627,41 @@ def _build_pcollection(pipeline):
     self.assertIn(file_format.file_suffix, f.name)
 
 
+class NoShuffleWriterTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ('tfrecord', file_adapters.FileFormat.TFRECORD),
+  )
+  def test_write(self, file_format: file_adapters.FileFormat):
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+      tmp_dir = epath.Path(tmp_dir)
+      filename_template = naming.ShardedFileTemplate(
+          dataset_name='foo',
+          split='train',
+          filetype_suffix=file_format.file_suffix,
+          data_dir=tmp_dir,
+      )
+      writer = writer_lib.NoShuffleWriter(
+          serializer=testing.DummySerializer('dummy specs'),
+          filename_template=filename_template,
+          example_writer=writer_lib.ExampleWriter(file_format=file_format),
+      )
+      to_write = [(i, str(i).encode('utf-8')) for i in range(10)]
+      examples = []
+      for key, record in to_write:
+        examples.append(writer.write(key, record))
+      shard_path = writer.finalize(examples)
+      self.assertEqual(len(to_write), len(examples))
+      self.assertEqual(
+          set([key for key, _ in examples]), set([key for key, _ in to_write])
+      )
+      files = list(tmp_dir.iterdir())
+      self.assertLen(files, 1)
+      self.assertIn(file_format.file_suffix, files[0].name)
+      self.assertIn(file_format.file_suffix, shard_path.name)
+
+
 class CustomExampleWriter(writer_lib.ExampleWriter):
 
   def __init__(self):
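
Note (illustrative, not part of the patch): the NoShuffleWriterTest above doubles as a usage sketch for the new writer. A minimal standalone version follows; it reuses testing.DummySerializer and the 'foo'/'train' template values from the test (DummySerializer may not be a public API), and a real pipeline would instead pass an example_serializer.ExampleSerializer built from the dataset's feature spec.

import tempfile

from etils import epath
from tensorflow_datasets import testing
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import writer as writer_lib

with tempfile.TemporaryDirectory() as tmp_dir:
  # Template describing the single output shard.
  filename_template = naming.ShardedFileTemplate(
      dataset_name='foo',
      split='train',
      filetype_suffix='tfrecord',
      data_dir=epath.Path(tmp_dir),
  )
  writer = writer_lib.NoShuffleWriter(
      serializer=testing.DummySerializer('dummy specs'),
      filename_template=filename_template,
      example_writer=writer_lib.ExampleWriter(
          file_format=file_adapters.FileFormat.TFRECORD
      ),
  )
  # `write` only serializes and returns (key, serialized_example); the caller
  # buffers the pairs and hands them to `finalize`, which writes the one shard.
  examples = [writer.write(i, str(i).encode('utf-8')) for i in range(10)]
  shard_path = writer.finalize(examples=examples)
  print(shard_path)  # e.g. <tmp_dir>/foo-train.tfrecord-00000-of-00001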