Stream from Hugging Face instead of downloading and preparing everything.

PiperOrigin-RevId: 657212303
marcenacp authored and The TensorFlow Datasets Authors committed Jul 29, 2024
1 parent 2123db7 commit eb7ae91
Showing 1 changed file with 36 additions and 18 deletions.
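
For context: the change relies on the Hugging Face datasets streaming API, which yields examples lazily over the network instead of materializing the whole dataset on disk with download_and_prepare(). A minimal sketch of that pattern, independent of this commit (the repository id below is a placeholder, not taken from the change):

import datasets as hf_datasets

# streaming=True returns an IterableDataset; nothing is downloaded up front
# and examples are fetched lazily as the iterator is consumed.
ds = hf_datasets.load_dataset('username/some_dataset', split='train', streaming=True)
first_example = next(iter(ds))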
@@ -108,9 +108,22 @@ class _ShardInfo:
   num_exceptions: int
 
 
+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +149,19 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
+    dataset = iter(dataset)
     for i in range(shard_spec.num_examples):
+      if i < shard_spec.start_index:
+        next(dataset)
+        continue
+      if i >= shard_spec.end_index:
+        break
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -257,14 +277,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config
 
-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +290,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )
 
   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')
 
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
@@ -309,7 +328,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()
 
     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():
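
The rewritten loop in _write_shard reads the half-open slice [start_index, end_index) out of a streamed iterator by skipping leading examples and breaking once the shard is full. A rough equivalent of that pattern, assuming examples is any iterable (itertools.islice is not what the commit uses, just a compact way to express the same slice):

import itertools

def iter_shard_slice(examples, start_index, end_index):
  # Yield only the examples in [start_index, end_index) from a streamed iterable.
  return itertools.islice(iter(examples), start_index, end_index)

# e.g. the third shard of 1000 examples each:
# for example in iter_shard_slice(streamed_dataset, 2000, 3000): ...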
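The _hf_features fallback above works because a streaming load still exposes the dataset's metadata: the returned IterableDataset carries an info object whose features may already be populated even though no data has been fetched. A minimal sketch, again with a placeholder repository id:

import datasets as hf_datasets

ds = hf_datasets.load_dataset('username/some_dataset', split='train', streaming=True)
# ds.info.features may still be None for some datasets, which is why the
# builder iterates over splits until it finds one that declares features.
print(ds.info.features)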
