Skip to content

Commit

Permalink
docstring updates, attempt to make shape/len more precise
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Oct 2, 2024
1 parent 579727a commit b3b5432
Showing 1 changed file with 40 additions and 44 deletions.
84 changes: 40 additions & 44 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ def open_experiment(self) -> Generator[soma.Experiment, None, None]:


class ExperimentAxisQueryIterable(Iterable[XObsDatum]):
"""An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
"""An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator
produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and
:class:`pandas.DataFrame`, respectively.
Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and
:class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset`
and `ExperimentAxisQueryIterDataPipe` for more details on usage.
and :class:`ExperimentAxisQueryIterDataPipe` for more details on usage.
Lifecycle:
experimental
Expand Down Expand Up @@ -136,14 +136,10 @@ def __init__(
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
default), will return ``X`` data as a :class:`numpy.ndarray`.
use_eager_fetch:
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
available for processing via the iterator. This allows network (or filesystem) requests to be made in
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
cost of doubling memory utilization. Defaults to ``True``.
Returns:
An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
a :class:`torch.utils.data.DataLoader` instance.
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is
made available for processing via the iterator. This allows network (or filesystem) requests to be made
in parallel with client-side processing of the SOMA data, potentially improving overall performance at
the cost of doubling memory utilization. Defaults to ``True``.
Raises:
``ValueError`` on various unsupported or malformed parameter values.
Expand Down Expand Up @@ -300,16 +296,16 @@ def __iter__(self) -> Iterator[XObsDatum]:
yield from _mini_batch_iter

def __len__(self) -> int:
"""Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or
as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell)
count will reflect the size of the data partition assigned to the active process.
"""Return the number of batches this iterable will produce. If run in the context of :mod:`torch.distributed`
or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the
batch count will reflect the size of the data partition assigned to the active process.
See important caveats in the PyTorch
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
documentation regarding ``len(dataloader)``, which also apply to this class.
Returns:
An ``int``.
``int`` (number of batches).
Lifecycle:
experimental
Expand All @@ -318,25 +314,31 @@ def __len__(self) -> int:

@property
def shape(self) -> Tuple[int, int]:
"""Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
(i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
the size of the data partition assigned to the active process.
"""Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.
If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
the number of batches will reflect the size of the data partition assigned to the active process.
Returns:
A tuple of two ``int`` values: number of obs, number of vars.
A tuple of two ``int`` values: number of batches, number of vars.
Lifecycle:
experimental
"""
self._init_once()
assert self._obs_joinids is not None
assert self._var_joinids is not None
world_size, _ = _get_distributed_world_rank()
n_workers, _ = _get_worker_world_rank()
partition_len = len(self._obs_joinids) // world_size // n_workers
div, rem = divmod(partition_len, self.batch_size)
return div + bool(rem), len(self._var_joinids)
world_size, rank = _get_distributed_world_rank()
n_workers, worker_id = _get_worker_world_rank()
obs_per_proc, obs_rem = divmod(len(self._obs_joinids), world_size)
# obs rows assigned to this "distributed" process
n_proc_obs = obs_per_proc + bool(rank < obs_rem)
obs_per_worker, obs_rem = divmod(n_proc_obs, n_workers)
# obs rows assigned to this worker process
n_worker_obs = obs_per_worker + bool(worker_id < obs_rem)
n_batches, rem = divmod(n_worker_obs, self.batch_size)
# (num batches this worker will produce, num features)
return n_batches + bool(rem), len(self._var_joinids)

def __getitem__(self, index: int) -> XObsDatum:
raise NotImplementedError(
Expand All @@ -349,11 +351,9 @@ def _io_batch_iter(
X: soma.SparseNDArray,
obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]:
"""Iterate over IO batches, i.e., SOMA query/read, producing a tuple of
(X: csr_array, obs: DataFrame).
"""Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_matrix, obs: DataFrame)``.
obs joinids read are controlled by the obs_joinid_iter. Iterator results will
be reindexed.
``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed.
Private method.
"""
Expand Down Expand Up @@ -475,7 +475,7 @@ class ExperimentAxisQueryIterDataPipe(
torch.utils.data.dataset.Dataset[XObsDatum]
],
):
"""A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.
"""A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.Experiment`.
This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for
legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the
Expand Down Expand Up @@ -534,7 +534,7 @@ def __len__(self) -> int:
Lifecycle:
deprecated
"""
return self._exp_iter.__len__()
return len(self._exp_iter)

@property
def shape(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -640,10 +640,6 @@ def __init__(
parallel with client-side processing of the SOMA data, potentially improving overall performance at the
cost of doubling memory utilization. Defaults to ``True``.
Returns:
An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to
a :class:`torch.data.utils.DataLoader` instance.
Raises:
``ValueError`` on various unsupported or malformed parameter values.
Expand All @@ -663,7 +659,8 @@ def __init__(
)

def __iter__(self) -> Iterator[XObsDatum]:
"""Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.
"""Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.sparse.csr_matrix`) and
:class:`pandas.DataFrame`.
Returns:
``iterator``
Expand All @@ -678,30 +675,29 @@ def __iter__(self) -> Iterator[XObsDatum]:
yield X, obs

def __len__(self) -> int:
"""Return approximate number of batches this iterable will produce.
"""Return number of batches this iterable will produce.
See important caveats in the PyTorch
[:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
documentation regarding ``len(dataloader)``, which also apply to this class.
Returns:
An ``int``.
``int`` (number of batches).
Lifecycle:
experimental
"""
return self._exp_iter.__len__()
return len(self._exp_iter)

@property
def shape(self) -> Tuple[int, int]:
"""Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
"""Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
(i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
the size of the partition of the data assigned to the active process.
If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0),
the number of batches will reflect the size of the data partition assigned to the active process.
Returns:
A tuple of ``int``s, for obs and var counts, respectively.
A tuple of two ``int`` values: number of batches, number of vars.
Lifecycle:
experimental
Expand Down

0 comments on commit b3b5432

Please sign in to comment.