From 6b93631a3d108f4af82bfa66e85863fa13d38eb4 Mon Sep 17 00:00:00 2001
From: Tom van der Weide
Date: Tue, 10 Dec 2024 08:19:37 -0800
Subject: [PATCH] Do several small refactorings

PiperOrigin-RevId: 704717347
---
 tensorflow_datasets/core/data_sources/base.py |  1 +
 tensorflow_datasets/core/dataset_builder.py   |  1 -
 tensorflow_datasets/core/reader.py            | 29 ++++++++++---------
 tensorflow_datasets/core/reader_test.py       |  2 +-
 tensorflow_datasets/core/splits.py            | 26 +++++++++--------
 tensorflow_datasets/core/splits_test.py       |  9 +++---
 6 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/tensorflow_datasets/core/data_sources/base.py b/tensorflow_datasets/core/data_sources/base.py
index 0e4d9aa0e7d..09ece9410be 100644
--- a/tensorflow_datasets/core/data_sources/base.py
+++ b/tensorflow_datasets/core/data_sources/base.py
@@ -20,6 +20,7 @@
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar
 
+from absl import logging
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py
index 24404b6fe24..bc87d3eea5c 100644
--- a/tensorflow_datasets/core/dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builder.py
@@ -1543,7 +1543,6 @@ def _as_dataset(
     file_format = file_adapters.FileFormat.from_value(file_format)
 
     reader = reader_lib.Reader(
-        self.data_dir,
         example_specs=example_specs,
         file_format=file_format,
     )
diff --git a/tensorflow_datasets/core/reader.py b/tensorflow_datasets/core/reader.py
index f419ed8d967..2ff6fcb61fb 100644
--- a/tensorflow_datasets/core/reader.py
+++ b/tensorflow_datasets/core/reader.py
@@ -17,10 +17,11 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 import functools
 import os
 import re
-from typing import Any, Callable, List, NamedTuple, Optional, Sequence
+from typing import Any, Callable, NamedTuple
 
 from absl import logging
 import numpy as np
@@ -63,7 +64,7 @@ def _get_dataset_from_filename(
     do_take: bool,
     file_format: file_adapters.FileFormat,
     add_tfds_id: bool,
-    override_buffer_size: Optional[int] = None,
+    override_buffer_size: int | None = None,
 ) -> tf.data.Dataset:
   """Returns a tf.data.Dataset instance from given instructions."""
   ds = file_adapters.ADAPTER_FOR_FORMAT[file_format].make_tf_data(
@@ -361,7 +362,7 @@ def _verify_read_config_for_ordered_dataset(
     logging.warning(error_message)
 
 
-class Reader(object):
+class Reader:
   """Build a tf.data.Dataset object out of Instruction instance(s).
 
   This class should not typically be exposed to the TFDS user.
@@ -369,31 +370,30 @@ class Reader(object):
 
   def __init__(
       self,
-      path,  # TODO(b/216427814) remove this as it isn't used anymore
       example_specs,
-      file_format=file_adapters.DEFAULT_FILE_FORMAT,
+      file_format: (
+          str | file_adapters.FileFormat
+      ) = file_adapters.DEFAULT_FILE_FORMAT,
   ):
     """Initializes Reader.
 
     Args:
-      path (str): path where tfrecords are stored.
       example_specs: spec to build ExampleParser.
       file_format: file_adapters.FileFormat, format of the record files from
         which the dataset will be read/written.
""" - self._path = path self._parser = example_parser.ExampleParser(example_specs) - self._file_format = file_format + self._file_format = file_adapters.FileFormat.from_value(file_format) def read( self, *, instructions: Tree[splits_lib.SplitArg], - split_infos: List[splits_lib.SplitInfo], + split_infos: Sequence[splits_lib.SplitInfo], read_config: read_config_lib.ReadConfig, shuffle_files: bool, disable_shuffling: bool = False, - decode_fn: Optional[DecodeFn] = None, + decode_fn: DecodeFn | None = None, ) -> Tree[tf.data.Dataset]: """Returns tf.data.Dataset instance(s). @@ -417,8 +417,11 @@ def read( splits_dict = splits_lib.SplitDict(split_infos=split_infos) - def _read_instruction_to_ds(instruction): - file_instructions = splits_dict[instruction].file_instructions + def _read_instruction_to_ds( + instruction: splits_lib.SplitArg, + ) -> tf.data.Dataset: + split_info = splits_dict[instruction] + file_instructions = split_info.file_instructions return self.read_files( file_instructions, read_config=read_config, @@ -436,7 +439,7 @@ def read_files( read_config: read_config_lib.ReadConfig, shuffle_files: bool, disable_shuffling: bool = False, - decode_fn: Optional[DecodeFn] = None, + decode_fn: DecodeFn | None = None, ) -> tf.data.Dataset: """Returns single tf.data.Dataset instance for the set of file instructions. diff --git a/tensorflow_datasets/core/reader_test.py b/tensorflow_datasets/core/reader_test.py index ae9d60b8179..e478e26bfd7 100644 --- a/tensorflow_datasets/core/reader_test.py +++ b/tensorflow_datasets/core/reader_test.py @@ -76,7 +76,7 @@ def setUp(self): with mock.patch.object( example_parser, 'ExampleParser', testing.DummyParser ): - self.reader = reader_lib.Reader(self.tmp_dir, 'some_spec') + self.reader = reader_lib.Reader(self.tmp_dir, 'tfrecord') self.reader.read = functools.partial( self.reader.read, read_config=read_config_lib.ReadConfig(), diff --git a/tensorflow_datasets/core/splits.py b/tensorflow_datasets/core/splits.py index e5ae13c41df..1c4c1e47dc5 100644 --- a/tensorflow_datasets/core/splits.py +++ b/tensorflow_datasets/core/splits.py @@ -18,7 +18,7 @@ from __future__ import annotations import abc -from collections.abc import Iterable +from collections.abc import Iterable, Sequence import dataclasses import functools import itertools @@ -123,7 +123,7 @@ def __post_init__(self): def get_available_shards( self, data_dir: epath.Path | None = None, - file_format: file_adapters.FileFormat | None = None, + file_format: str | file_adapters.FileFormat | None = None, strict_matching: bool = True, ) -> list[epath.Path]: """Returns the list of shards that are present in the data dir. @@ -140,6 +140,7 @@ def get_available_shards( """ if filename_template := self.filename_template: if file_format: + file_format = file_adapters.FileFormat.from_value(file_format) filename_template = filename_template.replace( filetype_suffix=file_format.file_suffix ) @@ -250,7 +251,9 @@ def replace(self, **kwargs: Any) -> SplitInfo: """Returns a copy of the `SplitInfo` with updated attributes.""" return dataclasses.replace(self, **kwargs) - def file_spec(self, file_format: file_adapters.FileFormat) -> str: + def file_spec( + self, file_format: str | file_adapters.FileFormat + ) -> str | None: """Returns the file spec of the split for the given file format. A file spec is the full path with sharded notation, e.g., @@ -259,6 +262,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str: Args: file_format: the file format for which to create the file spec for. 
""" + file_format = file_adapters.FileFormat.from_value(file_format) if filename_template := self.filename_template: if filename_template.filetype_suffix != file_format.file_suffix: raise ValueError( @@ -268,9 +272,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str: return filename_template.sharded_filepaths_pattern( num_shards=self.num_shards ) - raise ValueError( - f'Could not get filename template for split from split info: {self}.' - ) + return None @dataclasses.dataclass(eq=False, frozen=True) @@ -425,7 +427,7 @@ def __repr__(self) -> str: if typing.TYPE_CHECKING: # For type checking, `tfds.Split` is an alias for `str` with additional # `.TRAIN`, `.TEST`,... attributes. All strings are valid split type. - Split = Union[Split, str] + Split = Split | str class SplitDict(utils.NonMutableDict[str, SplitInfo]): @@ -438,7 +440,7 @@ def __init__( # TODO(b/216470058): remove this parameter dataset_name: str | None = None, # deprecated, please don't use ): - super(SplitDict, self).__init__( + super().__init__( {split_info.name: split_info for split_info in split_infos}, error_msg='Split {key} already present', ) @@ -457,7 +459,7 @@ def __getitem__(self, key) -> SplitInfo | SubSplitInfo: ) # 1st case: The key exists: `info.splits['train']` elif str(key) in self.keys(): - return super(SplitDict, self).__getitem__(str(key)) + return super().__getitem__(str(key)) # 2nd case: Uses instructions: `info.splits['train[50%]']` else: instructions = _make_file_instructions( @@ -543,7 +545,7 @@ def _file_instructions_for_split( def _make_file_instructions( - split_infos: list[SplitInfo], + split_infos: Sequence[SplitInfo], instruction: SplitArg, ) -> list[shard_utils.FileInstruction]: """Returns file instructions by applying the given instruction on the given splits. @@ -587,7 +589,7 @@ class AbstractSplit(abc.ABC): """ @classmethod - def from_spec(cls, spec: SplitArg) -> 'AbstractSplit': + def from_spec(cls, spec: SplitArg) -> AbstractSplit: """Creates a ReadInstruction instance out of a string spec. Args: @@ -632,7 +634,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]: """ raise NotImplementedError - def __add__(self, other: Union[str, 'AbstractSplit']) -> 'AbstractSplit': + def __add__(self, other: str | AbstractSplit) -> AbstractSplit: """Sum of 2 splits.""" if not isinstance(other, (str, AbstractSplit)): raise TypeError(f'Adding split {self!r} with non-split value: {other!r}') diff --git a/tensorflow_datasets/core/splits_test.py b/tensorflow_datasets/core/splits_test.py index 13410b9bfdd..8ea4f521b9d 100644 --- a/tensorflow_datasets/core/splits_test.py +++ b/tensorflow_datasets/core/splits_test.py @@ -666,10 +666,11 @@ def test_file_spec_missing_template(self): num_bytes=42, filename_template=None, ) - with self.assertRaises(ValueError): - split_info.file_spec( - file_format=tfds.core.file_adapters.FileFormat.TFRECORD - ) + self.assertIsNone( + split_info.file_spec( + file_format=tfds.core.file_adapters.FileFormat.TFRECORD + ) + ) def test_get_available_shards(self): tmp_dir = epath.Path(self.tmp_dir)