Do several small refactorings
PiperOrigin-RevId: 704717347
tomvdw authored and The TensorFlow Datasets Authors committed Dec 10, 2024
1 parent cfeb104 commit 6b93631
Showing 6 changed files with 37 additions and 31 deletions.
1 change: 1 addition & 0 deletions tensorflow_datasets/core/data_sources/base.py
@@ -20,6 +20,7 @@
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

+from absl import logging
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
1 change: 0 additions & 1 deletion tensorflow_datasets/core/dataset_builder.py
@@ -1543,7 +1543,6 @@ def _as_dataset(
     file_format = file_adapters.FileFormat.from_value(file_format)

     reader = reader_lib.Reader(
-        self.data_dir,
         example_specs=example_specs,
         file_format=file_format,
     )
29 changes: 16 additions & 13 deletions tensorflow_datasets/core/reader.py
@@ -17,10 +17,11 @@

 from __future__ import annotations

+from collections.abc import Sequence
 import functools
 import os
 import re
-from typing import Any, Callable, List, NamedTuple, Optional, Sequence
+from typing import Any, Callable, NamedTuple

 from absl import logging
 import numpy as np
@@ -63,7 +64,7 @@ def _get_dataset_from_filename(
     do_take: bool,
     file_format: file_adapters.FileFormat,
     add_tfds_id: bool,
-    override_buffer_size: Optional[int] = None,
+    override_buffer_size: int | None = None,
 ) -> tf.data.Dataset:
   """Returns a tf.data.Dataset instance from given instructions."""
   ds = file_adapters.ADAPTER_FOR_FORMAT[file_format].make_tf_data(
@@ -361,39 +362,38 @@ def _verify_read_config_for_ordered_dataset(
   logging.warning(error_message)


-class Reader(object):
+class Reader:
   """Build a tf.data.Dataset object out of Instruction instance(s).

   This class should not typically be exposed to the TFDS user.
   """

   def __init__(
       self,
-      path,  # TODO(b/216427814) remove this as it isn't used anymore
       example_specs,
-      file_format=file_adapters.DEFAULT_FILE_FORMAT,
+      file_format: (
+          str | file_adapters.FileFormat
+      ) = file_adapters.DEFAULT_FILE_FORMAT,
   ):
     """Initializes Reader.

     Args:
-      path (str): path where tfrecords are stored.
       example_specs: spec to build ExampleParser.
       file_format: file_adapters.FileFormat, format of the record files in which
         the dataset will be read/written from.
     """
-    self._path = path
     self._parser = example_parser.ExampleParser(example_specs)
-    self._file_format = file_format
+    self._file_format = file_adapters.FileFormat.from_value(file_format)

   def read(
       self,
       *,
       instructions: Tree[splits_lib.SplitArg],
-      split_infos: List[splits_lib.SplitInfo],
+      split_infos: Sequence[splits_lib.SplitInfo],
       read_config: read_config_lib.ReadConfig,
       shuffle_files: bool,
       disable_shuffling: bool = False,
-      decode_fn: Optional[DecodeFn] = None,
+      decode_fn: DecodeFn | None = None,
   ) -> Tree[tf.data.Dataset]:
     """Returns tf.data.Dataset instance(s).
@@ -417,8 +417,11 @@ def read(

     splits_dict = splits_lib.SplitDict(split_infos=split_infos)

-    def _read_instruction_to_ds(instruction):
-      file_instructions = splits_dict[instruction].file_instructions
+    def _read_instruction_to_ds(
+        instruction: splits_lib.SplitArg,
+    ) -> tf.data.Dataset:
+      split_info = splits_dict[instruction]
+      file_instructions = split_info.file_instructions
       return self.read_files(
           file_instructions,
           read_config=read_config,
@@ -436,7 +439,7 @@ def read_files(
       read_config: read_config_lib.ReadConfig,
       shuffle_files: bool,
       disable_shuffling: bool = False,
-      decode_fn: Optional[DecodeFn] = None,
+      decode_fn: DecodeFn | None = None,
   ) -> tf.data.Dataset:
     """Returns single tf.data.Dataset instance for the set of file instructions.
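Taken together, the reader.py changes make Reader accept its file format as either a string or a FileFormat enum, normalizing once in the constructor via FileFormat.from_value. Below is a minimal, self-contained sketch of that normalization pattern; the FileFormat and Reader classes here are illustrative stand-ins, not the actual TFDS implementations:

import enum


class FileFormat(enum.Enum):
  # Illustrative stand-in for tensorflow_datasets.core.file_adapters.FileFormat.
  TFRECORD = 'tfrecord'
  ARRAY_RECORD = 'array_record'

  @classmethod
  def from_value(cls, value) -> 'FileFormat':
    # Accept either an enum member or its string value.
    if isinstance(value, cls):
      return value
    return cls(value)


class Reader:
  # Sketch: normalize file_format once at construction time, so the rest of
  # the class can rely on always holding a FileFormat member.
  def __init__(self, file_format='tfrecord'):
    self._file_format = FileFormat.from_value(file_format)


assert Reader('tfrecord')._file_format is FileFormat.TFRECORD
assert Reader(FileFormat.TFRECORD)._file_format is FileFormat.TFRECORD

This normalization plausibly also explains the reader_test.py change below: the format argument must now be a value FileFormat.from_value accepts, hence 'tfrecord' rather than 'some_spec'.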
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/reader_test.py
@@ -76,7 +76,7 @@ def setUp(self):
     with mock.patch.object(
         example_parser, 'ExampleParser', testing.DummyParser
     ):
-      self.reader = reader_lib.Reader(self.tmp_dir, 'some_spec')
+      self.reader = reader_lib.Reader(self.tmp_dir, 'tfrecord')
     self.reader.read = functools.partial(
         self.reader.read,
         read_config=read_config_lib.ReadConfig(),
26 changes: 14 additions & 12 deletions tensorflow_datasets/core/splits.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 import abc
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 import dataclasses
 import functools
 import itertools
@@ -123,7 +123,7 @@ def __post_init__(self):
   def get_available_shards(
       self,
       data_dir: epath.Path | None = None,
-      file_format: file_adapters.FileFormat | None = None,
+      file_format: str | file_adapters.FileFormat | None = None,
       strict_matching: bool = True,
   ) -> list[epath.Path]:
     """Returns the list of shards that are present in the data dir.
@@ -140,6 +140,7 @@
     """
     if filename_template := self.filename_template:
       if file_format:
+        file_format = file_adapters.FileFormat.from_value(file_format)
         filename_template = filename_template.replace(
             filetype_suffix=file_format.file_suffix
         )
@@ -250,7 +251,9 @@ def replace(self, **kwargs: Any) -> SplitInfo:
     """Returns a copy of the `SplitInfo` with updated attributes."""
     return dataclasses.replace(self, **kwargs)

-  def file_spec(self, file_format: file_adapters.FileFormat) -> str:
+  def file_spec(
+      self, file_format: str | file_adapters.FileFormat
+  ) -> str | None:
     """Returns the file spec of the split for the given file format.

     A file spec is the full path with sharded notation, e.g.,
@@ -259,6 +262,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
     Args:
       file_format: the file format for which to create the file spec for.
     """
+    file_format = file_adapters.FileFormat.from_value(file_format)
     if filename_template := self.filename_template:
       if filename_template.filetype_suffix != file_format.file_suffix:
         raise ValueError(
@@ -268,9 +272,7 @@ def file_spec(self, file_format: file_adapters.FileFormat) -> str:
       return filename_template.sharded_filepaths_pattern(
           num_shards=self.num_shards
       )
-    raise ValueError(
-        f'Could not get filename template for split from split info: {self}.'
-    )
+    return None


 @dataclasses.dataclass(eq=False, frozen=True)
@@ -425,7 +427,7 @@ def __repr__(self) -> str:
 if typing.TYPE_CHECKING:
   # For type checking, `tfds.Split` is an alias for `str` with additional
   # `.TRAIN`, `.TEST`,... attributes. All strings are valid split type.
-  Split = Union[Split, str]
+  Split = Split | str


 class SplitDict(utils.NonMutableDict[str, SplitInfo]):
@@ -438,7 +440,7 @@ def __init__(
       # TODO(b/216470058): remove this parameter
       dataset_name: str | None = None,  # deprecated, please don't use
   ):
-    super(SplitDict, self).__init__(
+    super().__init__(
         {split_info.name: split_info for split_info in split_infos},
         error_msg='Split {key} already present',
     )
@@ -457,7 +459,7 @@ def __getitem__(self, key) -> SplitInfo | SubSplitInfo:
     )
     # 1st case: The key exists: `info.splits['train']`
     elif str(key) in self.keys():
-      return super(SplitDict, self).__getitem__(str(key))
+      return super().__getitem__(str(key))
     # 2nd case: Uses instructions: `info.splits['train[50%]']`
     else:
       instructions = _make_file_instructions(
@@ -543,7 +545,7 @@ def _file_instructions_for_split(


 def _make_file_instructions(
-    split_infos: list[SplitInfo],
+    split_infos: Sequence[SplitInfo],
     instruction: SplitArg,
 ) -> list[shard_utils.FileInstruction]:
   """Returns file instructions by applying the given instruction on the given splits.
@@ -587,7 +589,7 @@ class AbstractSplit(abc.ABC):
   """

   @classmethod
-  def from_spec(cls, spec: SplitArg) -> 'AbstractSplit':
+  def from_spec(cls, spec: SplitArg) -> AbstractSplit:
     """Creates a ReadInstruction instance out of a string spec.

     Args:
@@ -632,7 +634,7 @@ def to_absolute(self, split_infos) -> list[_AbsoluteInstruction]:
     """
     raise NotImplementedError

-  def __add__(self, other: Union[str, 'AbstractSplit']) -> 'AbstractSplit':
+  def __add__(self, other: str | AbstractSplit) -> AbstractSplit:
     """Sum of 2 splits."""
     if not isinstance(other, (str, AbstractSplit)):
       raise TypeError(f'Adding split {self!r} with non-split value: {other!r}')
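One behavioral change worth noting in splits.py: SplitInfo.file_spec now returns None when no filename template is set, instead of raising ValueError, so error handling moves to the caller (the splits_test.py change below mirrors this). A small runnable sketch of the new contract, using a simplified stand-in rather than the real SplitInfo:

from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class SplitInfoSketch:
  # Simplified stand-in for SplitInfo; the real class derives the pattern
  # from its FilenameTemplate and the requested file format.
  name: str
  filename_template: str | None = None
  num_shards: int = 1

  def file_spec(self) -> str | None:
    # New contract: return None when there is no template, rather than
    # raising ValueError as before this commit.
    if self.filename_template:
      return f'{self.filename_template}@{self.num_shards}'
    return None


assert SplitInfoSketch('train', 'ds-train.tfrecord', 2).file_spec() == 'ds-train.tfrecord@2'
assert SplitInfoSketch('train').file_spec() is None

Callers that previously caught ValueError now need an explicit None check.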
9 changes: 5 additions & 4 deletions tensorflow_datasets/core/splits_test.py
@@ -666,10 +666,11 @@ def test_file_spec_missing_template(self):
         num_bytes=42,
         filename_template=None,
     )
-    with self.assertRaises(ValueError):
-      split_info.file_spec(
-          file_format=tfds.core.file_adapters.FileFormat.TFRECORD
-      )
+    self.assertIsNone(
+        split_info.file_spec(
+            file_format=tfds.core.file_adapters.FileFormat.TFRECORD
+        )
+    )

   def test_get_available_shards(self):
     tmp_dir = epath.Path(self.tmp_dir)
