Commit

data{pipe,set}.py
ryan-williams committed Dec 17, 2024
1 parent 869a99f commit 0708496
Showing 6 changed files with 341 additions and 323 deletions.
6 changes: 2 additions & 4 deletions src/tiledbsoma_ml/__init__.py
@@ -6,10 +6,8 @@
 """An API to support machine learning applications built on SOMA."""
 
 from .dataloader import experiment_dataloader
-from .pytorch import (
-    ExperimentAxisQueryIterableDataset,
-    ExperimentAxisQueryIterDataPipe,
-)
+from .datapipe import ExperimentAxisQueryIterDataPipe
+from .dataset import ExperimentAxisQueryIterableDataset
 
 __version__ = "0.1.0-dev"
 
6 changes: 2 additions & 4 deletions src/tiledbsoma_ml/dataloader.py
@@ -10,10 +10,8 @@
 from torch.utils.data import DataLoader
 
 from tiledbsoma_ml._distributed import init_multiprocessing
-from tiledbsoma_ml.pytorch import (
-    ExperimentAxisQueryIterableDataset,
-    ExperimentAxisQueryIterDataPipe,
-)
+from tiledbsoma_ml.datapipe import ExperimentAxisQueryIterDataPipe
+from tiledbsoma_ml.dataset import ExperimentAxisQueryIterableDataset
 
 _T = TypeVar("_T")
 
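
Note: as the `__init__.py` hunk above shows, the split of `pytorch.py` into `datapipe.py` and `dataset.py` leaves the package's public import surface unchanged; both classes, plus `experiment_dataloader`, are still re-exported from the package root. A minimal sketch of the unchanged downstream import path:

    from tiledbsoma_ml import (
        ExperimentAxisQueryIterableDataset,
        ExperimentAxisQueryIterDataPipe,
        experiment_dataloader,
    )
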
114 changes: 114 additions & 0 deletions src/tiledbsoma_ml/datapipe.py
@@ -0,0 +1,114 @@
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from __future__ import annotations

from typing import Iterator, Sequence, Tuple

from somacore import ExperimentAxisQuery
from torch.utils.data.dataset import Dataset
from torchdata.datapipes.iter import IterDataPipe

from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable


class ExperimentAxisQueryIterDataPipe(
    IterDataPipe[Dataset[Batch]]  # type:ignore[misc]
):
    """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.

    This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for
    legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the
    TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information.

    See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class.

    Lifecycle:
        deprecated
    """

    def __init__(
        self,
        query: ExperimentAxisQuery,
        X_name: str = "raw",
        obs_column_names: Sequence[str] = ("soma_joinid",),
        batch_size: int = 1,
        shuffle: bool = True,
        seed: int | None = None,
        io_batch_size: int = 2**16,
        shuffle_chunk_size: int = 64,
        return_sparse_X: bool = False,
        use_eager_fetch: bool = True,
    ):
        """
        See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class.

        Lifecycle:
            deprecated
        """
        super().__init__()
        self._exp_iter = ExperimentAxisQueryIterable(
            query=query,
            X_name=X_name,
            obs_column_names=obs_column_names,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            io_batch_size=io_batch_size,
            return_sparse_X=return_sparse_X,
            use_eager_fetch=use_eager_fetch,
            shuffle_chunk_size=shuffle_chunk_size,
        )

    def __iter__(self) -> Iterator[Batch]:
        """
        See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class.

        Lifecycle:
            deprecated
        """
        batch_size = self._exp_iter.batch_size
        for X, obs in self._exp_iter:
            if batch_size == 1:
                X = X[0]  # This is a no-op for `csr_matrix`s
            yield X, obs

    def __len__(self) -> int:
        """
        See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class.

        Lifecycle:
            deprecated
        """
        return len(self._exp_iter)

    @property
    def shape(self) -> Tuple[int, int]:
        """
        See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class.

        Lifecycle:
            deprecated
        """
        return self._exp_iter.shape

    def set_epoch(self, epoch: int) -> None:
        """
        Set the epoch for this Data iterator.

        When :attr:`shuffle=True`, this will ensure that all replicas use a different
        random ordering for each epoch. Failure to call this method before each epoch
        will result in the same data ordering.

        This call must be made before the per-epoch iterator is created.

        Lifecycle:
            experimental
        """
        self._exp_iter.set_epoch(epoch)

    @property
    def epoch(self) -> int:
        return self._exp_iter.epoch
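
For context, a minimal usage sketch of the legacy DataPipe class defined above; the experiment URI and measurement name are placeholders, and (per the class docstring) this wrapper is only intended for code still on the deprecated torchdata DataPipes API:

    import tiledbsoma
    import tiledbsoma_ml

    with tiledbsoma.Experiment.open("my_experiment_path") as exp:
        with exp.axis_query(measurement_name="RNA") as query:
            # Construct the deprecated DataPipe wrapper; arguments mirror
            # ExperimentAxisQueryIterableDataset.
            dp = tiledbsoma_ml.ExperimentAxisQueryIterDataPipe(
                query,
                obs_column_names=("soma_joinid",),
                batch_size=128,
                shuffle=True,
            )
            # With batch_size > 1, X is a rank-2 ndarray and obs a 128-row DataFrame.
            X, obs = next(iter(dp))
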
222 changes: 222 additions & 0 deletions src/tiledbsoma_ml/dataset.py
@@ -0,0 +1,222 @@
# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation
# Copyright (c) 2021-2024 TileDB, Inc.
#
# Licensed under the MIT License.

from __future__ import annotations

from typing import Iterator, Sequence, Tuple

from somacore import ExperimentAxisQuery
from torch.utils.data import IterableDataset

from tiledbsoma_ml.pytorch import Batch, ExperimentAxisQueryIterable


class ExperimentAxisQueryIterableDataset(IterableDataset[Batch]):  # type:ignore[misc]
    """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.

    This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as
    specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of
    ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray`
    and a :class:`pandas.DataFrame`.

    For example:

    >>> import torch
    >>> import tiledbsoma
    >>> import tiledbsoma_ml
    >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp:
    ...     with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query:
    ...         ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query)
    ...         dataloader = torch.utils.data.DataLoader(ds)
    >>> data = next(iter(dataloader))
    >>> data
    (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
       soma_joinid
    0     57905025)
    >>> data[0]
    array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)
    >>> data[1]
       soma_joinid
    0     57905025

    The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each
    iteration. If the ``batch_size`` is 1, then each result will have rank 1; otherwise it will have rank 2. A
    ``batch_size`` of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will
    usually be more performant to create mini-batches using this class, and set the ``DataLoader`` batch size to
    ``None``.

    The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the
    default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension).

    The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A
    larger value will increase total memory usage and may reduce average read time per row.

    Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using
    :class:`DataLoader` shuffling. The shuffling algorithm works as follows:

    1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk".
    2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_batch_size``).
    3. The entire I/O buffer is shuffled.

    Put another way, we read randomly selected groups of observations from across all query results, concatenate
    those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle
    is therefore determined by the ``io_batch_size`` (number of rows read) and the ``shuffle_chunk_size``
    (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness and
    decrease I/O performance.

    This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader`
    and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition
    data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any
    data tail will be dropped.

    Lifecycle:
        experimental
    """

    def __init__(
        self,
        query: ExperimentAxisQuery,
        X_name: str = "raw",
        obs_column_names: Sequence[str] = ("soma_joinid",),
        batch_size: int = 1,
        shuffle: bool = True,
        seed: int | None = None,
        io_batch_size: int = 2**16,
        shuffle_chunk_size: int = 64,
        return_sparse_X: bool = False,
        use_eager_fetch: bool = True,
    ):
        """
        Construct a new ``ExperimentAxisQueryIterableDataset``, suitable for use with :class:`torch.utils.data.DataLoader`.

        The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as
        a NumPy ``ndarray`` (or optionally, a :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame``, respectively.

        Args:
            query:
                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over.
            X_name:
                The name of the ``X`` layer to read.
            obs_column_names:
                The names of the ``obs`` columns to return. At least one column name must be specified.
                Default is ``('soma_joinid',)``.
            batch_size:
                The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value
                of ``1`` will result in a rank-1 array being returned (a single row); larger values will result in
                rank-2 arrays (multiple rows).
                Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
                batching, but you will achieve higher performance by performing batching in this class, and setting
                the ``DataLoader`` ``batch_size`` parameter to ``None``.
            shuffle:
                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
            io_batch_size:
                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
                this class's behavior: 1) the maximum memory utilization, with larger values providing
                better read performance, but also requiring more memory; 2) the number of rows read prior to shuffling
                (see the ``shuffle`` parameter for details). The default value of 65,536 provides high performance, but
                may need to be reduced in memory-limited hosts (or where a large number of :class:`DataLoader` workers
                are employed).
            shuffle_chunk_size:
                The number of contiguous rows sampled, prior to concatenation and shuffling.
                Larger numbers correspond to less randomness, but greater read performance.
                If ``shuffle == False``, this parameter is ignored.
            return_sparse_X:
                If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
                default), will return ``X`` data as a :class:`numpy.ndarray`.
            seed:
                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using
                :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
                processes.
            use_eager_fetch:
                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
                available for processing via the iterator. This allows network (or filesystem) requests to be made in
                parallel with client-side processing of the SOMA data, potentially improving overall performance at the
                cost of doubling memory utilization. Defaults to ``True``.

        Raises:
            ``ValueError`` on various unsupported or malformed parameter values.

        Lifecycle:
            experimental
        """
        super().__init__()
        self._exp_iter = ExperimentAxisQueryIterable(
            query=query,
            X_name=X_name,
            obs_column_names=obs_column_names,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
            io_batch_size=io_batch_size,
            return_sparse_X=return_sparse_X,
            use_eager_fetch=use_eager_fetch,
            shuffle_chunk_size=shuffle_chunk_size,
        )

    def __iter__(self) -> Iterator[Batch]:
        """Create an ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.sparse.csr_matrix`) and
        :class:`pandas.DataFrame`.

        Returns:
            ``iterator``

        Lifecycle:
            experimental
        """
        batch_size = self._exp_iter.batch_size
        for X, obs in self._exp_iter:
            if batch_size == 1:
                X = X[0]  # This is a no-op for `csr_matrix`s
            yield X, obs

    def __len__(self) -> int:
        """Return number of batches this iterable will produce.

        See important caveats in the PyTorch
        [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
        documentation regarding ``len(dataloader)``, which also apply to this class.

        Returns:
            ``int`` (number of batches).

        Lifecycle:
            experimental
        """
        return len(self._exp_iter)

    @property
    def shape(self) -> Tuple[int, int]:
        """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.

        If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with ``num_workers > 0``),
        the number of batches will reflect the size of the data partition assigned to the active process.

        Returns:
            A tuple of two ``int`` values: number of batches, number of vars.

        Lifecycle:
            experimental
        """
        return self._exp_iter.shape

    def set_epoch(self, epoch: int) -> None:
        """
        Set the epoch for this Data iterator.

        When :attr:`shuffle=True`, this will ensure that all replicas use a different
        random ordering for each epoch. Failure to call this method before each epoch
        will result in the same data ordering.

        This call must be made before the per-epoch iterator is created.

        Lifecycle:
            experimental
        """
        self._exp_iter.set_epoch(epoch)

    @property
    def epoch(self) -> int:
        return self._exp_iter.epoch
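
To tie together the ``batch_size``, ``shuffle``, and ``set_epoch`` notes in the docstrings above, a minimal per-epoch training-loop sketch; it assumes ``query`` is an open ``ExperimentAxisQuery`` as in the docstring example, and the training step itself is left as a placeholder:

    import torch
    import tiledbsoma_ml

    ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(
        query,              # an open tiledbsoma ExperimentAxisQuery
        batch_size=128,     # mini-batches are built by the dataset, not the DataLoader
        shuffle=True,
    )
    # Per the docstring, disable DataLoader batching when the dataset already batches.
    dataloader = torch.utils.data.DataLoader(ds, batch_size=None)

    for epoch in range(10):
        ds.set_epoch(epoch)  # re-seed the shuffle before creating the per-epoch iterator
        for X, obs in dataloader:
            ...              # training step on (X, obs) goes here
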