Make tfds.data_source pickable. #5429

Open · wants to merge 1 commit into master
18 changes: 2 additions & 16 deletions tensorflow_datasets/core/data_sources/array_record.py
@@ -20,13 +20,8 @@
"""

import dataclasses
from typing import Any, Optional

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.data_sources import base
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -42,18 +37,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
source.
"""

dataset_info: dataset_info_lib.DatasetInfo
split: splits_lib.Split = None
decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
None
)
# In order to lazy load array_record, we don't load
# `array_record_data_source.ArrayRecordDataSource` here.
data_source: Any = dataclasses.field(init=False)
length: int = dataclasses.field(init=False)

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
self.data_source = array_record_data_source.ArrayRecordDataSource(
file_instructions
)
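The rewritten __post_init__ above derives the file instructions from `self.dataset_builder.info` instead of taking a `DatasetInfo` (plus the decoder and field boilerplate) as constructor state. A minimal, self-contained sketch of that dataclass pattern, using hypothetical stand-in classes rather than TFDS ones:

import dataclasses
from typing import Any


@dataclasses.dataclass
class FakeInfo:
  splits: dict


@dataclasses.dataclass
class FakeBuilder:
  info: FakeInfo


@dataclasses.dataclass
class Source:
  builder: FakeBuilder
  split: str
  # Derived state: not a constructor argument, rebuilt from the builder.
  files: Any = dataclasses.field(init=False)

  def __post_init__(self):
    self.files = self.builder.info.splits[self.split]


src = Source(FakeBuilder(FakeInfo({'train': ['shard-0', 'shard-1']})), split='train')
assert src.files == ['shard-0', 'shard-1']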
32 changes: 24 additions & 8 deletions tensorflow_datasets/core/data_sources/base.py
@@ -17,12 +17,14 @@

from collections.abc import MappingView, Sequence
import dataclasses
import functools
import typing
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.features import top_level_feature
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -54,6 +56,14 @@ def file_instructions(
return split_dict[split].file_instructions


class _DatasetBuilder(Protocol):
"""Protocol for the DatasetBuilder to avoid cyclic imports."""

@property
def info(self) -> dataset_info_lib.DatasetInfo:
...


@dataclasses.dataclass
class BaseDataSource(MappingView, Sequence):
"""Base DataSource to override all dunder methods with the deserialization.
@@ -64,22 +74,28 @@ class BaseDataSource(MappingView, Sequence):
deserialization/decoding.
Attributes:
dataset_info: The DatasetInfo of the
dataset_builder: The dataset builder.
split: The split to load in the data source.
decoders: Optional decoders for decoding.
data_source: The underlying data source to initialize in the __post_init__.
"""

dataset_info: dataset_info_lib.DatasetInfo
dataset_builder: _DatasetBuilder
split: splits_lib.Split | None = None
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
data_source: DataSource[Any] = dataclasses.field(init=False)

@functools.cached_property
def _features(self) -> top_level_feature.TopLevelFeature:
"""Caches features because we log the use of dataset_builder.info."""
features = self.dataset_builder.info.features
if not features:
raise ValueError('No feature defined in the dataset builder.')
return features

def __getitem__(self, key: SupportsIndex) -> Any:
record = self.data_source[key.__index__()]
return self.dataset_info.features.deserialize_example_np(
record, decoders=self.decoders
)
return self._features.deserialize_example_np(record, decoders=self.decoders)

def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
"""Retrieves items by batch.
@@ -98,24 +114,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
if not keys:
return []
records = self.data_source.__getitems__(keys)
features = self.dataset_info.features
if len(keys) != len(records):
raise IndexError(
f'Requested {len(keys)} records but got'
f' {len(records)} records.'
f'{keys=}, {records=}'
)
return [
features.deserialize_example_np(record, decoders=self.decoders)
self._features.deserialize_example_np(record, decoders=self.decoders)
for record in records
]

def __repr__(self) -> str:
decoders_repr = (
tree.map_structure(type, self.decoders) if self.decoders else None
)
name = self.dataset_builder.info.name
return (
f'{self.__class__.__name__}(name={self.dataset_info.name}, '
f'{self.__class__.__name__}(name={name}, '
f'split={self.split!r}, '
f'decoders={decoders_repr})'
)
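For context on the `_DatasetBuilder` protocol introduced above: `typing.Protocol` lets `base.py` annotate `dataset_builder` as "anything exposing an `info` property" without importing the concrete `DatasetBuilder` class, which is what would create the cyclic import. The `functools.cached_property` on `_features` then looks the features up from `dataset_builder.info` once per instance instead of storing them as a field. A rough, self-contained sketch of the protocol pattern with stand-in names:

import dataclasses
from typing import Protocol


@dataclasses.dataclass
class Info:
  name: str


class _HasInfo(Protocol):
  """Structural type: any object with a compatible `info` property matches."""

  @property
  def info(self) -> Info:
    ...


class Builder:  # Note: no inheritance from _HasInfo is required.

  def __init__(self, name: str):
    self._info = Info(name)

  @property
  def info(self) -> Info:
    return self._info


def describe(builder: _HasInfo) -> str:
  # A type checker accepts any structurally compatible object here.
  return builder.info.name


assert describe(Builder('dataset_name')) == 'dataset_name'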
49 changes: 39 additions & 10 deletions tensorflow_datasets/core/data_sources/base_test.py
@@ -15,13 +15,15 @@

"""Tests for all data sources."""

import pickle
from unittest import mock

import cloudpickle
from etils import epath
import pytest
import tensorflow_datasets as tfds
from tensorflow_datasets import testing
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import file_adapters
@@ -77,7 +79,7 @@ def mocked_parquet_dataset():
)
def test_read_write(
tmp_path: epath.Path,
builder_cls: dataset_builder.DatasetBuilder,
builder_cls: dataset_builder_lib.DatasetBuilder,
file_format: file_adapters.FileFormat,
):
builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -106,28 +108,36 @@ def test_read_write(
]


def create_dataset_info(file_format: file_adapters.FileFormat):
def create_dataset_builder(
file_format: file_adapters.FileFormat,
) -> dataset_builder_lib.DatasetBuilder:
with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
split_mock.return_value.name = 'train'
split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
dataset_info.file_format = file_format
dataset_info.splits = {'train': split_mock()}
dataset_info.name = 'dataset_name'
return dataset_info

dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
dataset_builder.info = dataset_info

return dataset_builder


@pytest.mark.parametrize(
'data_source_cls',
_DATA_SOURCE_CLS,
)
def test_missing_split_raises_error(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
with pytest.raises(
ValueError,
match="Unknown split 'doesnotexist'.",
):
data_source_cls(dataset_info, split='doesnotexist')
data_source_cls(dataset_builder, split='doesnotexist')


@pytest.mark.usefixtures(*_FIXTURES)
@@ -136,8 +146,10 @@ def test_missing_split_raises_error(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
source = data_source_cls(dataset_info, split='train')
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(dataset_builder, split='train')
name = data_source_cls.__name__
assert (
repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -150,9 +162,11 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(
dataset_info,
dataset_builder,
split='train',
decoders={'my_feature': decode.SkipDecoding()},
)
@@ -181,3 +195,18 @@ def test_data_source_is_sliceable():
file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
assert file_instructions[0].skip == 0
assert file_instructions[0].take == 30000


# PyGrain requires that data sources are picklable.
@pytest.mark.parametrize(
'file_format',
file_adapters.FileFormat.with_random_access(),
)
@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
def test_data_source_is_picklable_after_use(file_format, pickle_module):
with tfds.testing.tmp_dir() as data_dir:
builder = tfds.testing.DummyDataset(data_dir=data_dir)
builder.download_and_prepare(file_format=file_format)
data_source = builder.as_data_source(split='train')
assert data_source[0] == {'id': 0}
assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
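The new test above exercises the property that motivates this PR: PyGrain hands data sources to worker processes started with the `spawn` method, so each worker receives a pickled copy. A stdlib-only illustration of that constraint; `TrivialSource` and `read` are hypothetical stand-ins, not TFDS or PyGrain APIs:

import multiprocessing as mp
import pickle


class TrivialSource:
  """Stand-in for a random-access data source."""

  def __len__(self):
    return 4

  def __getitem__(self, i):
    return {'id': i}


def read(args):
  source, index = args
  return source[index]['id']


if __name__ == '__main__':
  source = TrivialSource()
  # The source must survive a pickle round trip...
  assert pickle.loads(pickle.dumps(source))[0] == {'id': 0}
  # ...because `spawn` workers get a pickled copy of every argument.
  with mp.get_context('spawn').Pool(2) as pool:
    print(pool.map(read, [(source, i) for i in range(4)]))  # [0, 1, 2, 3]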
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/data_sources/parquet.py
@@ -57,7 +57,8 @@ class ParquetDataSource(base.BaseDataSource):
"""ParquetDataSource to read from a ParquetDataset."""

def __post_init__(self):
file_instructions = base.file_instructions(self.dataset_info, self.split)
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
filenames = [
file_instruction.filename for file_instruction in file_instructions
]
4 changes: 2 additions & 2 deletions tensorflow_datasets/core/dataset_builder.py
@@ -774,13 +774,13 @@ def build_single_data_source(
file_format = self.info.file_format
if file_format == file_adapters.FileFormat.ARRAY_RECORD:
return array_record.ArrayRecordDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
elif file_format == file_adapters.FileFormat.PARQUET:
return parquet.ParquetDataSource(
self.info,
self,
split=split,
decoders=decoders,
)
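This hunk only changes what gets passed down: build_single_data_source now hands the builder itself to ArrayRecordDataSource / ParquetDataSource instead of `self.info`. The public entry points are unchanged; a usage sketch (dataset name chosen for illustration, and the pickle round trip is the behaviour this PR is adding):

import pickle

import tensorflow_datasets as tfds

# Returns a random-access data source backed by ArrayRecordDataSource or
# ParquetDataSource, depending on the dataset's file format.
source = tfds.data_source('mnist', split='train')
print(len(source), source[0]['label'])

# With this change the source is expected to survive pickling, e.g. when
# PyGrain workers receive a copy of it.
restored = pickle.loads(pickle.dumps(source))
print(restored[0]['label'])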
49 changes: 39 additions & 10 deletions tensorflow_datasets/testing/mocking.py
@@ -120,26 +120,25 @@ def _getitems(
_getitem(self, record_key, generator, serialized=serialized)
for record_key in record_keys
]
if serialized:
return np.array(items)
return items
return np.asarray(items)


def _deserialize_example_np(serialized_example, *, decoders=None):
def _deserialize_example_np(self, serialized_example, *, decoders=None):
"""Function to overwrite dataset_info.features.deserialize_example_np.
Warning: this has to be defined in the outer scope in order for the function
to be pickable.
Args:
self: the dataset builder.
serialized_example: the example to deserialize.
decoders: optional decoders.
Returns:
The serialized example, because deserialization is taken care by
RandomFakeGenerator.
"""
del decoders
del self, decoders
return serialized_example


@@ -173,6 +172,7 @@ def mock_data(
as_data_source_fn: Optional[Callable[..., Sequence[Any]]] = None,
data_dir: Optional[str] = None,
mock_array_record_data_source: Optional[PickableDataSourceMock] = None,
use_in_multiprocessing: bool = False,
) -> Iterator[None]:
"""Mock tfds to generate random data.
@@ -262,6 +262,10 @@ def as_dataset(self, *args, **kwargs):
mock_array_record_data_source: Overwrite a mock for the underlying
ArrayRecord data source if it is used. Note: If used the same mock will be
used for all data sources loaded within this context.
use_in_multiprocessing: If True, the mock will use a multiprocessing-safe
approach to generate the data. It's notably useful for PyGrain. The goal
is to migrate the codebase to this mode by default. Find a more detailed
explanation of this parameter in a comment in the code below.
Yields:
None
@@ -361,9 +365,31 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
if split is None:
split = {s: s for s in self.info.splits}

generator_cls, features, _, _ = _get_fake_data_components(
decoders, self.info.features
)
features = self.info.features
if use_in_multiprocessing:
# In multiprocessing, we generate serialized data. The data is then
# re-deserialized by the feature as it would normally happen in TFDS. In
# this approach, we don't need to monkey-patch workers to propagate the
# information that deserialize_example_np should be a no-op. Indeed, doing
# so is difficult as PyGrain uses the `spawn` multiprocessing mode. Users
# of tfds.testing.mock_data in the codebase started relying on the
# function not serializing (for example, they don't have TensorFlow in
# their dependency), so we cannot have use_in_multiprocessing by default.
# ┌─────────────┐
# │ Main process│
# └─┬──────┬────┘
# ┌───────▼─┐ ┌─▼───────┐
# │ worker1 │ │ worker2 │ ...
# └───────┬─┘ └─┬───────┘
# serialized data by the generator
# ┌───────▼─┐ ┌─▼───────┐
# │ tfds 1 │ │ tfds 2 │ ...
# └───────┬─┘ └─┬───────┘
# deserialized data
generator_cls = SerializedRandomFakeGenerator
else:
# We generate already deserialized data with the generator.
generator_cls, _, _, _ = _get_fake_data_components(decoders, features)
generator = generator_cls(features, num_examples)

if actual_policy == MockPolicy.USE_CODE:
@@ -385,7 +411,6 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
# Force ARRAY_RECORD as the default file_format.
return_value=file_adapters.FileFormat.ARRAY_RECORD,
):
self.info.features.deserialize_example_np = _deserialize_example_np
mock_data_source.return_value.__len__.return_value = num_examples
mock_data_source.return_value._generator = ( # pylint:disable=protected-access
generator
@@ -399,7 +424,7 @@

def build_single_data_source(split):
single_data_source = array_record.ArrayRecordDataSource(
dataset_info=self.info, split=split, decoders=decoders
dataset_builder=self, split=split, decoders=decoders
)
return single_data_source

@@ -463,6 +488,10 @@ def new_builder_from_files(*args, **kwargs):
f'{core}.dataset_builder.FileReaderBuilder._as_dataset',
as_dataset_fn,
),
(
f'{core}.features.top_level_feature.TopLevelFeature.deserialize_example_np',
_deserialize_example_np,
),
]:
stack.enter_context(mock.patch(path, mocked_fn))
yield
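Note the strategy change in the mocking above: instead of assigning `_deserialize_example_np` onto `self.info.features`, the no-op is now installed with mock.patch on `TopLevelFeature.deserialize_example_np` at the class level (see the new entry in the patch list), so the patched behaviour lasts only for the duration of the context and leaves no per-instance state behind. A small sketch of that patching strategy, using hypothetical stand-in classes:

import pickle
from unittest import mock


class Feature:
  """Hypothetical stand-in for a TFDS feature."""

  def deserialize(self, raw):
    return {'decoded': raw}


def fake_deserialize(self, raw):
  # No-op stand-in: pretend `raw` is already decoded.
  return raw


feature = Feature()
with mock.patch.object(Feature, 'deserialize', fake_deserialize):
  assert feature.deserialize('x') == 'x'            # patched at the class level
  restored = pickle.loads(pickle.dumps(feature))    # instances still pickle
  assert restored.deserialize('x') == 'x'
assert feature.deserialize('x') == {'decoded': 'x'}  # original restored on exit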
9 changes: 9 additions & 0 deletions tensorflow_datasets/testing/mocking_test.py
@@ -392,3 +392,12 @@ def test_as_data_source_fn():
assert imagenet[0] == 'foo'
assert imagenet[1] == 'bar'
assert imagenet[2] == 'baz'


# PyGrain requires that data sources are picklable.
def test_mocked_data_source_is_pickable():
with tfds.testing.mock_data(num_examples=2):
data_source = tfds.data_source('imagenet2012', split='train')
pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
assert len(pickled_and_unpickled_data_source) == 2
assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)