From eb7ae911a8d1831af5e9dfcb5554d90d1e5e84b8 Mon Sep 17 00:00:00 2001
From: Pierre Marcenac
Date: Mon, 29 Jul 2024 08:38:52 -0700
Subject: [PATCH] Stream from Hugging Face instead of downloading and
 preparing everything.

PiperOrigin-RevId: 657212303
---
 .../huggingface_dataset_builder.py            | 55 ++++++++++++-------
 1 file changed, 36 insertions(+), 19 deletions(-)

diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
index 9d1ee6b57ee..2ef6475fa21 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
@@ -108,9 +108,22 @@ class _ShardInfo:
   num_exceptions: int
 
 
+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.IterableDataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +149,18 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # The whole split is streamed, so skip the examples before the shard's
+    # start index; range() stops the iteration at the shard's end index.
+    for i in range(shard_spec.end_index):
+      if i < shard_spec.start_index:
+        next(dataset)
+        continue
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -257,14 +276,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config
 
-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +289,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )
 
   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Otherwise, return the features from the first split that has them.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')
 
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
@@ -309,7 +327,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()
 
     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():
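
Note (illustrative, not part of the commit): the patch swaps
download_and_prepare for datasets.load_dataset(..., streaming=True) and
reproduces each shard's slice by skipping and truncating the stream. A minimal
sketch of that skip/take pattern, assuming an arbitrary public dataset
('ag_news') and arbitrary shard bounds; itertools.islice stands in for the
patch's manual next()/skip loop:

    import itertools

    import datasets as hf_datasets

    # Streaming yields examples lazily instead of materializing the split.
    ds = hf_datasets.load_dataset('ag_news', split='train', streaming=True)

    # Reproduce a shard covering absolute indices [start_index, end_index)
    # of the split by skipping the head of the stream and stopping at the end.
    start_index, end_index = 100, 110
    for example in itertools.islice(iter(ds), start_index, end_index):
      print(example['label'], example['text'][:40])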
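
Note (illustrative, not part of the commit): _hf_features now falls back to
reading features from a streaming split when the builder info has none. A
minimal sketch of that fallback, again using 'ag_news' as a stand-in dataset:

    import datasets as hf_datasets

    ds = hf_datasets.load_dataset('ag_news', split='train', streaming=True)
    # An IterableDataset still carries DatasetInfo; its features may be
    # populated from repository metadata without downloading any data.
    if ds.info.features:
      print(ds.info.features['label'])  # ClassLabel with 4 news categories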