From 1b032443066c7b01a2f2c165a4d5c6f372bd9557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alfonso=20Casta=C3=B1o?= Date: Fri, 11 Oct 2024 06:15:07 -0700 Subject: [PATCH] If there is only one config, load it by default PiperOrigin-RevId: 684813556 --- tensorflow_datasets/core/dataset_builder.py | 9 ------ tensorflow_datasets/core/read_only_builder.py | 29 +++++++++++++++---- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py index 26c222ae670..0c0efa22c69 100644 --- a/tensorflow_datasets/core/dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builder.py @@ -1982,15 +1982,6 @@ def _save_default_config_name( tmp_config_path.write_text(json.dumps(data)) -def load_default_config_name(builder_dir: epath.Path) -> str | None: - """Load `builder_cls` metadata (common to all builder configs).""" - config_path = builder_dir / ".config" / constants.METADATA_FILENAME - if not config_path.exists(): - return None - data = json.loads(config_path.read_text()) - return data.get("default_config_name") - - def canonical_version_for_config( instance_or_cls: Union[DatasetBuilder, Type[DatasetBuilder]], config: Optional[BuilderConfig] = None, diff --git a/tensorflow_datasets/core/read_only_builder.py b/tensorflow_datasets/core/read_only_builder.py index 2fea8072fc2..a457ddc96d0 100644 --- a/tensorflow_datasets/core/read_only_builder.py +++ b/tensorflow_datasets/core/read_only_builder.py @@ -20,6 +20,7 @@ from collections.abc import Sequence import concurrent.futures import functools +import json import os import typing from typing import Any, Type @@ -38,6 +39,8 @@ from tensorflow_datasets.core import registered from tensorflow_datasets.core import splits as splits_lib from tensorflow_datasets.core import utils + from tensorflow_datasets.core import constants + from tensorflow_datasets.core.features import feature as feature_lib from tensorflow_datasets.core.proto import 
dataset_info_pb2 from tensorflow_datasets.core.utils import error_utils @@ -436,10 +439,10 @@ def _list_possible_configs( configs = [] for data_dir in all_data_dirs: builder_dir = epath.Path(data_dir) / builder_name - if builder_dir.exists(): - for path in builder_dir.iterdir(): - if path.is_dir(): - configs.append(path.name) + variants = file_utils.list_dataset_variants( + dataset_dir=builder_dir, include_versions=False + ) + configs.extend(v.config for v in variants if v.config) return configs @@ -537,7 +540,7 @@ def _get_default_config_name( return cls.default_builder_config.name # Otherwise, try to load default config from common metadata - return dataset_builder.load_default_config_name(builder_dir) + return load_default_config_name(builder_dir) def _get_version( @@ -577,3 +580,19 @@ def _get_version( ) error_utils.add_context(error_msg) return None + + +def load_default_config_name(dataset_dir: epath.Path) -> str | None: + """Load the default config name; if no metadata file, use the only config.""" + config_path = dataset_dir / '.config' / constants.METADATA_FILENAME + if config_path.exists(): + data = json.loads(config_path.read_text()) + return data.get('default_config_name') + variants = list( + file_utils.list_dataset_variants( + dataset_dir=dataset_dir, include_versions=False + ) + ) + if len(variants) == 1: + return variants[0].config + return None