Skip to content

Commit

Permalink
Merge pull request #15 from single-cell-data/rw/pr/6
Browse files Browse the repository at this point in the history
CR #6
  • Loading branch information
ryan-williams authored Oct 1, 2024
2 parents 9d9e20d + 579727a commit cd51799
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 49 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ show_error_codes = true
ignore_missing_imports = true
warn_unreachable = true
strict = true
python_version = 3.9
python_version = '3.11'
plugins = "numpy.typing.mypy_plugin"

[tool.ruff]
lint.select = ["E", "F", "B", "I"]
lint.ignore = ["E501"] # line too long
lint.extend-select = ["I001"] # unsorted-imports
fix = true
target-version = "py39"
target-version = "py311"
line-length = 120
80 changes: 33 additions & 47 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@
from itertools import islice
from math import ceil
from typing import (
TYPE_CHECKING,
Any,
ContextManager,
Dict,
Generator,
Iterable,
Iterator,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)

import attrs
Expand All @@ -38,26 +37,17 @@
import torch
import torchdata
from somacore.query._eager_iter import EagerIterator as _EagerIterator
from typing_extensions import TypeAlias

logger = logging.getLogger("tiledbsoma_ml.pytorch")

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

if TYPE_CHECKING:
# Python 3.8 does not support subscripting types, so work-around by
# restricting this to when we are running a type checker. TODO: remove
# the conditional when Python 3.8 support is dropped.
NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]]
XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix]
else:
NDArrayNumber: TypeAlias = np.ndarray
XDatum: TypeAlias = Union[np.ndarray, sparse.csr_matrix]

XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame]
NDArrayNumber = npt.NDArray[np.number[Any]]
XDatum = Union[NDArrayNumber, sparse.csr_matrix]
XObsDatum = Tuple[XDatum, pd.DataFrame]
"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``,
which pairs a slice of ``X`` rows with a cooresponding slice of ``obs``. In the default case,
which pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case,
the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs``
respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is
returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray`
Expand All @@ -68,7 +58,7 @@
class _ExperimentLocator:
"""State required to open the Experiment.
Necessary as we will likely be invoked across multiple processes.
Serializable across multiple processes.
Private implementation class.
"""
Expand All @@ -86,12 +76,11 @@ def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator":
)

@contextmanager
def open_experiment(self) -> Iterator[soma.Experiment]:
def open_experiment(self) -> Generator[soma.Experiment, None, None]:
context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config)
with soma.Experiment.open(
yield soma.Experiment.open(
self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context
) as exp:
yield exp
)


class ExperimentAxisQueryIterable(Iterable[XObsDatum]):
Expand Down Expand Up @@ -122,8 +111,8 @@ def __init__(
Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`.
The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as
a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a
Pandas :class:`pandas.DataFrame`, respectively.
a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas
:class:`pandas.DataFrame`, respectively.
Args:
query:
Expand All @@ -136,16 +125,16 @@ def __init__(
batch_size:
The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of
``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will
result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows
this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
batching, but higher performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s
result in :class:`torch.Tensor` objects of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows
this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher
performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s
``batch_size`` parameter to ``None``.
io_batch_size:
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts
maximum memory utilization: larger values provide better read performance but require more memory.
return_sparse_X:
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will
return ``X`` data as a :class:`numpy.ndarray`.
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
default), will return ``X`` data as a :class:`numpy.ndarray`.
use_eager_fetch:
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
available for processing via the iterator. This allows network (or filesystem) requests to be made in
Expand All @@ -166,7 +155,7 @@ def __init__(

super().__init__()

# Anything set in the instance needs to be picklable for multi-process DataLoaders
# Anything set in the instance needs to be pickle-able for multi-process DataLoaders
self.experiment_locator = _ExperimentLocator.create(query.experiment)
self.layer_name = X_name
self.measurement_name = query.measurement_name
Expand Down Expand Up @@ -228,18 +217,17 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
].copy()

if logger.isEnabledFor(logging.DEBUG):
partition_size = sum([len(chunk) for chunk in obs_partition_joinids])
logger.debug(
f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, "
f"n_workers={n_workers}, "
f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}"
f"Process {os.getpid()} {rank=}, {world_size=}, {worker_id=}, n_workers={n_workers}, {partition_size=}"
)

return iter(obs_partition_joinids)

def _init_once(self, exp: soma.Experiment | None = None) -> None:
"""One-time per worker initialization.
All operations be idempotent in order to support pipe reset().
All operations should be idempotent in order to support pipe reset().
Private method.
"""
Expand Down Expand Up @@ -291,7 +279,7 @@ def __iter__(self) -> Iterator[XObsDatum]:
world_size, rank = _get_distributed_world_rank()
n_workers, worker_id = _get_worker_world_rank()
logger.debug(
f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}"
f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}"
)

with self.experiment_locator.open_experiment() as exp:
Expand Down Expand Up @@ -398,17 +386,16 @@ def _io_batch_iter(
)

# Now that X read is potentially in progress (in eager mode), go fetch obs data
#
obs_io_batch = cast(
pd.DataFrame,
# fmt: off
obs_io_batch = (
obs.read(coords=(obs_coords,), column_names=obs_column_names)
.concat()
.to_pandas()
.set_index("soma_joinid")
.reindex(obs_coords, copy=False)
.reset_index(),
)
obs_io_batch = obs_io_batch[self.obs_column_names]
.reset_index() # demote "soma_joinid" to a column
[self.obs_column_names]
) # fmt: on

del obs_indexer, obs_coords, X_tbl
gc.collect()
Expand Down Expand Up @@ -576,9 +563,9 @@ class ExperimentAxisQueryIterableDataset(
>>> import tiledbsoma
>>> import tiledbsoma_ml
>>> with tiledbsoma.Experiment.open("my_experiment_path") as exp:
with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query:
ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query)
dataloader = torch.utils.data.DataLoader(ds)
... with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query:
... ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query)
... dataloader = torch.utils.data.DataLoader(ds)
>>> data = next(iter(dataloader))
>>> data
(array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
Expand Down Expand Up @@ -723,7 +710,7 @@ def shape(self) -> Tuple[int, int]:


def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
"""For `total_length` points, compute start/stop offsets that split the length into roughly equal sizes.
"""For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes.
A total_length of L, split into N sections, will return L%N sections of size L//N+1,
and the remainder as size L//N. This results in the same split as numpy.array_split,
Expand All @@ -750,11 +737,10 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:

if sys.version_info >= (3, 12):
_batched = itertools.batched

else:

def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]:
"""Same as the Python 3.12+ itertools.batched -- polyfill for old Python versions."""
"""Same as the Python 3.12+ ``itertools.batched`` -- polyfill for old Python versions."""
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
Expand All @@ -763,13 +749,13 @@ def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]:


def _get_distributed_world_rank() -> Tuple[int, int]:
"""Return tuple containing equivalent of torch.distributed world size and rank."""
"""Return tuple containing equivalent of ``torch.distributed`` world size and rank."""
world_size, rank = 1, 0
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
# Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There
# Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There
# is a NODE_RANK for the node's rank, but no way to tell the local node's
# world. So computing a global rank is impossible(?). Using LOCAL_RANK as a
# proxy, which works fine on a single-CPU box. TODO: could throw/error
Expand Down

0 comments on commit cd51799

Please sign in to comment.