From acf584f464e77b65a969e276284b30012aa4f5b7 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 31 Jul 2024 05:52:34 -0700 Subject: [PATCH 01/70] initial commit of pytorch datapipe/loader --- apis/python/src/tiledbsoma/ml/__init__.py | 7 +- apis/python/src/tiledbsoma/ml/encoders.py | 19 ++- apis/python/src/tiledbsoma/ml/pytorch.py | 141 +++++++++++++++------- 3 files changed, 117 insertions(+), 50 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/__init__.py b/apis/python/src/tiledbsoma/ml/__init__.py index 8450cdc..82def32 100644 --- a/apis/python/src/tiledbsoma/ml/__init__.py +++ b/apis/python/src/tiledbsoma/ml/__init__.py @@ -1,4 +1,9 @@ -"""An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census.""" +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +"""An API to facilitate use of PyTorch ML training with data from SOMA data.""" from .encoders import BatchEncoder, Encoder, LabelEncoder from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader diff --git a/apis/python/src/tiledbsoma/ml/encoders.py b/apis/python/src/tiledbsoma/ml/encoders.py index 3d4fc4d..32a4125 100644 --- a/apis/python/src/tiledbsoma/ml/encoders.py +++ b/apis/python/src/tiledbsoma/ml/encoders.py @@ -1,3 +1,10 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + import abc import functools from typing import List @@ -30,7 +37,7 @@ def fit(self, obs: pd.DataFrame) -> None: pass @abc.abstractmethod - def transform(self, df: pd.DataFrame) -> pd.DataFrame: + def transform(self, df: pd.DataFrame) -> npt.ArrayLike: """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" pass @@ -63,8 +70,8 @@ def fit(self, obs: pd.DataFrame) -> None: """Fit the encoder with ``obs``.""" self._encoder.fit(obs[self.col].unique()) - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" + def transform(self, df: pd.DataFrame) -> npt.ArrayLike: + """Transform the obs :class:`pandas.DataFrame` into a :class:`numpy.typing.ArrayLike` of encoded values.""" return self._encoder.transform(df[self.col]) # type: ignore def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: @@ -98,9 +105,11 @@ def __init__(self, cols: List[str], name: str = "batch"): self._encoder = LabelEncoder() def _join_cols(self, df: pd.DataFrame): # type: ignore - return functools.reduce(lambda a, b: a + b, [df[c].astype(str) for c in self.cols]) + return functools.reduce( + lambda a, b: a + b, [df[c].astype(str) for c in self.cols] + ) - def transform(self, df: pd.DataFrame) -> pd.DataFrame: + def transform(self, df: pd.DataFrame) -> npt.ArrayLike: """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" arr = self._join_cols(df) return self._encoder.transform(arr) # type: ignore diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 5bef673..62bd3b9 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -1,3 +1,10 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
+ +from __future__ import annotations + import gc import itertools import logging @@ -13,23 +20,23 @@ import numpy.typing as npt import pandas as pd import psutil -import tiledbsoma as soma import torch import torchdata.datapipes.iter as pipes -from attr import define +from attrs import define from numpy.random import Generator from pyarrow import Table from scipy import sparse +from somacore.query._eager_iter import EagerIterator as _EagerIterator from torch import Tensor from torch import distributed as dist from torch.utils.data import DataLoader from torch.utils.data.dataset import Dataset -from ... import get_default_soma_context -from ..util._eager_iter import _EagerIterator +import tiledbsoma as soma + from .encoders import Encoder, LabelEncoder -pytorch_logger = logging.getLogger("cellxgene_census.experimental.pytorch") +pytorch_logger = logging.getLogger("tiledbsoma.ml.pytorch") # TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) ObsAndXDatum = Tuple[Tensor, Tensor] @@ -84,7 +91,10 @@ class Stats: """The number of chunks retrieved""" def __str__(self) -> str: - return f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " f"elapsed={timedelta(seconds=self.elapsed)}" + return ( + f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " + f"elapsed={timedelta(seconds=self.elapsed)}" + ) def __add__(self, other: "Stats") -> "Stats": self.n_obs += other.n_obs @@ -94,16 +104,32 @@ def __add__(self, other: "Stats") -> "Stats": return self -@contextmanager -def _open_experiment( - uri: str, - aws_region: Optional[str] = None, -) -> soma.Experiment: - """Internal method for opening a SOMA ``Experiment`` as a context manager.""" - context = get_default_soma_context().replace(tiledb_config={"vfs.s3.region": aws_region} if aws_region else {}) +@define(frozen=True, kw_only=True) +class ExperimentLocator: + """State required to open the Experiment. + + Necessary as we will likely be invoked across multiple processes. + """ + + uri: str + tiledb_timestamp_ms: int + tiledb_config: Dict[str, Union[str, float]] + + @classmethod + def create(cls, experiment: soma.Experiment) -> "ExperimentLocator": + return ExperimentLocator( + uri=experiment.uri, + tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, + tiledb_config=experiment.context.tiledb_config, + ) - with soma.Experiment.open(uri, context=context) as exp: - yield exp + @contextmanager + def open_experiment(self) -> Iterator[soma.Experiment]: + context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config) + with soma.Experiment.open( + self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context + ) as exp: + yield exp def _tables_to_np( @@ -156,7 +182,9 @@ def __init__( # The result is again a list of numpy arrays. 
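            # Editorial sketch (hypothetical joinid values, not part of the patch), making
            # the two-level shuffle concrete: given the pre-shuffled chunk list
            #   obs_joinids_chunked = [[40, 41], [0, 1], [42, 43], [2, 3]]
            # and shuffle_chunk_count = 2, list_split() groups consecutive chunks as
            #   [[[40, 41], [0, 1]], [[42, 43], [2, 3]]]
            # and each group is then concatenated and permuted, e.g. yielding
            #   [0, 41, 40, 1] and [2, 43, 3, 42].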
self.obs_joinids_chunks_iter = ( shuffle_rng.permutation(np.concatenate(grouped_chunks)) - for grouped_chunks in list_split(obs_joinids_chunked, shuffle_chunk_count) + for grouped_chunks in list_split( + obs_joinids_chunked, shuffle_chunk_count + ) ) else: self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) @@ -196,13 +224,18 @@ def __next__(self) -> _SOMAChunk: # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix, # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block) - blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise( - axis=0, size=len(obs_joinids_chunk), eager=False - ) + blockwise_iter = self.X.read( + coords=(obs_joinids_chunk, self.var_joinids) + ).blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) X_batch: ChunkX if not self.return_sparse_X: - res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids)))) + res = next( + _tables_to_np( + blockwise_iter.tables(), + shape=(obs_batch.shape[0], len(self.var_joinids)), + ) + ) X_batch, nnz = res[0], res[2] else: X_batch = next(blockwise_iter.scipy(compress=True))[0] @@ -317,7 +350,9 @@ def __next__(self) -> ObsAndXDatum: while n_obs < self.batch_size: try: - obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs) + obs_partial, X_partial = self._read_partial_torch_batch( + self.batch_size - n_obs + ) n_obs += len(obs_partial) obss.append(obs_partial) Xs.append(X_partial) @@ -371,7 +406,9 @@ def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, Chun # GC memory from previous soma_chunk self.soma_chunk = None pre_gc, _, gc_elapsed = run_gc() - self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss) + self.max_process_mem_usage_bytes = max( + self.max_process_mem_usage_bytes, pre_gc[0].uss + ) self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) self.stats += self.soma_chunk.stats @@ -540,8 +577,7 @@ def __init__( Lifecycle: experimental """ - self.exp_uri = experiment.uri - self.aws_region = experiment.context.tiledb_config.get("vfs.s3.region") + self.experiment_locator = ExperimentLocator.create(experiment) self.measurement_name = measurement_name self.layer_name = X_name self.obs_query = obs_query @@ -570,7 +606,9 @@ def __init__( if len(encoders) != len({enc.name for enc in encoders}): raise ValueError("Encoders must have unique names") - self.obs_column_names = list(dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders]))) + self.obs_column_names = list( + dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders])) + ) def _init(self) -> None: if self._initialized: @@ -578,7 +616,7 @@ def _init(self) -> None: pytorch_logger.debug("Initializing ExperimentDataPipe") - with _open_experiment(self.exp_uri, self.aws_region) as exp: + with self.experiment_locator.open_experiment() as exp: query = exp.axis_query( measurement_name=self.measurement_name, obs_query=self.obs_query, @@ -665,14 +703,20 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: dist_partition=dist.get_rank() if dist.is_initialized() else 0, num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, ) - obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition( - obs_joinids_chunked, partition, partitions - ) + obs_joinids_chunked_partition: List[ + npt.NDArray[np.int64] + ] = self._subset_ids_to_partition(obs_joinids_chunked, partition, 
partitions) + + with self.experiment_locator.open_experiment() as exp: + X = exp.ms[self.measurement_name].X[self.layer_name] + if not isinstance(X, soma.SparseNDArray): + raise NotImplementedError( + "ExperimentDataPipe only supported on X layers which are of type SparseNDArray" + ) - with _open_experiment(self.exp_uri, self.aws_region) as exp: obs_and_x_iter = _ObsAndXIterator( obs=exp.obs, - X=exp.ms[self.measurement_name].X[self.layer_name], + X=X, obs_column_names=self.obs_column_names, obs_joinids_chunked=obs_joinids_chunked_partition, var_joinids=self._var_joinids, @@ -687,15 +731,22 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: yield from obs_and_x_iter - self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes + self.max_process_mem_usage_bytes = ( + obs_and_x_iter.max_process_mem_usage_bytes + ) pytorch_logger.debug( - "max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" + "max process memory usage=" + f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" ) @staticmethod - def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> List[npt.NDArray[np.int64]]: + def _chunk_ids( + ids: npt.NDArray[np.int64], chunk_size: int + ) -> List[npt.NDArray[np.int64]]: num_chunks = max(1, ceil(len(ids) / chunk_size)) - pytorch_logger.debug(f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}") + pytorch_logger.debug( + f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" + ) return np.array_split(ids, num_chunks) def __len__(self) -> int: @@ -736,11 +787,11 @@ def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> List[Encoder]: # TODO: This does not work in multiprocessing mode, as child process's stats are not collected def stats(self) -> Stats: - """Get data loading stats for this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + """Get data loading stats for this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. Returns: - The :class:`cellxgene_census.experimental.ml.pytorch.Stats` object for this - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + The :class:`tiledbsoma.ml.pytorch.Stats` object for this + :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. Lifecycle: experimental @@ -749,7 +800,7 @@ def stats(self) -> Stats: @property def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + """Get the shape of the data that will be returned by this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect the size of the partition of the data assigned to the active process. @@ -796,7 +847,7 @@ def experiment_dataloader( **dataloader_kwargs: Any, ) -> DataLoader: """Factory method for :class:`torch.utils.data.DataLoader`. 
This method can be used to safely instantiate a - :class:`torch.utils.data.DataLoader` that works with :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`, + :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`, since some of the :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, ``collate_fn``). @@ -804,15 +855,15 @@ def experiment_dataloader( Args: datapipe: An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe` or any other + :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe` or any other :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the - :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. num_workers: Number of worker processes to use for data loading. If ``0``, data will be loaded in the main process. **dataloader_kwargs: Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, and ``collate_fn``, which are not - supported when using :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. + supported when using :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. Returns: A :class:`torch.utils.data.DataLoader`. @@ -832,7 +883,9 @@ def experiment_dataloader( "collate_fn", ] if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): - raise ValueError(f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported") + raise ValueError( + f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported" + ) if num_workers > 0: _init_multiprocessing() From 12237b0fdb5ca33f44aca89ae7d26e8abc656755 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 31 Jul 2024 10:06:28 -0700 Subject: [PATCH 02/70] update comments --- apis/python/src/tiledbsoma/ml/encoders.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/encoders.py b/apis/python/src/tiledbsoma/ml/encoders.py index 32a4125..13f0496 100644 --- a/apis/python/src/tiledbsoma/ml/encoders.py +++ b/apis/python/src/tiledbsoma/ml/encoders.py @@ -29,6 +29,9 @@ class Encoder(abc.ABC): - ``columns``: List of columns in ``obs`` that the encoder will be applied to. See the implementation of :class:`LabelEncoder` for an example. + + Lifecycle: + experimental """ @abc.abstractmethod @@ -60,7 +63,11 @@ def columns(self) -> List[str]: class LabelEncoder(Encoder): - """Default encoder based on :class:`sklearn.preprocessing.LabelEncoder`.""" + """Default encoder based on :class:`sklearn.preprocessing.LabelEncoder`. + + Lifecycle: + experimental + """ def __init__(self, col: str) -> None: self._encoder = SklearnLabelEncoder() @@ -95,7 +102,11 @@ def classes_(self): # type: ignore class BatchEncoder(Encoder): - """An encoder that concatenates and encodes several ``obs`` columns.""" + """An encoder that concatenates and encodes several ``obs`` columns. 
+ + Lifecycle: + experimental + """ def __init__(self, cols: List[str], name: str = "batch"): self.cols = cols From 2fc9bebaf1373aec7c0af06631539f35df0f99a3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 31 Jul 2024 11:25:43 -0700 Subject: [PATCH 03/70] more lint --- apis/python/src/tiledbsoma/ml/pytorch.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 62bd3b9..ebb3f67 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -703,9 +703,9 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: dist_partition=dist.get_rank() if dist.is_initialized() else 0, num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, ) - obs_joinids_chunked_partition: List[ - npt.NDArray[np.int64] - ] = self._subset_ids_to_partition(obs_joinids_chunked, partition, partitions) + obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = ( + self._subset_ids_to_partition(obs_joinids_chunked, partition, partitions) + ) with self.experiment_locator.open_experiment() as exp: X = exp.ms[self.measurement_name].X[self.layer_name] @@ -759,7 +759,9 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> ObsAndXDatum: raise NotImplementedError("IterDataPipe can only be iterated") - def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> List[Encoder]: + def _build_obs_encoders( + self, query: soma.ExperimentAxisQuery[soma.Experiment] # type: ignore[type-var] + ) -> List[Encoder]: pytorch_logger.debug("Initializing encoders") encoders = [] @@ -840,7 +842,6 @@ def _collate_noop(x: Any) -> Any: return x -# TODO: Move into somacore.ExperimentAxisQuery def experiment_dataloader( datapipe: pipes.IterDataPipe, num_workers: int = 0, From 220a11b6874d5748c428bf44fdb4d5ac0c994008 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Thu, 1 Aug 2024 13:07:15 -0700 Subject: [PATCH 04/70] fix typos --- apis/python/src/tiledbsoma/ml/pytorch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index ebb3f67..49aeead 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -672,6 +672,10 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: assert self._var_joinids is not None if self.soma_chunk_size is None: + # TODO: given that soma_chunk_size defaults to 64, this code will only run if the user explicitly + # provides None as a value. Which is a bit unorthodox. Consider using some other sentinel, e.g., + # define an 'memory_budget' param. + # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data sparsity, 8 bytes for the # X value and 8 bytes for the sparse matrix indices, and a 100% working memory overhead (2x). 
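            # Editorial worked example (hypothetical var count, not part of the patch):
            # the "* 8 * 3" term is 8 bytes for each of the value and the two coordinate
            # arrays of a COO triple. With len(self._var_joinids) == 20_000:
            #   X_row_memory_size = 0.05 * 20_000 * 8 * 3 * 2 = 48_000 bytes/row
            #   soma_chunk_size   = int(1024**3 / 48_000)     = 22_369 rows (~1 GiB)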
X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 @@ -695,7 +699,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: if self._shuffle_rng: self._shuffle_rng.shuffle(obs_joinids_chunked) - # subset to a single partition, as needed for distributed training and multi-processing datat loading + # subset to a single partition, as needed for distributed training and multi-processing data loading worker_info = torch.utils.data.get_worker_info() partition, partitions = self._compute_partitions( loader_partition=worker_info.id if worker_info else 0, @@ -745,7 +749,7 @@ def _chunk_ids( ) -> List[npt.NDArray[np.int64]]: num_chunks = max(1, ceil(len(ids) / chunk_size)) pytorch_logger.debug( - f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" + f"Splitting {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" ) return np.array_split(ids, num_chunks) From 2c870ea692eccd69a95178b10a71e25f431b2dd0 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 20 Aug 2024 10:53:16 -0700 Subject: [PATCH 05/70] rework for performance --- apis/python/src/tiledbsoma/ml/__init__.py | 10 +- apis/python/src/tiledbsoma/ml/pytorch.py | 1256 ++++++++++----------- 2 files changed, 577 insertions(+), 689 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/__init__.py b/apis/python/src/tiledbsoma/ml/__init__.py index 82def32..857f30f 100644 --- a/apis/python/src/tiledbsoma/ml/__init__.py +++ b/apis/python/src/tiledbsoma/ml/__init__.py @@ -6,11 +6,15 @@ """An API to facilitate use of PyTorch ML training with data from SOMA data.""" from .encoders import BatchEncoder, Encoder, LabelEncoder -from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader +from .pytorch import ( + ExperimentAxisQueryDataPipe, + ExperimentAxisQueryIterableDataset, + experiment_dataloader, +) __all__ = [ - "Stats", - "ExperimentDataPipe", + "ExperimentAxisQueryDataPipe", + "ExperimentAxisQueryIterableDataset", "experiment_dataloader", "Encoder", "LabelEncoder", diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 49aeead..0a77f4a 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -1,36 +1,38 @@ -# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation -# Copyright (c) 2021-2024 TileDB, Inc. -# -# Licensed under the MIT License. 
- from __future__ import annotations import gc import itertools import logging import os -import typing +import sys +import time from contextlib import contextmanager -from datetime import timedelta +from itertools import islice from math import ceil -from time import time -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union - +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import attrs import numpy as np import numpy.typing as npt import pandas as pd import psutil +import pyarrow as pa +import scipy.sparse as sparse import torch -import torchdata.datapipes.iter as pipes -from attrs import define -from numpy.random import Generator -from pyarrow import Table -from scipy import sparse +import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator -from torch import Tensor -from torch import distributed as dist -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset import tiledbsoma as soma @@ -38,77 +40,32 @@ pytorch_logger = logging.getLogger("tiledbsoma.ml.pytorch") -# TODO: Rename to reflect the correct order of the Tensors within the tuple: (X, obs) -ObsAndXDatum = Tuple[Tensor, Tensor] -"""Return type of ``ExperimentDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). -The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" - - -# "Chunk" of X data, returned by each `Method` above -ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix] - - -@define -class _SOMAChunk: - """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X`` - matrix. +_T_co = TypeVar("_T_co", covariant=True) - Lifecycle: - experimental - """ - - obs: pd.DataFrame - X: ChunkX - stats: "Stats" - - def __len__(self) -> int: - return len(self.obs) +XObsDatum = Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] +XObsNpDatum = Tuple[npt.NDArray[np.number[Any]], npt.NDArray[np.number[Any]]] +XObsTensorDatum = Tuple[torch.Tensor, torch.Tensor] +"""Return type of ``ExperimentAxisQueryDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). +The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" Encoders = Dict[str, Encoder] """A dictionary of ``Encoder``s keyed by the ``obs`` column name.""" -@define -class Stats: - """Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API. This is useful for assessing the read - throughput of SOMA data. 
- - Lifecycle: - experimental - """ - - n_obs: int = 0 - """The total number of obs rows retrieved""" - - nnz: int = 0 - """The total number of values retrieved""" - - elapsed: float = 0 - """The total elapsed time in seconds for retrieving all batches""" - - n_soma_chunks: int = 0 - """The number of chunks retrieved""" - - def __str__(self) -> str: - return ( - f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " - f"elapsed={timedelta(seconds=self.elapsed)}" - ) - - def __add__(self, other: "Stats") -> "Stats": - self.n_obs += other.n_obs - self.nnz += other.nnz - self.elapsed += other.elapsed - self.n_soma_chunks += other.n_soma_chunks - return self +if TYPE_CHECKING: + PsUtilMemInfo = Tuple[psutil.pfullmem, psutil.svmem, psutil.sswap] +else: + PsUtilMemInfo = Tuple[Any] -@define(frozen=True, kw_only=True) -class ExperimentLocator: +@attrs.define(frozen=True, kw_only=True) +class _ExperimentLocator: """State required to open the Experiment. Necessary as we will likely be invoked across multiple processes. + + Private. """ uri: str @@ -116,8 +73,8 @@ class ExperimentLocator: tiledb_config: Dict[str, Union[str, float]] @classmethod - def create(cls, experiment: soma.Experiment) -> "ExperimentLocator": - return ExperimentLocator( + def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator": + return _ExperimentLocator( uri=experiment.uri, tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, tiledb_config=experiment.context.tiledb_config, @@ -132,314 +89,382 @@ def open_experiment(self) -> Iterator[soma.Experiment]: yield exp -def _tables_to_np( - tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] -) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: - for tbl, indices in tables: - row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns) - nnz = len(data) - dense_matrix = np.zeros(shape, dtype=data.dtype) - dense_matrix[row_indices, col_indices] = data - yield dense_matrix, indices, nnz - - -class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): - """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, - not intended for public use. - """ - - X: soma.SparseNDArray - """A handle to the full X data of the SOMA ``Experiment``""" - - obs_joinids_chunks_iter: Iterator[npt.NDArray[np.int64]] +class ExperimentAxisQueryIterable(Iterable[XObsDatum]): + r"""Private base class for Dataset/DataPipe subclasses.""" - var_joinids: npt.NDArray[np.int64] - """The ``var`` joinids to be retrieved from the SOMA ``Experiment``""" + # XXX TODO - docstrings, slots, etc. 
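    # Editorial sketch (not part of the patch): intended use of this base class, assuming
    # an already-open Experiment at a hypothetical `uri`. Iteration yields (X, obs) pairs
    # per XObsDatum, where X is a dense numpy array and obs a pandas DataFrame of encoded
    # columns:
    #
    #   with tiledbsoma.Experiment.open(uri) as exp:
    #       batches = ExperimentAxisQueryIterable(
    #           exp,
    #           measurement_name="RNA",
    #           X_name="raw",
    #           obs_column_names=["cell_type"],
    #           batch_size=128,
    #       )
    #       for X_batch, obs_batch in batches:
    #           ...  # full batches have X_batch.shape == (128, n_vars)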
def __init__( self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, - return_sparse_X: bool = False, + experiment: soma.Experiment, + measurement_name: str = "RNA", + X_name: str = "raw", + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, + obs_column_names: Sequence[str] = (), + batch_size: int = 1, + shuffle: bool = True, + seed: int | None = None, + soma_chunk_size: int = 64, + use_eager_fetch: bool = True, + encoders: List[Encoder] | None = None, + shuffle_chunk_count: int = 2000, + partition: bool = True, ): - self.obs = obs - self.X = X + super().__init__() + + # Anything set in the instance needs to be picklable for multi-process users + self.experiment_locator = _ExperimentLocator.create(experiment) + self.measurement_name = measurement_name + self.layer_name = X_name + self.obs_query = obs_query + self.var_query = var_query self.obs_column_names = obs_column_names - if shuffle_chunk_count is not None: - assert shuffle_rng is not None - - # At the start of this step, `obs_joinids_chunked` is a list of one dimensional - # numpy arrays. Each numpy array corresponds to a chunk of contiguous rows in `obs`. - # Critically, `obs_joinids_chunked` is randomly ordered where each chunk is - # from a random section of `obs`. - # We then take `shuffle_chunk_count` of these in order, concatenate them into - # a larger numpy array and shuffle this larger numpy array. - # The result is again a list of numpy arrays. - self.obs_joinids_chunks_iter = ( - shuffle_rng.permutation(np.concatenate(grouped_chunks)) - for grouped_chunks in list_split( - obs_joinids_chunked, shuffle_chunk_count - ) + self.batch_size = batch_size + self.soma_chunk_size = soma_chunk_size + self.use_eager_fetch = use_eager_fetch + # XXX - TODO: when/if we add X encoders, how will they be differentiated from obs encoders? Naming? + self._encoders = encoders or [] + self._obs_joinids: npt.NDArray[np.int64] | None = None + self._var_joinids: npt.NDArray[np.int64] | None = None + self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None + self._shuffle_rng = np.random.default_rng(seed) if shuffle else None + self.partition = partition + self._initialized = False + + if obs_column_names and encoders: + raise ValueError( + "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically." 
) - else: - self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) - self.var_joinids = var_joinids - self.shuffle_chunk_count = shuffle_chunk_count - self.return_sparse_X = return_sparse_X - def __next__(self) -> _SOMAChunk: - pytorch_logger.debug("Retrieving next SOMA chunk...") - start_time = time() + if encoders: + # Check if names are unique + if len(encoders) != len({enc.name for enc in encoders}): + raise ValueError("Encoders must have unique names") - # If no more chunks to iterate through, raise StopIteration, as all iterators do when at end - obs_joinids_chunk = next(self.obs_joinids_chunks_iter) + self.obs_column_names = list( + dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders])) + ) - if "soma_joinid" not in self.obs_column_names: - cols = ["soma_joinid", *self.obs_column_names] - else: - cols = list(self.obs_column_names) + def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: + """Private - create iterator over obs id chunks. - obs_batch = ( - self.obs.read( - coords=(obs_joinids_chunk,), - column_names=cols, - ) - .concat() - .to_pandas() - .set_index("soma_joinid") - ) - assert obs_batch.shape[0] == obs_joinids_chunk.shape[0] - - # handle case of empty result (first batch has 0 rows) - if len(obs_batch) == 0: - raise StopIteration - - # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled - obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False) - - # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix, - # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block) - blockwise_iter = self.X.read( - coords=(obs_joinids_chunk, self.var_joinids) - ).blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) - - X_batch: ChunkX - if not self.return_sparse_X: - res = next( - _tables_to_np( - blockwise_iter.tables(), - shape=(obs_batch.shape[0], len(self.var_joinids)), - ) - ) - X_batch, nnz = res[0], res[2] - else: - X_batch = next(blockwise_iter.scipy(compress=True))[0] - nnz = X_batch.nnz + As appropriate, will chunk, shuffle and apply partitioning per worker. + """ + assert self._obs_joinids is not None + obs_joinids: npt.NDArray[np.int64] = self._obs_joinids - assert obs_batch.shape[0] == X_batch.shape[0] + # Chunk joinids by soma_chunk_size + assert self.soma_chunk_size is not None + num_chunks = max(1, ceil(len(obs_joinids) / self.soma_chunk_size)) + obs_joinids_chunked = np.array_split(obs_joinids, num_chunks) - end_time = time() - stats = Stats() - stats.n_obs += X_batch.shape[0] - stats.nnz += nnz - stats.elapsed += end_time - start_time - stats.n_soma_chunks += 1 + # Shuffle chunks. NOTE: this assumes a single global seed for the RNG, + # ensuring all workers get the same shuffle. 
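        # Editorial note (hypothetical counts, assuming _splits() returns partition
        # boundaries): since every worker computes the identical permutation from the
        # shared seed, the slice taken below is disjoint across workers and their union
        # covers all chunks, e.g. 10 chunks across 2 workers -> boundaries [0, 5, 10],
        # so worker 0 takes shuffled chunks [0:5] and worker 1 takes [5:10].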
+ if self._shuffle_rng: + self._shuffle_rng.shuffle(obs_joinids_chunked) - pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") - return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) + # Now extract the partition for this worker + partition, num_partitions = ( + _get_torch_partition_info() if self.partition else (0, 1) + ) + obs_splits = _splits(len(obs_joinids_chunked), num_partitions) + obs_partition_joinids = obs_joinids_chunked[ + obs_splits[partition] : obs_splits[partition + 1] + ] + obs_joinid_iter = iter(obs_partition_joinids) + if pytorch_logger.isEnabledFor(logging.DEBUG) and self.partition: + pytorch_logger.debug( + f"Process {os.getpid()} handling partition {partition + 1} of {num_partitions}, " + f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" + ) -def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: - """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. - TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. - """ - i = 0 - result = [] - while i < len(arr_list): - if (i + sublist_len) >= len(arr_list): - result.append(arr_list[i:]) - else: - result.append(arr_list[i : i + sublist_len]) + return obs_joinid_iter - i += sublist_len + def _init_once(self) -> None: + """One-time per worker initialization. All operations be idempotent in order to support pipe reset().""" + if self._initialized: + return - return result + pytorch_logger.debug("Initializing ExperimentAxisQueryIterable") + with self.experiment_locator.open_experiment() as exp: + query = exp.axis_query( + measurement_name=self.measurement_name, + obs_query=self.obs_query, + var_query=self.var_query, + ) + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() + self._encoders = self._build_encoders(query) -def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103 - proc = psutil.Process(os.getpid()) + self._initialized = True - pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() - start = time() - gc.collect() - gc_elapsed = time() - start - post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + def __iter__(self) -> Iterator[XObsDatum]: + with self.experiment_locator.open_experiment() as exp: + self._init_once() - pytorch_logger.debug(f"gc: pre={pre_gc}") - pytorch_logger.debug(f"gc: post={post_gc}") + X = exp.ms[self.measurement_name].X[self.layer_name] + if not isinstance(X, soma.SparseNDArray): + raise NotImplementedError( + "ExperimentAxisQueryIterDataPipe only supported on X layers which are of type SparseNDArray" + ) - return pre_gc, post_gc, gc_elapsed + obs_joinid_iter = self._create_obs_joinid_iter() + yield from self._encoded_mini_batch_iter(exp.obs, X, obs_joinid_iter) + def __len__(self) -> int: + self._init_once() + assert self._obs_joinids is not None -class _ObsAndXIterator(Iterator[ObsAndXDatum]): - """Iterates through a set of ``obs`` and corresponding ``X`` rows, where the rows to be returned are specified by - the ``obs_tables_iter`` argument. For the specified ``obs` rows, the corresponding ``X`` data is loaded and - joined together. It is returned from this iterator as 2-tuples of ``X`` and obs Tensors. + div, rem = divmod(len(self._obs_joinids), self.batch_size) + return div + bool(rem) - Internally manages the retrieval of data in SOMA-sized chunks, fetching the next chunk of SOMA data as needed. 
- Supports fetching the data in an eager manner, where the next SOMA chunk is fetched while the current chunk is - being read. This is an internal class, not intended for public use. - """ + @property + def shape(self) -> Tuple[int, int]: + """Get the shape of the data that will be returned by this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. + This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode + (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect + the size of the partition of the data assigned to the active process. - soma_chunk_iter: Iterator[_SOMAChunk] - """The iterator for SOMA chunks of paired obs and X data""" + Returns: + A 2-tuple of ``int``s, for obs and var counts, respectively. - soma_chunk: Optional[_SOMAChunk] - """The current SOMA chunk of obs and X data""" + Lifecycle: + experimental + """ + self._init_once() + assert self._obs_joinids is not None + assert self._var_joinids is not None + return len(self._obs_joinids), len(self._var_joinids) - i: int = -1 - """Index into current obs ``SOMA`` chunk""" + def __getitem__(self, index: int) -> XObsDatum: + raise NotImplementedError("Can only be iterated") - def __init__( + def _io_batch_iter( self, obs: soma.DataFrame, X: soma.SparseNDArray, - obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], - var_joinids: npt.NDArray[np.int64], - batch_size: int, - encoders: List[Encoder], - stats: Stats, - return_sparse_X: bool, - use_eager_fetch: bool, - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, - ) -> None: - self.soma_chunk_iter = _ObsAndXSOMAIterator( - obs, - X, - obs_column_names, - obs_joinids_chunked, - var_joinids, - shuffle_chunk_count, - shuffle_rng, - return_sparse_X=return_sparse_X, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[Tuple[sparse.csr_array, pd.DataFrame]]: + """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of + (X: csr_array, obs: DataFrame). + + obs joinids read are controlled by the obs_joinid_iter. Iterator results will + be reindexed and shuffled (if shuffling enabled). + + Private. 
+ """ + assert self._var_joinids is not None + + obs_column_names = ( + list(self.obs_column_names) + if "soma_joinid" in self.obs_column_names + else ["soma_joinid", *self.obs_column_names] ) - if use_eager_fetch: - self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) - self.soma_chunk = None - self.var_joinids = var_joinids - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.encoders = encoders - self.stats = stats - self.gc_elapsed = 0.0 - self.max_process_mem_usage_bytes = 0 - self.X_dtype = X.schema[2].type.to_pandas_dtype() - - def __next__(self) -> ObsAndXDatum: - """Read the next torch batch, possibly across multiple soma chunks.""" - obss: list[pd.DataFrame] = [] - Xs: list[ChunkX] = [] - n_obs = 0 - - while n_obs < self.batch_size: - try: - obs_partial, X_partial = self._read_partial_torch_batch( - self.batch_size - n_obs - ) - n_obs += len(obs_partial) - obss.append(obs_partial) - Xs.append(X_partial) - except StopIteration: - break - - if len(Xs) == 0: # If we ran out of data - raise StopIteration - else: - if self.return_sparse_X: - X = sparse.vstack(Xs) - else: - X = np.concatenate(Xs, axis=0) - obs = pd.concat(obss, axis=0) + var_indexer = soma.IntIndexer(self._var_joinids, context=X.context) - obs_encoded = pd.DataFrame() + batched_obs_joinid_iter = _batched( + obs_joinid_iter, + self._shuffle_chunk_count if self._shuffle_chunk_count is not None else 1, + ) + for obs_coord_chunks in batched_obs_joinid_iter: + st_time = time.perf_counter() + obs_coords = np.concatenate(obs_coord_chunks) + obs_shuffled_coords = ( + obs_coords + if self._shuffle_rng is None + else self._shuffle_rng.permuted(obs_coords) + ) + pytorch_logger.debug( + f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." + ) + + def _get_obs_io_batch( + obs: soma.DataFrame, + obs_coords: npt.NDArray[np.int64], + obs_shuffled_coords: npt.NDArray[np.int64], + ) -> pd.DataFrame: + return cast( + pd.DataFrame, + obs.read(coords=(obs_coords,), column_names=obs_column_names) + .concat() + .to_pandas() + .set_index("soma_joinid") + .reindex(obs_shuffled_coords, copy=False) + .reset_index(), + ) - # Add the soma_joinid to the original obs, in case that is requested by the encoders. 
- obs["soma_joinid"] = obs.index + obs_future = obs.context.threadpool.submit( + _get_obs_io_batch, obs, obs_coords, obs_shuffled_coords + ) - for enc in self.encoders: - obs_encoded[enc.name] = enc.transform(obs) + X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() + if self.use_eager_fetch: + X_tbl_iter = _EagerIterator( + X_tbl_iter, pool=X.context.threadpool + ) # type:ignore[assignment] + + obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) + X_tbl = pa.concat_tables( + pa.Table.from_pydict( + { + "soma_dim_0": obs_indexer.get_indexer( + tbl["soma_dim_0"].to_numpy() + ), + "soma_dim_1": var_indexer.get_indexer( + tbl["soma_dim_1"].to_numpy() + ), + "soma_data": tbl["soma_data"].to_numpy(), + } + ) + for tbl in X_tbl_iter + ).sort_by([("soma_dim_0", "ascending"), ("soma_dim_1", "ascending")]) - # `to_numpy()` avoids copying the numpy array data - obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) + i, j, d = _IJD(X_tbl) + X_io_batch = sparse.csr_array( + (d, (i, j)), shape=(len(obs_coords), len(self._var_joinids)), copy=False + ) - if not self.return_sparse_X: - X_tensor = torch.from_numpy(X) - else: - coo = X.tocoo() + obs_io_batch = obs_future.result() - X_tensor = torch.sparse_coo_tensor( - # Note: The `np.array` seems unnecessary, but PyTorch warns bare array is "extremely slow" - indices=torch.from_numpy(np.array([coo.row, coo.col])), - values=coo.data, - size=coo.shape, + del i, j, d, X_tbl, X_tbl_iter + del obs_future, obs_coords, obs_shuffled_coords, obs_indexer + _run_gc() + tm = time.perf_counter() - st_time + pytorch_logger.debug( + f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f}samples/sec" ) + yield X_io_batch, obs_io_batch - if self.batch_size == 1: - X_tensor = X_tensor[0] - obs_tensor = obs_tensor[0] + def _mini_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[Tuple[sparse.csr_array, pd.DataFrame]]: + """Break IO batches into shuffled mini-batch-sized chunks, still in internal format (CSR, Pandas). + + Private. 
+ """ + assert self._obs_joinids is not None + assert self._var_joinids is not None - return X_tensor, obs_tensor + io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter) + if self.use_eager_fetch: + io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool) + + mini_batch_size = self.batch_size + result: Tuple[sparse.csr_array, pd.DataFrame] | None = None # partial result + for X_io_batch, obs_io_batch in io_batch_iter: + assert X_io_batch.shape[0] == obs_io_batch.shape[0] + assert X_io_batch.shape[1] == len(self._var_joinids) + iob_idx = 0 # current offset into io batch + iob_len = X_io_batch.shape[0] + + while iob_idx < iob_len: + if result is None: + # perform zero copy slice where possible + result = ( + X_io_batch[iob_idx : iob_idx + mini_batch_size], + obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size], + ) + iob_idx += len(result[1]) + else: + # use remanent from previous IO batch + to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) + result = ( + sparse.vstack([result[0], X_io_batch[0:to_take]]), + pd.concat([result[1], obs_io_batch.iloc[0:to_take]]), + ) + iob_idx += to_take + + assert result[0].shape[0] == result[1].shape[0] + if result[0].shape[0] == mini_batch_size: + yield result + result = None + + else: + # yield a remnant, if any + if result is not None: + yield result + + def _encoded_mini_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[XObsDatum]: + """Apply encoding on top of the mini batches. - def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]: - """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may - contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current - SOMA chunk are fewer than the requested ``batch_size``. + Returns numpy encodings (obs, X). 
""" - if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): - # GC memory from previous soma_chunk - self.soma_chunk = None - pre_gc, _, gc_elapsed = run_gc() - self.max_process_mem_usage_bytes = max( - self.max_process_mem_usage_bytes, pre_gc[0].uss - ) - self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) - self.stats += self.soma_chunk.stats - self.gc_elapsed += gc_elapsed - self.i = 0 + for X_mini_batch, obs_mini_batch in self._mini_batch_iter( + obs, X, obs_joinid_iter + ): + # TODO - X encoding + X_encoded = X_mini_batch.todense() - pytorch_logger.debug( - f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}" + # Obs encoding + obs_encoded = pd.DataFrame( + {enc.name: enc.transform(obs_mini_batch) for enc in self._encoders} ) - obs_batch = self.soma_chunk.obs - X_chunk = self.soma_chunk.X + del obs_mini_batch, X_mini_batch + yield X_encoded, obs_encoded + + def _build_encoders( + self, query: soma.ExperimentAxisQuery[soma.Experiment] # type: ignore[type-var] + ) -> List[Encoder]: + pytorch_logger.debug("Initializing encoders") + + encoders = [] + + if "soma_joinid" not in self.obs_column_names: + cols = ["soma_joinid", *self.obs_column_names] + else: + cols = list(self.obs_column_names) - safe_batch_size = min(batch_size, len(obs_batch) - self.i) - slice_ = slice(self.i, self.i + safe_batch_size) - assert slice_.stop <= obs_batch.shape[0] + obs = query.obs(column_names=cols).concat().to_pandas() - obs_rows = obs_batch.iloc[slice_] - assert obs_rows.index.is_unique - assert safe_batch_size == obs_rows.shape[0] + if self._encoders: + # Fit all the custom encoders with obs + for enc in self._encoders: + enc.fit(obs) + encoders.append(enc) + else: + # Create one LabelEncoder for each column, and fit it with obs + for col in self.obs_column_names: + enc = LabelEncoder(col) + enc.fit(obs) + encoders.append(enc) - X_batch = X_chunk[slice_] + return encoders - assert obs_rows.shape[0] == X_batch.shape[0] + @property + def encoders(self) -> Encoders: + """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on ``obs`` column names, + which were used to encode the ``obs`` column values. - self.i += safe_batch_size + These encoders can be used to decode the encoded values as follows: - return obs_rows, X_batch + >>> exp_data_pipe.encoders[""].inverse_transform(encoded_values) + Returns: + A ``Dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing.LabelEncoder` objects. + """ + self._init_once() + assert self._encoders is not None + return {enc.name: enc for enc in self._encoders} -class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore + +class ExperimentAxisQueryDataPipe( + torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] + torch.utils.data.dataset.Dataset[XObsTensorDatum] + ], +): r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` axes. Provides an iterator over these data when the object is passed to Python's built-in ``iter`` function. @@ -472,397 +497,158 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig ``obs_column_names``, and string-typed columns are encoded as integer values. 
If needed, these values can be decoded by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` method: - >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values) + >>> exp_data_pipe.encoders[""].inverse_transform(encoded_values) Lifecycle: experimental """ - _initialized: bool - - _obs_joinids: Optional[npt.NDArray[np.int64]] - - _var_joinids: Optional[npt.NDArray[np.int64]] - - _encoders: List[Encoder] - - _stats: Stats - - _shuffle_rng: Optional[Generator] - - # TODO: Consider adding another convenience method wrapper to construct this object whose signature is more closely - # aligned with get_anndata() params (i.e. "exploded" AxisQuery params). def __init__( self, experiment: soma.Experiment, measurement_name: str = "RNA", X_name: str = "raw", - obs_query: Optional[soma.AxisQuery] = None, - var_query: Optional[soma.AxisQuery] = None, + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, obs_column_names: Sequence[str] = (), batch_size: int = 1, shuffle: bool = True, - seed: Optional[int] = None, - return_sparse_X: bool = False, - soma_chunk_size: Optional[int] = 64, + seed: int | None = None, + soma_chunk_size: int = 64, use_eager_fetch: bool = True, - encoders: Optional[List[Encoder]] = None, - shuffle_chunk_count: Optional[int] = 2000, - ) -> None: - r"""Construct a new ``ExperimentDataPipe``. - - Args: - experiment: - The :class:`tiledbsoma.Experiment` from which to read data. - measurement_name: - The name of the :class:`tiledbsoma.Measurement` to read. Defaults to ``"RNA"``. - X_name: - The name of the X layer to read. Defaults to ``"raw"``. - obs_query: - The query used to filter along the ``obs`` axis. If not specified, all ``obs`` and ``X`` data will - be returned, which can be very large. - var_query: - The query used to filter along the ``var`` axis. If not specified, all ``var`` columns (genes/features) - will be returned. - obs_column_names: - The names of the ``obs`` columns to return. The ``soma_joinid`` index "column" does not need to be - specified and will always be returned. If not specified, only the ``soma_joinid`` will be returned. - If custom encoders are passed, this parameter must not be used, since the columns will be inferred - automatically from the encoders. - batch_size: - The number of rows of ``obs`` and ``X`` data to return in each iteration. Defaults to ``1``. A value of - ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). - shuffle: - Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. - For performance reasons, shuffling is not performed globally across all rows, but rather in chunks. - More specifically, we select ``shuffle_chunk_count`` non-contiguous chunks across all the observations - in the query, concatenate the chunks and shuffle the associated observations. - The randomness of the shuffling is therefore determined by the - (``soma_chunk_size``, ``shuffle_chunk_count``) selection. The default values have been determined - to yield a good trade-off between randomness and performance. Further tuning may be required for - different type of models. Note that memory usage is correlated to the product - ``soma_chunk_size * shuffle_chunk_count``. - seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). 
This *must* be specified when using - :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker - processes. - return_sparse_X: - Controls whether the ``X`` data is returned as a dense or sparse :class:`torch.Tensor`. As ``X`` data is - very sparse, setting this to ``True`` will reduce memory usage, if the model supports use of sparse - :class:`torch.Tensor`\ s. Defaults to ``False``, since sparse :class:`torch.Tensor`\ s are still - experimental in PyTorch. - soma_chunk_size: - The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of - this class's behavior: 1) The maximum memory utilization, with larger values providing - better read performance, but also requiring more memory; 2) The granularity of the global shuffling - step (see ``shuffle`` parameter for details). The default value of 64 works well in conjunction - with the default ``shuffle_chunk_count`` value. - use_eager_fetch: - Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made - available for processing via the iterator. This allows network (or filesystem) requests to be made in - parallel with client-side processing of the SOMA data, potentially improving overall performance at the - cost of doubling memory utilization. Defaults to ``True``. - shuffle_chunk_count: - The number of contiguous blocks (chunks) of rows sampled to then concatenate and shuffle. - Larger numbers correspond to more randomness per training batch. - If ``shuffle == False``, this parameter is ignored. Defaults to ``2000``. - encoders: - Specify custom encoders to be used. If not specified, a LabelEncoder will be created and - used for each column in ``obs_column_names``. If specified, only columns for which an encoder - has been registered will be returned in the ``obs`` tensor. Each encoder needs to have a unique name. - If this parameter is specified, the ``obs_column_names`` parameter must not be used, - since the columns will be inferred automatically from the encoders. - - Lifecycle: - experimental - """ - self.experiment_locator = ExperimentLocator.create(experiment) - self.measurement_name = measurement_name - self.layer_name = X_name - self.obs_query = obs_query - self.var_query = var_query - self.obs_column_names = obs_column_names - self.batch_size = batch_size - self.return_sparse_X = return_sparse_X - self.soma_chunk_size = soma_chunk_size - self.use_eager_fetch = use_eager_fetch - self._stats = Stats() - self._encoders = encoders or [] - self._obs_joinids = None - self._var_joinids = None - self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None - self._shuffle_rng = np.random.default_rng(seed) if shuffle else None - self._initialized = False - self.max_process_mem_usage_bytes = 0 - - if obs_column_names and encoders: - raise ValueError( - "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically." 
- ) - - if encoders: - # Check if names are unique - if len(encoders) != len({enc.name for enc in encoders}): - raise ValueError("Encoders must have unique names") - - self.obs_column_names = list( - dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders])) - ) - - def _init(self) -> None: - if self._initialized: - return - - pytorch_logger.debug("Initializing ExperimentDataPipe") - - with self.experiment_locator.open_experiment() as exp: - query = exp.axis_query( - measurement_name=self.measurement_name, - obs_query=self.obs_query, - var_query=self.var_query, - ) - - # The to_numpy() call is a workaround for a possible bug in TileDB-SOMA: - # https://github.com/single-cell-data/TileDB-SOMA/issues/1456 - self._obs_joinids = query.obs_joinids().to_numpy() - self._var_joinids = query.var_joinids().to_numpy() - - self._encoders = self._build_obs_encoders(query) - - self._initialized = True - - @staticmethod - def _subset_ids_to_partition( - ids_chunked: List[npt.NDArray[np.int64]], - partition_index: int, - num_partitions: int, - ) -> List[npt.NDArray[np.int64]]: - """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's distributed rank and world - size. - """ - # subset to a single partition - # typing does not reflect that is actually a List of 2D NDArrays - partition_indices = np.array_split(range(len(ids_chunked)), num_partitions) - partition = [ids_chunked[i] for i in partition_indices[partition_index]] - - if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0: - pytorch_logger.debug( - f"Process {os.getpid()} handling partition {partition_index + 1} of {num_partitions}, " - f"partition_size={sum([len(chunk) for chunk in partition])}" - ) - - return partition - - @staticmethod - def _compute_partitions( - loader_partition: int, - loader_partitions: int, - dist_partition: int, - num_dist_partitions: int, - ) -> Tuple[int, int]: - # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload - total_partitions = num_dist_partitions * loader_partitions - partition = dist_partition * loader_partitions + loader_partition - return partition, total_partitions - - def __iter__(self) -> Iterator[ObsAndXDatum]: - self._init() - assert self._obs_joinids is not None - assert self._var_joinids is not None - - if self.soma_chunk_size is None: - # TODO: given that soma_chunk_size defaults to 64, this code will only run if the user explicitly - # provides None as a value. Which is a bit unorthodox. Consider using some other sentinel, e.g., - # define an 'memory_budget' param. - - # set soma_chunk_size to utilize ~1 GiB of RAM per SOMA chunk; assumes 95% X data sparsity, 8 bytes for the - # X value and 8 bytes for the sparse matrix indices, and a 100% working memory overhead (2x). 
- X_row_memory_size = 0.05 * len(self._var_joinids) * 8 * 3 * 2 - self.soma_chunk_size = int((1 * 1024**3) / X_row_memory_size) - pytorch_logger.debug(f"Using {self.soma_chunk_size=}") - - if ( - self.return_sparse_X - and torch.utils.data.get_worker_info() - and torch.utils.data.get_worker_info().num_workers > 0 - ): - raise NotImplementedError( - "torch does not work with sparse tensors in multi-processing mode " - "(see https://github.com/pytorch/pytorch/issues/20248)" - ) - - # chunk the obs joinids into batches of size soma_chunk_size - obs_joinids_chunked = self._chunk_ids(self._obs_joinids, self.soma_chunk_size) - - # globally shuffle the chunks, if requested - if self._shuffle_rng: - self._shuffle_rng.shuffle(obs_joinids_chunked) - - # subset to a single partition, as needed for distributed training and multi-processing data loading - worker_info = torch.utils.data.get_worker_info() - partition, partitions = self._compute_partitions( - loader_partition=worker_info.id if worker_info else 0, - loader_partitions=worker_info.num_workers if worker_info else 1, - dist_partition=dist.get_rank() if dist.is_initialized() else 0, - num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, - ) - obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = ( - self._subset_ids_to_partition(obs_joinids_chunked, partition, partitions) + encoders: List[Encoder] | None = None, + shuffle_chunk_count: int = 2000, + ): + self._exp_iter = ExperimentAxisQueryIterable( + experiment=experiment, + measurement_name=measurement_name, + X_name=X_name, + obs_query=obs_query, + var_query=var_query, + obs_column_names=obs_column_names, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + soma_chunk_size=soma_chunk_size, + use_eager_fetch=use_eager_fetch, + encoders=encoders, + shuffle_chunk_count=shuffle_chunk_count, + # --- + partition=True, ) - with self.experiment_locator.open_experiment() as exp: - X = exp.ms[self.measurement_name].X[self.layer_name] - if not isinstance(X, soma.SparseNDArray): - raise NotImplementedError( - "ExperimentDataPipe only supported on X layers which are of type SparseNDArray" - ) - - obs_and_x_iter = _ObsAndXIterator( - obs=exp.obs, - X=X, - obs_column_names=self.obs_column_names, - obs_joinids_chunked=obs_joinids_chunked_partition, - var_joinids=self._var_joinids, - batch_size=self.batch_size, - encoders=self._encoders, - stats=self._stats, - return_sparse_X=self.return_sparse_X, - use_eager_fetch=self.use_eager_fetch, - shuffle_rng=self._shuffle_rng, - shuffle_chunk_count=self._shuffle_chunk_count, - ) - - yield from obs_and_x_iter - - self.max_process_mem_usage_bytes = ( - obs_and_x_iter.max_process_mem_usage_bytes - ) - pytorch_logger.debug( - "max process memory usage=" - f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" - ) + def __iter__(self) -> Iterator[XObsTensorDatum]: + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + X_tensor, obs_tensor = torch.from_numpy(X), torch.from_numpy(obs.to_numpy()) + if batch_size == 1: + X_tensor = X_tensor[0] + obs_tensor = obs_tensor[0] - @staticmethod - def _chunk_ids( - ids: npt.NDArray[np.int64], chunk_size: int - ) -> List[npt.NDArray[np.int64]]: - num_chunks = max(1, ceil(len(ids) / chunk_size)) - pytorch_logger.debug( - f"Splitting {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}" - ) - return np.array_split(ids, num_chunks) + yield X_tensor, obs_tensor def __len__(self) -> int: - self._init() - assert self._obs_joinids is not None - - div, rem = 
divmod(len(self._obs_joinids), self.batch_size) - return div + bool(rem) - - def __getitem__(self, index: int) -> ObsAndXDatum: - raise NotImplementedError("IterDataPipe can only be iterated") - - def _build_obs_encoders( - self, query: soma.ExperimentAxisQuery[soma.Experiment] # type: ignore[type-var] - ) -> List[Encoder]: - pytorch_logger.debug("Initializing encoders") + return self._exp_iter.__len__() - encoders = [] - - if "soma_joinid" not in self.obs_column_names: - cols = ["soma_joinid", *self.obs_column_names] - else: - cols = list(self.obs_column_names) + @property + def shape(self) -> Tuple[int, int]: + return self._exp_iter.shape - obs = query.obs(column_names=cols).concat().to_pandas() + @property + def encoders(self) -> Encoders: + return self._exp_iter.encoders - if self._encoders: - # Fit all the custom encoders with obs - for enc in self._encoders: - enc.fit(obs) - encoders.append(enc) - else: - # Create one LabelEncoder for each column, and fit it with obs - for col in self.obs_column_names: - enc = LabelEncoder(col) - enc.fit(obs) - encoders.append(enc) - return encoders +class ExperimentAxisQueryIterableDataset( + torch.utils.data.IterableDataset[XObsNpDatum] # type:ignore[misc] +): + def __init__( + self, + experiment: soma.Experiment, + measurement_name: str = "RNA", + X_name: str = "raw", + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, + obs_column_names: Sequence[str] = (), + batch_size: int = 1, # XXX add docstring noting values >1 will not work with default collator + shuffle: bool = True, + seed: int | None = None, + soma_chunk_size: int = 64, + use_eager_fetch: bool = True, + encoders: List[Encoder] | None = None, + shuffle_chunk_count: int = 2000, + ): + self._exp_iter = ExperimentAxisQueryIterable( + experiment=experiment, + measurement_name=measurement_name, + X_name=X_name, + obs_query=obs_query, + var_query=var_query, + obs_column_names=obs_column_names, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + soma_chunk_size=soma_chunk_size, + use_eager_fetch=use_eager_fetch, + encoders=encoders, + shuffle_chunk_count=shuffle_chunk_count, + # --- + partition=True, + ) - # TODO: This does not work in multiprocessing mode, as child process's stats are not collected - def stats(self) -> Stats: - """Get data loading stats for this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. + def __iter__(self) -> Iterator[XObsNpDatum]: + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + obs_np: npt.NDArray[np.number[Any]] = obs.to_numpy() + if batch_size == 1: + X = X[0] + obs_np = obs_np[0] - Returns: - The :class:`tiledbsoma.ml.pytorch.Stats` object for this - :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. + yield X, obs_np - Lifecycle: - experimental - """ - return self._stats + def __len__(self) -> int: + return self._exp_iter.__len__() @property def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. - This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode - (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the partition of the data assigned to the active process. - - Returns: - A 2-tuple of ``int``s, for obs and var counts, respectively. 
-
- Lifecycle:
- experimental
- """
- self._init()
- assert self._obs_joinids is not None
- assert self._var_joinids is not None
-
- return len(self._obs_joinids), len(self._var_joinids)
+ return self._exp_iter.shape

 @property
- def obs_encoders(self) -> Encoders:
- """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on ``obs`` column names,
- which were used to encode the ``obs`` column values.
-
- These encoders can be used to decode the encoded values as follows:
-
- >>> exp_data_pipe.obs_encoders[""].inverse_transform(encoded_values)
+ def encoders(self) -> Encoders:
+ return self._exp_iter.encoders

- Returns:
- A ``Dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing.LabelEncoder` objects.
- """
- self._init()
- assert self._encoders is not None
-
- return {enc.name: enc for enc in self._encoders}


+def _collate_ndarray_to_tensor(
+ datum: Sequence[npt.NDArray[np.number[Any]] | torch.Tensor],
+) -> Tuple[torch.Tensor, ...]:
+ """Default collate_fn for ``experiment_dataloader`` converts ndarray to a Tensor.

-# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling
-def _collate_noop(x: Any) -> Any:
- return x
+ Must be a top-level function to play nice with pickling for multiprocessing use cases.
+ """
+ return tuple(
+ torch.from_numpy(d) if not isinstance(d, torch.Tensor) else d for d in datum
+ )


 def experiment_dataloader(
- datapipe: pipes.IterDataPipe,
+ ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
 num_workers: int = 0,
 **dataloader_kwargs: Any,
-) -> DataLoader:
+) -> torch.utils.data.DataLoader:
 """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
- :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`,
- since some of the :class:`torch.utils.data.DataLoader` constructor parameters are not applicable when using a
- :class:`torchdata.datapipes.iter.IterDataPipe` (``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``,
- ``collate_fn``).
+ :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`.
+
+ Several :class:`torch.utils.data.DataLoader` parameters are disallowed, as they are either non-performant and better
+ specified on the :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset` (``shuffle``, ``batch_size``) or are not compatible
+ (``sampler``, ``batch_sampler``, ``collate_fn``).

 Args:
 datapipe:
- An :class:`torchdata.datapipes.iter.IterDataPipe`, which can be an
- :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe` or any other
- :class:`torchdata.datapipes.iter.IterDataPipe` that has been chained to the
- :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`.
+ A :class:`torch.utils.data.IterableDataset`, which can be an
+ :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`,
+ :class:`tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` or any other
+ :class:`torch.utils.data.IterableDataset` that has been chained to the
+ :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`.
 num_workers:
 Number of worker processes to use for data loading. If ``0``, data will be loaded in the main process.
 **dataloader_kwargs:
@@ -895,18 +681,117 @@ def experiment_dataloader(
 if num_workers > 0:
 _init_multiprocessing()

- return DataLoader(
- datapipe,
- batch_size=None, # batching is handled by our ExperimentDataPipe
+ return torch.utils.data.DataLoader(
+ ds,
+ batch_size=None, # batching is handled by our ExperimentAxisQueryIterableDataset
 num_workers=num_workers,
- # avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches
- collate_fn=_collate_noop,
- # shuffling is handled by our ExperimentDataPipe
+ collate_fn=_collate_ndarray_to_tensor,
+ # shuffling is handled by ExperimentAxisQueryIterableDataset
 shuffle=False,
 **dataloader_kwargs,
 )


+def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
+ """For `total_length` points, compute start/stop offsets that split the length into roughly equal sizes.
+
+ A total_length of L, split into N sections, will return L%N sections of size L//N+1,
+ and the remaining sections of size L//N.
+
+ This results in the same split as numpy.array_split, for an array of length L.
+
+ Examples
+ --------
+ >>> _splits(10, 3)
+ array([0, 4, 7, 10])
+ >>> _splits(4, 2)
+ array([0, 2, 4])
+ """
+ if not sections > 0:
+ raise ValueError("number of sections must be greater than 0.") from None
+ each_section, extras = divmod(total_length, sections)
+ per_section_sizes = (
+ [0] + extras * [each_section + 1] + (sections - extras) * [each_section]
+ )
+ splits = np.array(per_section_sizes, dtype=np.intp).cumsum()
+ return splits
+
+
+def _IJD(
+ tbl: pa.Table,
+) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.number[Any]]]:
+ """Given SOMA-style Pyarrow Table of COO sparse array data, return I,J,D vectors."""
+ return (
+ tbl["soma_dim_0"].to_numpy(),
+ tbl["soma_dim_1"].to_numpy(),
+ tbl["soma_data"].to_numpy(),
+ )
+
+
+if sys.version_info >= (3, 12):
+ _batched = itertools.batched
+
+else:
+
+ def _batched(
+ iterable: Iterable[_T_co], n: int, *, strict: bool = False
+ ) -> Iterator[Tuple[_T_co, ...]]:
+ """Same as the Python 3.12+ itertools.batched, but polyfilled for old implementations."""
+ if n < 1:
+ raise ValueError("n must be at least one")
+ it = iter(iterable)
+ while batch := tuple(islice(it, n)):
+ if strict and len(batch) != n:
+ raise ValueError("batched(): incomplete batch")
+ yield batch
+
+
+def _run_gc() -> Tuple[PsUtilMemInfo, PsUtilMemInfo, float]:
+ """Run Python GC and log stats."""
+ proc = psutil.Process(os.getpid())
+
+ pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
+ start = time.time()
+ gc.collect()
+ gc_elapsed = time.time() - start
+ post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
+
+ pytorch_logger.debug(f"gc: pre={pre_gc}")
+ pytorch_logger.debug(f"gc: post={post_gc}")
+
+ return pre_gc, post_gc, gc_elapsed
+
+
+def _get_torch_partition_info() -> Tuple[int, int]:
+ """Return this worker's partition and total partition count as a tuple.
+
+ Private. Used to partition the iterator in some cases.
+ + Examples + -------- + >>> _get_torch_partition_info() + (0, 1) + + """ + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + loader_partition, loader_partitions = 0, 1 + else: + loader_partition = worker_info.id + loader_partitions = worker_info.num_workers + + if not torch.distributed.is_initialized(): + dist_partition, num_dist_partitions = 0, 1 + else: + dist_partition = torch.distributed.get_rank() + num_dist_partitions = torch.distributed.get_world_size() + + total_partitions = num_dist_partitions * loader_partitions + partition = dist_partition * loader_partitions + loader_partition + + return partition, total_partitions + + def _init_multiprocessing() -> None: """Ensures use of "spawn" for starting child processes with multiprocessing. @@ -916,7 +801,6 @@ def _init_multiprocessing() -> None: https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing """ - torch.multiprocessing.set_start_method("fork", force=True) orig_start_method = torch.multiprocessing.get_start_method() if orig_start_method != "spawn": if orig_start_method: From 1aedf3d407e97533b266a1b9a2308ba247edec49 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 21 Aug 2024 04:51:42 +0000 Subject: [PATCH 06/70] tuning --- apis/python/src/tiledbsoma/ml/pytorch.py | 149 +++++++++++++---------- 1 file changed, 85 insertions(+), 64 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 0a77f4a..49d33c9 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -24,6 +24,7 @@ ) import attrs +import numba import numpy as np import numpy.typing as npt import pandas as pd @@ -187,7 +188,9 @@ def _init_once(self) -> None: if self._initialized: return - pytorch_logger.debug("Initializing ExperimentAxisQueryIterable") + pytorch_logger.debug( + f"Initializing ExperimentAxisQueryIterable (shuffle={self._shuffle_rng is not None})" + ) with self.experiment_locator.open_experiment() as exp: query = exp.axis_query( @@ -212,7 +215,15 @@ def __iter__(self) -> Iterator[XObsDatum]: ) obs_joinid_iter = self._create_obs_joinid_iter() - yield from self._encoded_mini_batch_iter(exp.obs, X, obs_joinid_iter) + _mini_batch_iter = self._encoded_mini_batch_iter( + exp.obs, X, obs_joinid_iter + ) + if self.use_eager_fetch: + _mini_batch_iter = _EagerIterator( + _mini_batch_iter, pool=exp.context.threadpool + ) + + yield from _mini_batch_iter def __len__(self) -> int: self._init_once() @@ -281,23 +292,14 @@ def _io_batch_iter( f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." 
) - def _get_obs_io_batch( - obs: soma.DataFrame, - obs_coords: npt.NDArray[np.int64], - obs_shuffled_coords: npt.NDArray[np.int64], - ) -> pd.DataFrame: - return cast( - pd.DataFrame, - obs.read(coords=(obs_coords,), column_names=obs_column_names) - .concat() - .to_pandas() - .set_index("soma_joinid") - .reindex(obs_shuffled_coords, copy=False) - .reset_index(), - ) - - obs_future = obs.context.threadpool.submit( - _get_obs_io_batch, obs, obs_coords, obs_shuffled_coords + obs_io_batch = cast( + pd.DataFrame, + obs.read(coords=(obs_coords,), column_names=obs_column_names) + .concat() + .to_pandas() + .set_index("soma_joinid") + .reindex(obs_shuffled_coords, copy=False) + .reset_index(), ) X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() @@ -307,34 +309,30 @@ def _get_obs_io_batch( ) # type:ignore[assignment] obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) - X_tbl = pa.concat_tables( - pa.Table.from_pydict( - { - "soma_dim_0": obs_indexer.get_indexer( - tbl["soma_dim_0"].to_numpy() - ), - "soma_dim_1": var_indexer.get_indexer( - tbl["soma_dim_1"].to_numpy() - ), - "soma_data": tbl["soma_data"].to_numpy(), - } - ) - for tbl in X_tbl_iter - ).sort_by([("soma_dim_0", "ascending"), ("soma_dim_1", "ascending")]) - - i, j, d = _IJD(X_tbl) + X_tbl = pa.concat_tables(X_tbl_iter) + X_tbl = pa.Table.from_pydict( + { + "soma_dim_0": obs_indexer.get_indexer( + X_tbl["soma_dim_0"].to_numpy() + ), + "soma_dim_1": var_indexer.get_indexer( + X_tbl["soma_dim_1"].to_numpy() + ), + "soma_data": X_tbl["soma_data"].to_numpy(), + } + ) X_io_batch = sparse.csr_array( - (d, (i, j)), shape=(len(obs_coords), len(self._var_joinids)), copy=False + _D_IJ(X_tbl), + shape=(len(obs_coords), len(self._var_joinids)), + copy=False, ) - obs_io_batch = obs_future.result() - - del i, j, d, X_tbl, X_tbl_iter - del obs_future, obs_coords, obs_shuffled_coords, obs_indexer + del X_tbl, X_tbl_iter + del obs_coords, obs_shuffled_coords, obs_indexer _run_gc() tm = time.perf_counter() - st_time pytorch_logger.debug( - f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f}samples/sec" + f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec" ) yield X_io_batch, obs_io_batch @@ -400,12 +398,12 @@ def _encoded_mini_batch_iter( Returns numpy encodings (obs, X). """ - for X_mini_batch, obs_mini_batch in self._mini_batch_iter( obs, X, obs_joinid_iter ): # TODO - X encoding - X_encoded = X_mini_batch.todense() + # X_encoded = X_mini_batch.todense() SLOW + X_encoded = _csr_to_dense(X_mini_batch) # Obs encoding obs_encoded = pd.DataFrame( @@ -465,12 +463,17 @@ class ExperimentAxisQueryDataPipe( torch.utils.data.dataset.Dataset[XObsTensorDatum] ], ): - r"""An :class:`torchdata.datapipes.iter.IterDataPipe` that reads ``obs`` and ``X`` data from a - :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` axes. Provides an - iterator over these data when the object is passed to Python's built-in ``iter`` function. + """An :class:`torchdata.datapipes.iter.IterDataPipe` which reads ``obs`` and ``X`` data from a + :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` axes. Provides + a standard Python iterable interface. 
+
+ >>> for batch in ExperimentAxisQueryDataPipe(...):
+ X_batch, obs_batch = batch

- >>> for batch in iter(ExperimentDataPipe(...)):
- X_batch, y_batch = batch
+ **WARNING:** :class:`torchdata.datapipes` is deprecated as of version 0.9 (July 2024), and is slated for removal
+ in a future release (late 2024). It is recommended that new code utilize :class:`ExperimentAxisQueryIterableDataset`.
+ Older code should pin the torchdata version to 0.9 or older. For more information, see
+ https://github.com/pytorch/data/issues/1196.

 The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each
 iteration. If the ``batch_size`` is 1, then each Tensor will have rank 1:

 >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data
 tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded

 For larger ``batch_size`` values, the returned Tensors will have rank 2:

 >>> DataLoader(..., batch_size=3, ...):
 (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch
 [0., 0., 0., 0., 0., 0., 0., 0., 0.],
 [0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[2415, 0, 0], # obs batch
 [2416, 0, 4],
 [2417, 0, 3]], dtype=torch.int64))

- The ``return_sparse_X`` parameter controls whether the ``X`` data is returned as a dense or sparse
- :class:`torch.Tensor`. If the model supports use of sparse :class:`torch.Tensor`\ s, this will reduce memory usage.
+ The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor.
+ User-specified encoders may be provided - when not provided, the ``X`` batch will not be encoded, and the
+ ``obs`` batch will be encoded with a simple label encoder.

- The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. The first
- element is always the ``soma_joinid`` of the ``obs`` :class:`pandas.DataFrame` (or, equivalently, the
- ``soma_dim_0`` of the ``X`` matrix). The remaining elements are the ``obs`` columns specified by
- ``obs_column_names``, and string-typed columns are encoded as integer values. If needed, these values can be decoded
- by obtaining the encoder for a given ``obs`` column name and calling its ``inverse_transform`` method:
+ If needed, encoded values can be decoded by calling the ``inverse_transform`` method on the encoder, e.g.,

 >>> exp_data_pipe.encoders[""].inverse_transform(encoded_values)

 Lifecycle:
 experimental
 """
@@ -621,7 +621,7 @@ def encoders(self) -> Encoders:
 def _collate_ndarray_to_tensor(
 datum: Sequence[npt.NDArray[np.number[Any]] | torch.Tensor],
 ) -> Tuple[torch.Tensor, ...]:
- """Default collate_fn for ``experiment_dataloader`` converts ndarray to a Tensor.
+ """Default torch.utils.data.DataLoader collate_fn for ``experiment_dataloader`` -- converts ndarray to a Tensor.

 Must be a top-level function to play nice with pickling for multiprocessing use cases.
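+
+ A minimal illustration of the behavior (array shapes here are arbitrary,
+ chosen only for this sketch):
+
+ >>> X = np.zeros((3, 2), dtype=np.float32)
+ >>> obs = np.arange(3)
+ >>> [t.shape for t in _collate_ndarray_to_tensor((X, obs))]
+ [torch.Size([3, 2]), torch.Size([3])]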
""" @@ -717,15 +717,16 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: return splits -def _IJD( +def _D_IJ( tbl: pa.Table, -) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.number[Any]]]: - """Given SOMA-style Pyarrow Table of COO sparse array data, return I,J,D vectors.""" - return ( - tbl["soma_dim_0"].to_numpy(), - tbl["soma_dim_1"].to_numpy(), - tbl["soma_data"].to_numpy(), - ) +) -> Tuple[ + npt.NDArray[np.number[Any]], Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]] +]: + """Given SOMA-style Pyarrow Table of COO sparse array data, return tuple (D, (I, J)) vectors.""" + d = tbl["soma_data"].to_numpy() + i = tbl["soma_dim_0"].to_numpy() + j = tbl["soma_dim_1"].to_numpy() + return d, (i, j) if sys.version_info >= (3, 12): @@ -809,3 +810,23 @@ def _init_multiprocessing() -> None: f'"{torch.multiprocessing.get_start_method()}" to "spawn"' ) torch.multiprocessing.set_start_method("spawn", force=True) + + +@numba.njit(nogil=True, parallel=True) +def _csr_to_dense_inner(indptr, indices, data, out): # type:ignore[no-untyped-def] + n_rows = out.shape[0] + for i in numba.prange(n_rows): + for j in range(indptr[i], indptr[i + 1]): + out[i, indices[j]] = data[j] + + return out + + +def _csr_to_dense(sp: sparse.csr_array) -> npt.NDArray[np.number[Any]]: + assert isinstance(sp, (sparse.csr_array, sparse.csr_matrix)) + return cast( + npt.NDArray[np.number[Any]], + _csr_to_dense_inner( + sp.indptr, sp.indices, sp.data, np.zeros(sp.shape, dtype=sp.dtype) + ), + ) From e577ecd08f6e381406a70f4dac938bd2e18c966c Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 21 Aug 2024 08:45:45 -0700 Subject: [PATCH 07/70] tweaks, checkpoint --- apis/python/src/tiledbsoma/ml/pytorch.py | 28 +++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 49d33c9..b0a4641 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -303,10 +303,10 @@ def _io_batch_iter( ) X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() - if self.use_eager_fetch: - X_tbl_iter = _EagerIterator( - X_tbl_iter, pool=X.context.threadpool - ) # type:ignore[assignment] + # if self.use_eager_fetch: + # X_tbl_iter = _EagerIterator( + # X_tbl_iter, pool=X.context.threadpool + # ) # type:ignore[assignment] obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) X_tbl = pa.concat_tables(X_tbl_iter) @@ -396,18 +396,26 @@ def _encoded_mini_batch_iter( ) -> Iterator[XObsDatum]: """Apply encoding on top of the mini batches. - Returns numpy encodings (obs, X). + Returns numpy encodings (X, obs). 
""" - for X_mini_batch, obs_mini_batch in self._mini_batch_iter( - obs, X, obs_joinid_iter - ): + _encoded_mini_batch_iter = self._mini_batch_iter(obs, X, obs_joinid_iter) + if self.use_eager_fetch: + _encoded_mini_batch_iter = _EagerIterator( + _encoded_mini_batch_iter, pool=X.context.threadpool + ) + + for X_mini_batch, obs_mini_batch in _encoded_mini_batch_iter: # TODO - X encoding # X_encoded = X_mini_batch.todense() SLOW X_encoded = _csr_to_dense(X_mini_batch) # Obs encoding - obs_encoded = pd.DataFrame( - {enc.name: enc.transform(obs_mini_batch) for enc in self._encoders} + obs_encoded = ( + pd.DataFrame( + {enc.name: enc.transform(obs_mini_batch) for enc in self._encoders} + ) + if self._encoders + else obs_mini_batch ) del obs_mini_batch, X_mini_batch From ee2929d401cd14d4fb3014e29f1818f2c839ca0c Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 21 Aug 2024 20:41:46 +0000 Subject: [PATCH 08/70] lint --- apis/python/src/tiledbsoma/ml/pytorch.py | 32 +++++++++++------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index b0a4641..54a391c 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -43,8 +43,12 @@ _T_co = TypeVar("_T_co", covariant=True) -XObsDatum = Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] -XObsNpDatum = Tuple[npt.NDArray[np.number[Any]], npt.NDArray[np.number[Any]]] +if TYPE_CHECKING: + NDArrayNumber = npt.NDArray[np.number[Any]] +else: + NDArrayNumber = "npt.NDArray[np.number[Any]]" +XObsDatum = Tuple[NDArrayNumber, pd.DataFrame] +XObsNpDatum = Tuple[NDArrayNumber, NDArrayNumber] XObsTensorDatum = Tuple[torch.Tensor, torch.Tensor] """Return type of ``ExperimentAxisQueryDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" @@ -303,13 +307,8 @@ def _io_batch_iter( ) X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() - # if self.use_eager_fetch: - # X_tbl_iter = _EagerIterator( - # X_tbl_iter, pool=X.context.threadpool - # ) # type:ignore[assignment] - - obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) X_tbl = pa.concat_tables(X_tbl_iter) + obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) X_tbl = pa.Table.from_pydict( { "soma_dim_0": obs_indexer.get_indexer( @@ -406,8 +405,9 @@ def _encoded_mini_batch_iter( for X_mini_batch, obs_mini_batch in _encoded_mini_batch_iter: # TODO - X encoding - # X_encoded = X_mini_batch.todense() SLOW - X_encoded = _csr_to_dense(X_mini_batch) + X_encoded = _csr_to_dense( + X_mini_batch + ) # same as X_mini_batch.todense(), which is SLOW # Obs encoding obs_encoded = ( @@ -607,7 +607,7 @@ def __init__( def __iter__(self) -> Iterator[XObsNpDatum]: batch_size = self._exp_iter.batch_size for X, obs in self._exp_iter: - obs_np: npt.NDArray[np.number[Any]] = obs.to_numpy() + obs_np: NDArrayNumber = obs.to_numpy() if batch_size == 1: X = X[0] obs_np = obs_np[0] @@ -627,7 +627,7 @@ def encoders(self) -> Encoders: def _collate_ndarray_to_tensor( - datum: Sequence[npt.NDArray[np.number[Any]] | torch.Tensor], + datum: Sequence[NDArrayNumber | torch.Tensor], ) -> Tuple[torch.Tensor, ...]: """Default torch.utils.data.DataLoader collate_fn for ``experiment_dataloader`` -- converts ndarray to a Tensor. 
@@ -727,9 +727,7 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: def _D_IJ( tbl: pa.Table, -) -> Tuple[ - npt.NDArray[np.number[Any]], Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]] -]: +) -> Tuple[NDArrayNumber, Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]]: """Given SOMA-style Pyarrow Table of COO sparse array data, return tuple (D, (I, J)) vectors.""" d = tbl["soma_data"].to_numpy() i = tbl["soma_dim_0"].to_numpy() @@ -830,10 +828,10 @@ def _csr_to_dense_inner(indptr, indices, data, out): # type:ignore[no-untyped-d return out -def _csr_to_dense(sp: sparse.csr_array) -> npt.NDArray[np.number[Any]]: +def _csr_to_dense(sp: sparse.csr_array) -> NDArrayNumber: assert isinstance(sp, (sparse.csr_array, sparse.csr_matrix)) return cast( - npt.NDArray[np.number[Any]], + NDArrayNumber, _csr_to_dense_inner( sp.indptr, sp.indices, sp.data, np.zeros(sp.shape, dtype=sp.dtype) ), From 013cea607f9ace6168943197e94871cec8acef5f Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Thu, 22 Aug 2024 02:16:59 +0000 Subject: [PATCH 09/70] py 3.8 lint --- apis/python/src/tiledbsoma/ml/pytorch.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 54a391c..1c076f9 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: NDArrayNumber = npt.NDArray[np.number[Any]] else: - NDArrayNumber = "npt.NDArray[np.number[Any]]" + NDArrayNumber = np.ndarray XObsDatum = Tuple[NDArrayNumber, pd.DataFrame] XObsNpDatum = Tuple[NDArrayNumber, NDArrayNumber] XObsTensorDatum = Tuple[torch.Tensor, torch.Tensor] @@ -340,7 +340,7 @@ def _mini_batch_iter( obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[npt.NDArray[np.int64]], - ) -> Iterator[Tuple[sparse.csr_array, pd.DataFrame]]: + ) -> Iterator[Tuple[sparse.csr_array | sparse.csr_matrix, pd.DataFrame]]: """Break IO batches into shuffled mini-batch-sized chunks, still in internal format (CSR, Pandas). Private. @@ -353,7 +353,9 @@ def _mini_batch_iter( io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool) mini_batch_size = self.batch_size - result: Tuple[sparse.csr_array, pd.DataFrame] | None = None # partial result + result: Tuple[sparse.csr_array | sparse.csr_matrix, pd.DataFrame] | None = ( + None # partial result + ) for X_io_batch, obs_io_batch in io_batch_iter: assert X_io_batch.shape[0] == obs_io_batch.shape[0] assert X_io_batch.shape[1] == len(self._var_joinids) @@ -372,6 +374,9 @@ def _mini_batch_iter( # use remanent from previous IO batch to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) result = ( + # In older versions of scipy.sparse, vstack will return _matrix when + # called with _array. 
Various code paths must accommodate this bug (mostly
+ in their type declarations)
 sparse.vstack([result[0], X_io_batch[0:to_take]]),
 pd.concat([result[1], obs_io_batch.iloc[0:to_take]]),
 )

From 39dbab62aa471bfe0799a2ad76647d073291bee3 Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Thu, 22 Aug 2024 23:49:32 +0000
Subject: [PATCH 10/70] rework io and shuffle buffer size params

---
 apis/python/src/tiledbsoma/ml/pytorch.py | 73 ++++++++++++++----------
 1 file changed, 42 insertions(+), 31 deletions(-)

diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py
index 1c076f9..51c35aa 100644
--- a/apis/python/src/tiledbsoma/ml/pytorch.py
+++ b/apis/python/src/tiledbsoma/ml/pytorch.py
@@ -109,11 +109,11 @@ def __init__(
 obs_column_names: Sequence[str] = (),
 batch_size: int = 1,
 shuffle: bool = True,
+ io_batch_size: int = 2**17,
+ shuffle_chunk_size: int = 64,
 seed: int | None = None,
- soma_chunk_size: int = 64,
 use_eager_fetch: bool = True,
 encoders: List[Encoder] | None = None,
 partition: bool = True,
 ):
 super().__init__()
@@ -126,17 +126,24 @@
 self.var_query = var_query
 self.obs_column_names = obs_column_names
 self.batch_size = batch_size
- self.soma_chunk_size = soma_chunk_size
+ self.io_batch_size = io_batch_size
+ self.shuffle = shuffle
 self.use_eager_fetch = use_eager_fetch
 # XXX - TODO: when/if we add X encoders, how will they be differentiated from obs encoders? Naming?
 self._encoders = encoders or []
 self._obs_joinids: npt.NDArray[np.int64] | None = None
 self._var_joinids: npt.NDArray[np.int64] | None = None
- self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
 self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
 self.partition = partition
 self._initialized = False
+ self.shuffle_chunk_size = shuffle_chunk_size
+ if self.shuffle:
+ # round io_batch_size up to a unit of shuffle_chunk_size to simplify code.
+ self.io_batch_size = (
+ ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size
+ )
+
 if obs_column_names and encoders:
 raise ValueError(
 "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically."
 )
@@ -152,22 +159,30 @@ def __init__(
 )

 def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]:
- """Private - create iterator over obs id chunks.
+ """Private - create iterator over obs id chunks with split size of (roughly) io_batch_size.

 As appropriate, will chunk, shuffle and apply partitioning per worker.
 """
 assert self._obs_joinids is not None
 obs_joinids: npt.NDArray[np.int64] = self._obs_joinids

- # Chunk joinids by soma_chunk_size
- assert self.soma_chunk_size is not None
- num_chunks = max(1, ceil(len(obs_joinids) / self.soma_chunk_size))
- obs_joinids_chunked = np.array_split(obs_joinids, num_chunks)
-
- # Shuffle chunks. NOTE: this assumes a single global seed for the RNG,
- # ensuring all workers get the same shuffle.
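+ # A worked sketch of the chunked shuffle introduced below (values are
+ # illustrative, not the defaults). With shuffle_chunk_size=2, io_batch_size=4
+ # and obs_joinids == [0, 1, 2, 3, 4, 5, 6, 7]:
+ #   split into chunks:    [[0, 1], [2, 3], [4, 5], [6, 7]]
+ #   shuffle chunk order:  [[4, 5], [0, 1], [6, 7], [2, 3]]
+ #   regroup per IO batch: [[4, 5, 0, 1], [6, 7, 2, 3]]
+ # Each IO batch is then permuted element-wise in _io_batch_iter.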
- if self._shuffle_rng: - self._shuffle_rng.shuffle(obs_joinids_chunked) + if self.shuffle: + assert self._shuffle_rng is not None + assert self.io_batch_size % self.shuffle_chunk_size == 0 + shuffle_split = np.array_split( + obs_joinids, max(1, ceil(len(obs_joinids) / self.shuffle_chunk_size)) + ) + self._shuffle_rng.shuffle(shuffle_split) + obs_joinids_chunked = list( + np.concatenate(b) + for b in _batched( + shuffle_split, self.io_batch_size // self.shuffle_chunk_size + ) + ) + else: + obs_joinids_chunked = np.array_split( + obs_joinids, max(1, ceil(len(obs_joinids) / self.io_batch_size)) + ) # Now extract the partition for this worker partition, num_partitions = ( @@ -176,7 +191,7 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: obs_splits = _splits(len(obs_joinids_chunked), num_partitions) obs_partition_joinids = obs_joinids_chunked[ obs_splits[partition] : obs_splits[partition + 1] - ] + ].copy() obs_joinid_iter = iter(obs_partition_joinids) if pytorch_logger.isEnabledFor(logging.DEBUG) and self.partition: @@ -193,7 +208,7 @@ def _init_once(self) -> None: return pytorch_logger.debug( - f"Initializing ExperimentAxisQueryIterable (shuffle={self._shuffle_rng is not None})" + f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})" ) with self.experiment_locator.open_experiment() as exp: @@ -280,13 +295,9 @@ def _io_batch_iter( ) var_indexer = soma.IntIndexer(self._var_joinids, context=X.context) - batched_obs_joinid_iter = _batched( - obs_joinid_iter, - self._shuffle_chunk_count if self._shuffle_chunk_count is not None else 1, - ) - for obs_coord_chunks in batched_obs_joinid_iter: + for obs_coords in obs_joinid_iter: st_time = time.perf_counter() - obs_coords = np.concatenate(obs_coord_chunks) + # obs_coords = np.concatenate(obs_coord_chunks) obs_shuffled_coords = ( obs_coords if self._shuffle_rng is None @@ -527,10 +538,10 @@ def __init__( batch_size: int = 1, shuffle: bool = True, seed: int | None = None, - soma_chunk_size: int = 64, + io_batch_size: int = 2**17, + shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, encoders: List[Encoder] | None = None, - shuffle_chunk_count: int = 2000, ): self._exp_iter = ExperimentAxisQueryIterable( experiment=experiment, @@ -542,10 +553,10 @@ def __init__( batch_size=batch_size, shuffle=shuffle, seed=seed, - soma_chunk_size=soma_chunk_size, + io_batch_size=io_batch_size, use_eager_fetch=use_eager_fetch, encoders=encoders, - shuffle_chunk_count=shuffle_chunk_count, + shuffle_chunk_size=shuffle_chunk_size, # --- partition=True, ) @@ -586,10 +597,10 @@ def __init__( batch_size: int = 1, # XXX add docstring noting values >1 will not work with default collator shuffle: bool = True, seed: int | None = None, - soma_chunk_size: int = 64, + io_batch_size: int = 2**17, + shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, encoders: List[Encoder] | None = None, - shuffle_chunk_count: int = 2000, ): self._exp_iter = ExperimentAxisQueryIterable( experiment=experiment, @@ -601,10 +612,10 @@ def __init__( batch_size=batch_size, shuffle=shuffle, seed=seed, - soma_chunk_size=soma_chunk_size, + io_batch_size=io_batch_size, use_eager_fetch=use_eager_fetch, encoders=encoders, - shuffle_chunk_count=shuffle_chunk_count, + shuffle_chunk_size=shuffle_chunk_size, # --- partition=True, ) @@ -720,7 +731,7 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: >>> _splits(4, 2) array([0, 2, 4]) """ - if not sections > 0: + if sections <= 0: raise ValueError("number of sections must greater than 0.") 
from None each_section, extras = divmod(total_length, sections) per_section_sizes = ( From 98c4510874473b56a49207628ef52c9b910983f9 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 22:28:28 +0000 Subject: [PATCH 11/70] lint --- apis/python/src/tiledbsoma/ml/encoders.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/apis/python/src/tiledbsoma/ml/encoders.py b/apis/python/src/tiledbsoma/ml/encoders.py index 13f0496..314aaf4 100644 --- a/apis/python/src/tiledbsoma/ml/encoders.py +++ b/apis/python/src/tiledbsoma/ml/encoders.py @@ -7,15 +7,22 @@ import abc import functools -from typing import List +from typing import TYPE_CHECKING, Any, List, Tuple +import numpy as np import numpy.typing as npt import pandas as pd from sklearn.preprocessing import LabelEncoder as SklearnLabelEncoder +if TYPE_CHECKING: + NDArrayNumber = npt.NDArray[np.number[Any]] +else: + NDArrayNumber = np.ndarray +XObsDatum = Tuple[NDArrayNumber, pd.DataFrame] + class Encoder(abc.ABC): - """Base class for ``obs`` encoders. + """Base class for ``pytorch`` encoders. To define a custom encoder, five methods must be implemented: @@ -63,7 +70,9 @@ def columns(self) -> List[str]: class LabelEncoder(Encoder): - """Default encoder based on :class:`sklearn.preprocessing.LabelEncoder`. + """Default encoder - will encode obs values only (will not encode X values). + + Based on :class:`sklearn.preprocessing.LabelEncoder`. Lifecycle: experimental From 933787d7bdef481c6eac08055483cc20442173a0 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 22:29:07 +0000 Subject: [PATCH 12/70] remove encoders; more perf work --- apis/python/src/tiledbsoma/ml/encoders.py | 159 ---------- apis/python/src/tiledbsoma/ml/pytorch.py | 338 ++++------------------ 2 files changed, 61 insertions(+), 436 deletions(-) delete mode 100644 apis/python/src/tiledbsoma/ml/encoders.py diff --git a/apis/python/src/tiledbsoma/ml/encoders.py b/apis/python/src/tiledbsoma/ml/encoders.py deleted file mode 100644 index 314aaf4..0000000 --- a/apis/python/src/tiledbsoma/ml/encoders.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation -# Copyright (c) 2021-2024 TileDB, Inc. -# -# Licensed under the MIT License. - -from __future__ import annotations - -import abc -import functools -from typing import TYPE_CHECKING, Any, List, Tuple - -import numpy as np -import numpy.typing as npt -import pandas as pd -from sklearn.preprocessing import LabelEncoder as SklearnLabelEncoder - -if TYPE_CHECKING: - NDArrayNumber = npt.NDArray[np.number[Any]] -else: - NDArrayNumber = np.ndarray -XObsDatum = Tuple[NDArrayNumber, pd.DataFrame] - - -class Encoder(abc.ABC): - """Base class for ``pytorch`` encoders. - - To define a custom encoder, five methods must be implemented: - - - ``fit``: defines how the encoder will be fitted to the data. - - ``transform``: defines how the encoder will be applied to the data - in order to create an ``obs`` tensor. - - ``inverse_transform``: defines how to decode the encoded values back - to the original values. - - ``name``: The name of the encoder. This will be used as the key in the - dictionary of encoders. Each encoder passed to a :class:`.pytorch.ExperimentDataPipe` must have a unique name. - - ``columns``: List of columns in ``obs`` that the encoder will be applied to. - - See the implementation of :class:`LabelEncoder` for an example. 
- - Lifecycle: - experimental - """ - - @abc.abstractmethod - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with obs.""" - pass - - @abc.abstractmethod - def transform(self, df: pd.DataFrame) -> npt.ArrayLike: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" - pass - - @abc.abstractmethod - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - pass - - @property - @abc.abstractmethod - def name(self) -> str: - """Name of the encoder.""" - pass - - @property - @abc.abstractmethod - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - pass - - -class LabelEncoder(Encoder): - """Default encoder - will encode obs values only (will not encode X values). - - Based on :class:`sklearn.preprocessing.LabelEncoder`. - - Lifecycle: - experimental - """ - - def __init__(self, col: str) -> None: - self._encoder = SklearnLabelEncoder() - self.col = col - - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with ``obs``.""" - self._encoder.fit(obs[self.col].unique()) - - def transform(self, df: pd.DataFrame) -> npt.ArrayLike: - """Transform the obs :class:`pandas.DataFrame` into a :class:`numpy.typing.ArrayLike` of encoded values.""" - return self._encoder.transform(df[self.col]) # type: ignore - - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - return self._encoder.inverse_transform(encoded_values) # type: ignore - - @property - def name(self) -> str: - """Name of the encoder.""" - return self.col - - @property - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - return [self.col] - - @property - def classes_(self): # type: ignore - """Classes of the encoder.""" - return self._encoder.classes_ - - -class BatchEncoder(Encoder): - """An encoder that concatenates and encodes several ``obs`` columns. 
- - Lifecycle: - experimental - """ - - def __init__(self, cols: List[str], name: str = "batch"): - self.cols = cols - from sklearn.preprocessing import LabelEncoder - - self._name = name - self._encoder = LabelEncoder() - - def _join_cols(self, df: pd.DataFrame): # type: ignore - return functools.reduce( - lambda a, b: a + b, [df[c].astype(str) for c in self.cols] - ) - - def transform(self, df: pd.DataFrame) -> npt.ArrayLike: - """Transform the obs :class:`pandas.DataFrame` into a :class:`pandas.DataFrame` of encoded values.""" - arr = self._join_cols(df) - return self._encoder.transform(arr) # type: ignore - - def inverse_transform(self, encoded_values: npt.ArrayLike) -> npt.ArrayLike: - """Inverse transform the encoded values back to the original values.""" - return self._encoder.inverse_transform(encoded_values) # type: ignore - - def fit(self, obs: pd.DataFrame) -> None: - """Fit the encoder with ``obs``.""" - arr = self._join_cols(obs) - self._encoder.fit(arr.unique()) - - @property - def columns(self) -> List[str]: - """Columns in ``obs`` that the encoder will be applied to.""" - return self.cols - - @property - def name(self) -> str: - """Name of the encoder.""" - return self._name - - @property - def classes_(self): # type: ignore - """Classes of the encoder.""" - return self._encoder.classes_ diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/apis/python/src/tiledbsoma/ml/pytorch.py index 51c35aa..44dc028 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/apis/python/src/tiledbsoma/ml/pytorch.py @@ -15,7 +15,6 @@ Dict, Iterable, Iterator, - List, Sequence, Tuple, TypeVar, @@ -28,7 +27,6 @@ import numpy as np import numpy.typing as npt import pandas as pd -import psutil import pyarrow as pa import scipy.sparse as sparse import torch @@ -37,10 +35,9 @@ import tiledbsoma as soma -from .encoders import Encoder, LabelEncoder - -pytorch_logger = logging.getLogger("tiledbsoma.ml.pytorch") +logger = logging.getLogger("tiledbsoma.ml.pytorch") +_T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) if TYPE_CHECKING: @@ -54,16 +51,6 @@ The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" -Encoders = Dict[str, Encoder] -"""A dictionary of ``Encoder``s keyed by the ``obs`` column name.""" - - -if TYPE_CHECKING: - PsUtilMemInfo = Tuple[psutil.pfullmem, psutil.svmem, psutil.sswap] -else: - PsUtilMemInfo = Tuple[Any] - - @attrs.define(frozen=True, kw_only=True) class _ExperimentLocator: """State required to open the Experiment. @@ -106,14 +93,13 @@ def __init__( X_name: str = "raw", obs_query: soma.AxisQuery | None = None, var_query: soma.AxisQuery | None = None, - obs_column_names: Sequence[str] = (), + obs_column_names: Sequence[str] = ("soma_joinid",), batch_size: int = 1, shuffle: bool = True, io_batch_size: int = 2**17, shuffle_chunk_size: int = 64, seed: int | None = None, use_eager_fetch: bool = True, - encoders: List[Encoder] | None = None, partition: bool = True, ): super().__init__() @@ -124,13 +110,11 @@ def __init__( self.layer_name = X_name self.obs_query = obs_query self.var_query = var_query - self.obs_column_names = obs_column_names + self.obs_column_names = list(obs_column_names) self.batch_size = batch_size self.io_batch_size = io_batch_size self.shuffle = shuffle self.use_eager_fetch = use_eager_fetch - # XXX - TODO: when/if we add X encoders, how will they be differentiated from obs encoders? Naming? 
- self._encoders = encoders or [] self._obs_joinids: npt.NDArray[np.int64] | None = None self._var_joinids: npt.NDArray[np.int64] | None = None self._shuffle_rng = np.random.default_rng(seed) if shuffle else None @@ -144,19 +128,8 @@ def __init__( ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size ) - if obs_column_names and encoders: - raise ValueError( - "Cannot specify both `obs_column_names` and `encoders`. If `encoders` are specified, columns will be inferred automatically." - ) - - if encoders: - # Check if names are unique - if len(encoders) != len({enc.name for enc in encoders}): - raise ValueError("Encoders must have unique names") - - self.obs_column_names = list( - dict.fromkeys(itertools.chain(*[enc.columns for enc in encoders])) - ) + if not self.obs_column_names: + raise ValueError("Must specify at least one value in `obs_column_names`") def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: """Private - create iterator over obs id chunks with split size of (roughly) io_batch_size. @@ -194,8 +167,8 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: ].copy() obs_joinid_iter = iter(obs_partition_joinids) - if pytorch_logger.isEnabledFor(logging.DEBUG) and self.partition: - pytorch_logger.debug( + if logger.isEnabledFor(logging.DEBUG) and self.partition: + logger.debug( f"Process {os.getpid()} handling partition {partition + 1} of {num_partitions}, " f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" ) @@ -207,7 +180,7 @@ def _init_once(self) -> None: if self._initialized: return - pytorch_logger.debug( + logger.debug( f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})" ) @@ -219,7 +192,6 @@ def _init_once(self) -> None: ) self._obs_joinids = query.obs_joinids().to_numpy() self._var_joinids = query.var_joinids().to_numpy() - self._encoders = self._build_encoders(query) self._initialized = True @@ -234,9 +206,7 @@ def __iter__(self) -> Iterator[XObsDatum]: ) obs_joinid_iter = self._create_obs_joinid_iter() - _mini_batch_iter = self._encoded_mini_batch_iter( - exp.obs, X, obs_joinid_iter - ) + _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter) if self.use_eager_fetch: _mini_batch_iter = _EagerIterator( _mini_batch_iter, pool=exp.context.threadpool @@ -297,13 +267,12 @@ def _io_batch_iter( for obs_coords in obs_joinid_iter: st_time = time.perf_counter() - # obs_coords = np.concatenate(obs_coord_chunks) obs_shuffled_coords = ( obs_coords if self._shuffle_rng is None else self._shuffle_rng.permuted(obs_coords) ) - pytorch_logger.debug( + logger.debug( f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." 
) @@ -316,32 +285,30 @@ def _io_batch_iter( .reindex(obs_shuffled_coords, copy=False) .reset_index(), ) + obs_io_batch = obs_io_batch[self.obs_column_names] X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() X_tbl = pa.concat_tables(X_tbl_iter) obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) - X_tbl = pa.Table.from_pydict( - { - "soma_dim_0": obs_indexer.get_indexer( - X_tbl["soma_dim_0"].to_numpy() - ), - "soma_dim_1": var_indexer.get_indexer( - X_tbl["soma_dim_1"].to_numpy() - ), - "soma_data": X_tbl["soma_data"].to_numpy(), - } - ) + + i = obs_indexer.get_indexer(X_tbl["soma_dim_0"].to_numpy()) + j = var_indexer.get_indexer(X_tbl["soma_dim_1"].to_numpy()) + d = X_tbl["soma_data"].to_numpy() + if len(i) < np.iinfo(np.int32).max: + # downcast where able, as this saves considerable memory + i = i.astype(np.int32) + j = j.astype(np.int32) X_io_batch = sparse.csr_array( - _D_IJ(X_tbl), + (d, (i, j)), shape=(len(obs_coords), len(self._var_joinids)), copy=False, ) - del X_tbl, X_tbl_iter + del X_tbl, X_tbl_iter, i, j, d del obs_coords, obs_shuffled_coords, obs_indexer - _run_gc() + gc.collect() tm = time.perf_counter() - st_time - pytorch_logger.debug( + logger.debug( f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec" ) yield X_io_batch, obs_io_batch @@ -351,8 +318,8 @@ def _mini_batch_iter( obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[npt.NDArray[np.int64]], - ) -> Iterator[Tuple[sparse.csr_array | sparse.csr_matrix, pd.DataFrame]]: - """Break IO batches into shuffled mini-batch-sized chunks, still in internal format (CSR, Pandas). + ) -> Iterator[XObsDatum]: + """Break IO batches into shuffled mini-batch-sized chunks. Private. """ @@ -395,137 +362,23 @@ def _mini_batch_iter( assert result[0].shape[0] == result[1].shape[0] if result[0].shape[0] == mini_batch_size: - yield result + # yield result + yield (_csr_to_dense(result[0]), result[1]) result = None else: # yield a remnant, if any if result is not None: - yield result - - def _encoded_mini_batch_iter( - self, - obs: soma.DataFrame, - X: soma.SparseNDArray, - obs_joinid_iter: Iterator[npt.NDArray[np.int64]], - ) -> Iterator[XObsDatum]: - """Apply encoding on top of the mini batches. - - Returns numpy encodings (X, obs). 
- """ - _encoded_mini_batch_iter = self._mini_batch_iter(obs, X, obs_joinid_iter) - if self.use_eager_fetch: - _encoded_mini_batch_iter = _EagerIterator( - _encoded_mini_batch_iter, pool=X.context.threadpool - ) - - for X_mini_batch, obs_mini_batch in _encoded_mini_batch_iter: - # TODO - X encoding - X_encoded = _csr_to_dense( - X_mini_batch - ) # same as X_mini_batch.todense(), which is SLOW - - # Obs encoding - obs_encoded = ( - pd.DataFrame( - {enc.name: enc.transform(obs_mini_batch) for enc in self._encoders} - ) - if self._encoders - else obs_mini_batch - ) - - del obs_mini_batch, X_mini_batch - yield X_encoded, obs_encoded - - def _build_encoders( - self, query: soma.ExperimentAxisQuery[soma.Experiment] # type: ignore[type-var] - ) -> List[Encoder]: - pytorch_logger.debug("Initializing encoders") - - encoders = [] - - if "soma_joinid" not in self.obs_column_names: - cols = ["soma_joinid", *self.obs_column_names] - else: - cols = list(self.obs_column_names) - - obs = query.obs(column_names=cols).concat().to_pandas() - - if self._encoders: - # Fit all the custom encoders with obs - for enc in self._encoders: - enc.fit(obs) - encoders.append(enc) - else: - # Create one LabelEncoder for each column, and fit it with obs - for col in self.obs_column_names: - enc = LabelEncoder(col) - enc.fit(obs) - encoders.append(enc) - - return encoders - - @property - def encoders(self) -> Encoders: - """Returns a dictionary of :class:`sklearn.preprocessing.LabelEncoder` objects, keyed on ``obs`` column names, - which were used to encode the ``obs`` column values. - - These encoders can be used to decode the encoded values as follows: - - >>> exp_data_pipe.encoders[""].inverse_transform(encoded_values) - - Returns: - A ``Dict[str, LabelEncoder]``, mapping column names to :class:`sklearn.preprocessing.LabelEncoder` objects. - """ - self._init_once() - assert self._encoders is not None - return {enc.name: enc for enc in self._encoders} + # yield result + yield (_csr_to_dense(result[0]), result[1]) class ExperimentAxisQueryDataPipe( torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] - torch.utils.data.dataset.Dataset[XObsTensorDatum] + torch.utils.data.dataset.Dataset[XObsDatum] ], ): - """An :class:`torchdata.datapipes.iter.IterDataPipe` which reads ``obs`` and ``X`` data from a - :class:`tiledbsoma.Experiment`, based upon the specified queries along the ``obs`` and ``var`` axes. Provides - a standard Python iterable interface. - - >>> for batch in ExperimentAxisQueryDataPipe(...): - X_batch, obs_batch = batch - - **WARNING:** :class:`torchdata.datapipes` is deprecated as of version 0.9 (July 2024), and is slated for removal - in a future release (late 2024). It is recommended that new code utilize :class:`ExperimentAxisQueryIterableDataset`. - Older code should pin the torchdata version to 0.9 or older. For more information, see - https://github.com/pytorch/data/issues/1196. - - The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each - iteration. 
If the ``batch_size`` is 1, then each Tensor will have rank 1: - - >>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data - tensor([2415, 0, 0], dtype=torch.int64)) # obs data, encoded - - For larger ``batch_size`` values, the returned Tensors will have rank 2: - - >>> DataLoader(..., batch_size=3, ...): - (tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch - [0., 0., 0., 0., 0., 0., 0., 0., 0.], - [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), - tensor([[2415, 0, 0], # obs batch - [2416, 0, 4], - [2417, 0, 3]], dtype=torch.int64)) - - The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` Tensor. - User-specified encoders may be provided - when not provided, the ``X`` batch will not be encoded, and the - ``obs`` batch will be encoded with a simple label encoder. - - If needed, encoded valeus can be decoded by calling ``inverse_transform`` method on the encoder, e.g., - - >>> exp_data_pipe.encoders[""].inverse_transform(encoded_values) - - Lifecycle: - experimental - """ + # TODO: XXX docstrings def __init__( self, @@ -534,15 +387,16 @@ def __init__( X_name: str = "raw", obs_query: soma.AxisQuery | None = None, var_query: soma.AxisQuery | None = None, - obs_column_names: Sequence[str] = (), + obs_column_names: Sequence[str] = ("soma_joinid",), batch_size: int = 1, shuffle: bool = True, seed: int | None = None, io_batch_size: int = 2**17, shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, - encoders: List[Encoder] | None = None, + # encoders: List[Encoder] | None = None, ): + super().__init__() self._exp_iter = ExperimentAxisQueryIterable( experiment=experiment, measurement_name=measurement_name, @@ -555,21 +409,17 @@ def __init__( seed=seed, io_batch_size=io_batch_size, use_eager_fetch=use_eager_fetch, - encoders=encoders, shuffle_chunk_size=shuffle_chunk_size, # --- partition=True, ) - def __iter__(self) -> Iterator[XObsTensorDatum]: + def __iter__(self) -> Iterator[XObsDatum]: batch_size = self._exp_iter.batch_size for X, obs in self._exp_iter: - X_tensor, obs_tensor = torch.from_numpy(X), torch.from_numpy(obs.to_numpy()) if batch_size == 1: - X_tensor = X_tensor[0] - obs_tensor = obs_tensor[0] - - yield X_tensor, obs_tensor + X = X[0] + yield X, obs def __len__(self) -> int: return self._exp_iter.__len__() @@ -578,14 +428,12 @@ def __len__(self) -> int: def shape(self) -> Tuple[int, int]: return self._exp_iter.shape - @property - def encoders(self) -> Encoders: - return self._exp_iter.encoders - class ExperimentAxisQueryIterableDataset( - torch.utils.data.IterableDataset[XObsNpDatum] # type:ignore[misc] + torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] ): + # TODO: XXX docstrings + def __init__( self, experiment: soma.Experiment, @@ -593,15 +441,15 @@ def __init__( X_name: str = "raw", obs_query: soma.AxisQuery | None = None, var_query: soma.AxisQuery | None = None, - obs_column_names: Sequence[str] = (), + obs_column_names: Sequence[str] = ("soma_joinid",), batch_size: int = 1, # XXX add docstring noting values >1 will not work with default collator shuffle: bool = True, seed: int | None = None, io_batch_size: int = 2**17, shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, - encoders: List[Encoder] | None = None, ): + super().__init__() self._exp_iter = ExperimentAxisQueryIterable( experiment=experiment, measurement_name=measurement_name, @@ -614,21 +462,17 @@ def __init__( seed=seed, io_batch_size=io_batch_size, use_eager_fetch=use_eager_fetch, - encoders=encoders, shuffle_chunk_size=shuffle_chunk_size, # 
--- partition=True, ) - def __iter__(self) -> Iterator[XObsNpDatum]: + def __iter__(self) -> Iterator[XObsDatum]: batch_size = self._exp_iter.batch_size for X, obs in self._exp_iter: - obs_np: NDArrayNumber = obs.to_numpy() if batch_size == 1: X = X[0] - obs_np = obs_np[0] - - yield X, obs_np + yield X, obs def __len__(self) -> int: return self._exp_iter.__len__() @@ -637,21 +481,10 @@ def __len__(self) -> int: def shape(self) -> Tuple[int, int]: return self._exp_iter.shape - @property - def encoders(self) -> Encoders: - return self._exp_iter.encoders - -def _collate_ndarray_to_tensor( - datum: Sequence[NDArrayNumber | torch.Tensor], -) -> Tuple[torch.Tensor, ...]: - """Default torch.utils.data.DataLoader collate_fn for ``experiment_dataloader`` -- converts ndarray to a Tensor. - - Must be a top-level function to play nice with picking for multiprocessing use cases. - """ - return tuple( - torch.from_numpy(d) if not isinstance(d, torch.Tensor) else d for d in datum - ) +def _collate_noop(datum: _T) -> _T: + """Disable collation in dataloader instance.""" + return datum def experiment_dataloader( @@ -659,37 +492,10 @@ def experiment_dataloader( num_workers: int = 0, **dataloader_kwargs: Any, ) -> torch.utils.data.DataLoader: - """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a - :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`. - - Several :class:`torch.utils.data.DataLoader` parameters are disallowed as they are either non-performant and better - specified on the :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset` (``shuffle``, ``batch_size``) or are not compatible - (``sampler``, ``batch_sampler``, ``collate_fn``). - - Args: - datapipe: - An :class:`torch.util.data.IterableDataset`, which can be an - :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`, - :class:`tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` or any other - :class:`torch.util.data.IterableDataset` that has been chained to the - :class:`tiledbsoma.ml.ExperimentAxisQueryIterableDataset`. - num_workers: - Number of worker processes to use for data loading. If ``0``, data will be loaded in the main process. - **dataloader_kwargs: - Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, - except for ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, and ``collate_fn``, which are not - supported when using :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. - - Returns: - A :class:`torch.utils.data.DataLoader`. - - Raises: - ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, ``batch_sampler``, or ``collate_fn`` params - are passed as keyword arguments. - - Lifecycle: - experimental - """ + # TODO: XXX docstrings + + # XXX why do we disallow collate_fn? 
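+ # A plausible rationale (assumed, not stated in this changeset): the loader
+ # below is created with batch_size=None and a no-op collate_fn, so batching
+ # and collation are owned by the upstream iterable; a user-supplied
+ # collate_fn would conflict with that contract.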
+
     unsupported_dataloader_args = [
         "shuffle",
         "batch_size",
         "sampler",
         "batch_sampler",
         "collate_fn",
     ]
     if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
         raise ValueError(f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported")

     if num_workers > 0:
         _init_multiprocessing()

     return torch.utils.data.DataLoader(
         ds,
-        batch_size=None,  # batching is handled by our ExperimentAxisQueryIterableDataset
+        batch_size=None,  # batching is handled by upstream iterator
         num_workers=num_workers,
-        collate_fn=_collate_ndarray_to_tensor,
-        # shuffling is handled by ExperimentAxisQueryIterableDataset
+        collate_fn=_collate_noop,
+        # shuffling is handled by upstream iterator
         shuffle=False,
         **dataloader_kwargs,
     )
@@ -741,16 +547,6 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
     return splits


-def _D_IJ(
-    tbl: pa.Table,
-) -> Tuple[NDArrayNumber, Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]]:
-    """Given SOMA-style Pyarrow Table of COO sparse array data, return tuple (D, (I, J)) vectors."""
-    d = tbl["soma_data"].to_numpy()
-    i = tbl["soma_dim_0"].to_numpy()
-    j = tbl["soma_dim_1"].to_numpy()
-    return d, (i, j)
-
-
 if sys.version_info >= (3, 12):
     _batched = itertools.batched

@@ -759,7 +555,7 @@ def _D_IJ(
     def _batched(
         iterable: Iterable[_T_co], n: int, *, strict: bool = False
     ) -> Iterator[Tuple[_T_co, ...]]:
-        """Same as the Python 3.12+ itertools.batched, but polyfilled for old implementations."""
+        """Same as the Python 3.12+ itertools.batched -- polyfill for old Python versions."""
         if n < 1:
             raise ValueError("n must be at least one")
         it = iter(iterable)
@@ -769,22 +565,6 @@ def _batched(
         yield batch


-def _run_gc() -> Tuple[PsUtilMemInfo, PsUtilMemInfo, float]:
-    """Run Python GC and log stats."""
-    proc = psutil.Process(os.getpid())
-
-    pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
-    start = time.time()
-    gc.collect()
-    gc_elapsed = time.time() - start
-    post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
-
-    pytorch_logger.debug(f"gc: pre={pre_gc}")
-    pytorch_logger.debug(f"gc: post={post_gc}")
-
-    return pre_gc, post_gc, gc_elapsed
-
-
 def _get_torch_partition_info() -> Tuple[int, int]:
     """Return this workers partition and total partition count as a tuple.
@@ -827,7 +607,7 @@ def _init_multiprocessing() -> None:
     orig_start_method = torch.multiprocessing.get_start_method()
     if orig_start_method != "spawn":
         if orig_start_method:
-            pytorch_logger.warning(
+            logger.warning(
                 "switching torch multiprocessing start method from "
                 f'"{torch.multiprocessing.get_start_method()}" to "spawn"'
             )
@@ -845,6 +625,10 @@ def _csr_to_dense_inner(indptr, indices, data, out):  # type:ignore[no-untyped-d


 def _csr_to_dense(sp: sparse.csr_array | sparse.csr_matrix) -> NDArrayNumber:
+    """Fast, parallel variant of scipy.sparse.csr_array.todense.
+
+    Typically 4-8X faster, depending on host and size of array/matrix.
+    """
     assert isinstance(sp, (sparse.csr_array, sparse.csr_matrix))
     return cast(
         NDArrayNumber,

From d44fcae3de89d6393b34f8605d8442ac83abb74d Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Sun, 25 Aug 2024 00:21:51 +0000
Subject: [PATCH 13/70] reorganize into separate python package

---
 CHANGELOG.md                                      |  18 +
 README.md                                         |  36 +
 pyproject.toml                                    |  53 ++
 .../ml => src/tiledbsoma_ml}/__init__.py          |   8 +-
 .../ml => src/tiledbsoma_ml}/pytorch.py           |   7 +-
 tests/test_pytorch.py                             | 757 ++++++++++++++++++
 6 files changed, 873 insertions(+), 6 deletions(-)
 create mode 100644 CHANGELOG.md
 create mode 100644 README.md
 create mode 100644 pyproject.toml
 rename {apis/python/src/tiledbsoma/ml => src/tiledbsoma_ml}/__init__.py (66%)
 rename {apis/python/src/tiledbsoma/ml => src/tiledbsoma_ml}/pytorch.py (99%)
 create mode 100644 tests/test_pytorch.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..2254e83
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,18 @@
+
+# Change Log
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](http://keepachangelog.com/)
+and this project adheres to [Semantic Versioning](http://semver.org/).
+
+## [Unreleased] - yyyy-mm-dd
+
+Porting and enhancing initial code contribution from the Chan Zuckerberg Initiative Foundation
+[CELLxGENE](https://cellxgene.cziscience.com/) project.
+
+### Added
+
+### Changed
+
+### Fixed
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..60b5bb5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,36 @@
+
+# tiledbsoma-ml
+
+A Python package containing ML tools for use with `tiledbsoma`.
+
+## Description
+
+The package currently contains a prototype PyTorch `IterableDataset` for use with the
+[`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+API.
+
+## Getting Started
+
+### Installing
+
+Install using your favorite package installer. For example, with pip:
+
+> pip install tiledbsoma-ml
+
+
+### Documentation
+
+TBD
+
+## Version History
+
+See the [CHANGELOG.md](CHANGELOG.md) file.
+
+## License
+
+This project is licensed under the MIT License.
+
+## Acknowledgements
+
+The SOMA team is grateful to the Chan Zuckerberg Initiative Foundation [CELLxGENE Census](https://cellxgene.cziscience.com)
+team for their initial contribution.
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..09b8833 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "tiledbsoma-ml" +dynamic = ["version"] +dependencies = [ + "attrs", + "tiledbsoma", + "torch", + "torchdata<=0.9", +] +requires-python = ">= 3.8" +description = "Machine learning tools for use with tiledbsoma" +readme = "README.md" +authors = [ + {name = "TileDB, Inc.", email = "help@tiledb.io"}, + {name = "The Chan Zuckerberg Initiative Foundation", email = "soma@chanzuckerberg.com" }, +] +maintainers = [ + {name = "TileDB, Inc.", email="help@tiledb.io"}, +] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Operating System :: Unix", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +[project.urls] +Repository = "https://github.com/single-cell-data/TileDB-SOMA.git" +Issues = "https://github.com/single-cell-data/TileDB-SOMA/issues" +Changelog = "https://github.com/single-cell-data/TileDB-SOMA/blob/main/CHANGELOG.md" + +[tool.setuptools.dynamic] +version = {attr = "tiledbsoma_ml.__version__"} + +[tool.setuptools_scm] +root = "../../.." diff --git a/apis/python/src/tiledbsoma/ml/__init__.py b/src/tiledbsoma_ml/__init__.py similarity index 66% rename from apis/python/src/tiledbsoma/ml/__init__.py rename to src/tiledbsoma_ml/__init__.py index 857f30f..e610573 100644 --- a/apis/python/src/tiledbsoma/ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -3,20 +3,18 @@ # # Licensed under the MIT License. -"""An API to facilitate use of PyTorch ML training with data from SOMA data.""" +"""An API to support machine learning applications built on SOMA.""" -from .encoders import BatchEncoder, Encoder, LabelEncoder from .pytorch import ( ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset, experiment_dataloader, ) +__version__ = "0.1.0-dev" + __all__ = [ "ExperimentAxisQueryDataPipe", "ExperimentAxisQueryIterableDataset", "experiment_dataloader", - "Encoder", - "LabelEncoder", - "BatchEncoder", ] diff --git a/apis/python/src/tiledbsoma/ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py similarity index 99% rename from apis/python/src/tiledbsoma/ml/pytorch.py rename to src/tiledbsoma_ml/pytorch.py index 44dc028..85fa3dc 100644 --- a/apis/python/src/tiledbsoma/ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -1,3 +1,8 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
+ from __future__ import annotations import gc @@ -35,7 +40,7 @@ import tiledbsoma as soma -logger = logging.getLogger("tiledbsoma.ml.pytorch") +logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py new file mode 100644 index 0000000..d3b8fd8 --- /dev/null +++ b/tests/test_pytorch.py @@ -0,0 +1,757 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +import pathlib +from typing import Callable, List, Optional, Sequence, Union +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from scipy import sparse +from scipy.sparse import coo_matrix, spmatrix + +import tiledbsoma as soma +from tiledbsoma import Experiment, _factory +from tiledbsoma._collection import CollectionBase + +# conditionally import torch, as it will not be available in all test environments. +# This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent +# tests. +try: + from tiledbsoma_ml.pytorch import ( + ExperimentAxisQueryDataPipe, + ExperimentAxisQueryIterableDataset, + experiment_dataloader, + ) + from torch.utils.data._utils.worker import WorkerInfo +except ImportError: + # this should only occur when not running `ml`-marked tests + pass + + +def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + occupied_shape = ( + obs_range.stop - obs_range.start, + var_range.stop - var_range.start, + ) + checkerboard_of_ones = coo_matrix(np.indices(occupied_shape).sum(axis=0) % 2) + checkerboard_of_ones.row += obs_range.start + checkerboard_of_ones.col += var_range.start + return checkerboard_of_ones + + +def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + """A sparse matrix where the values of each col are the obs_range values. 
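+    (Editor's illustration, not part of the original patch: with obs_range=range(3),
+    every column of the result reads [0, 1, 2] top to bottom, i.e., X[i, j] == i.)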
Useful for checking the + X values are being returned in the correct order.""" + data = np.vstack([list(obs_range)] * len(var_range)).flatten() + rows = np.vstack([list(obs_range)] * len(var_range)).flatten() + cols = np.column_stack([list(var_range)] * len(obs_range)).flatten() + return coo_matrix((data, (rows, cols))) + + +@pytest.fixture +def X_layer_names() -> List[str]: + return ["raw"] + + +@pytest.fixture +def obsp_layer_names() -> Optional[List[str]]: + return None + + +@pytest.fixture +def varp_layer_names() -> Optional[List[str]]: + return None + + +def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: + df = coll.add_new_dataframe( + key, + schema=pa.schema( + [ + ("soma_joinid", pa.int64()), + ("label", pa.large_string()), + ("label2", pa.large_string()), + ] + ), + index_column_names=["soma_joinid"], + ) + df.write( + pa.Table.from_pydict( + { + "soma_joinid": list(value_range), + "label": [str(i) for i in value_range], + "label2": ["c" for i in value_range], + } + ) + ) + + +def add_sparse_array( + coll: CollectionBase, + key: str, + obs_range: range, + var_range: range, + value_gen: Callable[[range, range], spmatrix], +) -> None: + a = coll.add_new_sparse_ndarray( + key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) + ) + tensor = pa.SparseCOOTensor.from_scipy(value_gen(obs_range, var_range)) + a.write(tensor) + + +@pytest.fixture(scope="function") +def soma_experiment( + tmp_path: pathlib.Path, + obs_range: Union[int, range], + var_range: Union[int, range], + X_value_gen: Callable[[range, range], sparse.spmatrix], + obsp_layer_names: Sequence[str], + varp_layer_names: Sequence[str], +) -> soma.Experiment: + with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp: + if isinstance(obs_range, int): + obs_range = range(obs_range) + if isinstance(var_range, int): + var_range = range(var_range) + + add_dataframe(exp, "obs", obs_range) + ms = exp.add_new_collection("ms") + rna = ms.add_new_collection("RNA", soma.Measurement) + add_dataframe(rna, "var", var_range) + rna_x = rna.add_new_collection("X", soma.Collection) + add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) + + if obsp_layer_names: + obsp = rna.add_new_collection("obsp") + for obsp_layer_name in obsp_layer_names: + add_sparse_array( + obsp, obsp_layer_name, obs_range, var_range, X_value_gen + ) + + if varp_layer_names: + varp = rna.add_new_collection("varp") + for varp_layer_name in varp_layer_names: + add_sparse_array( + varp, varp_layer_name, obs_range, var_range, X_value_gen + ) + return _factory.open((tmp_path / "exp").as_posix()) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_non_batched( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + # batch_size should default to 1 + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + assert type(exp_data_pipe.shape) is tuple + assert len(exp_data_pipe.shape) == 2 + assert exp_data_pipe.shape == (6, 3) + + row_iter = iter(exp_data_pipe) + + row = next(row_iter) + assert isinstance(row[0], np.ndarray) + assert isinstance(row[1], pd.DataFrame) + assert 
row[0].shape == (3,) + assert row[1].shape == (1, 1) + assert row[0].tolist() == [0, 1, 0] + assert row[1].keys() == ["label"] + assert row[1]["label"].tolist() == ["0"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_uneven_soma_and_result_batches( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + """This is checking that batches are correctly created when they require fetching multiple chunks.""" + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + shuffle=False, + batch_size=3, + io_batch_size=2, + use_eager_fetch=use_eager_fetch, + ) + row_iter = iter(exp_data_pipe) + + X_batch, obs_batch = next(row_iter) + assert isinstance(X_batch, np.ndarray) + assert isinstance(obs_batch, pd.DataFrame) + assert X_batch.shape[0] == obs_batch.shape[0] + assert X_batch.shape == (3, 3) + assert obs_batch.shape == (3, 1) + assert X_batch[0].tolist() == [0, 1, 0] + assert ["label"] == obs_batch.keys() + assert obs_batch["label"].tolist() == ["0", "1", "2"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__all_batches_full_size( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["3", "4", "5"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (range(100_000_000, 100_000_003), 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_unique_soma_joinids( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["soma_joinid", "label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] + ) + assert len(np.unique(soma_joinids)) == len(soma_joinids) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) 
+@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__partial_final_batch_size( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + next(batch_iter) + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__exactly_one_batch( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__empty_query_result( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_query=soma.AxisQuery(coords=([],)), + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 1, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_batching__partial_soma_batches_are_concatenated( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + exp_data_pipe = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches + io_batch_size=4, + use_eager_fetch=use_eager_fetch, + ) + + full_result = list(exp_data_pipe) + + assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_multiprocessing__returns_full_result( + PipeClass: 
ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + """Tests the ExperimentAxisQueryDataPipe provides all data, as collected from multiple processes that are managed by a + PyTorch DataLoader with multiple workers configured.""" + + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["soma_joinid", "label"], + io_batch_size=3, # two chunks, one per worker + ) + # Note we're testing the ExperimentAxisQueryDataPipe via a DataLoader, since this is what sets up the multiprocessing + dl = experiment_dataloader(dp, num_workers=2) + + full_result = list(iter(dl)) + + soma_joinids = np.concatenate([t[1]["soma_joinid"].to_numpy() for t in full_result]) + assert sorted(soma_joinids) == list(range(6)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_distributed__returns_data_partition_for_rank( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, + using mocks to avoid having to do real PyTorch distributed setup.""" + + with patch("torch.distributed.is_initialized") as mock_dist_is_initialized, patch( + "torch.distributed.get_rank" + ) as mock_dist_get_rank, patch( + "torch.distributed.get_world_size" + ) as mock_dist_get_world_size: + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = 1 + mock_dist_get_world_size.return_value = 3 + + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) + full_result = list(iter(dp)) + + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + # Of the 6 obs rows, the PyTorch process of rank 1 should get [2, 3] + # (rank 0 gets [0, 1], rank 2 gets [4, 5]) + assert sorted(soma_joinids) == [2, 3] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(12, 3, pytorch_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_distributed_and_multiprocessing__returns_data_partition_for_rank( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and + DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch + setup or real DataLoader multiprocessing.""" + + with patch("torch.utils.data.get_worker_info") as mock_get_worker_info, patch( + "torch.distributed.is_initialized" + ) as mock_dist_is_initialized, patch( + "torch.distributed.get_rank" + ) as mock_dist_get_rank, patch( + "torch.distributed.get_world_size" + ) as mock_dist_get_world_size: + mock_get_worker_info.return_value = WorkerInfo(id=1, num_workers=2, seed=1234) + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = 1 + mock_dist_get_world_size.return_value = 3 + + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) + + full_result = list(iter(dp)) + + soma_joinids = np.concatenate( + 
[t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + # Of the 12 obs rows, the PyTorch process of rank 1 should get [4..7], and then within that partition, + # the 2nd DataLoader process should get the second half of the rank's partition, which is just [6, 7] + # (rank 0 gets [0..3], rank 2 gets [8..11]) + assert sorted(soma_joinids) == [6, 7] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__non_batched( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + assert all(d[0].shape == (3,) for d in data) + assert all(d[1].shape == (1, 1) for d in data) + + row = data[0] + assert row[0].tolist() == [0, 1, 0] + assert row[1]["label"].tolist() == ["0"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__batched( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + + batch = data[0] + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].to_numpy().tolist() == [[0], [1], [2]] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__batched_length( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + assert len(dl) == len(list(dl)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test__X_tensor_dtype_matches_X_matrix( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + data = next(iter(dp)) + + assert data[0].dtype == np.float32 + + +@pytest.mark.parametrize( 
+ "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)] +) +def test__pytorch_splitting( + soma_experiment: Experiment, +) -> None: + dp = ExperimentAxisQueryDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + ) + # function not available for IterableDataset, yet.... + dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=1234) + dl = experiment_dataloader(dp_train) + + all_rows = list(iter(dl)) + assert len(all_rows) == 7 + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test__shuffle( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, +) -> None: + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + shuffle=True, + ) + + all_rows = list(iter(dp)) + assert all(r[0].shape == (1,) for r in all_rows) + soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows] + X_values = [row[0][0].item() for row in all_rows] + + # same elements + assert set(soma_joinids) == set(range(16)) + # not ordered! (...with a `1/16!` probability of being ordered) + assert soma_joinids != list(range(16)) + # randomizes X in same order as obs + # note: X values were explicitly set to match obs_joinids to allow for this simple assertion + assert X_values == soma_joinids + + +def test_experiment_dataloader__unsupported_params__fails() -> None: + with patch( + "tiledbsoma_ml.pytorch.ExperimentAxisQueryDataPipe" + ) as dummy_exp_data_pipe: + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, shuffle=True) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_size=3) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, sampler=[]) + with pytest.raises(ValueError): + experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x) + + +def test_batched() -> None: + from tiledbsoma_ml.pytorch import _batched + + assert list(_batched(range(6), 1)) == list((i,) for i in range(6)) + assert list(_batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] + assert list(_batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] + assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] + assert list(_batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] + assert list(_batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] + assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] + + # bogus batch value + with pytest.raises(ValueError): + list(_batched([0, 1], 0)) + with pytest.raises(ValueError): + list(_batched([2, 3], -1)) + + # strict enforcement + with pytest.raises(ValueError): + list(_batched([0, 1, 2], 2, strict=True)) + + +def test_splits() -> None: + from tiledbsoma_ml.pytorch import _splits + + assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 3).tolist() == [0, 4, 7, 10] + assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] + assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert _splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] + + # bad number of sections + with pytest.raises(ValueError): + _splits(10, 0) + with pytest.raises(ValueError): + _splits(10, -1) + + +def test_csr_to_dense() -> None: + from tiledbsoma_ml.pytorch import _csr_to_dense + + coo = 
sparse.eye(1001, 77, format="coo", dtype=np.float32) + + assert np.array_equal( + sparse.csr_array(coo).todense(), _csr_to_dense(sparse.csr_array(coo)) + ) + assert np.array_equal( + sparse.csr_matrix(coo).todense(), _csr_to_dense(sparse.csr_matrix(coo)) + ) + + csr = sparse.csr_array(coo) + assert np.array_equal(csr.todense(), _csr_to_dense(csr)) + assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :])) + assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:])) + assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22])) From 8840983d49653cf21b06082e318457dc844517f8 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 00:52:33 +0000 Subject: [PATCH 14/70] fix name --- .github/workflows/python-tiledbsoma-ml.yml | 59 +++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 9b6911d..0990a51 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,10 +1,65 @@ name: python-tiledbsoma-ml on: + pull_request: + branches: ["*"] + paths: + - !'**' + - 'other_packages/python/tiledbsoma_ml/**' + + push: + branches: [main] + paths: + - !'**' + - 'other_packages/python/tiledbsoma_ml/**' + workflow_dispatch: jobs: - job: + lint: + runs-on: ubuntu-latest + env: + PYTHON_VERSION: "3.8" + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + + - name: Restore pre-commit cache + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: Install pre-commit + run: pip -v install pre-commit + + - name: Run pre-commit hooks on all files + run: pre-commit run -v --files echo `git ls-files other_packages/python/tiledbsoma_ml` + + tests: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: python-spec/requirements-py${{ matrix.python-version }}.txt + + - name: Install prereqs + run: | + pip install --upgrade pip wheel pytest pytest-cov setuptools + pip install . 
+ + - name: Run tests + working-directory: ./other_packages/python/tiledbsoma_ml/ + run: | + pytest -s -v tests/ From bb70d6d810ab44ef67fa9220530faec9d94114b3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 00:54:47 +0000 Subject: [PATCH 15/70] add more paths to CI --- .github/workflows/python-tiledbsoma-ml.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 0990a51..2272455 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -6,12 +6,14 @@ on: paths: - !'**' - 'other_packages/python/tiledbsoma_ml/**' + - '.github/workflows/python-tiledbsoma-ml.yml' push: branches: [main] paths: - !'**' - 'other_packages/python/tiledbsoma_ml/**' + - '.github/workflows/python-tiledbsoma-ml.yml' workflow_dispatch: From 8e51344ac50addded20e982ca0930886c72928fe Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 00:58:17 +0000 Subject: [PATCH 16/70] fix typo in ci --- .github/workflows/python-tiledbsoma-ml.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 2272455..0d95685 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -4,7 +4,7 @@ on: pull_request: branches: ["*"] paths: - - !'**' + - '!**' - 'other_packages/python/tiledbsoma_ml/**' - '.github/workflows/python-tiledbsoma-ml.yml' From 6714a8965d7b91e1ebf2f9ffc42bf6112af43e56 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 00:59:14 +0000 Subject: [PATCH 17/70] fix a second typo in ci --- .github/workflows/python-tiledbsoma-ml.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 0d95685..4a9370c 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -11,7 +11,7 @@ on: push: branches: [main] paths: - - !'**' + - '!**' - 'other_packages/python/tiledbsoma_ml/**' - '.github/workflows/python-tiledbsoma-ml.yml' From 5505c286db18c0a3541e31391ee8fd2fdeb61d01 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 18:09:23 -0700 Subject: [PATCH 18/70] set working dir in CI --- .github/workflows/python-tiledbsoma-ml.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 4a9370c..898bd14 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -57,6 +57,7 @@ jobs: cache-dependency-path: python-spec/requirements-py${{ matrix.python-version }}.txt - name: Install prereqs + working-directory: ./other_packages/python/tiledbsoma_ml/ run: | pip install --upgrade pip wheel pytest pytest-cov setuptools pip install . 
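Editor's note: the next patch ("make batched 3.12 compat") drops the `strict` keyword from the `_batched` polyfill, presumably because `itertools.batched` only gained the `strict` flag in Python 3.13, so aliasing it on 3.12 would break any caller (or test) passing `strict=True`. A short, self-contained sketch of the semantics the polyfill preserves; the assertions mirror `tests/test_pytorch.py::test_batched`:

```python
from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

_T_co = TypeVar("_T_co", covariant=True)


def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]:
    """Yield n-sized tuples from iterable; the final batch may be shorter."""
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


# Semantics mirrored from test_batched elsewhere in this series:
assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)]
assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)]
```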
From 6090af21813c6cf1ff40944fd8e1be16eaf58da5 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 18:29:15 -0700 Subject: [PATCH 19/70] make batched 3.12 compat --- src/tiledbsoma_ml/pytorch.py | 6 +----- tests/test_pytorch.py | 4 ---- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 85fa3dc..61cfb81 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -557,16 +557,12 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: else: - def _batched( - iterable: Iterable[_T_co], n: int, *, strict: bool = False - ) -> Iterator[Tuple[_T_co, ...]]: + def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: """Same as the Python 3.12+ itertools.batched -- polyfill for old Python versions.""" if n < 1: raise ValueError("n must be at least one") it = iter(iterable) while batch := tuple(islice(it, n)): - if strict and len(batch) != n: - raise ValueError("batched(): incomplete batch") yield batch diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index d3b8fd8..e847771 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -717,10 +717,6 @@ def test_batched() -> None: with pytest.raises(ValueError): list(_batched([2, 3], -1)) - # strict enforcement - with pytest.raises(ValueError): - list(_batched([0, 1, 2], 2, strict=True)) - def test_splits() -> None: from tiledbsoma_ml.pytorch import _splits From c9789f9fcf467784b907a76800dd3b2dfc99b235 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 18:30:19 -0700 Subject: [PATCH 20/70] debugging pre-commit failure --- .github/workflows/python-tiledbsoma-ml.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 898bd14..6ce2f27 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -39,7 +39,7 @@ jobs: run: pip -v install pre-commit - name: Run pre-commit hooks on all files - run: pre-commit run -v --files echo `git ls-files other_packages/python/tiledbsoma_ml` + run: pre-commit run -v -a tests: runs-on: ubuntu-latest From 33f3c2d7fbb4c05fe82b109f3826da45b917c7c4 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 19:06:11 -0700 Subject: [PATCH 21/70] lint, lint, lint --- pyproject.toml | 26 ++++++++++++++++++++++++++ src/tiledbsoma_ml/pytorch.py | 11 ++++++----- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 09b8833..8e8a40e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,29 @@ version = {attr = "tiledbsoma_ml.__version__"} [tool.setuptools_scm] root = "../../.." 
+ +[tool.mypy] +show_error_codes = true +ignore_missing_imports = true +warn_unreachable = true +strict = true +python_version = 3.8 +plugins = "numpy.typing.mypy_plugin" + +[tool.ruff] +lint.select = ["E", "F", "B", "I"] +lint.ignore = ["E501"] # line too long +lint.extend-select = ["I001"] # unsorted-imports +fix = true +target-version = "py38" +line-length = 120 + + +[tool.ruff.lint.isort] +# HACK: tiledb needs to come after tiledbsoma: https://github.com/single-cell-data/TileDB-SOMA/issues/2293 +section-order = ["future", "standard-library", "third-party", "tiledbsoma", "tiledb", "first-party", "local-folder"] +no-lines-before = ["tiledb"] + +[tool.ruff.lint.isort.sections] +"tiledbsoma" = ["tiledbsoma"] +"tiledb" = ["tiledb"] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 61cfb81..d055656 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -37,6 +37,7 @@ import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator +from typing_extensions import TypeAlias import tiledbsoma as soma @@ -46,12 +47,12 @@ _T_co = TypeVar("_T_co", covariant=True) if TYPE_CHECKING: - NDArrayNumber = npt.NDArray[np.number[Any]] + NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] else: - NDArrayNumber = np.ndarray -XObsDatum = Tuple[NDArrayNumber, pd.DataFrame] -XObsNpDatum = Tuple[NDArrayNumber, NDArrayNumber] -XObsTensorDatum = Tuple[torch.Tensor, torch.Tensor] + NDArrayNumber: TypeAlias = np.ndarray +XObsDatum: TypeAlias = Tuple[NDArrayNumber, pd.DataFrame] +XObsNpDatum: TypeAlias = Tuple[NDArrayNumber, NDArrayNumber] +XObsTensorDatum: TypeAlias = Tuple[torch.Tensor, torch.Tensor] """Return type of ``ExperimentAxisQueryDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). 
The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" From 467bb1570a88d5888cac5f1361d3499781408f1a Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 19:50:03 -0700 Subject: [PATCH 22/70] more CI debugging --- .github/workflows/python-tiledbsoma-ml.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 6ce2f27..1287208 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -21,7 +21,7 @@ jobs: lint: runs-on: ubuntu-latest env: - PYTHON_VERSION: "3.8" + PYTHON_VERSION: "3.11" steps: - uses: actions/checkout@v4 From 52efb62965b26b86e0e8608a37ffd72f123d2298 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 24 Aug 2024 19:56:34 -0700 Subject: [PATCH 23/70] add build test to CI --- .github/workflows/python-tiledbsoma-ml.yml | 32 ++++++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 1287208..89c4ae0 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -4,30 +4,28 @@ on: pull_request: branches: ["*"] paths: - - '!**' - - 'other_packages/python/tiledbsoma_ml/**' - - '.github/workflows/python-tiledbsoma-ml.yml' + - "!**" + - "other_packages/python/tiledbsoma_ml/**" + - ".github/workflows/python-tiledbsoma-ml.yml" push: branches: [main] paths: - - '!**' - - 'other_packages/python/tiledbsoma_ml/**' - - '.github/workflows/python-tiledbsoma-ml.yml' + - "!**" + - "other_packages/python/tiledbsoma_ml/**" + - ".github/workflows/python-tiledbsoma-ml.yml" workflow_dispatch: jobs: lint: runs-on: ubuntu-latest - env: - PYTHON_VERSION: "3.11" steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: ${{ env.PYTHON_VERSION }} + python-version: "3.11" - name: Restore pre-commit cache uses: actions/cache@v4 @@ -66,3 +64,19 @@ jobs: working-directory: ./other_packages/python/tiledbsoma_ml/ run: | pytest -s -v tests/ + + build: + # for now, just do a test build to ensure that it works + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Set up environment + working-directory: ./other_packages/python/tiledbsoma_ml/ + run: | + pip install --upgrade build pip wheel setuptools setuptools-scm + python -m build . From 6ab8334be5802f107f29dabf82378a14021d8483 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 07:44:42 -0700 Subject: [PATCH 24/70] add code coverage --- .github/workflows/python-tiledbsoma-ml.yml | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 89c4ae0..af67d28 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -61,9 +61,22 @@ jobs: pip install . 
- name: Run tests - working-directory: ./other_packages/python/tiledbsoma_ml/ run: | - pytest -s -v tests/ + PYTHONPATH=$(pwd)/other_packages/python/tiledbsoma_ml python -m pytest \ + --cov=other_packages/python/tiledbsoma_ml/src \ + --cov-report=xml other_packages/python/tiledbsoma_ml/tests \ + -v + + - name: Report coverage to Codecov + uses: codecov/codecov-action@v4 + with: + flags: python + # Although Codecov isn't supposed to require an auth token for public repos like this one, + # the uploader can be unreliable without one; see + # https://github.com/codecov/codecov-action/issues/557#issuecomment-1216749652 + # As of this writing (8 Nov 2022) the CODECOV_TOKEN was generated by @aaronwolen in his + # Codecov settings page for this repo, then filled into the GitHub Actions secrets. + token: ${{ secrets.CODECOV_TOKEN }} build: # for now, just do a test build to ensure that it works From 4df5049f8e5ea2a9dc2f08810d587c078063c8fb Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 08:11:48 -0700 Subject: [PATCH 25/70] update GHA --- .github/workflows/python-tiledbsoma-ml.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index af67d28..2ab57ba 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -68,6 +68,7 @@ jobs: -v - name: Report coverage to Codecov + if: ${{ matrix.python-version == '3.11' }} uses: codecov/codecov-action@v4 with: flags: python @@ -88,7 +89,7 @@ jobs: with: python-version: "3.11" - - name: Set up environment + - name: Do build working-directory: ./other_packages/python/tiledbsoma_ml/ run: | pip install --upgrade build pip wheel setuptools setuptools-scm From 1b31d3229750cb72bddd4586194d5f78b3822ca7 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 08:17:48 -0700 Subject: [PATCH 26/70] test TypeAlias --- src/tiledbsoma_ml/pytorch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index d055656..dd9b651 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -15,7 +15,6 @@ from itertools import islice from math import ceil from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, @@ -46,10 +45,11 @@ _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) -if TYPE_CHECKING: - NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] -else: - NDArrayNumber: TypeAlias = np.ndarray +# if TYPE_CHECKING: +# NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] +# else: +# NDArrayNumber: TypeAlias = np.ndarray +NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] XObsDatum: TypeAlias = Tuple[NDArrayNumber, pd.DataFrame] XObsNpDatum: TypeAlias = Tuple[NDArrayNumber, NDArrayNumber] XObsTensorDatum: TypeAlias = Tuple[torch.Tensor, torch.Tensor] From 71be8029312e116cabf9443316c28f62672ca8af Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:11:14 -0700 Subject: [PATCH 27/70] add missing dependencies --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8e8a40e..6bf4654 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,11 @@ dependencies = [ "tiledbsoma", "torch", "torchdata<=0.9", + "numpy", + "numba", + "pandas", + "pyarrow", + "scipy" ] requires-python = ">= 3.8" description = "Machine learning tools for use with tiledbsoma" From a0a834480eed13f540e2b28a89baa41976883f7e Mon Sep 17 
00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:11:26 -0700 Subject: [PATCH 28/70] extend tests --- tests/test_pytorch.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index e847771..2063687 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -27,6 +27,7 @@ from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterable, experiment_dataloader, ) from torch.utils.data._utils.worker import WorkerInfo @@ -684,6 +685,30 @@ def test__shuffle( assert X_values == soma_joinids +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] +) +def test_experiment_axis_query_iterable_error_checks(soma_experiment: Experiment) -> None: + dp = ExperimentAxisQueryIterable( + soma_experiment, + measurement_name="RNA", + X_name="raw", + shuffle=True, + ) + with pytest.raises(NotImplementedError): + dp[0] + + with pytest.raises(ValueError): + dp = ExperimentAxisQueryIterable( + soma_experiment, + obs_column_names=(), + measurement_name="RNA", + X_name="raw", + shuffle=True, + ) + + + def test_experiment_dataloader__unsupported_params__fails() -> None: with patch( "tiledbsoma_ml.pytorch.ExperimentAxisQueryDataPipe" From 13da9c70011d6cd5dc3330774844d8abab766cc2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:11:48 -0700 Subject: [PATCH 29/70] remove coverage reporting from CI for now --- .github/workflows/python-tiledbsoma-ml.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 2ab57ba..6d3e846 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -67,17 +67,17 @@ jobs: --cov-report=xml other_packages/python/tiledbsoma_ml/tests \ -v - - name: Report coverage to Codecov - if: ${{ matrix.python-version == '3.11' }} - uses: codecov/codecov-action@v4 - with: - flags: python - # Although Codecov isn't supposed to require an auth token for public repos like this one, - # the uploader can be unreliable without one; see - # https://github.com/codecov/codecov-action/issues/557#issuecomment-1216749652 - # As of this writing (8 Nov 2022) the CODECOV_TOKEN was generated by @aaronwolen in his - # Codecov settings page for this repo, then filled into the GitHub Actions secrets. - token: ${{ secrets.CODECOV_TOKEN }} + # - name: Report coverage to Codecov + # if: ${{ matrix.python-version == '3.11' }} + # uses: codecov/codecov-action@v4 + # with: + # flags: python + # # Although Codecov isn't supposed to require an auth token for public repos like this one, + # # the uploader can be unreliable without one; see + # # https://github.com/codecov/codecov-action/issues/557#issuecomment-1216749652 + # # As of this writing (8 Nov 2022) the CODECOV_TOKEN was generated by @aaronwolen in his + # # Codecov settings page for this repo, then filled into the GitHub Actions secrets. 
+ # token: ${{ secrets.CODECOV_TOKEN }} build: # for now, just do a test build to ensure that it works From 7074e99a03b2f405ff83ce60d26d24717608a1df Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:12:05 -0700 Subject: [PATCH 30/70] docstrings --- src/tiledbsoma_ml/pytorch.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index dd9b651..e2f002c 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -15,6 +15,7 @@ from itertools import islice from math import ceil from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, @@ -45,16 +46,18 @@ _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) -# if TYPE_CHECKING: -# NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] -# else: -# NDArrayNumber: TypeAlias = np.ndarray -NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] +if TYPE_CHECKING: + # Python 3.8 does not support subscripting types, so work-around by + # restricting this to when we are running a type checker. TODO: remove + # the conditional when Python 3.8 support is dropped. + NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] +else: + NDArrayNumber: TypeAlias = np.ndarray + XObsDatum: TypeAlias = Tuple[NDArrayNumber, pd.DataFrame] -XObsNpDatum: TypeAlias = Tuple[NDArrayNumber, NDArrayNumber] -XObsTensorDatum: TypeAlias = Tuple[torch.Tensor, torch.Tensor] -"""Return type of ``ExperimentAxisQueryDataPipe`` that pairs a Tensor of ``obs`` row(s) with a Tensor of ``X`` matrix row(s). -The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" +"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, +which pairs a NumPy ndarray of ``X`` row(s) with a Pandas DataFrame of ``obs`` row(s). If the +``batch_size`` is 1, the objects are of rank 1, else they are of rank 2.""" @attrs.define(frozen=True, kw_only=True) From 47d705235cdaa266c5c6a4da87ab0fd81df20574 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:12:16 -0700 Subject: [PATCH 31/70] more file organization --- notebooks/turtorial_pytorch.ipynb | 619 ++++++++++++++++++++++++++++++ 1 file changed, 619 insertions(+) create mode 100644 notebooks/turtorial_pytorch.ipynb diff --git a/notebooks/turtorial_pytorch.ipynb b/notebooks/turtorial_pytorch.ipynb new file mode 100644 index 0000000..533f9cc --- /dev/null +++ b/notebooks/turtorial_pytorch.ipynb @@ -0,0 +1,619 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a PyTorch Model\n", + "\n", + "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. 
This is intended only to demonstrate the use of the `ExperimentAxisQueryDataPipe`, and not as an example of how to train a biologically useful model.\n",
+    "\n",
+    "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n",
+    "\n",
+    "**Prerequisites**\n",
+    "\n",
+    "Install `tiledbsoma` with the optional `ml` dependencies, for example:\n",
+    "\n",
+    "> pip install tiledbsoma[ml]\n",
+    "\n",
+    "\n",
+    "**Contents**\n",
+    "\n",
+    "* [Create an ExperimentAxisQueryDataPipe](#Create-an-ExperimentAxisQueryDataPipe)\n",
+    "* [Split the dataset](#Split-the-dataset)\n",
+    "* [Create the DataLoader](#Create-the-DataLoader)\n",
+    "* [Define the model](#Define-the-model)\n",
+    "* [Train the model](#Train-the-model)\n",
+    "* [Make predictions with the model](#Make-predictions-with-the-model)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create an ExperimentAxisQueryDataPipe\n",
+    "\n",
+    "To train a model in PyTorch on Census data, first open a SOMA Experiment and create an `ExperimentAxisQueryDataPipe`. This example utilizes a recent CZI Census release, accessed directly from S3.\n",
+    "\n",
+    "We are also going to create an encoder for the `obs` labels at the same time, and fit it on the `cell_type` labels. In this example we use the LabelEncoder from `scikit-learn`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/ubuntu/miniforge3/envs/tiledbsoma-dev/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
+      "################################################################################\n",
+      "WARNING!\n",
+      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
+      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
+      "to learn more and leave feedback.\n",
+      "################################################################################\n",
+      "\n",
+      "  deprecation_warning()\n"
+     ]
+    }
+   ],
+   "source": [
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import tiledbsoma as soma\n",
+    "import tiledbsoma.ml as soma_ml\n",
+    "\n",
+    "CZI_Census_Homo_Sapiens_URL = (\n",
+    "    \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+    ")\n",
+    "\n",
+    "experiment = soma.open(\n",
+    "    CZI_Census_Homo_Sapiens_URL, context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"})\n",
+    ")\n",
+    "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+    "obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+    "\n",
+    "experiment_dataset = soma_ml.ExperimentAxisQueryDataPipe(\n",
+    "    experiment,\n",
+    "    measurement_name=\"RNA\",\n",
+    "    X_name=\"raw\",\n",
+    "    obs_query=obs_query,\n",
+    "    obs_column_names=[\"cell_type\"],\n",
+    "    batch_size=128,\n",
+    "    shuffle=True,\n",
+    ")\n",
+    "\n",
+    "with experiment.axis_query(measurement_name=\"RNA\", obs_query=obs_query) as query:\n",
+    "    obs_df = query.obs(column_names=['cell_type']).concat().to_pandas()\n",
+    "    cell_type_encoder = LabelEncoder().fit(obs_df['cell_type'].unique())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `ExperimentAxisQueryDataPipe` class explained\n",
+    "\n",
+    "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source.
The `ExperimentAxisQueryDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n", + "\n", + "### `ExperimentAxisQueryDataPipe` parameters explained\n", + "\n", + "The constructor only requires a single parameter, `experiment`, which is a `soma.Experiment` containing the data of the organism to be used for training.\n", + "\n", + "To retrieve a subset of the Experiment's data, along either the `obs` or `var` axes, you may specify query filters via the `obs_query` and `var_query` parameters, which are both `soma.AxisQuery` objects.\n", + "\n", + "The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.\n", + "\n", + "The `batch_size` allows you to specify the number of obs rows (cells) to be returned by each return PyTorch tensor. You may exclude this parameter if you want single rows (`batch_size=1`).\n", + "\n", + "The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:\n", + "* You should use this flag instead of the `DataLoader` `shuffle` flag, primarily for performance reasons.\n", + "* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is alternate mechanism one can use to perform shuffling of an `IterableDataset`. However, the `Shuffler` will not \"globally\" randomize the training data, as it only \"locally\" randomizes the ordering of the training data within fixed-size \"windows\". Due to the layout of Census data, a given \"window\" of Census data may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can inspect the shape of the full dataset, without causing the full dataset to be loaded:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(15020, 60530)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_dataset.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split the dataset\n", + "\n", + "You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create the DataLoader\n", + "\n", + "With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(train_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the model\n", + "\n", + "With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "\n", + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define a function to train the model for a single epoch:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + " model.train()\n", + " train_loss = 0\n", + " train_correct = 0\n", + " train_total = 0\n", + "\n", + " for X_batch, y_batch in train_dataloader:\n", + " optimizer.zero_grad()\n", + "\n", + " X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + " # Perform prediction\n", + " outputs = model(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute the loss and perform back propagation\n", + "\n", + " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + "\n", + " train_correct += (predictions == y_batch).sum().item()\n", + " train_total += len(predictions)\n", + "\n", + " loss = loss_fn(outputs, y_batch.long())\n", + " train_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss /= train_total\n", + " train_accuracy = train_correct / train_total\n", + " return train_loss, train_accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the line, `X_batch, y_batch = batch`. Since the `train_dataloader` was configured with `batch_size=16`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([[0., 0., 0., ..., 1., 0., 0.],\n", + " [0., 0., 2., ..., 0., 3., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 1., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 8.]])\n", + " \n", + "```\n", + "\n", + "For `batch_size=1`, the tensors will be of rank 1. 
The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([0., 0., 0., ..., 1., 0., 0.])\n", + "```\n", + " \n", + "For `y_batch`, this will contain the user-specified `obs` `cell_type` training labels. By default, these are encoded using a LabelEncoder and it will be a matrix where each column represents the encoded values of each column specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). It will look like this:\n", + "\n", + "```\n", + "tensor([1, 1, 3, ..., 2, 1, 4])\n", + "\n", + "```\n", + "Note that cell type values are integer-encoded values, which can be decoded using `experiment_dataset.encoders` (more on this below).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model\n", + "\n", + "Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method and then iterate through the desired number of training epochs. Note how the `train_dataloader` is passed into `train_epoch`, where for each epoch it will provide a new iterator through the training dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0162193 Accuracy 0.3910\n", + "Epoch 2: Train Loss: 0.0148544 Accuracy 0.5580\n", + "Epoch 3: Train Loss: 0.0143674 Accuracy 0.5863\n", + "Epoch 4: Train Loss: 0.0140639 Accuracy 0.7015\n", + "Epoch 5: Train Loss: 0.0138526 Accuracy 0.7417\n", + "Epoch 6: Train Loss: 0.0137076 Accuracy 0.7998\n", + "Epoch 7: Train Loss: 0.0136148 Accuracy 0.8783\n", + "Epoch 8: Train Loss: 0.0135245 Accuracy 0.8984\n", + "Epoch 9: Train Loss: 0.0134525 Accuracy 0.9051\n", + "Epoch 10: Train Loss: 0.0133846 Accuracy 0.9122\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "for epoch in range(10):\n", + " train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n", + " print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make predictions with the model\n", + "\n", + "To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the \"test\" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split." 
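+    ,
+    "\n",
+    "\n",
+    "Here we fetch only the first mini-batch via `next(iter(...))`; iterating over the full test split works the same way.\n"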
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(test_dataset)\n", + "X_batch, y_batch = next(iter(experiment_dataloader))\n", + "X_batch = torch.from_numpy(X_batch)\n", + "y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we invoke the model on the `X_batch` input data and extract the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1, 1, 1, 8, 1, 7, 1, 5, 1, 1, 8, 1, 8, 8, 1, 8, 7, 5,\n", + " 1, 1, 1, 5, 1, 7, 7, 5, 1, 5, 7, 5, 8, 1, 5, 1, 7, 1,\n", + " 5, 7, 1, 1, 7, 8, 5, 8, 7, 8, 1, 7, 1, 8, 5, 1, 1, 5,\n", + " 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 1, 1, 1, 7, 8, 1, 7, 8,\n", + " 8, 1, 5, 1, 6, 1, 5, 1, 7, 1, 1, 1, 1, 7, 1, 7, 1, 1,\n", + " 1, 1, 1, 5, 1, 5, 11, 1, 1, 5, 5, 1, 1, 1, 1, 1, 1, 5,\n", + " 5, 1, 8, 8, 1, 9, 1, 1, 8, 8, 5, 5, 5, 5, 1, 7, 7, 1,\n", + " 1, 1])" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "\n", + "model.to(device)\n", + "outputs = model(X_batch.to(device))\n", + "\n", + "probabilities = torch.nn.functional.softmax(outputs, 1)\n", + "predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + "display(predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n", + "\n", + "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." 
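+    ,
+    "\n",
+    "\n",
+    "For example, a minimal sketch (the file name `cell_type_encoder.pkl` is arbitrary):\n",
+    "\n",
+    "```python\n",
+    "import pickle\n",
+    "\n",
+    "# At training time: persist the fitted encoder alongside the model.\n",
+    "with open(\"cell_type_encoder.pkl\", \"wb\") as f:\n",
+    "    pickle.dump(cell_type_encoder, f)\n",
+    "\n",
+    "# At inference time: restore the encoder before decoding predictions.\n",
+    "with open(\"cell_type_encoder.pkl\", \"rb\") as f:\n",
+    "    cell_type_encoder = pickle.load(f)\n",
+    "```\n"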
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'basal cell', 'leukocyte',\n", + " 'leukocyte', 'basal cell', 'leukocyte', 'keratinocyte',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'epithelial cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'basal cell', 'epithelial cell', 'keratinocyte',\n", + " 'epithelial cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'leukocyte', 'epithelial cell', 'leukocyte', 'keratinocyte',\n", + " 'leukocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", + " 'basal cell', 'keratinocyte', 'leukocyte', 'leukocyte',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'fibroblast',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'vein endothelial cell', 'basal cell', 'basal cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'leukocyte',\n", + " 'leukocyte', 'basal cell', 'pericyte', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'leukocyte', 'epithelial cell', 'epithelial cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell'],\n", + " dtype=object)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n", + "\n", + "display(predicted_cell_types)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we create a Pandas DataFrame to examine the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
actual cell typepredicted cell type
0basal cellbasal cell
1basal cellbasal cell
2basal cellbasal cell
3leukocyteleukocyte
4basal cellbasal cell
.........
123keratinocytekeratinocyte
124keratinocytekeratinocyte
125basal cellbasal cell
126basal cellbasal cell
127basal cellbasal cell
\n", + "

128 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " actual cell type predicted cell type\n", + "0 basal cell basal cell\n", + "1 basal cell basal cell\n", + "2 basal cell basal cell\n", + "3 leukocyte leukocyte\n", + "4 basal cell basal cell\n", + ".. ... ...\n", + "123 keratinocyte keratinocyte\n", + "124 keratinocyte keratinocyte\n", + "125 basal cell basal cell\n", + "126 basal cell basal cell\n", + "127 basal cell basal cell\n", + "\n", + "[128 rows x 2 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "display(\n", + " pd.DataFrame(\n", + " {\n", + " \"actual cell type\": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),\n", + " \"predicted cell type\": predicted_cell_types,\n", + " }\n", + " )\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tiledbsoma-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From c16b68d7b1d0ede1392749ba17f4d4433e0240de Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:21:44 -0700 Subject: [PATCH 32/70] add missing test --- tests/test_pytorch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 2063687..5e007dd 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -26,8 +26,8 @@ try: from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryDataPipe, - ExperimentAxisQueryIterableDataset, ExperimentAxisQueryIterable, + ExperimentAxisQueryIterableDataset, experiment_dataloader, ) from torch.utils.data._utils.worker import WorkerInfo @@ -688,7 +688,9 @@ def test__shuffle( @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] ) -def test_experiment_axis_query_iterable_error_checks(soma_experiment: Experiment) -> None: +def test_experiment_axis_query_iterable_error_checks( + soma_experiment: Experiment, +) -> None: dp = ExperimentAxisQueryIterable( soma_experiment, measurement_name="RNA", @@ -708,7 +710,6 @@ def test_experiment_axis_query_iterable_error_checks(soma_experiment: Experiment ) - def test_experiment_dataloader__unsupported_params__fails() -> None: with patch( "tiledbsoma_ml.pytorch.ExperimentAxisQueryDataPipe" From 2d0647625e677e060b4ed8e6857b2aae54d833a1 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 09:22:02 -0700 Subject: [PATCH 33/70] re-run notebook --- notebooks/turtorial_pytorch.ipynb | 172 ++++++++++++++---------------- 1 file changed, 78 insertions(+), 94 deletions(-) diff --git a/notebooks/turtorial_pytorch.ipynb b/notebooks/turtorial_pytorch.ipynb index 533f9cc..7dedfed 100644 --- a/notebooks/turtorial_pytorch.ipynb +++ b/notebooks/turtorial_pytorch.ipynb @@ -40,28 +40,12 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ubuntu/miniforge3/envs/tiledbsoma-dev/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", - "################################################################################\n", - "WARNING!\n", - "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", - "future torchdata release! 
Please see https://github.com/pytorch/data/issues/1196\n", - "to learn more and leave feedback.\n", - "################################################################################\n", - "\n", - " deprecation_warning()\n" - ] - } - ], + "outputs": [], "source": [ + "import tiledbsoma_ml as soma_ml\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "import tiledbsoma as soma\n", - "import tiledbsoma.ml as soma_ml\n", "\n", "CZI_Census_Homo_Sapiens_URL = (\n", " \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", @@ -307,16 +291,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0162193 Accuracy 0.3910\n", - "Epoch 2: Train Loss: 0.0148544 Accuracy 0.5580\n", - "Epoch 3: Train Loss: 0.0143674 Accuracy 0.5863\n", - "Epoch 4: Train Loss: 0.0140639 Accuracy 0.7015\n", - "Epoch 5: Train Loss: 0.0138526 Accuracy 0.7417\n", - "Epoch 6: Train Loss: 0.0137076 Accuracy 0.7998\n", - "Epoch 7: Train Loss: 0.0136148 Accuracy 0.8783\n", - "Epoch 8: Train Loss: 0.0135245 Accuracy 0.8984\n", - "Epoch 9: Train Loss: 0.0134525 Accuracy 0.9051\n", - "Epoch 10: Train Loss: 0.0133846 Accuracy 0.9122\n" + "Epoch 1: Train Loss: 0.0160182 Accuracy 0.3806\n", + "Epoch 2: Train Loss: 0.0147012 Accuracy 0.4701\n", + "Epoch 3: Train Loss: 0.0143336 Accuracy 0.5381\n", + "Epoch 4: Train Loss: 0.0141242 Accuracy 0.5828\n", + "Epoch 5: Train Loss: 0.0139505 Accuracy 0.6105\n", + "Epoch 6: Train Loss: 0.0138496 Accuracy 0.6249\n", + "Epoch 7: Train Loss: 0.0137310 Accuracy 0.6665\n", + "Epoch 8: Train Loss: 0.0136376 Accuracy 0.7125\n", + "Epoch 9: Train Loss: 0.0135705 Accuracy 0.7954\n", + "Epoch 10: Train Loss: 0.0134742 Accuracy 0.8539\n" ] } ], @@ -374,14 +358,14 @@ { "data": { "text/plain": [ - "tensor([ 1, 1, 1, 8, 1, 7, 1, 5, 1, 1, 8, 1, 8, 8, 1, 8, 7, 5,\n", - " 1, 1, 1, 5, 1, 7, 7, 5, 1, 5, 7, 5, 8, 1, 5, 1, 7, 1,\n", - " 5, 7, 1, 1, 7, 8, 5, 8, 7, 8, 1, 7, 1, 8, 5, 1, 1, 5,\n", - " 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 1, 1, 1, 7, 8, 1, 7, 8,\n", - " 8, 1, 5, 1, 6, 1, 5, 1, 7, 1, 1, 1, 1, 7, 1, 7, 1, 1,\n", - " 1, 1, 1, 5, 1, 5, 11, 1, 1, 5, 5, 1, 1, 1, 1, 1, 1, 5,\n", - " 5, 1, 8, 8, 1, 9, 1, 1, 8, 8, 5, 5, 5, 5, 1, 7, 7, 1,\n", - " 1, 1])" + "tensor([ 1, 8, 11, 1, 8, 7, 5, 1, 1, 1, 1, 1, 8, 1, 5, 8, 7, 1,\n", + " 8, 8, 1, 1, 8, 1, 5, 1, 7, 8, 5, 1, 5, 5, 1, 1, 5, 7,\n", + " 8, 8, 1, 8, 1, 1, 1, 1, 5, 1, 1, 11, 1, 1, 5, 7, 7, 1,\n", + " 1, 5, 1, 1, 7, 1, 5, 5, 5, 7, 7, 1, 8, 1, 1, 7, 7, 7,\n", + " 8, 8, 5, 1, 1, 8, 1, 5, 5, 1, 6, 5, 1, 5, 8, 8, 1, 5,\n", + " 7, 1, 1, 5, 7, 1, 1, 7, 5, 5, 8, 1, 1, 1, 8, 5, 1, 7,\n", + " 1, 7, 8, 1, 1, 5, 5, 1, 1, 1, 1, 1, 7, 1, 9, 1, 5, 8,\n", + " 1, 7], device='cuda:0')" ] }, "metadata": {}, @@ -417,39 +401,39 @@ { "data": { "text/plain": [ - "array(['basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", - " 'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',\n", - " 'basal cell', 'basal cell', 'leukocyte', 'basal cell', 'leukocyte',\n", - " 'leukocyte', 'basal cell', 'leukocyte', 'keratinocyte',\n", - " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'epithelial cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", - " 'epithelial cell', 'basal cell', 'epithelial cell', 'keratinocyte',\n", - " 'epithelial cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", - " 'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'leukocyte', 'epithelial cell', 'leukocyte', 
'keratinocyte',\n", - " 'leukocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'epithelial cell', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", - " 'basal cell', 'keratinocyte', 'leukocyte', 'leukocyte',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'fibroblast',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", + "array(['basal cell', 'leukocyte', 'vein endothelial cell', 'basal cell',\n", + " 'leukocyte', 'keratinocyte', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'epithelial cell', 'leukocyte',\n", + " 'keratinocyte', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", " 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'epithelial cell',\n", + " 'keratinocyte', 'leukocyte', 'leukocyte', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", " 'vein endothelial cell', 'basal cell', 'basal cell',\n", - " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'epithelial cell', 'epithelial cell', 'basal cell', 'leukocyte',\n", - " 'leukocyte', 'basal cell', 'pericyte', 'basal cell', 'basal cell',\n", - " 'leukocyte', 'leukocyte', 'epithelial cell', 'epithelial cell',\n", - " 'epithelial cell', 'epithelial cell', 'basal cell', 'keratinocyte',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell'],\n", - " dtype=object)" + " 'epithelial cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'epithelial cell', 'epithelial cell',\n", + " 'epithelial cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'keratinocyte', 'keratinocyte', 'leukocyte', 'leukocyte',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", + " 'fibroblast', 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'leukocyte', 'leukocyte', 'basal cell', 'epithelial cell',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'epithelial cell',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'epithelial cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'epithelial cell',\n", + " 'basal cell', 'keratinocyte', 'basal cell', 'keratinocyte',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'pericyte', 'basal cell', 'epithelial cell', 'leukocyte',\n", + " 'basal cell', 'keratinocyte'], dtype=object)" ] }, "metadata": {}, @@ -502,28 +486,28 @@ 
" \n", " \n", " 0\n", - " basal cell\n", + " keratinocyte\n", " basal cell\n", " \n", " \n", " 1\n", - " basal cell\n", - " basal cell\n", + " leukocyte\n", + " leukocyte\n", " \n", " \n", " 2\n", - " basal cell\n", - " basal cell\n", + " capillary endothelial cell\n", + " vein endothelial cell\n", " \n", " \n", " 3\n", - " leukocyte\n", - " leukocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 4\n", - " basal cell\n", - " basal cell\n", + " leukocyte\n", + " leukocyte\n", " \n", " \n", " ...\n", @@ -532,18 +516,18 @@ " \n", " \n", " 123\n", - " keratinocyte\n", - " keratinocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 124\n", - " keratinocyte\n", - " keratinocyte\n", + " epithelial cell\n", + " epithelial cell\n", " \n", " \n", " 125\n", " basal cell\n", - " basal cell\n", + " leukocyte\n", " \n", " \n", " 126\n", @@ -552,8 +536,8 @@ " \n", " \n", " 127\n", - " basal cell\n", - " basal cell\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", "\n", @@ -561,18 +545,18 @@ "" ], "text/plain": [ - " actual cell type predicted cell type\n", - "0 basal cell basal cell\n", - "1 basal cell basal cell\n", - "2 basal cell basal cell\n", - "3 leukocyte leukocyte\n", - "4 basal cell basal cell\n", - ".. ... ...\n", - "123 keratinocyte keratinocyte\n", - "124 keratinocyte keratinocyte\n", - "125 basal cell basal cell\n", - "126 basal cell basal cell\n", - "127 basal cell basal cell\n", + " actual cell type predicted cell type\n", + "0 keratinocyte basal cell\n", + "1 leukocyte leukocyte\n", + "2 capillary endothelial cell vein endothelial cell\n", + "3 basal cell basal cell\n", + "4 leukocyte leukocyte\n", + ".. ... ...\n", + "123 basal cell basal cell\n", + "124 epithelial cell epithelial cell\n", + "125 basal cell leukocyte\n", + "126 basal cell basal cell\n", + "127 keratinocyte keratinocyte\n", "\n", "[128 rows x 2 columns]" ] From 778022d4d2d21f51b2141b18d533d495ff2791e8 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sun, 25 Aug 2024 10:20:30 -0700 Subject: [PATCH 34/70] update changelog --- CHANGELOG.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2254e83..6f0d697 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,18 +1,20 @@ # Change Log + All notable changes to this project will be documented in this file. - + The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). - + ## [Unreleased] - yyyy-mm-dd Porting and enhancing initial code contribution from the Chan Zuckerberg Initiative Foundation [CELLxGENE](https://cellxgene.cziscience.com/) project. 
### Added - + +- Initial commits via PR [#2823](https://github.com/single-cell-data/TileDB-SOMA/pull/2823) + ### Changed - + ### Fixed - From 05e7fc512176f2599bae26e17b288c3a0275bda3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 26 Aug 2024 19:09:48 +0000 Subject: [PATCH 35/70] add collate unit test --- tests/test_pytorch.py | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 5e007dd..5e20b1d 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -6,6 +6,7 @@ from __future__ import annotations import pathlib +from functools import partial from typing import Callable, List, Optional, Sequence, Union from unittest.mock import patch @@ -24,13 +25,14 @@ # This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent # tests. try: + from torch.utils.data._utils.worker import WorkerInfo + from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterable, ExperimentAxisQueryIterableDataset, experiment_dataloader, ) - from torch.utils.data._utils.worker import WorkerInfo except ImportError: # this should only occur when not running `ml`-marked tests pass @@ -609,6 +611,41 @@ def test_experiment_dataloader__batched_length( assert len(dl) == len(list(dl)) +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,batch_size", + [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_experiment_dataloader__collate_fn( + PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + batch_size, +): + def collate_fn(batch_size, data): + assert isinstance(data, tuple) + assert len(data) == 2 + assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame) + if batch_size > 1: + assert data[0].shape[0] == data[1].shape[0] + assert data[0].shape[0] <= batch_size + else: + assert data[0].ndim == 1 + assert data[1].shape[1] <= batch_size + + dp = PipeClass( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + batch_size=batch_size, + shuffle=False, + ) + dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) + assert len(list(dl)) > 0 + + @pytest.mark.parametrize( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], @@ -722,8 +759,6 @@ def test_experiment_dataloader__unsupported_params__fails() -> None: experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[]) with pytest.raises(ValueError): experiment_dataloader(dummy_exp_data_pipe, sampler=[]) - with pytest.raises(ValueError): - experiment_dataloader(dummy_exp_data_pipe, collate_fn=lambda x: x) def test_batched() -> None: From 1e442a7cd94289a0e2e95eb9784a4ea3c2817b95 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 26 Aug 2024 19:10:01 +0000 Subject: [PATCH 36/70] clean up experiment_dataloader function --- src/tiledbsoma_ml/pytorch.py | 51 +++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index e2f002c..131d700 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -498,35 +498,62 @@ def _collate_noop(datum: _T) -> _T: def experiment_dataloader( ds: torchdata.datapipes.iter.IterDataPipe | 
torch.utils.data.IterableDataset, - num_workers: int = 0, + # num_workers: int = 0, **dataloader_kwargs: Any, ) -> torch.utils.data.DataLoader: - # TODO: XXX docstrings - - # XXX why do we disallow collate_fn? - + """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a + :class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` + or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`. + + Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant, + when using loaders form this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``. + Specifying any of these parameters will result in an error. + + Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on + :class:`torch.utils.data.DataLoader` parameters. + + Args: + ds: + A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May + include chained data pipes. + num_workers: + How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) + **dataloader_kwargs: + Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, + except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not + supported when data loaders in this module. + + Returns: + A :class:`torch.utils.data.DataLoader`. + + Raises: + ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params + are passed as keyword arguments. + + Lifecycle: + experimental + """ unsupported_dataloader_args = [ "shuffle", "batch_size", "sampler", "batch_sampler", - "collate_fn", ] if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()): raise ValueError( - f"The {','.join(unsupported_dataloader_args)} DataLoader params are not supported" + f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported" ) - if num_workers > 0: + if dataloader_kwargs.get("num_workers", 0) > 0: _init_multiprocessing() + if "collate_fn" not in dataloader_kwargs: + dataloader_kwargs["collate_fn"] = _collate_noop + return torch.utils.data.DataLoader( ds, batch_size=None, # batching is handled by upstream iterator - num_workers=num_workers, - collate_fn=_collate_noop, - # shuffling is handled by upstream iterator - shuffle=False, + shuffle=False, # shuffling is handled by upstream iterator **dataloader_kwargs, ) From b01f609310090bb387c866a494f837ad6239b6e6 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 27 Aug 2024 02:02:51 +0000 Subject: [PATCH 37/70] docstrings --- src/tiledbsoma_ml/pytorch.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 131d700..1080991 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -491,11 +491,6 @@ def shape(self) -> Tuple[int, int]: return self._exp_iter.shape -def _collate_noop(datum: _T) -> _T: - """Disable collation in dataloader instance.""" - return datum - - def experiment_dataloader( ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset, # num_workers: int = 0, @@ -558,13 +553,22 @@ def experiment_dataloader( ) +def _collate_noop(datum: _T) -> _T: + """Noop collation for use with a dataloader instance. + + Private. 
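+
+    Used as the default ``collate_fn`` by :func:`experiment_dataloader`: the
+    upstream iterable already yields complete mini-batches, so each datum is
+    returned unchanged rather than being converted or re-collated.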
+ """ + return datum + + def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: """For `total_length` points, compute start/stop offsets that split the length into roughly equal sizes. A total_length of L, split into N sections, will return L%N sections of size L//N+1, - and the remainder as size L//N. + and the remainder as size L//N. This results in the same split as numpy.array_split, + for an array of length L and sections N. - This results in is the same split as numpy.array_split, for an array of length L. + Private. Examples -------- @@ -635,6 +639,7 @@ def _init_multiprocessing() -> None: Also, CUDA does not support forked child processes: https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + Private. """ orig_start_method = torch.multiprocessing.get_start_method() if orig_start_method != "spawn": @@ -660,6 +665,8 @@ def _csr_to_dense(sp: sparse.csr_array | sparse.csr_matrix) -> NDArrayNumber: """Fast, parallel, variant of scipy.sparse.csr_array.todense. Typically 4-8X faster, dending on host and size of array/matrix. + + Private. """ assert isinstance(sp, (sparse.csr_array, sparse.csr_matrix)) return cast( From 7b56809daeb77aaba36d23c83d5d0318bb039d54 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 28 Aug 2024 16:54:20 +0000 Subject: [PATCH 38/70] fix typo in notebook name (thanks Ryan!) --- notebooks/{turtorial_pytorch.ipynb => tutorial_pytorch.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename notebooks/{turtorial_pytorch.ipynb => tutorial_pytorch.ipynb} (100%) diff --git a/notebooks/turtorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb similarity index 100% rename from notebooks/turtorial_pytorch.ipynb rename to notebooks/tutorial_pytorch.ipynb From 330bf43af57b3b8bb99bbcb753d9f91cd643ad7a Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 30 Aug 2024 15:54:00 +0000 Subject: [PATCH 39/70] checkpoint updates --- src/tiledbsoma_ml/pytorch.py | 652 +++++++++++++++++++++++++++++------ tests/test_pytorch.py | 587 +++++++++++++++---------------- 2 files changed, 840 insertions(+), 399 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 1080991..edf5e44 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -5,9 +5,11 @@ from __future__ import annotations +import contextlib import gc import itertools import logging +import math import os import sys import time @@ -17,6 +19,7 @@ from typing import ( TYPE_CHECKING, Any, + ContextManager, Dict, Iterable, Iterator, @@ -33,7 +36,6 @@ import numpy.typing as npt import pandas as pd import pyarrow as pa -import scipy.sparse as sparse import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator @@ -56,7 +58,7 @@ XObsDatum: TypeAlias = Tuple[NDArrayNumber, pd.DataFrame] """Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, -which pairs a NumPy ndarray of ``X`` row(s) with a Pandas DataFrame of ``obs`` row(s). If the +which pairs a :class:`numpy.ndarray` of ``X`` row(s) with a :class:`pandas.DataFrame` of ``obs`` row(s). If the ``batch_size`` is 1, the objects are of rank 1, else they are of rank 2.""" @@ -66,7 +68,7 @@ class _ExperimentLocator: Necessary as we will likely be invoked across multiple processes. - Private. + Private implementation class. 
""" uri: str @@ -91,34 +93,95 @@ def open_experiment(self) -> Iterator[soma.Experiment]: class ExperimentAxisQueryIterable(Iterable[XObsDatum]): - r"""Private base class for Dataset/DataPipe subclasses.""" + """An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as + selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator + produces equal sized ``X`` and ``obs`` data, in the form of a :class:``numpy.ndarray`` and + :class:`pandas.DataFrame`. + + Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and + :class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset` + and `ExperimentAxisQueryDataPipe` for more details on usage. - # XXX TODO - docstrings, slots, etc. + Lifecycle: + experimental + """ def __init__( self, - experiment: soma.Experiment, - measurement_name: str = "RNA", - X_name: str = "raw", - obs_query: soma.AxisQuery | None = None, - var_query: soma.AxisQuery | None = None, + query: soma.ExperimentAxisQuery, + X_name: str, obs_column_names: Sequence[str] = ("soma_joinid",), batch_size: int = 1, shuffle: bool = True, - io_batch_size: int = 2**17, + io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, seed: int | None = None, use_eager_fetch: bool = True, - partition: bool = True, ): + """ + Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a 2-tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy ``ndarray`` and a Pandas ``DataFrame`` respectively. + + Args: + query: + A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over. + X_name: + The name of the X layer to read. + obs_column_names: + The names of the ``obs`` columns to return. At least one column name must be specified. + Default is ``('soma_joinid',)``. + batch_size: + The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of + ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will + result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). + + Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` + batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader`` + batch_size parameter to ``None``. + shuffle: + Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. + io_batch_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of + this class's behavior: 1) The maximum memory utilization, with larger values providing + better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling + (see ``shuffle`` parameter for details). The default value of 131,072 provides high performance, but + may need to be reduced in memory limited hosts (or where a large number of :class:`DataLoader` workers + are employed). + shuffle_chunk_size: + The number of contiguous rows sampled, prior to concatenation and shuffling. + Larger numbers correspond to more randomness per training batch. + If ``shuffle == False``, this parameter is ignored. + seed: + The random seed used for shuffling. Defaults to ``None`` (no seed). 
This arguiment *must* be specified when using + :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker + processes. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made + available for processing via the iterator. This allows network (or filesystem) requests to be made in + parallel with client-side processing of the SOMA data, potentially improving overall performance at the + cost of doubling memory utilization. Defaults to ``True``. + + Returns: + An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to + a :class:`torch.data.utils.DataLoader` instance. + + Raises: + ``ValueError`` on various unsupported or malformed parameter values. + + Lifecycle: + experimental + """ + super().__init__() - # Anything set in the instance needs to be picklable for multi-process users - self.experiment_locator = _ExperimentLocator.create(experiment) - self.measurement_name = measurement_name + # Anything set in the instance needs to be picklable for multi-process DataLoaders + self.experiment_locator = _ExperimentLocator.create(query.experiment) self.layer_name = X_name - self.obs_query = obs_query - self.var_query = var_query + self.measurement_name = query.measurement_name + self.obs_query = query._matrix_axis_query.obs + self.var_query = query._matrix_axis_query.var self.obs_column_names = list(obs_column_names) self.batch_size = batch_size self.io_batch_size = io_batch_size @@ -127,10 +190,9 @@ def __init__( self._obs_joinids: npt.NDArray[np.int64] | None = None self._var_joinids: npt.NDArray[np.int64] | None = None self._shuffle_rng = np.random.default_rng(seed) if shuffle else None - self.partition = partition + self.shuffle_chunk_size = shuffle_chunk_size self._initialized = False - self.shuffle_chunk_size = shuffle_chunk_size if self.shuffle: # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. self.io_batch_size = ( @@ -141,9 +203,11 @@ def __init__( raise ValueError("Must specify at least one value in `obs_column_names`") def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: - """Private - create iterator over obs id chunks with split size of (roughly) io_batch_size. + """Create iterator over obs id chunks with split size of (roughly) io_batch_size. As appropriate, will chunk, shuffle and apply partitioning per worker. + + Private method. """ assert self._obs_joinids is not None obs_joinids: npt.NDArray[np.int64] = self._obs_joinids @@ -167,16 +231,15 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: ) # Now extract the partition for this worker - partition, num_partitions = ( - _get_torch_partition_info() if self.partition else (0, 1) - ) + partition, num_partitions = _get_torch_partition_info() + obs_splits = _splits(len(obs_joinids_chunked), num_partitions) obs_partition_joinids = obs_joinids_chunked[ obs_splits[partition] : obs_splits[partition + 1] ].copy() obs_joinid_iter = iter(obs_partition_joinids) - if logger.isEnabledFor(logging.DEBUG) and self.partition: + if logger.isEnabledFor(logging.DEBUG): logger.debug( f"Process {os.getpid()} handling partition {partition + 1} of {num_partitions}, " f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" @@ -184,8 +247,13 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: return obs_joinid_iter - def _init_once(self) -> None: - """One-time per worker initialization. 
All operations be idempotent in order to support pipe reset().""" + def _init_once(self, exp: soma.Experiment | None = None) -> None: + """One-time per worker initialization. + + All operations be idempotent in order to support pipe reset(). + + Private method. + """ if self._initialized: return @@ -193,21 +261,37 @@ def _init_once(self) -> None: f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})" ) - with self.experiment_locator.open_experiment() as exp: - query = exp.axis_query( + if exp is None: + # If no user-provided Experiment, open/close it ourselves + exp_cm: ContextManager[soma.Experiment] = ( + self.experiment_locator.open_experiment() + ) + else: + # else, it is caller responsibility to open/close the experiment + exp_cm = contextlib.nullcontext(exp) + + with exp_cm as exp: + with exp.axis_query( measurement_name=self.measurement_name, obs_query=self.obs_query, var_query=self.var_query, - ) - self._obs_joinids = query.obs_joinids().to_numpy() - self._var_joinids = query.var_joinids().to_numpy() + ) as query: + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() self._initialized = True def __iter__(self) -> Iterator[XObsDatum]: - with self.experiment_locator.open_experiment() as exp: - self._init_once() + """Create iterator over query. + + Returns: + ``iterator`` + Lifecycle: + experimental + """ + with self.experiment_locator.open_experiment() as exp: + self._init_once(exp) X = exp.ms[self.measurement_name].X[self.layer_name] if not isinstance(X, soma.SparseNDArray): raise NotImplementedError( @@ -224,15 +308,26 @@ def __iter__(self) -> Iterator[XObsDatum]: yield from _mini_batch_iter def __len__(self) -> int: + """Return approximate number of batches this iterable will product. + + See import caveats in the PyTorch + [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) + domentation regarding ``len(dataloader)``, which also apply to this class. + + Returns: + An ``int``. + + Lifecycle: + experimental + """ self._init_once() assert self._obs_joinids is not None - div, rem = divmod(len(self._obs_joinids), self.batch_size) return div + bool(rem) @property def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`tiledbsoma.ml.pytorch.ExperimentDataPipe`. + """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect the size of the partition of the data assigned to the active process. @@ -249,21 +344,23 @@ def shape(self) -> Tuple[int, int]: return len(self._obs_joinids), len(self._var_joinids) def __getitem__(self, index: int) -> XObsDatum: - raise NotImplementedError("Can only be iterated") + raise NotImplementedError( + "``ExperimentAxisQueryIterable can only be iterated - does not support mapping" + ) def _io_batch_iter( self, obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[npt.NDArray[np.int64]], - ) -> Iterator[Tuple[sparse.csr_array, pd.DataFrame]]: + ) -> Iterator[Tuple[_CSR, pd.DataFrame]]: """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of (X: csr_array, obs: DataFrame). obs joinids read are controlled by the obs_joinid_iter. 
Iterator results will be reindexed and shuffled (if shuffling enabled). - Private. + Private method. """ assert self._var_joinids is not None @@ -296,26 +393,38 @@ def _io_batch_iter( ) obs_io_batch = obs_io_batch[self.obs_column_names] - X_tbl_iter = X.read(coords=(obs_coords, self._var_joinids)).tables() - X_tbl = pa.concat_tables(X_tbl_iter) + X_tbl_iter: Iterator[pa.Table] = X.read( + coords=(obs_coords, self._var_joinids) + ).tables() obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) - i = obs_indexer.get_indexer(X_tbl["soma_dim_0"].to_numpy()) - j = var_indexer.get_indexer(X_tbl["soma_dim_1"].to_numpy()) - d = X_tbl["soma_data"].to_numpy() - if len(i) < np.iinfo(np.int32).max: - # downcast where able, as this saves considerable memory - i = i.astype(np.int32) - j = j.astype(np.int32) - X_io_batch = sparse.csr_array( - (d, (i, j)), - shape=(len(obs_coords), len(self._var_joinids)), - copy=False, + def make_csr( + X_tbl: pa.Table, + obs_coords: npt.NDArray[np.int64], + var_coords: npt.NDArray[np.int64], + obs_indexer: soma.IntIndexer, + ) -> _CSR: + """This function provides a GC after we throw off (large) garbage.""" + m = _CSR.from_ijd( + obs_indexer.get_indexer(X_tbl["soma_dim_0"]), + var_indexer.get_indexer(X_tbl["soma_dim_1"]), + X_tbl["soma_data"].to_numpy(), + shape=(len(obs_coords), len(var_coords)), + ) + gc.collect(generation=0) + return m + + _csr_iter = ( + make_csr(X_tbl, obs_coords, self._var_joinids, obs_indexer) + for X_tbl in X_tbl_iter ) + if self.use_eager_fetch: + _csr_iter = _EagerIterator(_csr_iter, pool=X.context.threadpool) + X_io_batch = _CSR.merge(tuple(_csr_iter)) - del X_tbl, X_tbl_iter, i, j, d - del obs_coords, obs_shuffled_coords, obs_indexer + del obs_indexer, obs_coords, obs_shuffled_coords, _csr_iter gc.collect() + tm = time.perf_counter() - st_time logger.debug( f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec" @@ -330,7 +439,7 @@ def _mini_batch_iter( ) -> Iterator[XObsDatum]: """Break IO batches into shuffled mini-batch-sized chunks. - Private. + Private method. """ assert self._obs_joinids is not None assert self._var_joinids is not None @@ -340,9 +449,7 @@ def _mini_batch_iter( io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool) mini_batch_size = self.batch_size - result: Tuple[sparse.csr_array | sparse.csr_matrix, pd.DataFrame] | None = ( - None # partial result - ) + result: Tuple[NDArrayNumber, pd.DataFrame] | None = None for X_io_batch, obs_io_batch in io_batch_iter: assert X_io_batch.shape[0] == obs_io_batch.shape[0] assert X_io_batch.shape[1] == len(self._var_joinids) @@ -353,7 +460,9 @@ def _mini_batch_iter( if result is None: # perform zero copy slice where possible result = ( - X_io_batch[iob_idx : iob_idx + mini_batch_size], + X_io_batch.densified_slice( + slice(iob_idx, iob_idx + mini_batch_size) + ), obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size], ) iob_idx += len(result[1]) @@ -361,25 +470,22 @@ def _mini_batch_iter( # use remanent from previous IO batch to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) result = ( - # In older versions of scipy.sparse, vstack will return _matrix when - # called with _array. 
Various code paths must accomadate this bug (mostly - # in their type declarations) - sparse.vstack([result[0], X_io_batch[0:to_take]]), + np.concatenate( + [result[0], X_io_batch.densified_slice(slice(0, to_take))] + ), pd.concat([result[1], obs_io_batch.iloc[0:to_take]]), ) iob_idx += to_take assert result[0].shape[0] == result[1].shape[0] if result[0].shape[0] == mini_batch_size: - # yield result - yield (_csr_to_dense(result[0]), result[1]) + yield result result = None else: # yield a remnant, if any if result is not None: - # yield result - yield (_csr_to_dense(result[0]), result[1]) + yield result class ExperimentAxisQueryDataPipe( @@ -387,31 +493,40 @@ class ExperimentAxisQueryDataPipe( torch.utils.data.dataset.Dataset[XObsDatum] ], ): - # TODO: XXX docstrings + """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + + This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for + legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the + TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information. + + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ def __init__( self, - experiment: soma.Experiment, - measurement_name: str = "RNA", + query: soma.ExperimentAxisQuery, X_name: str = "raw", - obs_query: soma.AxisQuery | None = None, - var_query: soma.AxisQuery | None = None, obs_column_names: Sequence[str] = ("soma_joinid",), batch_size: int = 1, shuffle: bool = True, seed: int | None = None, - io_batch_size: int = 2**17, + io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, - # encoders: List[Encoder] | None = None, ): + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ super().__init__() self._exp_iter = ExperimentAxisQueryIterable( - experiment=experiment, - measurement_name=measurement_name, + query=query, X_name=X_name, - obs_query=obs_query, - var_query=var_query, obs_column_names=obs_column_names, batch_size=batch_size, shuffle=shuffle, @@ -419,11 +534,15 @@ def __init__( io_batch_size=io_batch_size, use_eager_fetch=use_eager_fetch, shuffle_chunk_size=shuffle_chunk_size, - # --- - partition=True, ) def __iter__(self) -> Iterator[XObsDatum]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ batch_size = self._exp_iter.batch_size for X, obs in self._exp_iter: if batch_size == 1: @@ -431,40 +550,153 @@ def __iter__(self) -> Iterator[XObsDatum]: yield X, obs def __len__(self) -> int: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ return self._exp_iter.__len__() @property def shape(self) -> Tuple[int, int]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ return self._exp_iter.shape class ExperimentAxisQueryIterableDataset( torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] ): - # TODO: XXX docstrings + """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. 
+ + This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as + specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of + ``obs`` and ``X`` data. Each iteration will yield a tuple containing an NumPy ndarray and a Pandas DataFrame. + + For example: + + >>> import torch + >>> import tiledbsoma + >>> import tiledbsoma_ml + >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: + with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: + ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(ds) + dataloader = torch.utils.data.DataLoader(ds) + >>> data = next(iter(dataloader)) + >>> data + (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), + soma_joinid + 0 57905025) + >>> data[0] + array([0., 0., 0., ..., 0., 0., 0.], dtype=float32) + >>> data[1] + soma_joinid + 0 57905025 + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each + iteration. If the ``batch_size`` is 1, then each result will have rank 1, else it will have rank 2. A ``batch_size`` + of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will usually be more + performant to create mini-batches using this class, and set the ``DataLoader`` batch size to `None`. + + The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the + default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). + + The `io_batch_size` parameter determines the number of rows read, from which mini-batches are yielded. A + larger value will increase total memory usage, and may reduce average read time per row. + + Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using + :class:`DataLoader` shuffling. The shuffling algorithm works as follows: + + 1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk". + 2. A random selection of shuffle `chunks` is drawn and read as a single I/O buffer (of size ``io_buffer_size``). + 3. The entire I/O buffer is shuffled. + + Put another way, we read randomly selected groups of observations from across all query results, concatenate + those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle + is therefore determiend by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` + (number of rows in each draw). + + Lifecycle: + experimental + """ def __init__( self, - experiment: soma.Experiment, - measurement_name: str = "RNA", + query: soma.ExperimentAxisQuery, X_name: str = "raw", - obs_query: soma.AxisQuery | None = None, - var_query: soma.AxisQuery | None = None, obs_column_names: Sequence[str] = ("soma_joinid",), - batch_size: int = 1, # XXX add docstring noting values >1 will not work with default collator + batch_size: int = 1, shuffle: bool = True, seed: int | None = None, - io_batch_size: int = 2**17, + io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, use_eager_fetch: bool = True, ): + """ + Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a 2-tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy ``ndarray`` and a Pandas ``DataFrame`` respectively. 
+
+        Args:
+            query:
+                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over.
+            X_name:
+                The name of the X layer to read.
+            obs_column_names:
+                The names of the ``obs`` columns to return. At least one column name must be specified.
+                Default is ``('soma_joinid',)``.
+            batch_size:
+                The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of
+                ``1`` will result in a :class:`numpy.ndarray` of rank 1 being returned (a single row); larger values will
+                result in :class:`numpy.ndarray`\ s of rank 2 (multiple rows).
+
+                Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
+                batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader``
+                ``batch_size`` parameter to ``None``.
+            shuffle:
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
+            io_batch_size:
+                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
+                this class's behavior: 1) The maximum memory utilization, with larger values providing
+                better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling
+                (see ``shuffle`` parameter for details). The default value of 65,536 provides high performance, but
+                may need to be reduced in memory-limited hosts (or where a large number of :class:`DataLoader` workers
+                are employed).
+            shuffle_chunk_size:
+                The number of contiguous rows sampled, prior to concatenation and shuffling.
+                Smaller values correspond to more randomness per training batch, at some cost in read performance.
+                If ``shuffle == False``, this parameter is ignored.
+            seed:
+                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using
+                :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
+                processes.
+            use_eager_fetch:
+                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
+                available for processing via the iterator. This allows network (or filesystem) requests to be made in
+                parallel with client-side processing of the SOMA data, potentially improving overall performance at the
+                cost of doubling memory utilization. Defaults to ``True``.
+
+        Returns:
+            An ``iterable``, which can be iterated over using the Python ``iter()`` built-in, or passed directly to
+            a :class:`torch.utils.data.DataLoader` instance.
+
+        Raises:
+            ``ValueError`` on various unsupported or malformed parameter values.
+
+        Lifecycle:
+            experimental
+
+        """
         super().__init__()
         self._exp_iter = ExperimentAxisQueryIterable(
-            experiment=experiment,
-            measurement_name=measurement_name,
+            query=query,
             X_name=X_name,
-            obs_query=obs_query,
-            var_query=var_query,
             obs_column_names=obs_column_names,
             batch_size=batch_size,
             shuffle=shuffle,
@@ -472,11 +704,17 @@ def __init__(
             io_batch_size=io_batch_size,
             use_eager_fetch=use_eager_fetch,
             shuffle_chunk_size=shuffle_chunk_size,
-            # ---
-            partition=True,
         )
 
     def __iter__(self) -> Iterator[XObsDatum]:
+        """Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.
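+
+        For example (illustrative only; ``ds`` is an instance of this class, created with
+        the default ``batch_size`` of 1, so ``X`` below is a rank-1 :class:`numpy.ndarray`):
+
+        >>> X, obs = next(iter(ds))
+        >>> assert X.ndim == 1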
+
+        Returns:
+            ``iterator``
+
+        Lifecycle:
+            experimental
+        """
         batch_size = self._exp_iter.batch_size
         for X, obs in self._exp_iter:
             if batch_size == 1:
@@ -484,10 +722,34 @@ def __init__(
             yield X, obs
 
     def __len__(self) -> int:
+        """Return approximate number of batches this iterable will produce.
+
+        See important caveats in the PyTorch
+        [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+        documentation regarding ``len(dataloader)``, which also apply to this class.
+
+        Returns:
+            An ``int``.
+
+        Lifecycle:
+            experimental
+        """
         return self._exp_iter.__len__()
 
     @property
     def shape(self) -> Tuple[int, int]:
+        """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.
+
+        This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
+        (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect
+        the size of the partition of the data assigned to the active process.
+
+        Returns:
+            A 2-tuple of ``int``s, for obs and var counts, respectively.
+
+        Lifecycle:
+            experimental
+        """
         return self._exp_iter.shape
 
 
@@ -651,27 +913,199 @@ def _init_multiprocessing() -> None:
     torch.multiprocessing.set_start_method("spawn", force=True)
 
 
-@numba.njit(nogil=True, parallel=True)
-def _csr_to_dense_inner(indptr, indices, data, out):  # type:ignore[no-untyped-def]
-    n_rows = out.shape[0]
-    for i in numba.prange(n_rows):
+class _CSR:
+    """Implement fast slice and conversion to numpy.ndarray, which is used to materialize
+    X mini-batches.
+
+    This is faster and uses less memory than the equivalent scipy.sparse.csr_array operations,
+    e.g., `mtrx[n:n+m].todense()`, and also produces fewer memory copies as views are materialized.
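+
+    For example, where SciPy would require a slice followed by a separate densification step
+    (illustrative only; ``sp`` is a hypothetical ``scipy.sparse.csr_array``, with ``n`` and
+    ``m`` as assumed row offsets):
+
+        batch = sp[n : n + m].todense()
+
+    this class produces the dense mini-batch directly from the stored arrays:
+
+        batch = csr.densified_slice(slice(n, n + m))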
+ + Primary optimizations: + * zero copy / parallel conversion to dense ndarray + * faster construction: stores in classic P,J,V format, but does not sort/de-dup minor axis + * incremental construction from COO via merging allows overlapping construction and I/O + """ + + __slots__ = ("indptr", "indices", "data", "shape") + + def __init__( + self, + indptr: NDArrayNumber, + indices: NDArrayNumber, + data: NDArrayNumber, + shape: Tuple[int, int], + ) -> None: + self.shape = shape + self.indptr = indptr + self.indices = indices + self.data = data + + @staticmethod + def from_ijd( + i: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] + ) -> _CSR: + """Factory from COO""" + nnz = len(d) + if nnz < np.iinfo(np.int32).max: + index_dtype: npt.DTypeLike = np.int32 + else: + index_dtype = np.int64 + indptr = np.zeros((shape[0] + 1), dtype=index_dtype) + indices = np.empty((nnz,), dtype=index_dtype) + data = np.empty((nnz,), dtype=d.dtype) + _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) + return _CSR(indptr, indices, data, shape) + + @staticmethod + def from_pjd( + p: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] + ) -> _CSR: + """Factory from CSR""" + return _CSR(p, j, d, shape) + + @property + def nnz(self) -> int: + return len(self.indices) + + @property + def nybtes(self) -> int: + return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes) + + @property + def dtype(self) -> npt.DTypeLike: + return self.data.dtype + + def densified_slice(self, row_index: slice) -> NDArrayNumber: + assert isinstance(row_index, slice) + assert row_index.step in (1, None) + row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + indptr = self.indptr + indices = self.indices + data = self.data + return _csr_to_dense_inner( + row_idx_start, + row_idx_end, + indptr, + indices, + data, + np.zeros( + (row_idx_end - row_idx_start, self.shape[1]), dtype=self.data.dtype + ), + ) + + @staticmethod + def merge(mtxs: Sequence[_CSR]) -> _CSR: + assert len(mtxs) > 0 + nnz = sum(m.nnz for m in mtxs) + shape = mtxs[0].shape + assert all(m.shape == shape for m in mtxs) + + if nnz < np.iinfo(np.int32).max: + index_dtype: npt.DTypeLike = np.int32 + else: + index_dtype = np.int64 + + indptr = np.sum([m.indptr for m in mtxs], axis=0, dtype=index_dtype) + indices = np.empty((nnz,), dtype=index_dtype) + data = np.empty((nnz,), mtxs[0].data.dtype) + + _csr_merge_inner( + tuple((m.indptr, m.indices, m.data) for m in mtxs), indptr, indices, data + ) + return _CSR.from_pjd(indptr, indices, data, shape) + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_merge_inner( + As: Tuple[Tuple[NDArrayNumber, NDArrayNumber, NDArrayNumber], ...], # P,J,D + Bp: NDArrayNumber, + Bj: NDArrayNumber, + Bd: NDArrayNumber, +) -> None: + n_rows = len(Bp) - 1 + offsets = Bp.copy() + for Ap, Aj, Ad in As: + n_elmts = Ap[1:] - Ap[:-1] + for n in numba.prange(n_rows): + Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]] + Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]] + offsets[:-1] += n_elmts + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_to_dense_inner( + row_idx_start: int, + row_idx_end: int, + indptr: NDArrayNumber, + indices: NDArrayNumber, + data: NDArrayNumber, + out: NDArrayNumber, +) -> NDArrayNumber: + for i in numba.prange(row_idx_start, row_idx_end): for j in range(indptr[i], indptr[i + 1]): - out[i, indices[j]] = data[j] + out[i - row_idx_start, 
indices[j]] = data[j]
     return out
 
 
-def _csr_to_dense(sp: sparse.csr_array | sparse.csr_matrix) -> NDArrayNumber:
-    """Fast, parallel, variant of scipy.sparse.csr_array.todense.
+@numba.njit(nogil=True, parallel=True, inline="always")  # type:ignore[misc]
+def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber:
+    """Private: parallel row count."""
+    nnz = len(Ai)
 
-    Typically 4-8X faster, depending on host and size of array/matrix.
+    partition_size = 10 * 1024**2
+    n_partitions = math.ceil(nnz / partition_size)
+    if n_partitions > 1:
+        counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype)
+        for p in numba.prange(n_partitions):
+            for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)):
+                row = Ai[n]
+                counts[p, row] += 1
 
-    Private.
-    """
-    assert isinstance(sp, (sparse.csr_array, sparse.csr_matrix))
-    return cast(
-        NDArrayNumber,
-        _csr_to_dense_inner(
-            sp.indptr, sp.indices, sp.data, np.zeros(sp.shape, dtype=sp.dtype)
-        ),
-    )
+        Bp[:-1] = counts.sum(axis=0)
+    else:
+        for n in range(nnz):
+            row = Ai[n]
+            Bp[row] += 1
+
+    return Bp
+
+
+@numba.njit(nogil=True, parallel=True)  # type:ignore[misc]
+def _coo_to_csr_inner(
+    n_rows: int,
+    Ai: NDArrayNumber,
+    Aj: NDArrayNumber,
+    Ad: NDArrayNumber,
+    Bp: NDArrayNumber,
+    Bj: NDArrayNumber,
+    Bd: NDArrayNumber,
+) -> None:
+    nnz = len(Ai)
+
+    _count_rows(n_rows, Ai, Bp)
+
+    # cum sum to get the row index pointers (NOTE: starting with zero)
+    cumsum = 0
+    for n in range(n_rows):
+        tmp = Bp[n]
+        Bp[n] = cumsum
+        cumsum += tmp
+    Bp[n_rows] = nnz
+
+    # reorganize all of the data. side-effect: pointers shifted.
+    for n in range(nnz):
+        row = Ai[n]
+        dst_row = Bp[row]
+
+        Bj[dst_row] = Aj[n]
+        Bd[dst_row] = Ad[n]
+
+        Bp[row] += 1
+
+    # and shift the pointers by one (i.e., start at zero)
+    prev_ptr = 0
+    for n in range(n_rows + 1):
+        tmp = Bp[n]
+        Bp[n] = prev_ptr
+        prev_ptr = tmp
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
index 5e20b1d..1445293 100644
--- a/tests/test_pytorch.py
+++ b/tests/test_pytorch.py
@@ -7,10 +7,11 @@
 
 import pathlib
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from unittest.mock import patch
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -25,14 +26,13 @@
 # This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent
 # tests.
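 # For example (assumed invocation; the exact expression depends on how the mark is
 # registered in this repo's pytest config): `pytest -m "not ml"` would skip these tests.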
try: - from torch.utils.data._utils.worker import WorkerInfo - from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterable, ExperimentAxisQueryIterableDataset, experiment_dataloader, ) + from torch.utils.data._utils.worker import WorkerInfo except ImportError: # this should only occur when not running `ml`-marked tests pass @@ -161,28 +161,28 @@ def test_non_batched( use_eager_fetch: bool, ) -> None: # batch_size should default to 1 - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - assert type(exp_data_pipe.shape) is tuple - assert len(exp_data_pipe.shape) == 2 - assert exp_data_pipe.shape == (6, 3) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + assert type(exp_data_pipe.shape) is tuple + assert len(exp_data_pipe.shape) == 2 + assert exp_data_pipe.shape == (6, 3) - row_iter = iter(exp_data_pipe) + row_iter = iter(exp_data_pipe) - row = next(row_iter) - assert isinstance(row[0], np.ndarray) - assert isinstance(row[1], pd.DataFrame) - assert row[0].shape == (3,) - assert row[1].shape == (1, 1) - assert row[0].tolist() == [0, 1, 0] - assert row[1].keys() == ["label"] - assert row[1]["label"].tolist() == ["0"] + row = next(row_iter) + assert isinstance(row[0], np.ndarray) + assert isinstance(row[1], pd.DataFrame) + assert row[0].shape == (3,) + assert row[1].shape == (1, 1) + assert row[0].tolist() == [0, 1, 0] + assert row[1].keys() == ["label"] + assert row[1]["label"].tolist() == ["0"] @pytest.mark.parametrize( @@ -198,27 +198,27 @@ def test_uneven_soma_and_result_batches( use_eager_fetch: bool, ) -> None: """This is checking that batches are correctly created when they require fetching multiple chunks.""" - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - shuffle=False, - batch_size=3, - io_batch_size=2, - use_eager_fetch=use_eager_fetch, - ) - row_iter = iter(exp_data_pipe) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + batch_size=3, + io_batch_size=2, + use_eager_fetch=use_eager_fetch, + ) + row_iter = iter(exp_data_pipe) - X_batch, obs_batch = next(row_iter) - assert isinstance(X_batch, np.ndarray) - assert isinstance(obs_batch, pd.DataFrame) - assert X_batch.shape[0] == obs_batch.shape[0] - assert X_batch.shape == (3, 3) - assert obs_batch.shape == (3, 1) - assert X_batch[0].tolist() == [0, 1, 0] - assert ["label"] == obs_batch.keys() - assert obs_batch["label"].tolist() == ["0", "1", "2"] + X_batch, obs_batch = next(row_iter) + assert isinstance(X_batch, np.ndarray) + assert isinstance(obs_batch, pd.DataFrame) + assert X_batch.shape[0] == obs_batch.shape[0] + assert X_batch.shape == (3, 3) + assert obs_batch.shape == (3, 1) + assert X_batch[0].tolist() == [0, 1, 0] + assert ["label"] == obs_batch.keys() + assert obs_batch["label"].tolist() == ["0", "1", "2"] @pytest.mark.parametrize( @@ -233,29 +233,29 @@ def test_batching__all_batches_full_size( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - shuffle=False, - 
use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1].keys() == ["label"] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["0", "1", "2"] - batch = next(batch_iter) - assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] - assert batch[1].keys() == ["label"] - assert batch[1]["label"].tolist() == ["3", "4", "5"] + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["3", "4", "5"] - with pytest.raises(StopIteration): - next(batch_iter) + with pytest.raises(StopIteration): + next(batch_iter) @pytest.mark.parametrize( @@ -273,19 +273,19 @@ def test_unique_soma_joinids( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["soma_joinid", "label"], - batch_size=3, - use_eager_fetch=use_eager_fetch, - ) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) - soma_joinids = np.concatenate( - [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] - ) - assert len(np.unique(soma_joinids)) == len(soma_joinids) + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] + ) + assert len(np.unique(soma_joinids)) == len(soma_joinids) @pytest.mark.parametrize( @@ -300,23 +300,23 @@ def test_batching__partial_final_batch_size( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) - - next(batch_iter) - batch = next(batch_iter) - assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) - with pytest.raises(StopIteration): next(batch_iter) + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + + with pytest.raises(StopIteration): + next(batch_iter) @pytest.mark.parametrize( @@ -331,23 +331,23 @@ def test_batching__exactly_one_batch( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + 
use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1]["label"].tolist() == ["0", "1", "2"] - with pytest.raises(StopIteration): - next(batch_iter) + with pytest.raises(StopIteration): + next(batch_iter) @pytest.mark.parametrize( @@ -362,19 +362,20 @@ def test_batching__empty_query_result( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_query=soma.AxisQuery(coords=([],)), - obs_column_names=["label"], - batch_size=3, - use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) + with soma_experiment.axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(coords=([],)) + ) as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) - with pytest.raises(StopIteration): - next(batch_iter) + with pytest.raises(StopIteration): + next(batch_iter) @pytest.mark.parametrize( @@ -392,20 +393,20 @@ def test_batching__partial_soma_batches_are_concatenated( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - exp_data_pipe = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches - io_batch_size=4, - use_eager_fetch=use_eager_fetch, - ) + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches + io_batch_size=4, + use_eager_fetch=use_eager_fetch, + ) - full_result = list(exp_data_pipe) + full_result = list(exp_data_pipe) - assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] @pytest.mark.parametrize( @@ -420,21 +421,22 @@ def test_multiprocessing__returns_full_result( ) -> None: """Tests the ExperimentAxisQueryDataPipe provides all data, as collected from multiple processes that are managed by a PyTorch DataLoader with multiple workers configured.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + io_batch_size=3, # two chunks, one per worker + ) + # Note we're testing the ExperimentAxisQueryDataPipe via a DataLoader, since this is what sets up the multiprocessing + dl = experiment_dataloader(dp, num_workers=2) - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["soma_joinid", "label"], - io_batch_size=3, # two chunks, one per worker - ) - # Note we're testing the ExperimentAxisQueryDataPipe via a DataLoader, since this is what sets up the multiprocessing - dl = experiment_dataloader(dp, num_workers=2) - - full_result = list(iter(dl)) + full_result = list(iter(dl)) - soma_joinids = np.concatenate([t[1]["soma_joinid"].to_numpy() for t in full_result]) - assert sorted(soma_joinids) == list(range(6)) + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in 
full_result] + ) + assert sorted(soma_joinids) == list(range(6)) @pytest.mark.parametrize( @@ -459,23 +461,23 @@ def test_distributed__returns_data_partition_for_rank( mock_dist_get_rank.return_value = 1 mock_dist_get_world_size.return_value = 3 - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["soma_joinid"], - io_batch_size=2, - shuffle=False, - ) - full_result = list(iter(dp)) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) + full_result = list(iter(dp)) - soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] - ) + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) - # Of the 6 obs rows, the PyTorch process of rank 1 should get [2, 3] - # (rank 0 gets [0, 1], rank 2 gets [4, 5]) - assert sorted(soma_joinids) == [2, 3] + # Of the 6 obs rows, the PyTorch process of rank 1 should get [2, 3] + # (rank 0 gets [0, 1], rank 2 gets [4, 5]) + assert sorted(soma_joinids) == [2, 3] @pytest.mark.parametrize( @@ -504,25 +506,25 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( mock_dist_get_rank.return_value = 1 mock_dist_get_world_size.return_value = 3 - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["soma_joinid"], - io_batch_size=2, - shuffle=False, - ) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + shuffle=False, + ) - full_result = list(iter(dp)) + full_result = list(iter(dp)) - soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] - ) + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) - # Of the 12 obs rows, the PyTorch process of rank 1 should get [4..7], and then within that partition, - # the 2nd DataLoader process should get the second half of the rank's partition, which is just [6, 7] - # (rank 0 gets [0..3], rank 2 gets [8..11]) - assert sorted(soma_joinids) == [6, 7] + # Of the 12 obs rows, the PyTorch process of rank 1 should get [4..7], and then within that partition, + # the 2nd DataLoader process should get the second half of the rank's partition, which is just [6, 7] + # (rank 0 gets [0..3], rank 2 gets [8..11]) + assert sorted(soma_joinids) == [6, 7] @pytest.mark.parametrize( @@ -537,22 +539,22 @@ def test_experiment_dataloader__non_batched( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - data = [row for row in dl] - assert all(d[0].shape == (3,) for d in data) - assert all(d[1].shape == (1, 1) for d in data) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] + assert all(d[0].shape == (3,) for d in data) + assert all(d[1].shape == (1, 1) for d in data) - row = data[0] - assert row[0].tolist() == [0, 1, 0] - assert row[1]["label"].tolist() == ["0"] + row = data[0] + assert row[0].tolist() == [0, 1, 0] + assert row[1]["label"].tolist() == 
["0"] @pytest.mark.parametrize( @@ -567,20 +569,20 @@ def test_experiment_dataloader__batched( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - data = [row for row in dl] + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + data = [row for row in dl] - batch = data[0] - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1].to_numpy().tolist() == [[0], [1], [2]] + batch = data[0] + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].to_numpy().tolist() == [[0], [1], [2]] @pytest.mark.parametrize( @@ -598,17 +600,17 @@ def test_experiment_dataloader__batched_length( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - shuffle=False, - use_eager_fetch=use_eager_fetch, - ) - dl = experiment_dataloader(dp) - assert len(dl) == len(list(dl)) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + dl = experiment_dataloader(dp) + assert len(dl) == len(list(dl)) @pytest.mark.parametrize( @@ -621,9 +623,11 @@ def test_experiment_dataloader__batched_length( def test_experiment_dataloader__collate_fn( PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, - batch_size, -): - def collate_fn(batch_size, data): + batch_size: int, +) -> None: + def collate_fn( + batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame] + ) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]: assert isinstance(data, tuple) assert len(data) == 2 assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame) @@ -633,17 +637,18 @@ def collate_fn(batch_size, data): else: assert data[0].ndim == 1 assert data[1].shape[1] <= batch_size + return data - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=batch_size, - shuffle=False, - ) - dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) - assert len(list(dl)) > 0 + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=batch_size, + shuffle=False, + ) + dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size)) + assert len(list(dl)) > 0 @pytest.mark.parametrize( @@ -658,17 +663,17 @@ def test__X_tensor_dtype_matches_X_matrix( soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - batch_size=3, - use_eager_fetch=use_eager_fetch, - ) - data = next(iter(dp)) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + data = next(iter(dp)) - assert data[0].dtype == np.float32 + assert data[0].dtype == np.float32 @pytest.mark.parametrize( @@ -677,18 
+682,20 @@ def test__X_tensor_dtype_matches_X_matrix( def test__pytorch_splitting( soma_experiment: Experiment, ) -> None: - dp = ExperimentAxisQueryDataPipe( - soma_experiment, - measurement_name="RNA", - X_name="raw", - obs_column_names=["label"], - ) - # function not available for IterableDataset, yet.... - dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=1234) - dl = experiment_dataloader(dp_train) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryDataPipe( + query, + X_name="raw", + obs_column_names=["label"], + ) + # function not available for IterableDataset, yet.... + dp_train, dp_test = dp.random_split( + weights={"train": 0.7, "test": 0.3}, seed=1234 + ) + dl = experiment_dataloader(dp_train) - all_rows = list(iter(dl)) - assert len(all_rows) == 7 + all_rows = list(iter(dl)) + assert len(all_rows) == 7 @pytest.mark.parametrize( @@ -701,25 +708,25 @@ def test__shuffle( PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, ) -> None: - dp = PipeClass( - soma_experiment, - measurement_name="RNA", - X_name="raw", - shuffle=True, - ) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + shuffle=True, + ) - all_rows = list(iter(dp)) - assert all(r[0].shape == (1,) for r in all_rows) - soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows] - X_values = [row[0][0].item() for row in all_rows] + all_rows = list(iter(dp)) + assert all(r[0].shape == (1,) for r in all_rows) + soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows] + X_values = [row[0][0].item() for row in all_rows] - # same elements - assert set(soma_joinids) == set(range(16)) - # not ordered! (...with a `1/16!` probability of being ordered) - assert soma_joinids != list(range(16)) - # randomizes X in same order as obs - # note: X values were explicitly set to match obs_joinids to allow for this simple assertion - assert X_values == soma_joinids + # same elements + assert set(soma_joinids) == set(range(16)) + # not ordered! 
(...with a `1/16!` probability of being ordered) + assert soma_joinids != list(range(16)) + # randomizes X in same order as obs + # note: X values were explicitly set to match obs_joinids to allow for this simple assertion + assert X_values == soma_joinids @pytest.mark.parametrize( @@ -728,23 +735,22 @@ def test__shuffle( def test_experiment_axis_query_iterable_error_checks( soma_experiment: Experiment, ) -> None: - dp = ExperimentAxisQueryIterable( - soma_experiment, - measurement_name="RNA", - X_name="raw", - shuffle=True, - ) - with pytest.raises(NotImplementedError): - dp[0] - - with pytest.raises(ValueError): + with soma_experiment.axis_query(measurement_name="RNA") as query: dp = ExperimentAxisQueryIterable( - soma_experiment, - obs_column_names=(), - measurement_name="RNA", + query, X_name="raw", shuffle=True, ) + with pytest.raises(NotImplementedError): + dp[0] + + with pytest.raises(ValueError): + dp = ExperimentAxisQueryIterable( + query, + obs_column_names=(), + X_name="raw", + shuffle=True, + ) def test_experiment_dataloader__unsupported_params__fails() -> None: @@ -795,20 +801,21 @@ def test_splits() -> None: _splits(10, -1) -def test_csr_to_dense() -> None: - from tiledbsoma_ml.pytorch import _csr_to_dense +# temp comment out while building _CSR tests +# def test_csr_to_dense() -> None: +# from tiledbsoma_ml.pytorch import _csr_to_dense - coo = sparse.eye(1001, 77, format="coo", dtype=np.float32) +# coo = sparse.eye(1001, 77, format="coo", dtype=np.float32) - assert np.array_equal( - sparse.csr_array(coo).todense(), _csr_to_dense(sparse.csr_array(coo)) - ) - assert np.array_equal( - sparse.csr_matrix(coo).todense(), _csr_to_dense(sparse.csr_matrix(coo)) - ) +# assert np.array_equal( +# sparse.csr_array(coo).todense(), _csr_to_dense(sparse.csr_array(coo)) +# ) +# assert np.array_equal( +# sparse.csr_matrix(coo).todense(), _csr_to_dense(sparse.csr_matrix(coo)) +# ) - csr = sparse.csr_array(coo) - assert np.array_equal(csr.todense(), _csr_to_dense(csr)) - assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :])) - assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:])) - assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22])) +# csr = sparse.csr_array(coo) +# assert np.array_equal(csr.todense(), _csr_to_dense(csr)) +# assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :])) +# assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:])) +# assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22])) From 745d60055c6460f7db542368b07bedcad4da2ec2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 30 Aug 2024 10:40:35 -0700 Subject: [PATCH 40/70] update tests to include _CSR tests --- tests/test_pytorch.py | 94 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 1445293..f64ddd5 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -819,3 +819,97 @@ def test_splits() -> None: # assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :])) # assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:])) # assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22])) + + +@pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) +def 
test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None:
+    from tiledbsoma_ml.pytorch import _CSR
+
+    sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.05)
+    sp_csr = sp_coo.tocsr()
+
+    _ncsr = _CSR.from_ijd(sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape)
+    assert _ncsr.nnz == sp_coo.nnz == sp_csr.nnz
+    assert _ncsr.dtype == sp_coo.dtype == sp_csr.dtype
+    assert _ncsr.nbytes == (
+        _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes
+    )
+
+    # _CSR makes no guarantees about minor axis ordering (i.e., "canonical" form), so
+    # use the SciPy sparse csr package to validate by round-tripping.
+    assert (
+        sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape)
+        != sp_csr
+    ).nnz == 0
+
+    assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_coo.toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_csr.toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(1, -1)), sp_csr[1:-1].toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(None, -2)), sp_csr[:-2].toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(None)), sp_csr[:].toarray())
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)],
+)
+@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
+def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None:
+    from tiledbsoma_ml.pytorch import _CSR
+
+    sp_csr = sparse.random(shape[0], shape[1], dtype=dtype, format="csr", density=0.05)
+
+    _ncsr = _CSR.from_pjd(
+        sp_csr.indptr.copy(),
+        sp_csr.indices.copy(),
+        sp_csr.data.copy(),
+        shape=sp_csr.shape,
+    )
+
+    # _CSR makes no guarantees about minor axis ordering (i.e., "canonical" form), so
+    # use the SciPy sparse csr package to validate by round-tripping.
+    assert (
+        sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape)
+        != sp_csr
+    ).nnz == 0
+
+    assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_csr.toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(1, -1)), sp_csr[1:-1].toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(None, -2)), sp_csr[:-2].toarray())
+    assert np.array_equal(_ncsr.densified_slice(slice(None)), sp_csr[:].toarray())
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(100, 10), (10, 100)],
+)
+@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
+@pytest.mark.parametrize("n_splits", [2, 3, 4])
+def test_csr__merge(
+    shape: Tuple[int, int], dtype: npt.DTypeLike, n_splits: int
+) -> None:
+    from tiledbsoma_ml.pytorch import _CSR
+
+    sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.5)
+    splits = [
+        t
+        for t in zip(
+            np.array_split(sp_coo.row, n_splits),
+            np.array_split(sp_coo.col, n_splits),
+            np.array_split(sp_coo.data, n_splits),
+        )
+    ]
+    _ncsr = _CSR.merge(
+        [_CSR.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits]
+    )
+
+    assert (
+        sp_coo.tocsr()
+        != sparse.csr_matrix(
+            (_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape
+        )
+    ).nnz == 0

From 065f78c5727aa5aaee466f7d56d4c728cf9c1649 Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Fri, 30 Aug 2024 10:40:51 -0700
Subject: [PATCH 41/70] fix typo in method name

---
 src/tiledbsoma_ml/pytorch.py | 42 ++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py
index edf5e44..54f1fe0 100644
--- a/src/tiledbsoma_ml/pytorch.py
+++ b/src/tiledbsoma_ml/pytorch.py
@@ -914,16 +914,15 @@ def _init_multiprocessing() -> None:
 
 
 class _CSR:
-    """Implement fast slice and conversion to numpy.ndarray, which is used to materialize
-    X mini-batches.
+    """Implement a minimal CSR matrix with specific optimizations for use in this package.
 
-    This is faster and uses less memory than the equivalent scipy.sparse.csr_array operations,
-    e.g., `mtrx[n:n+m].todense()`, and also produces fewer memory copies as views are materialized.
+    Operations supported are:
+    * Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for I/O batches,
+      and a final "merge" step which combines the result.
+    * Zero intermediate copy conversion of an arbitrary row slice to dense (i.e., mini-batch extraction).
+    * Parallel ops where it makes sense (construction, merge, etc.)
 
-    Primary optimizations:
-    * zero copy / parallel conversion to dense ndarray
-    * faster construction: stores in classic P,J,V format, but does not sort/de-dup minor axis
-    * incremental construction from COO via merging allows overlapping construction and I/O
+    Overall, this is significantly faster, and uses less memory, than the equivalent scipy.sparse operations.
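+
+    A sketch of the intended build/merge/slice pattern (illustrative only; ``coo_batches``,
+    ``shape``, and ``mini_batch_size`` are hypothetical locals, with ``coo_batches`` being an
+    iterable of per-I/O-batch COO triples):
+
+        pieces = [_CSR.from_ijd(i, j, d, shape) for i, j, d in coo_batches]
+        merged = _CSR.merge(pieces)
+        dense = merged.densified_slice(slice(0, mini_batch_size))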
""" __slots__ = ("indptr", "indices", "data", "shape") @@ -935,6 +934,11 @@ def __init__( data: NDArrayNumber, shape: Tuple[int, int], ) -> None: + """Construct from PJV format.""" + assert len(data) == len(indices) + assert len(data) < np.iinfo(indptr.dtype).max + assert indptr[-1] == len(data) and indptr[0] == 0 + self.shape = shape self.indptr = indptr self.indices = indices @@ -968,7 +972,7 @@ def nnz(self) -> int: return len(self.indices) @property - def nybtes(self) -> int: + def nbytes(self) -> int: return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes) @property @@ -979,19 +983,17 @@ def densified_slice(self, row_index: slice) -> NDArrayNumber: assert isinstance(row_index, slice) assert row_index.step in (1, None) row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + if row_idx_end - row_idx_start <= 0: + return np.zeros((0, self.shape[1]), dtype=self.data.dtype) + indptr = self.indptr indices = self.indices data = self.data - return _csr_to_dense_inner( - row_idx_start, - row_idx_end, - indptr, - indices, - data, - np.zeros( - (row_idx_end - row_idx_start, self.shape[1]), dtype=self.data.dtype - ), + out = np.zeros( + (row_idx_end - row_idx_start, self.shape[1]), dtype=self.data.dtype ) + _csr_to_dense_inner(row_idx_start, row_idx_end, indptr, indices, data, out) + return out @staticmethod def merge(mtxs: Sequence[_CSR]) -> _CSR: @@ -1040,13 +1042,11 @@ def _csr_to_dense_inner( indices: NDArrayNumber, data: NDArrayNumber, out: NDArrayNumber, -) -> NDArrayNumber: +) -> None: for i in numba.prange(row_idx_start, row_idx_end): for j in range(indptr[i], indptr[i + 1]): out[i - row_idx_start, indices[j]] = data[j] - return out - @numba.njit(nogil=True, parallel=True, inline="always") # type:ignore[misc] def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber: From 0fcdd11fbe4e945218652d2c826ab45ca9896faf Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 30 Aug 2024 11:11:42 -0700 Subject: [PATCH 42/70] tuning --- src/tiledbsoma_ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 54f1fe0..c69db9c 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -1053,7 +1053,7 @@ def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNum """Private: parallel row count.""" nnz = len(Ai) - partition_size = 10 * 1024**2 + partition_size = 32 * 1024**2 n_partitions = math.ceil(nnz / partition_size) if n_partitions > 1: counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype) From 0b8f786be5e3102387acfbc6a33ac046f57292d2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 30 Aug 2024 11:23:02 -0700 Subject: [PATCH 43/70] update demo notebook --- notebooks/tutorial_pytorch.ipynb | 190 +++++++++++++++---------------- 1 file changed, 94 insertions(+), 96 deletions(-) diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index 7dedfed..d24e6b5 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -47,29 +47,27 @@ "\n", "import tiledbsoma as soma\n", "\n", - "CZI_Census_Homo_Sapiens_URL = (\n", - " \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", - ")\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", "\n", "experiment = soma.open(\n", - " CZI_Census_Homo_Sapiens_URL, 
context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"})\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", ")\n", "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", - "obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", - "\n", - "experiment_dataset = soma_ml.ExperimentAxisQueryDataPipe(\n", - " experiment,\n", - " measurement_name=\"RNA\",\n", - " X_name=\"raw\",\n", - " obs_query=obs_query,\n", - " obs_column_names=[\"cell_type\"],\n", - " batch_size=128,\n", - " shuffle=True,\n", - ")\n", "\n", - "with experiment.axis_query(measurement_name=\"RNA\", obs_query=obs_query) as query:\n", - " obs_df = query.obs(column_names=['cell_type']).concat().to_pandas()\n", - " cell_type_encoder = LabelEncoder().fit(obs_df['cell_type'].unique())" + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryDataPipe(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n" ] }, { @@ -291,16 +289,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0160182 Accuracy 0.3806\n", - "Epoch 2: Train Loss: 0.0147012 Accuracy 0.4701\n", - "Epoch 3: Train Loss: 0.0143336 Accuracy 0.5381\n", - "Epoch 4: Train Loss: 0.0141242 Accuracy 0.5828\n", - "Epoch 5: Train Loss: 0.0139505 Accuracy 0.6105\n", - "Epoch 6: Train Loss: 0.0138496 Accuracy 0.6249\n", - "Epoch 7: Train Loss: 0.0137310 Accuracy 0.6665\n", - "Epoch 8: Train Loss: 0.0136376 Accuracy 0.7125\n", - "Epoch 9: Train Loss: 0.0135705 Accuracy 0.7954\n", - "Epoch 10: Train Loss: 0.0134742 Accuracy 0.8539\n" + "Epoch 1: Train Loss: 0.0161885 Accuracy 0.2666\n", + "Epoch 2: Train Loss: 0.0148625 Accuracy 0.4692\n", + "Epoch 3: Train Loss: 0.0144180 Accuracy 0.5561\n", + "Epoch 4: Train Loss: 0.0141377 Accuracy 0.6762\n", + "Epoch 5: Train Loss: 0.0139051 Accuracy 0.7803\n", + "Epoch 6: Train Loss: 0.0137547 Accuracy 0.8508\n", + "Epoch 7: Train Loss: 0.0136169 Accuracy 0.8972\n", + "Epoch 8: Train Loss: 0.0134998 Accuracy 0.9135\n", + "Epoch 9: Train Loss: 0.0134092 Accuracy 0.9220\n", + "Epoch 10: Train Loss: 0.0133577 Accuracy 0.9245\n" ] } ], @@ -358,14 +356,14 @@ { "data": { "text/plain": [ - "tensor([ 1, 8, 11, 1, 8, 7, 5, 1, 1, 1, 1, 1, 8, 1, 5, 8, 7, 1,\n", - " 8, 8, 1, 1, 8, 1, 5, 1, 7, 8, 5, 1, 5, 5, 1, 1, 5, 7,\n", - " 8, 8, 1, 8, 1, 1, 1, 1, 5, 1, 1, 11, 1, 1, 5, 7, 7, 1,\n", - " 1, 5, 1, 1, 7, 1, 5, 5, 5, 7, 7, 1, 8, 1, 1, 7, 7, 7,\n", - " 8, 8, 5, 1, 1, 8, 1, 5, 5, 1, 6, 5, 1, 5, 8, 8, 1, 5,\n", - " 7, 1, 1, 5, 7, 1, 1, 7, 5, 5, 8, 1, 1, 1, 8, 5, 1, 7,\n", - " 1, 7, 8, 1, 1, 5, 5, 1, 1, 1, 1, 1, 7, 1, 9, 1, 5, 8,\n", - " 1, 7], device='cuda:0')" + "tensor([ 1, 7, 5, 7, 7, 1, 7, 5, 1, 7, 7, 8, 11, 1, 7, 1, 1, 5,\n", + " 8, 8, 5, 1, 1, 1, 7, 1, 1, 1, 8, 1, 1, 7, 1, 7, 1, 7,\n", + " 6, 7, 1, 1, 5, 8, 1, 1, 8, 1, 1, 7, 8, 1, 1, 1, 1, 1,\n", + " 5, 1, 5, 8, 1, 5, 1, 8, 7, 1, 7, 7, 1, 1, 1, 1, 7, 1,\n", + " 1, 1, 8, 1, 1, 7, 1, 5, 5, 1, 1, 7, 7, 9, 1, 5, 1, 1,\n", + " 8, 1, 7, 7, 11, 7, 1, 7, 1, 7, 1, 8, 1, 1, 1, 11, 1, 1,\n", + " 1, 8, 1, 5, 1, 1, 7, 1, 7, 8, 7, 1, 5, 7, 7, 5, 1, 7,\n", + " 8, 7], 
device='cuda:0')" ] }, "metadata": {}, @@ -401,39 +399,39 @@ { "data": { "text/plain": [ - "array(['basal cell', 'leukocyte', 'vein endothelial cell', 'basal cell',\n", - " 'leukocyte', 'keratinocyte', 'epithelial cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'epithelial cell', 'leukocyte',\n", - " 'keratinocyte', 'basal cell', 'leukocyte', 'leukocyte',\n", - " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", - " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", - " 'epithelial cell', 'basal cell', 'epithelial cell',\n", - " 'epithelial cell', 'basal cell', 'basal cell', 'epithelial cell',\n", - " 'keratinocyte', 'leukocyte', 'leukocyte', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'vein endothelial cell', 'basal cell', 'basal cell',\n", - " 'epithelial cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'epithelial cell', 'epithelial cell',\n", - " 'epithelial cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'keratinocyte', 'keratinocyte', 'leukocyte', 'leukocyte',\n", + "array(['basal cell', 'keratinocyte', 'epithelial cell', 'keratinocyte',\n", + " 'keratinocyte', 'basal cell', 'keratinocyte', 'epithelial cell',\n", + " 'basal cell', 'keratinocyte', 'keratinocyte', 'leukocyte',\n", + " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'epithelial cell', 'leukocyte',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'keratinocyte', 'fibroblast', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", + " 'epithelial cell', 'basal cell', 'leukocyte', 'keratinocyte',\n", + " 'basal cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'keratinocyte', 'pericyte', 'basal cell',\n", " 'epithelial cell', 'basal cell', 'basal cell', 'leukocyte',\n", - " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", - " 'fibroblast', 'epithelial cell', 'basal cell', 'epithelial cell',\n", - " 'leukocyte', 'leukocyte', 'basal cell', 'epithelial cell',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'epithelial cell',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'epithelial cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", - " 'basal cell', 'basal cell', 'leukocyte', 'epithelial cell',\n", + " 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'vein endothelial cell', 'keratinocyte', 'basal cell',\n", + " 
'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'vein endothelial cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'epithelial cell', 'basal cell',\n", " 'basal cell', 'keratinocyte', 'basal cell', 'keratinocyte',\n", - " 'leukocyte', 'basal cell', 'basal cell', 'epithelial cell',\n", - " 'epithelial cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'pericyte', 'basal cell', 'epithelial cell', 'leukocyte',\n", - " 'basal cell', 'keratinocyte'], dtype=object)" + " 'leukocyte', 'keratinocyte', 'basal cell', 'epithelial cell',\n", + " 'keratinocyte', 'keratinocyte', 'epithelial cell', 'basal cell',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte'], dtype=object)" ] }, "metadata": {}, @@ -486,28 +484,28 @@ " \n", " \n", " 0\n", - " keratinocyte\n", + " basal cell\n", " basal cell\n", " \n", " \n", " 1\n", - " leukocyte\n", - " leukocyte\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", " 2\n", - " capillary endothelial cell\n", - " vein endothelial cell\n", + " basal cell\n", + " epithelial cell\n", " \n", " \n", " 3\n", - " basal cell\n", - " basal cell\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", " 4\n", - " leukocyte\n", - " leukocyte\n", + " basal cell\n", + " keratinocyte\n", " \n", " \n", " ...\n", @@ -516,23 +514,23 @@ " \n", " \n", " 123\n", - " basal cell\n", - " basal cell\n", + " epithelial cell\n", + " epithelial cell\n", " \n", " \n", " 124\n", - " epithelial cell\n", - " epithelial cell\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 125\n", - " basal cell\n", - " leukocyte\n", + " keratinocyte\n", + " keratinocyte\n", " \n", " \n", " 126\n", - " basal cell\n", - " basal cell\n", + " leukocyte\n", + " leukocyte\n", " \n", " \n", " 127\n", @@ -545,18 +543,18 @@ "" ], "text/plain": [ - " actual cell type predicted cell type\n", - "0 keratinocyte basal cell\n", - "1 leukocyte leukocyte\n", - "2 capillary endothelial cell vein endothelial cell\n", - "3 basal cell basal cell\n", - "4 leukocyte leukocyte\n", - ".. ... ...\n", - "123 basal cell basal cell\n", - "124 epithelial cell epithelial cell\n", - "125 basal cell leukocyte\n", - "126 basal cell basal cell\n", - "127 keratinocyte keratinocyte\n", + " actual cell type predicted cell type\n", + "0 basal cell basal cell\n", + "1 keratinocyte keratinocyte\n", + "2 basal cell epithelial cell\n", + "3 keratinocyte keratinocyte\n", + "4 basal cell keratinocyte\n", + ".. ... ...\n", + "123 epithelial cell epithelial cell\n", + "124 basal cell basal cell\n", + "125 keratinocyte keratinocyte\n", + "126 leukocyte leukocyte\n", + "127 keratinocyte keratinocyte\n", "\n", "[128 rows x 2 columns]" ] From 5ab985bdb961fe067c7cb31221fd1415fd2fa89d Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 30 Aug 2024 11:30:43 -0700 Subject: [PATCH 44/70] add to README --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 60b5bb5..576b1b5 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ -# tiledbsoma-ml +# tiledbsoma_ml A Python package containing ML tools for use with `tiledbsoma`. 
## Description
 
 The package currently contains a prototype PyTorch `IterableDataset` for use with the
-[`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+[`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
 API.
 
 ## Getting Started
 
@@ -17,11 +17,20 @@ Install using your favorite package installer. For example, with pip:
 
 > pip install tiledbsoma-ml
 
+Developers may install editable, from source, in the usual manner:
+
+> pip install -e .
 
 ### Documentation
 
 TBD
 
+## Builds
+
+This is a pure Python package. To build a wheel, ensure you have the `build` package installed, and then:
+
+> python -m build .
+
 ## Version History
 
 See the [CHANGELOG.md](CHANGELOG.md) file.

From 075b1ab7c56c10620077940848d175777fcd29b3 Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Sat, 31 Aug 2024 01:29:27 +0000
Subject: [PATCH 45/70] concurrency tweak

---
 src/tiledbsoma_ml/pytorch.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py
index c69db9c..0786015 100644
--- a/src/tiledbsoma_ml/pytorch.py
+++ b/src/tiledbsoma_ml/pytorch.py
@@ -378,25 +378,19 @@ def _io_batch_iter(
                 if self._shuffle_rng is None
                 else self._shuffle_rng.permuted(obs_coords)
             )
+            obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context)
 
             logger.debug(
                 f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
             )
 
-            obs_io_batch = cast(
-                pd.DataFrame,
-                obs.read(coords=(obs_coords,), column_names=obs_column_names)
-                .concat()
-                .to_pandas()
-                .set_index("soma_joinid")
-                .reindex(obs_shuffled_coords, copy=False)
-                .reset_index(),
-            )
-            obs_io_batch = obs_io_batch[self.obs_column_names]
-
+            # to maximize opportunities for concurrency, when in eager_fetch mode,
+            # create the X read iterator first, as the eager iterator will begin
+            # the read-ahead immediately. Then proceed to fetch obs DataFrame.
+            # This matters most on high-latency backing stores, e.g., S3.
+ # X_tbl_iter: Iterator[pa.Table] = X.read( coords=(obs_coords, self._var_joinids) ).tables() - obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) def make_csr( X_tbl: pa.Table, @@ -420,6 +414,20 @@ def make_csr( ) if self.use_eager_fetch: _csr_iter = _EagerIterator(_csr_iter, pool=X.context.threadpool) + + # Now that X read is potentially in progress (in eager mode), go fetch obs data + # + obs_io_batch = cast( + pd.DataFrame, + obs.read(coords=(obs_coords,), column_names=obs_column_names) + .concat() + .to_pandas() + .set_index("soma_joinid") + .reindex(obs_shuffled_coords, copy=False) + .reset_index(), + ) + obs_io_batch = obs_io_batch[self.obs_column_names] + X_io_batch = _CSR.merge(tuple(_csr_iter)) del obs_indexer, obs_coords, obs_shuffled_coords, _csr_iter From 34d4952ccdee237efeeb25af699d6fce2c9a3826 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 31 Aug 2024 22:07:32 +0000 Subject: [PATCH 46/70] additional memory reductions --- src/tiledbsoma_ml/pytorch.py | 58 +++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 0786015..3596795 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -944,7 +944,8 @@ def __init__( ) -> None: """Construct from PJV format.""" assert len(data) == len(indices) - assert len(data) < np.iinfo(indptr.dtype).max + assert len(data) <= np.iinfo(indptr.dtype).max + assert shape[1] <= np.iinfo(indices.dtype).max assert indptr[-1] == len(data) and indptr[0] == 0 self.shape = shape @@ -958,12 +959,8 @@ def from_ijd( ) -> _CSR: """Factory from COO""" nnz = len(d) - if nnz < np.iinfo(np.int32).max: - index_dtype: npt.DTypeLike = np.int32 - else: - index_dtype = np.int64 - indptr = np.zeros((shape[0] + 1), dtype=index_dtype) - indices = np.empty((nnz,), dtype=index_dtype) + indptr = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz)) + indices = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1])) data = np.empty((nnz,), dtype=d.dtype) _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) return _CSR(indptr, indices, data, shape) @@ -991,16 +988,12 @@ def densified_slice(self, row_index: slice) -> NDArrayNumber: assert isinstance(row_index, slice) assert row_index.step in (1, None) row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) - if row_idx_end - row_idx_start <= 0: - return np.zeros((0, self.shape[1]), dtype=self.data.dtype) - - indptr = self.indptr - indices = self.indices - data = self.data - out = np.zeros( - (row_idx_end - row_idx_start, self.shape[1]), dtype=self.data.dtype - ) - _csr_to_dense_inner(row_idx_start, row_idx_end, indptr, indices, data, out) + n_rows = max(row_idx_end - row_idx_start, 0) + out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype) + if n_rows >= 0: + _csr_to_dense_inner( + row_idx_start, n_rows, self.indptr, self.indices, self.data, out + ) return out @staticmethod @@ -1008,23 +1001,34 @@ def merge(mtxs: Sequence[_CSR]) -> _CSR: assert len(mtxs) > 0 nnz = sum(m.nnz for m in mtxs) shape = mtxs[0].shape + for m in mtxs[1:]: + assert m.shape == mtxs[0].shape + assert m.indices.dtype == mtxs[0].indices.dtype assert all(m.shape == shape for m in mtxs) - if nnz < np.iinfo(np.int32).max: - index_dtype: npt.DTypeLike = np.int32 - else: - index_dtype = np.int64 - - indptr = np.sum([m.indptr for m in mtxs], axis=0, dtype=index_dtype) - indices = np.empty((nnz,), dtype=index_dtype) + indptr = np.sum( + [m.indptr for m in mtxs], axis=0, 
dtype=smallest_uint_dtype(nnz) + ) + indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype) data = np.empty((nnz,), mtxs[0].data.dtype) _csr_merge_inner( - tuple((m.indptr, m.indices, m.data) for m in mtxs), indptr, indices, data + tuple((m.indptr.astype(indptr.dtype), m.indices, m.data) for m in mtxs), + indptr, + indices, + data, ) return _CSR.from_pjd(indptr, indices, data, shape) +def smallest_uint_dtype(max_val: int) -> npt.DTypeLike: + for dt in [np.uint16, np.uint32]: + if max_val <= np.iinfo(dt).max: + return dt + else: + return np.uint64 + + @numba.njit(nogil=True, parallel=True) # type:ignore[misc] def _csr_merge_inner( As: Tuple[Tuple[NDArrayNumber, NDArrayNumber, NDArrayNumber], ...], # P,J,D @@ -1045,13 +1049,13 @@ def _csr_merge_inner( @numba.njit(nogil=True, parallel=True) # type:ignore[misc] def _csr_to_dense_inner( row_idx_start: int, - row_idx_end: int, + n_rows: int, indptr: NDArrayNumber, indices: NDArrayNumber, data: NDArrayNumber, out: NDArrayNumber, ) -> None: - for i in numba.prange(row_idx_start, row_idx_end): + for i in numba.prange(row_idx_start, row_idx_start + n_rows): for j in range(indptr[i], indptr[i + 1]): out[i - row_idx_start, indices[j]] = data[j] From c2e7fac8c77a43c892ed51771617f4d788668653 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 4 Sep 2024 21:52:08 +0000 Subject: [PATCH 47/70] DDP/multi-GPU support --- src/tiledbsoma_ml/pytorch.py | 121 +++++++++++++++++++++-------------- tests/test_pytorch.py | 72 ++++++++++++++------- 2 files changed, 121 insertions(+), 72 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 3596795..8c59404 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -207,16 +207,39 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: As appropriate, will chunk, shuffle and apply partitioning per worker. + IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will + always partition such that each process has the same number of samples. Where + the number of obs_joinids is not evenly divisible by the number of processes, + the number of joinids will be dropped (dropped ids can never exceed WORLD_SIZE-1). + + Abstractly, the steps taken: + 1. Split the joinids into WORLD_SIZE sections (aka number of GPUS in DDP) + 2. Trim the splits to be of equal length + 3. Chunk and optionally shuffle the chunks + 4. Partition by number of data loader workers (to not generate redundant batches + in cases where the DataLoader is running with `n_workers>1`). + Private method. """ assert self._obs_joinids is not None obs_joinids: npt.NDArray[np.int64] = self._obs_joinids + # 1. Get the split for the model replica/GPU + world_size, rank = _get_distributed_world_rank() + _gpu_splits = _splits(len(obs_joinids), world_size) + _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] + + # 2. Trip to be all of equal length + min_len = np.diff(_gpu_splits).min() + assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 + _gpu_split = _gpu_split[:min_len] + + # 3. 
Chunk and optionally shuffle chunks if self.shuffle: assert self._shuffle_rng is not None assert self.io_batch_size % self.shuffle_chunk_size == 0 shuffle_split = np.array_split( - obs_joinids, max(1, ceil(len(obs_joinids) / self.shuffle_chunk_size)) + _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) ) self._shuffle_rng.shuffle(shuffle_split) obs_joinids_chunked = list( @@ -227,21 +250,20 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: ) else: obs_joinids_chunked = np.array_split( - obs_joinids, max(1, ceil(len(obs_joinids) / self.io_batch_size)) + _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) ) - # Now extract the partition for this worker - partition, num_partitions = _get_torch_partition_info() - - obs_splits = _splits(len(obs_joinids_chunked), num_partitions) + # 4. Partition by DataLoader worker + n_workers, worker_id = _get_worker_world_rank() + obs_splits = _splits(len(obs_joinids_chunked), n_workers) obs_partition_joinids = obs_joinids_chunked[ - obs_splits[partition] : obs_splits[partition + 1] + obs_splits[worker_id] : obs_splits[worker_id + 1] ].copy() obs_joinid_iter = iter(obs_partition_joinids) if logger.isEnabledFor(logging.DEBUG): logger.debug( - f"Process {os.getpid()} handling partition {partition + 1} of {num_partitions}, " + f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}, " f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" ) @@ -308,7 +330,9 @@ def __iter__(self) -> Iterator[XObsDatum]: yield from _mini_batch_iter def __len__(self) -> int: - """Return approximate number of batches this iterable will product. + """Return approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or + as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) + count will reflect the size of the partition of the data assigned to the active process. See import caveats in the PyTorch [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) @@ -320,14 +344,17 @@ def __len__(self) -> int: Lifecycle: experimental """ - self._init_once() - assert self._obs_joinids is not None - div, rem = divmod(len(self._obs_joinids), self.batch_size) - return div + bool(rem) + # self._init_once() + # assert self._obs_joinids is not None + # world_size, _ = _get_distributed_world_rank() + # n_workers, _ = _get_worker_world_rank() + # div, rem = divmod(len(self._obs_joinids) // world_size, self.batch_size) + # return div + bool(rem) + return self.shape[0] @property def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + """Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect the size of the partition of the data assigned to the active process. 
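A minimal sketch of the partition arithmetic that the following hunk implements
(values are illustrative; at runtime world_size comes from torch.distributed and
n_workers from torch.utils.data.get_worker_info()):

    from math import ceil

    def expected_batches(n_obs: int, world_size: int, n_workers: int, batch_size: int) -> int:
        # Each distributed rank sees ~n_obs // world_size rows; each DataLoader
        # worker within a rank then sees a 1/n_workers share of those rows.
        partition_len = n_obs // world_size // n_workers
        return ceil(partition_len / batch_size)  # equivalent to div + bool(rem)

    assert expected_batches(n_obs=1000, world_size=2, n_workers=2, batch_size=64) == 4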
@@ -341,7 +368,12 @@ def shape(self) -> Tuple[int, int]: self._init_once() assert self._obs_joinids is not None assert self._var_joinids is not None - return len(self._obs_joinids), len(self._var_joinids) + world_size, _ = _get_distributed_world_rank() + n_workers, _ = _get_worker_world_rank() + div, rem = divmod( + len(self._obs_joinids) // world_size // n_workers, self.batch_size + ) + return div + bool(rem), len(self._var_joinids) def __getitem__(self, index: int) -> XObsDatum: raise NotImplementedError( @@ -353,7 +385,7 @@ def _io_batch_iter( obs: soma.DataFrame, X: soma.SparseNDArray, obs_joinid_iter: Iterator[npt.NDArray[np.int64]], - ) -> Iterator[Tuple[_CSR, pd.DataFrame]]: + ) -> Iterator[Tuple[_CSR_IO_Buffer, pd.DataFrame]]: """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of (X: csr_array, obs: DataFrame). @@ -397,9 +429,9 @@ def make_csr( obs_coords: npt.NDArray[np.int64], var_coords: npt.NDArray[np.int64], obs_indexer: soma.IntIndexer, - ) -> _CSR: + ) -> _CSR_IO_Buffer: """This function provides a GC after we throw off (large) garbage.""" - m = _CSR.from_ijd( + m = _CSR_IO_Buffer.from_ijd( obs_indexer.get_indexer(X_tbl["soma_dim_0"]), var_indexer.get_indexer(X_tbl["soma_dim_1"]), X_tbl["soma_data"].to_numpy(), @@ -428,7 +460,7 @@ def make_csr( ) obs_io_batch = obs_io_batch[self.obs_column_names] - X_io_batch = _CSR.merge(tuple(_csr_iter)) + X_io_batch = _CSR_IO_Buffer.merge(tuple(_csr_iter)) del obs_indexer, obs_coords, obs_shuffled_coords, _csr_iter gc.collect() @@ -871,34 +903,25 @@ def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: yield batch -def _get_torch_partition_info() -> Tuple[int, int]: - """Return this workers partition and total partition count as a tuple. +def _get_distributed_world_rank() -> Tuple[int, int]: + """Return tuple containing equivalent of torch.distributed world size and rank.""" + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + else: + # sometimes these are set even before torch.distributed is initialized, e.g., by torchrun + world_size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", 0)) - Private. Used to partition the iterator in some cases. + return world_size, rank - Examples - -------- - >>> _get_torch_partition_info() - (0, 1) - """ +def _get_worker_world_rank() -> Tuple[int, int]: + """Return number of dataloader workers and our worker rank/id""" worker_info = torch.utils.data.get_worker_info() if worker_info is None: - loader_partition, loader_partitions = 0, 1 - else: - loader_partition = worker_info.id - loader_partitions = worker_info.num_workers - - if not torch.distributed.is_initialized(): - dist_partition, num_dist_partitions = 0, 1 - else: - dist_partition = torch.distributed.get_rank() - num_dist_partitions = torch.distributed.get_world_size() - - total_partitions = num_dist_partitions * loader_partitions - partition = dist_partition * loader_partitions + loader_partition - - return partition, total_partitions + return 1, 0 + return worker_info.num_workers, worker_info.id def _init_multiprocessing() -> None: @@ -921,7 +944,7 @@ def _init_multiprocessing() -> None: torch.multiprocessing.set_start_method("spawn", force=True) -class _CSR: +class _CSR_IO_Buffer: """Implement a minimal CSR matrix with specific optimizations for use in this package. 
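    (Relative to scipy.sparse.csr_matrix, this buffer trades generality for
    numba-parallel construction and merge, and for index dtypes sized to the
    data via smallest_uint_dtype, reducing peak memory while assembling IO
    batches.)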
Operations supported are: @@ -956,21 +979,21 @@ def __init__( @staticmethod def from_ijd( i: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] - ) -> _CSR: + ) -> _CSR_IO_Buffer: """Factory from COO""" nnz = len(d) indptr = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz)) indices = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1])) data = np.empty((nnz,), dtype=d.dtype) _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data) - return _CSR(indptr, indices, data, shape) + return _CSR_IO_Buffer(indptr, indices, data, shape) @staticmethod def from_pjd( p: NDArrayNumber, j: NDArrayNumber, d: NDArrayNumber, shape: Tuple[int, int] - ) -> _CSR: + ) -> _CSR_IO_Buffer: """Factory from CSR""" - return _CSR(p, j, d, shape) + return _CSR_IO_Buffer(p, j, d, shape) @property def nnz(self) -> int: @@ -997,7 +1020,7 @@ def densified_slice(self, row_index: slice) -> NDArrayNumber: return out @staticmethod - def merge(mtxs: Sequence[_CSR]) -> _CSR: + def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer: assert len(mtxs) > 0 nnz = sum(m.nnz for m in mtxs) shape = mtxs[0].shape @@ -1018,7 +1041,7 @@ def merge(mtxs: Sequence[_CSR]) -> _CSR: indices, data, ) - return _CSR.from_pjd(indptr, indices, data, shape) + return _CSR_IO_Buffer.from_pjd(indptr, indices, data, shape) def smallest_uint_dtype(max_val: int) -> npt.DTypeLike: diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index f64ddd5..c573bd3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -440,7 +440,12 @@ def test_multiprocessing__returns_full_result( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank", + [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], ) @pytest.mark.parametrize( "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) @@ -448,6 +453,9 @@ def test_multiprocessing__returns_full_result( def test_distributed__returns_data_partition_for_rank( PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, ) -> None: """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, using mocks to avoid having to do real PyTorch distributed setup.""" @@ -458,8 +466,8 @@ def test_distributed__returns_data_partition_for_rank( "torch.distributed.get_world_size" ) as mock_dist_get_world_size: mock_dist_is_initialized.return_value = True - mock_dist_get_rank.return_value = 1 - mock_dist_get_world_size.return_value = 3 + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( @@ -470,18 +478,26 @@ def test_distributed__returns_data_partition_for_rank( shuffle=False, ) full_result = list(iter(dp)) - soma_joinids = np.concatenate( [t[1]["soma_joinid"].to_numpy() for t in full_result] ) - # Of the 6 obs rows, the PyTorch process of rank 1 should get [2, 3] - # (rank 0 gets [0, 1], rank 2 gets [4, 5]) - assert sorted(soma_joinids) == [2, 3] + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ].tolist() + assert sorted(soma_joinids) == expected_joinids @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", [(12, 3, pytorch_x_value_gen)] + "obs_range,var_range,X_value_gen", + 
[(12, 3, pytorch_x_value_gen), (13, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank,num_workers,worker_id", + [ + (3, 1, 2, 0), + (3, 1, 2, 1), + ], ) @pytest.mark.parametrize( "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) @@ -489,6 +505,11 @@ def test_distributed__returns_data_partition_for_rank( def test_distributed_and_multiprocessing__returns_data_partition_for_rank( PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, + num_workers: int, + worker_id: int, ) -> None: """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch @@ -501,10 +522,12 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( ) as mock_dist_get_rank, patch( "torch.distributed.get_world_size" ) as mock_dist_get_world_size: - mock_get_worker_info.return_value = WorkerInfo(id=1, num_workers=2, seed=1234) + mock_get_worker_info.return_value = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=1234 + ) mock_dist_is_initialized.return_value = True - mock_dist_get_rank.return_value = 1 - mock_dist_get_world_size.return_value = 3 + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( @@ -521,10 +544,11 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( [t[1]["soma_joinid"].to_numpy() for t in full_result] ) - # Of the 12 obs rows, the PyTorch process of rank 1 should get [4..7], and then within that partition, - # the 2nd DataLoader process should get the second half of the rank's partition, which is just [6, 7] - # (rank 0 gets [0..3], rank 2 gets [8..11]) - assert sorted(soma_joinids) == [6, 7] + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ] + expected_joinids = np.array_split(expected_joinids, num_workers)[worker_id] + assert sorted(soma_joinids) == expected_joinids.tolist() @pytest.mark.parametrize( @@ -827,19 +851,21 @@ def test_splits() -> None: ) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: - from tiledbsoma_ml.pytorch import _CSR + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.05) sp_csr = sp_coo.tocsr() - _ncsr = _CSR.from_ijd(sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape) + _ncsr = _CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ) assert _ncsr.nnz == sp_coo.nnz == sp_csr.nnz assert _ncsr.dtype == sp_coo.dtype == sp_csr.dtype assert _ncsr.nbytes == ( _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes ) - # _CSR makes no guarantees about minor axis ordering (ie.., "canonical" form), so + # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie.., "canonical" form), so # use the SciPy sparse csr package to validate by round-tripping. 
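    # (Why this round-trip works: scipy's `!=` on two sparse matrices returns
    # a sparse boolean matrix of the positions that differ, so an empty result,
    # i.e. nnz == 0, means the matrices are element-wise identical even when
    # the stored minor-axis index order differs.)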
assert ( sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) @@ -859,11 +885,11 @@ def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) - ) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32]) def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) -> None: - from tiledbsoma_ml.pytorch import _CSR + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer sp_csr = sparse.random(shape[0], shape[1], dtype=dtype, format="csr", density=0.05) - _ncsr = _CSR.from_pjd( + _ncsr = _CSR_IO_Buffer.from_pjd( sp_csr.indptr.copy(), sp_csr.indices.copy(), sp_csr.data.copy(), @@ -892,7 +918,7 @@ def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) - def test_csr__merge( shape: Tuple[int, int], dtype: npt.DTypeLike, n_splits: int ) -> None: - from tiledbsoma_ml.pytorch import _CSR + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer sp_coo = sparse.random(shape[0], shape[1], dtype=dtype, format="coo", density=0.5) splits = [ @@ -903,8 +929,8 @@ def test_csr__merge( np.array_split(sp_coo.data, n_splits), ) ] - _ncsr = _CSR.merge( - [_CSR.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits] + _ncsr = _CSR_IO_Buffer.merge( + [_CSR_IO_Buffer.from_ijd(i, j, d, shape=sp_coo.shape) for i, j, d in splits] ) assert ( From 723fa21377973541fc7f69bbcd5587a70251f7c3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 6 Sep 2024 18:24:45 +0000 Subject: [PATCH 48/70] add further concurrency to CSR construction --- src/tiledbsoma_ml/pytorch.py | 56 +++++++++++++++++++++++++++++------- tests/test_pytorch.py | 49 +++++++++++++++---------------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 8c59404..91f5734 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -39,7 +39,7 @@ import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias import tiledbsoma as soma @@ -1043,6 +1043,11 @@ def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer: ) return _CSR_IO_Buffer.from_pjd(indptr, indices, data, shape) + def sort_indices(self) -> Self: + """Sort indices, IN PLACE.""" + _csr_sort_indices(self.indptr, self.indices, self.data) + return self + def smallest_uint_dtype(max_val: int) -> npt.DTypeLike: for dt in [np.uint16, np.uint32]: @@ -1128,19 +1133,48 @@ def _coo_to_csr_inner( cumsum += tmp Bp[n_rows] = nnz - # reorganize all of the data. side-effect: pointers shifted. - for n in range(nnz): - row = Ai[n] - dst_row = Bp[row] - - Bj[dst_row] = Aj[n] - Bd[dst_row] = Ad[n] - - Bp[row] += 1 + # Reorganize all of the data. Side-effect: pointers shifted (reversed in the + # subsequent section). + # + # Method is concurrent (partioned by rows) if number of rows is greater + # than 2**partition_bits. This partitioning scheme leverages the fact + # that reads are much cheaper than writes. 
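+    # (Worked example of the partition arithmetic: with partition_bits=13, each
+    # partition covers 2**13 == 8192 rows, so n_rows=20_000 yields
+    # n_partitions == 3, and row r is handled by partition r >> partition_bits.)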
+ # + # The code is equivalent to: + # for n in range(nnz): + # row = Ai[n] + # dst_row = Bp[row] + # Bj[dst_row] = Aj[n] + # Bd[dst_row] = Ad[n] + # Bp[row] += 1 + + partition_bits = 13 + n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits + for p in numba.prange(n_partitions): + for n in range(nnz): + row = Ai[n] + if (row >> partition_bits) != p: + continue + dst_row = Bp[row] + Bj[dst_row] = Aj[n] + Bd[dst_row] = Ad[n] + Bp[row] += 1 - # and shift the pointers by one (ie., start at zero) + # Shift the pointers by one slot (ie., start at zero) prev_ptr = 0 for n in range(n_rows + 1): tmp = Bp[n] Bp[n] = prev_ptr prev_ptr = tmp + + +@numba.njit(nogil=True, parallel=True) # type:ignore[misc] +def _csr_sort_indices(Bp: NDArrayNumber, Bj: NDArrayNumber, Bd: NDArrayNumber) -> None: + """In-place sort of minor axis indices""" + n_rows = len(Bp) - 1 + for r in numba.prange(n_rows): + row_start = Bp[r] + row_end = Bp[r + 1] + order = np.argsort(Bj[row_start:row_end]) + Bj[row_start:row_end] = Bj[row_start:row_end][order] + Bd[row_start:row_end] = Bd[row_start:row_end][order] diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index c573bd3..12aac75 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -825,26 +825,6 @@ def test_splits() -> None: _splits(10, -1) -# temp comment out while building _CSR tests -# def test_csr_to_dense() -> None: -# from tiledbsoma_ml.pytorch import _csr_to_dense - -# coo = sparse.eye(1001, 77, format="coo", dtype=np.float32) - -# assert np.array_equal( -# sparse.csr_array(coo).todense(), _csr_to_dense(sparse.csr_array(coo)) -# ) -# assert np.array_equal( -# sparse.csr_matrix(coo).todense(), _csr_to_dense(sparse.csr_matrix(coo)) -# ) - -# csr = sparse.csr_array(coo) -# assert np.array_equal(csr.todense(), _csr_to_dense(csr)) -# assert np.array_equal(csr[1:, :].todense(), _csr_to_dense(csr[1:, :])) -# assert np.array_equal(csr[:, 1:].todense(), _csr_to_dense(csr[:, 1:])) -# assert np.array_equal(csr[3:501, 1:22].todense(), _csr_to_dense(csr[3:501, 1:22])) - - @pytest.mark.parametrize( # keep these small as we materialize as a dense ndarray "shape", [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], @@ -865,8 +845,8 @@ def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) - _ncsr.data.nbytes + _ncsr.indices.nbytes + _ncsr.indptr.nbytes ) - # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie.., "canonical" form), so - # use the SciPy sparse csr package to validate by round-tripping. + # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. assert ( sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) != sp_csr @@ -896,8 +876,8 @@ def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) - shape=sp_csr.shape, ) - # _CSR makes no guarantees about minor axis ordering (ie.., "canonical" form), so - # use the SciPy sparse csr package to validate by round-tripping. + # _CSR_IO_Buffer makes no guarantees about minor axis ordering (ie, "canonical" form) until + # sort_indices is called, so use the SciPy sparse csr package to validate by round-tripping. 
assert ( sparse.csr_matrix((_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape) != sp_csr @@ -939,3 +919,24 @@ def test_csr__merge( (_ncsr.data, _ncsr.indices, _ncsr.indptr), shape=_ncsr.shape ) ).nnz == 0 + + +@pytest.mark.parametrize( + "shape", + [(100, 10), (10, 100), (1, 1), (1, 100), (100, 1), (0, 0), (10, 0), (0, 10)], +) +def test_csr__sort_indices(shape: Tuple[int, int]) -> None: + from tiledbsoma_ml.pytorch import _CSR_IO_Buffer + + sp_coo = sparse.random( + shape[0], shape[1], dtype=np.float32, format="coo", density=0.05 + ) + sp_csr = sp_coo.tocsr() + + _ncsr = _CSR_IO_Buffer.from_ijd( + sp_coo.row, sp_coo.col, sp_coo.data, shape=sp_coo.shape + ).sort_indices() + + assert np.array_equal(sp_csr.indptr, _ncsr.indptr) + assert np.array_equal(sp_csr.indices, _ncsr.indices) + assert np.array_equal(sp_csr.data, _ncsr.data) From ea38c5c33041548af829cae9d7418c3fb58836b2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 10 Sep 2024 02:38:22 +0000 Subject: [PATCH 49/70] cleanup --- src/tiledbsoma_ml/pytorch.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 91f5734..d4bd8f7 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -344,12 +344,6 @@ def __len__(self) -> int: Lifecycle: experimental """ - # self._init_once() - # assert self._obs_joinids is not None - # world_size, _ = _get_distributed_world_rank() - # n_workers, _ = _get_worker_world_rank() - # div, rem = divmod(len(self._obs_joinids) // world_size, self.batch_size) - # return div + bool(rem) return self.shape[0] @property @@ -795,7 +789,6 @@ def shape(self) -> Tuple[int, int]: def experiment_dataloader( ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset, - # num_workers: int = 0, **dataloader_kwargs: Any, ) -> torch.utils.data.DataLoader: """Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a @@ -813,8 +806,6 @@ def experiment_dataloader( ds: A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May include chained data pipes. - num_workers: - How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) **dataloader_kwargs: Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not From f704c837d2c78051cb0febb10cfb7d367bf08844 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Fri, 13 Sep 2024 21:19:51 +0000 Subject: [PATCH 50/70] fix multi-gpu hang due to incorrect __len__ return value --- src/tiledbsoma_ml/pytorch.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index d4bd8f7..297c4ab 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -192,6 +192,8 @@ def __init__( self._shuffle_rng = np.random.default_rng(seed) if shuffle else None self.shuffle_chunk_size = shuffle_chunk_size self._initialized = False + self._obs_joinids_partition: npt.NDArray[np.int64] | None = None + self._obs_partition_length: int = -1 if self.shuffle: # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. 
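The rounding referred to in the trailing context line above can be sketched as
follows (illustrative; the actual statement falls outside this hunk's context
window, and the class later asserts io_batch_size % shuffle_chunk_size == 0):

    from math import ceil

    def round_up_io_batch_size(io_batch_size: int, shuffle_chunk_size: int) -> int:
        # Round up to the nearest whole multiple of shuffle_chunk_size, so an
        # IO batch is always composed of whole shuffle chunks.
        return ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size

    assert round_up_io_batch_size(2**16, 64) == 2**16
    assert round_up_io_batch_size(100, 64) == 128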
@@ -202,7 +204,7 @@ def __init__( if not self.obs_column_names: raise ValueError("Must specify at least one value in `obs_column_names`") - def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: + def _create_obs_joinids_partition(self) -> None: """Create iterator over obs id chunks with split size of (roughly) io_batch_size. As appropriate, will chunk, shuffle and apply partitioning per worker. @@ -229,7 +231,7 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: _gpu_splits = _splits(len(obs_joinids), world_size) _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] - # 2. Trip to be all of equal length + # 2. Trim to be all of equal length min_len = np.diff(_gpu_splits).min() assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 _gpu_split = _gpu_split[:min_len] @@ -259,7 +261,10 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: obs_partition_joinids = obs_joinids_chunked[ obs_splits[worker_id] : obs_splits[worker_id + 1] ].copy() - obs_joinid_iter = iter(obs_partition_joinids) + self._obs_joinids_partition = obs_partition_joinids + self._obs_partition_length = sum( + len(partition) for partition in obs_partition_joinids + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -267,8 +272,6 @@ def _create_obs_joinid_iter(self) -> Iterator[npt.NDArray[np.int64]]: f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" ) - return obs_joinid_iter - def _init_once(self, exp: soma.Experiment | None = None) -> None: """One-time per worker initialization. @@ -301,6 +304,7 @@ def _init_once(self, exp: soma.Experiment | None = None) -> None: self._obs_joinids = query.obs_joinids().to_numpy() self._var_joinids = query.var_joinids().to_numpy() + self._create_obs_joinids_partition() self._initialized = True def __iter__(self) -> Iterator[XObsDatum]: @@ -320,7 +324,8 @@ def __iter__(self) -> Iterator[XObsDatum]: "ExperimentAxisQueryIterDataPipe only supported on X layers which are of type SparseNDArray" ) - obs_joinid_iter = self._create_obs_joinid_iter() + assert self._obs_joinids_partition is not None + obs_joinid_iter = iter(self._obs_joinids_partition) _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter) if self.use_eager_fetch: _mini_batch_iter = _EagerIterator( @@ -360,13 +365,10 @@ def shape(self) -> Tuple[int, int]: experimental """ self._init_once() - assert self._obs_joinids is not None assert self._var_joinids is not None - world_size, _ = _get_distributed_world_rank() - n_workers, _ = _get_worker_world_rank() - div, rem = divmod( - len(self._obs_joinids) // world_size // n_workers, self.batch_size - ) + assert self._obs_joinids_partition is not None + assert self._obs_partition_length >= 0 + div, rem = divmod(self._obs_partition_length, self.batch_size) return div + bool(rem), len(self._var_joinids) def __getitem__(self, index: int) -> XObsDatum: From 8e47320d77ac8c360d079fd50774961137bdf9b3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 14 Sep 2024 05:14:08 +0000 Subject: [PATCH 51/70] compat with Lightning --- src/tiledbsoma_ml/pytorch.py | 56 +++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 297c4ab..644741b 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -192,8 +192,6 @@ def __init__( self._shuffle_rng = np.random.default_rng(seed) if shuffle else None self.shuffle_chunk_size = shuffle_chunk_size self._initialized = False - 
self._obs_joinids_partition: npt.NDArray[np.int64] | None = None - self._obs_partition_length: int = -1 if self.shuffle: # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. @@ -204,7 +202,7 @@ def __init__( if not self.obs_column_names: raise ValueError("Must specify at least one value in `obs_column_names`") - def _create_obs_joinids_partition(self) -> None: + def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: """Create iterator over obs id chunks with split size of (roughly) io_batch_size. As appropriate, will chunk, shuffle and apply partitioning per worker. @@ -261,10 +259,6 @@ def _create_obs_joinids_partition(self) -> None: obs_partition_joinids = obs_joinids_chunked[ obs_splits[worker_id] : obs_splits[worker_id + 1] ].copy() - self._obs_joinids_partition = obs_partition_joinids - self._obs_partition_length = sum( - len(partition) for partition in obs_partition_joinids - ) if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -272,6 +266,8 @@ def _create_obs_joinids_partition(self) -> None: f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" ) + return iter(obs_partition_joinids) + def _init_once(self, exp: soma.Experiment | None = None) -> None: """One-time per worker initialization. @@ -304,7 +300,6 @@ def _init_once(self, exp: soma.Experiment | None = None) -> None: self._obs_joinids = query.obs_joinids().to_numpy() self._var_joinids = query.var_joinids().to_numpy() - self._create_obs_joinids_partition() self._initialized = True def __iter__(self) -> Iterator[XObsDatum]: @@ -324,8 +319,7 @@ def __iter__(self) -> Iterator[XObsDatum]: "ExperimentAxisQueryIterDataPipe only supported on X layers which are of type SparseNDArray" ) - assert self._obs_joinids_partition is not None - obs_joinid_iter = iter(self._obs_joinids_partition) + obs_joinid_iter = self._create_obs_joinids_partition() _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter) if self.use_eager_fetch: _mini_batch_iter = _EagerIterator( @@ -365,10 +359,12 @@ def shape(self) -> Tuple[int, int]: experimental """ self._init_once() + assert self._obs_joinids is not None assert self._var_joinids is not None - assert self._obs_joinids_partition is not None - assert self._obs_partition_length >= 0 - div, rem = divmod(self._obs_partition_length, self.batch_size) + world_size, _ = _get_distributed_world_rank() + n_workers, _ = _get_worker_world_rank() + partition_len = len(self._obs_joinids) // world_size // n_workers + div, rem = divmod(partition_len, self.batch_size) return div + bool(rem), len(self._var_joinids) def __getitem__(self, index: int) -> XObsDatum: @@ -898,23 +894,37 @@ def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: def _get_distributed_world_rank() -> Tuple[int, int]: """Return tuple containing equivalent of torch.distributed world size and rank.""" - if torch.distributed.is_initialized(): + world_size, rank = 1, 0 + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ: + # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There + # is a NODE_RANK for the node's rank, but no way to tell the local node's + # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a + # proxy, which works fine on a single-CPU box. TODO: could throw/error + # if NODE_RANK != 0. 
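+        # (For example, a single-node Lightning DDP run with two devices spawns
+        # processes that see WORLD_SIZE=2 with LOCAL_RANK=0 and LOCAL_RANK=1.)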
+ world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["LOCAL_RANK"]) + elif torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() - else: - # sometimes these are set even before torch.distributed is initialized, e.g., by torchrun - world_size = int(os.environ.get("WORLD_SIZE", 1)) - rank = int(os.environ.get("RANK", 0)) return world_size, rank def _get_worker_world_rank() -> Tuple[int, int]: - """Return number of dataloader workers and our worker rank/id""" - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - return 1, 0 - return worker_info.num_workers, worker_info.id + """Return number of DataLoader workers and our worker rank/id""" + num_workers, worker = 1, 0 + if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: + num_workers = int(os.environ["NUM_WORKERS"]) + worker = int(os.environ["WORKER"]) + else: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + num_workers = worker_info.num_workers + worker = worker_info.id + return num_workers, worker def _init_multiprocessing() -> None: From 70cc170f0d450d5c5d548e738131c1785cec4f0b Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 14 Sep 2024 09:24:08 -0700 Subject: [PATCH 52/70] PR review edits --- notebooks/tutorial_pytorch.ipynb | 18 +++---- src/tiledbsoma_ml/__init__.py | 4 +- src/tiledbsoma_ml/pytorch.py | 86 ++++++++++++++++---------------- tests/test_pytorch.py | 78 ++++++++++++++--------------- 4 files changed, 93 insertions(+), 93 deletions(-) diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index d24e6b5..cce57fb 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -6,7 +6,7 @@ "source": [ "# Training a PyTorch Model\n", "\n", - "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryDataPipe`, and not as an example of how to train a biologically useful model.\n", + "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n", "\n", "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n", "\n", @@ -29,9 +29,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create an ExperimentAxisQueryDataPipe\n", + "## Create an ExperimentAxisQueryIterDataPipe\n", "\n", - "To train a model in PyTorch using this `census` data object, first instantiate open a SOMA Experiment, and create a `ExperimentAxisQueryDataPipe`. This example utilizes a recent CZI Census release, access directly from S3.\n", + "To train a model in PyTorch using this `census` data object, first instantiate open a SOMA Experiment, and create a `ExperimentAxisQueryIterDataPipe`. This example utilizes a recent CZI Census release, access directly from S3.\n", "\n", "We are also going to create an encoder for the `obs` labels at the same time, and train it on the `cell_type` labels. In this example we use the LabelEncoder from `scikit-learn`." 
] @@ -61,7 +61,7 @@ " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", "\n", - " experiment_dataset = soma_ml.ExperimentAxisQueryDataPipe(\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n", " query,\n", " X_name=\"raw\",\n", " obs_column_names=[\"cell_type\"],\n", @@ -74,11 +74,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `ExperimentAxisQueryDataPipe` class explained\n", + "### `ExperimentAxisQueryIterDataPipe` class explained\n", "\n", - "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n", + "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n", "\n", - "### `ExperimentAxisQueryDataPipe` parameters explained\n", + "### `ExperimentAxisQueryIterDataPipe` parameters explained\n", "\n", "The constructor only requires a single parameter, `experiment`, which is a `soma.Experiment` containing the data of the organism to be used for training.\n", "\n", @@ -160,7 +160,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage." + "Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryIterDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage." ] }, { @@ -388,7 +388,7 @@ "source": [ "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n", "\n", - "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." 
+ "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." ] }, { diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py index e610573..263608f 100644 --- a/src/tiledbsoma_ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -6,15 +6,15 @@ """An API to support machine learning applications built on SOMA.""" from .pytorch import ( - ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, experiment_dataloader, ) __version__ = "0.1.0-dev" __all__ = [ - "ExperimentAxisQueryDataPipe", + "ExperimentAxisQueryIterDataPipe", "ExperimentAxisQueryIterableDataset", "experiment_dataloader", ] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 644741b..a22344a 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -95,12 +95,12 @@ def open_experiment(self) -> Iterator[soma.Experiment]: class ExperimentAxisQueryIterable(Iterable[XObsDatum]): """An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator - produces equal sized ``X`` and ``obs`` data, in the form of a :class:``numpy.ndarray`` and - :class:`pandas.DataFrame`. + produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and + :class:`pandas.DataFrame`, respectively. Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and :class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset` - and `ExperimentAxisQueryDataPipe` for more details on usage. + and `ExperimentAxisQueryIterDataPipe` for more details on usage. Lifecycle: experimental @@ -122,11 +122,11 @@ def __init__( Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. The resulting iterator will produce a 2-tuple containing associated slices of ``X`` and ``obs`` data, as - a NumPy ``ndarray`` and a Pandas ``DataFrame`` respectively. + a NumPy :class:`numpy.ndarray` and a Pandas :class:`pandas.DataFrame`, respectively. Args: query: - A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over. + A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data to iterate over. X_name: The name of the X layer to read. obs_column_names: @@ -134,27 +134,25 @@ def __init__( Default is ``('soma_joinid',)``. batch_size: The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of - ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). - - Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` - batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader`` - batch_size parameter to ``None``. + ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will + result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). 
Note that a ``batch_size`` of 1 allows + this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` + batching, but higher performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s + ``batch_size`` parameter to ``None``. shuffle: Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``. io_batch_size: - The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of - this class's behavior: 1) The maximum memory utilization, with larger values providing - better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling - (see ``shuffle`` parameter for details). The default value of 131,072 provides high performance, but - may need to be reduced in memory limited hosts (or where a large number of :class:`DataLoader` workers - are employed). + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts: + 1. Maximum memory utilization, larger values provide better read performance, but require more memory. + 2. The number of rows read prior to shuffling (see the ``shuffle`` parameter for details). + The default value of 65,536 provides high performance but may need to be reduced in memory-limited hosts + or when using a large number of :class:`DataLoader` workers. shuffle_chunk_size: - The number of contiguous rows sampled, prior to concatenation and shuffling. - Larger numbers correspond to more randomness per training batch. + The number of contiguous rows sampled prior to concatenation and shuffling. + Larger numbers correspond to less randomness, but greater read performance. If ``shuffle == False``, this parameter is ignored. seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). This arguiment *must* be specified when using + The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker processes. use_eager_fetch: @@ -165,7 +163,7 @@ def __init__( Returns: An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to - a :class:`torch.data.utils.DataLoader` instance. + a :class:`torch.utils.data.DataLoader` instance. Raises: ``ValueError`` on various unsupported or malformed parameter values. @@ -316,7 +314,7 @@ def __iter__(self) -> Iterator[XObsDatum]: X = exp.ms[self.measurement_name].X[self.layer_name] if not isinstance(X, soma.SparseNDArray): raise NotImplementedError( - "ExperimentAxisQueryIterDataPipe only supported on X layers which are of type SparseNDArray" + "ExperimentAxisQueryIterable only supports X layers which are of type SparseNDArray" ) obs_joinid_iter = self._create_obs_joinids_partition() @@ -329,13 +327,13 @@ def __iter__(self) -> Iterator[XObsDatum]: yield from _mini_batch_iter def __len__(self) -> int: - """Return approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or + """Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) - count will reflect the size of the partition of the data assigned to the active process. + count will reflect the size of the data partition assigned to the active process. 
- See import caveats in the PyTorch + See important caveats in the PyTorch [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) - domentation regarding ``len(dataloader)``, which also apply to this class. + documentation regarding ``len(dataloader)``, which also apply to this class. Returns: An ``int``. @@ -350,10 +348,10 @@ def shape(self) -> Tuple[int, int]: """Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the partition of the data assigned to the active process. + the size of the data partition assigned to the active process. Returns: - A 2-tuple of ``int``s, for obs and var counts, respectively. + A tuple of two ``int`` values: number of obs, number of vars. Lifecycle: experimental @@ -520,7 +518,7 @@ def _mini_batch_iter( yield result -class ExperimentAxisQueryDataPipe( +class ExperimentAxisQueryIterDataPipe( torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] torch.utils.data.dataset.Dataset[XObsDatum] ], @@ -608,7 +606,8 @@ class ExperimentAxisQueryIterableDataset( This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of - ``obs`` and ``X`` data. Each iteration will yield a tuple containing an NumPy ndarray and a Pandas DataFrame. + ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray` + and a :class:`pandas.DataFrame`. For example: @@ -617,7 +616,7 @@ class ExperimentAxisQueryIterableDataset( >>> import tiledbsoma_ml >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: - ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(ds) + ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) dataloader = torch.utils.data.DataLoader(ds) >>> data = next(iter(dataloader)) >>> data @@ -638,20 +637,21 @@ class ExperimentAxisQueryIterableDataset( The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). - The `io_batch_size` parameter determines the number of rows read, from which mini-batches are yielded. A - larger value will increase total memory usage, and may reduce average read time per row. + The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A + larger value will increase total memory usage and may reduce average read time per row. Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using :class:`DataLoader` shuffling. The shuffling algorithm works as follows: 1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk". - 2. A random selection of shuffle `chunks` is drawn and read as a single I/O buffer (of size ``io_buffer_size``). + 2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_buffer_size``). 3. The entire I/O buffer is shuffled. 
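    An illustrative NumPy-only sketch of these three steps (the chunk and buffer
    sizes are made-up example values):

        import numpy as np

        rng = np.random.default_rng(0)
        obs_joinids = np.arange(1_000)
        shuffle_chunk_size, io_buffer_size = 10, 100

        # 1. Subdivide the query result into shuffle chunks.
        chunks = [obs_joinids[i : i + shuffle_chunk_size]
                  for i in range(0, len(obs_joinids), shuffle_chunk_size)]
        # 2. Draw chunks in random order and concatenate one I/O buffer's worth.
        rng.shuffle(chunks)
        io_buffer = np.concatenate(chunks[: io_buffer_size // shuffle_chunk_size])
        # 3. Shuffle the entire I/O buffer.
        rng.shuffle(io_buffer)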
Put another way, we read randomly selected groups of observations from across all query results, concatenate those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle - is therefore determiend by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` - (number of rows in each draw). + is therefore determined by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` + (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O + performance. Lifecycle: experimental @@ -679,13 +679,13 @@ def __init__( query: A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over. X_name: - The name of the X layer to read. + The name of the ``X`` layer to read. obs_column_names: The names of the ``obs`` columns to return. At least one column name must be specified. Default is ``('soma_joinid',)``. batch_size: The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of - ``1`` will result in :class:`torch.Tensor` of rank 1 being returns (a single row); larger values will + ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` @@ -702,7 +702,7 @@ def __init__( are employed). shuffle_chunk_size: The number of contiguous rows sampled, prior to concatenation and shuffling. - Larger numbers correspond to more randomness per training batch. + Larger numbers correspond to less randomness, but greater read performance. If ``shuffle == False``, this parameter is ignored. seed: The random seed used for shuffling. Defaults to ``None`` (no seed). This arguiment *must* be specified when using @@ -756,9 +756,9 @@ def __iter__(self) -> Iterator[XObsDatum]: def __len__(self) -> int: """Return approximate number of batches this iterable will produce. - See import caveats in the PyTorch + See important caveats in the PyTorch [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) - domentation regarding ``len(dataloader)``, which also apply to this class. + documentation regarding ``len(dataloader)``, which also apply to this class. Returns: An ``int``. @@ -794,7 +794,7 @@ def experiment_dataloader( or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`. Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant, - when using loaders form this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``. + when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``. Specifying any of these parameters will result in an error. Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on @@ -807,7 +807,7 @@ def experiment_dataloader( **dataloader_kwargs: Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor, except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not - supported when data loaders in this module. + supported when using data loaders in this module. Returns: A :class:`torch.utils.data.DataLoader`. 
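A usage sketch for this factory, mirroring the tutorial notebook earlier in this
series (the experiment URI and query values are illustrative):

    import tiledbsoma as soma
    import tiledbsoma_ml as soma_ml

    with soma.Experiment.open("my_experiment_path") as exp:
        with exp.axis_query(measurement_name="RNA") as query:
            ds = soma_ml.ExperimentAxisQueryIterableDataset(
                query, X_name="raw", obs_column_names=["cell_type"], batch_size=128
            )
            dl = soma_ml.experiment_dataloader(ds, num_workers=2)
            for X_batch, obs_batch in dl:
                # X_batch is a numpy.ndarray, obs_batch a pandas.DataFrame
                break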
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 12aac75..f671a93 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -27,9 +27,9 @@ # tests. try: from tiledbsoma_ml.pytorch import ( - ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterable, ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, experiment_dataloader, ) from torch.utils.data._utils.worker import WorkerInfo @@ -153,10 +153,10 @@ def soma_experiment( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_non_batched( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -190,10 +190,10 @@ def test_non_batched( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_uneven_soma_and_result_batches( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -226,10 +226,10 @@ def test_uneven_soma_and_result_batches( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_batching__all_batches_full_size( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -266,10 +266,10 @@ def test_batching__all_batches_full_size( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_unique_soma_joinids( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -293,10 +293,10 @@ def test_unique_soma_joinids( [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_batching__partial_final_batch_size( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -324,10 +324,10 @@ def test_batching__partial_final_batch_size( [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, 
ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_batching__exactly_one_batch( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -355,10 +355,10 @@ def test_batching__exactly_one_batch( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_batching__empty_query_result( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -386,10 +386,10 @@ def test_batching__empty_query_result( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_batching__partial_soma_batches_are_concatenated( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -413,13 +413,13 @@ def test_batching__partial_soma_batches_are_concatenated( "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_multiprocessing__returns_full_result( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, ) -> None: - """Tests the ExperimentAxisQueryDataPipe provides all data, as collected from multiple processes that are managed by a + """Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a PyTorch DataLoader with multiple workers configured.""" with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( @@ -428,7 +428,7 @@ def test_multiprocessing__returns_full_result( obs_column_names=["soma_joinid", "label"], io_batch_size=3, # two chunks, one per worker ) - # Note we're testing the ExperimentAxisQueryDataPipe via a DataLoader, since this is what sets up the multiprocessing + # Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing dl = experiment_dataloader(dp, num_workers=2) full_result = list(iter(dl)) @@ -448,10 +448,10 @@ def test_multiprocessing__returns_full_result( [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_distributed__returns_data_partition_for_rank( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, obs_range: int, 
world_size: int, @@ -500,10 +500,10 @@ def test_distributed__returns_data_partition_for_rank( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_distributed_and_multiprocessing__returns_data_partition_for_rank( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, obs_range: int, world_size: int, @@ -556,10 +556,10 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_experiment_dataloader__non_batched( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -586,10 +586,10 @@ def test_experiment_dataloader__non_batched( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_experiment_dataloader__batched( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -617,10 +617,10 @@ def test_experiment_dataloader__batched( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_experiment_dataloader__batched_length( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -642,10 +642,10 @@ def test_experiment_dataloader__batched_length( [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_experiment_dataloader__collate_fn( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, batch_size: int, ) -> None: @@ -680,10 +680,10 @@ def collate_fn( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test__X_tensor_dtype_matches_X_matrix( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, 
use_eager_fetch: bool, ) -> None: @@ -707,7 +707,7 @@ def test__pytorch_splitting( soma_experiment: Experiment, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = ExperimentAxisQueryDataPipe( + dp = ExperimentAxisQueryIterDataPipe( query, X_name="raw", obs_column_names=["label"], @@ -726,10 +726,10 @@ def test__pytorch_splitting( "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test__shuffle( - PipeClass: ExperimentAxisQueryDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: @@ -779,7 +779,7 @@ def test_experiment_axis_query_iterable_error_checks( def test_experiment_dataloader__unsupported_params__fails() -> None: with patch( - "tiledbsoma_ml.pytorch.ExperimentAxisQueryDataPipe" + "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe" ) as dummy_exp_data_pipe: with pytest.raises(ValueError): experiment_dataloader(dummy_exp_data_pipe, shuffle=True) From 37bc9b1adeb614cfefa5bfe79523925d8991d721 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Sat, 14 Sep 2024 09:29:59 -0700 Subject: [PATCH 53/70] formatting --- src/tiledbsoma_ml/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index a22344a..f82dfd6 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -135,7 +135,7 @@ def __init__( batch_size: The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows + result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s ``batch_size`` parameter to ``None``. @@ -650,7 +650,7 @@ class ExperimentAxisQueryIterableDataset( Put another way, we read randomly selected groups of observations from across all query results, concatenate those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle is therefore determined by the ``io_buffer_size`` (number of rows read), and the ``shuffle_chunk_size`` - (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O + (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O performance. 
Lifecycle: From b0c4547c25bf96e59a142405659e4a28ed39d6b2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 16 Sep 2024 19:49:55 +0000 Subject: [PATCH 54/70] add py.typed to package --- pyproject.toml | 3 +++ src/tiledbsoma_ml/py.typed | 2 ++ 2 files changed, 5 insertions(+) create mode 100644 src/tiledbsoma_ml/py.typed diff --git a/pyproject.toml b/pyproject.toml index 6bf4654..530d1f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,9 @@ Changelog = "https://github.com/single-cell-data/TileDB-SOMA/blob/main/CHANGELOG [tool.setuptools.dynamic] version = {attr = "tiledbsoma_ml.__version__"} +[tool.setuptools.package-data] +"tiledbsoma_ml" = ["py.typed"] + [tool.setuptools_scm] root = "../../.." diff --git a/src/tiledbsoma_ml/py.typed b/src/tiledbsoma_ml/py.typed new file mode 100644 index 0000000..288a150 --- /dev/null +++ b/src/tiledbsoma_ml/py.typed @@ -0,0 +1,2 @@ +# Marker file to indicate that this package contains Python typing information, +# and that mypy can use it to typecheck client code. From 8ae39927c5ec50b900a5eba724cb3f0d6e04cbe7 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 16 Sep 2024 23:28:26 +0000 Subject: [PATCH 55/70] add sparse support --- src/tiledbsoma_ml/pytorch.py | 108 ++++++++++++++++++++++++------- tests/test_pytorch.py | 120 +++++++++++++++++++++++++++++---- 2 files changed, 191 insertions(+), 37 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index f82dfd6..f6c4351 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -36,6 +36,7 @@ import numpy.typing as npt import pandas as pd import pyarrow as pa +import scipy.sparse as sparse import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator @@ -53,13 +54,18 @@ # restricting this to when we are running a type checker. TODO: remove # the conditional when Python 3.8 support is dropped. NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] + XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix] else: NDArrayNumber: TypeAlias = np.ndarray + XDatum: TypeAlias = np.ndarray | sparse.csr_matrix -XObsDatum: TypeAlias = Tuple[NDArrayNumber, pd.DataFrame] +XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame] """Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, which pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case, the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` will be returned with rank 1; in all other cases, objects are returned with rank 2.""" @attrs.define(frozen=True, kw_only=True) @@ -115,14 +121,16 @@ def __init__( shuffle: bool = True, io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, seed: int | None = None, use_eager_fetch: bool = True, ): """ Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. 
- The resulting iterator will produce a 2-tuple containing associated slices of ``X`` and ``obs`` data, as - a NumPy :class:`numpy.ndarray` and a Pandas :class:`pandas.DataFrame`, respectively. + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a + Pandas :class:`pandas.DataFrame`, respectively. Args: query: @@ -151,6 +159,9 @@ def __init__( The number of contiguous rows sampled prior to concatenation and shuffling. Larger numbers correspond to less randomness, but greater read performance. If ``shuffle == False``, this parameter is ignored. + return_sparse_X: + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will + return ``X`` data as a :class:`numpy.ndarray`. seed: The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker @@ -184,6 +195,7 @@ def __init__( self.batch_size = batch_size self.io_batch_size = io_batch_size self.shuffle = shuffle + self.return_sparse_X = return_sparse_X self.use_eager_fetch = use_eager_fetch self._obs_joinids: npt.NDArray[np.int64] | None = None self._var_joinids: npt.NDArray[np.int64] | None = None @@ -309,6 +321,17 @@ def __iter__(self) -> Iterator[XObsDatum]: Lifecycle: experimental """ + + if ( + self.return_sparse_X + and torch.utils.data.get_worker_info() + and torch.utils.data.get_worker_info().num_workers > 0 + ): + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) + with self.experiment_locator.open_experiment() as exp: self._init_once(exp) X = exp.ms[self.measurement_name].X[self.layer_name] @@ -414,7 +437,7 @@ def _io_batch_iter( coords=(obs_coords, self._var_joinids) ).tables() - def make_csr( + def make_io_buffer( X_tbl: pa.Table, obs_coords: npt.NDArray[np.int64], var_coords: npt.NDArray[np.int64], @@ -430,12 +453,12 @@ def make_csr( gc.collect(generation=0) return m - _csr_iter = ( - make_csr(X_tbl, obs_coords, self._var_joinids, obs_indexer) + _io_buf_iter = ( + make_io_buffer(X_tbl, obs_coords, self._var_joinids, obs_indexer) for X_tbl in X_tbl_iter ) if self.use_eager_fetch: - _csr_iter = _EagerIterator(_csr_iter, pool=X.context.threadpool) + _io_buf_iter = _EagerIterator(_io_buf_iter, pool=X.context.threadpool) # Now that X read is potentially in progress (in eager mode), go fetch obs data # @@ -450,9 +473,9 @@ def make_csr( ) obs_io_batch = obs_io_batch[self.obs_column_names] - X_io_batch = _CSR_IO_Buffer.merge(tuple(_csr_iter)) + X_io_batch = _CSR_IO_Buffer.merge(tuple(_io_buf_iter)) - del obs_indexer, obs_coords, obs_shuffled_coords, _csr_iter + del obs_indexer, obs_coords, obs_shuffled_coords, _io_buf_iter gc.collect() tm = time.perf_counter() - st_time @@ -489,20 +512,34 @@ def _mini_batch_iter( while iob_idx < iob_len: if result is None: # perform zero copy slice where possible - result = ( - X_io_batch.densified_slice( + X_datum = ( + X_io_batch.slice_toscipy( slice(iob_idx, iob_idx + mini_batch_size) - ), + ) + if self.return_sparse_X + else X_io_batch.slice_tonumpy( + slice(iob_idx, iob_idx + mini_batch_size) + ) + ) + result = ( + X_datum, obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size], ) iob_idx += len(result[1]) else: - # use remanent from previous IO batch + # 
use any remnant from previous IO batch to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) + X_datum = ( + sparse.vstack( + [result[0], X_io_batch.slice_toscipy(slice(0, to_take))] + ) + if self.return_sparse_X + else np.concatenate( + [result[0], X_io_batch.slice_tonumpy(slice(0, to_take))] + ) + ) result = ( - np.concatenate( - [result[0], X_io_batch.densified_slice(slice(0, to_take))] - ), + X_datum, pd.concat([result[1], obs_io_batch.iloc[0:to_take]]), ) iob_idx += to_take @@ -513,7 +550,7 @@ def _mini_batch_iter( result = None else: - # yield a remnant, if any + # yield the remnant, if any if result is not None: yield result @@ -545,6 +582,7 @@ def __init__( seed: int | None = None, io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, use_eager_fetch: bool = True, ): """ @@ -562,6 +600,7 @@ def __init__( shuffle=shuffle, seed=seed, io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, use_eager_fetch=use_eager_fetch, shuffle_chunk_size=shuffle_chunk_size, ) @@ -667,13 +706,14 @@ def __init__( seed: int | None = None, io_batch_size: int = 2**16, shuffle_chunk_size: int = 64, + return_sparse_X: bool = False, use_eager_fetch: bool = True, ): """ Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. - The resulting iterator will produce a 2-tuple containing associated slices of ``X`` and ``obs`` data, as - a NumPy ``ndarray`` and a Pandas ``DataFrame`` respectively. + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame`` respectively. Args: query: @@ -704,6 +744,9 @@ def __init__( The number of contiguous rows sampled, prior to concatenation and shuffling. Larger numbers correspond to less randomness, but greater read performance. If ``shuffle == False``, this parameter is ignored. + return_sparse_X: + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will + return ``X`` data as a :class:`numpy.ndarray`. seed: The random seed used for shuffling. Defaults to ``None`` (no seed). This arguiment *must* be specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker @@ -734,6 +777,7 @@ def __init__( shuffle=shuffle, seed=seed, io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, use_eager_fetch=use_eager_fetch, shuffle_chunk_size=shuffle_chunk_size, ) @@ -777,7 +821,7 @@ def shape(self) -> Tuple[int, int]: the size of the partition of the data assigned to the active process. Returns: - A 2-tuple of ``int``s, for obs and var counts, respectively. + A tuple of ``int``s, for obs and var counts, respectively. Lifecycle: experimental @@ -955,6 +999,7 @@ class _CSR_IO_Buffer: and a final "merge" step which combines the result. * Zero intermediate copy conversion of an arbitrary row slice to dense (ie., mini-batch extraction). * Parallel ops where it makes sense (construction, merge, etc) + * Minimize memory use for index arrays Overall is significantly faster, and uses less memory, than the equivalent scipy.sparse operations. """ @@ -1010,7 +1055,8 @@ def nbytes(self) -> int: def dtype(self) -> npt.DTypeLike: return self.data.dtype - def densified_slice(self, row_index: slice) -> NDArrayNumber: + def slice_tonumpy(self, row_index: slice) -> NDArrayNumber: + """Extract slice as a dense ndarray. 
Does not assume any particular ordering of minor axis.""" assert isinstance(row_index, slice) assert row_index.step in (1, None) row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) @@ -1022,6 +1068,22 @@ def densified_slice(self, row_index: slice) -> NDArrayNumber: ) return out + def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix: + """Extract slice as a sparse.csr_matrix. Does not assume any particular ordering of minor axis, but + will return a canonically ordered scipy sparse object.""" + assert isinstance(row_index, slice) + assert row_index.step in (1, None) + row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1) + n_rows = max(row_idx_end - row_idx_start, 0) + if n_rows == 0: + return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype) + + indptr = self.indptr[row_idx_start : row_idx_end + 1].copy() + indices = self.indices[indptr[0] : indptr[-1]].copy() + data = self.data[indptr[0] : indptr[-1]].copy() + indptr -= indptr[0] + return sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1])) + @staticmethod def merge(mtxs: Sequence[_CSR_IO_Buffer]) -> _CSR_IO_Buffer: assert len(mtxs) > 0 diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index f671a93..1e7ed02 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -152,6 +152,7 @@ def soma_experiment( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) +@pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize( "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_non_batched( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, + return_sparse_X: bool, ) -> None: # batch_size should default to 1 with soma_experiment.axis_query(measurement_name="RNA") as query: @@ -168,6 +170,7 @@ def test_non_batched( obs_column_names=["label"], shuffle=False, use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, ) assert type(exp_data_pipe.shape) is tuple assert len(exp_data_pipe.shape) == 2 @@ -176,11 +179,20 @@ def test_non_batched( row_iter = iter(exp_data_pipe) row = next(row_iter) - assert isinstance(row[0], np.ndarray) + + if return_sparse_X: + # sparse slices remain 2D, always + assert isinstance(row[0], sparse.csr_matrix) + assert row[0].shape == (1, 3) + assert row[0].todense().tolist() == [[0, 1, 0]] + + else: + assert isinstance(row[0], np.ndarray) + assert row[0].shape == (3,) + assert row[0].tolist() == [0, 1, 0] + assert isinstance(row[1], pd.DataFrame) - assert row[0].shape == (3,) assert row[1].shape == (1, 1) - assert row[0].tolist() == [0, 1, 0] assert row[1].keys() == ["label"] assert row[1]["label"].tolist() == ["0"] @@ -189,6 +201,7 @@ def test_non_batched( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) +@pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize( "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) ) def test_uneven_soma_and_result_batches( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, soma_experiment: Experiment, use_eager_fetch: bool, + return_sparse_X: bool, ) -> None: """This is checking that batches are correctly created when they require fetching 
multiple chunks.""" with soma_experiment.axis_query(measurement_name="RNA") as query: @@ -207,16 +221,23 @@ def test_uneven_soma_and_result_batches( batch_size=3, io_batch_size=2, use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, ) row_iter = iter(exp_data_pipe) X_batch, obs_batch = next(row_iter) - assert isinstance(X_batch, np.ndarray) + + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + assert X_batch.todense()[0].tolist() == [[0, 1, 0]] + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch[0].tolist() == [0, 1, 0] + assert isinstance(obs_batch, pd.DataFrame) assert X_batch.shape[0] == obs_batch.shape[0] assert X_batch.shape == (3, 3) assert obs_batch.shape == (3, 1) - assert X_batch[0].tolist() == [0, 1, 0] assert ["label"] == obs_batch.keys() assert obs_batch["label"].tolist() == ["0", "1", "2"] @@ -378,6 +399,63 @@ def test_batching__empty_query_result( next(batch_iter) +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_sparse_output__non_batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + return_sparse_X=True, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().A.squeeze().tolist() == [0, 1, 0] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize( + "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) +) +def test_sparse_output__batched( + PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + return_sparse_X=True, + shuffle=False, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + + @pytest.mark.parametrize( "obs_range,var_range,X_value_gen,use_eager_fetch", [ @@ -852,11 +930,18 @@ def test_csr__construct_from_ijd(shape: Tuple[int, int], dtype: npt.DTypeLike) - != sp_csr ).nnz == 0 - assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_coo.toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_csr.toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(1, -1)), sp_csr[1:-1].toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(None, -2)), sp_csr[:-2].toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(None)), sp_csr[:].toarray()) + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_coo.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + 
assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 @@ -883,10 +968,17 @@ def test_csr__construct_from_pjd(shape: Tuple[int, int], dtype: npt.DTypeLike) - != sp_csr ).nnz == 0 - assert np.array_equal(_ncsr.densified_slice(slice(0, shape[0])), sp_csr.toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(1, -1)), sp_csr[1:-1].toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(None, -2)), sp_csr[:-2].toarray()) - assert np.array_equal(_ncsr.densified_slice(slice(None)), sp_csr[:].toarray()) + # Check dense slicing + assert np.array_equal(_ncsr.slice_tonumpy(slice(0, shape[0])), sp_csr.toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(1, -1)), sp_csr[1:-1].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None, -2)), sp_csr[:-2].toarray()) + assert np.array_equal(_ncsr.slice_tonumpy(slice(None)), sp_csr[:].toarray()) + + # Check sparse slicing + assert (_ncsr.slice_toscipy(slice(0, shape[0])) != sp_csr).nnz == 0 + assert (_ncsr.slice_toscipy(slice(1, -1)) != sp_csr[1:-1]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None, -2)) != sp_csr[:-2]).nnz == 0 + assert (_ncsr.slice_toscipy(slice(None)) != sp_csr[:]).nnz == 0 @pytest.mark.parametrize( From 9809c5c277989bb91a74c9056ba92cce148960c3 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 16 Sep 2024 23:50:07 +0000 Subject: [PATCH 56/70] start draft of Lightning notebook --- notebooks/tutorial_lightning.ipynb | 179 +++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 notebooks/tutorial_lightning.ipynb diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb new file mode 100644 index 0000000..e975f95 --- /dev/null +++ b/notebooks/tutorial_lightning.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a model with PyTorch Lightning\n", + "\n", + "This tutorial provides a quick overview of training a toy model with Lightning, using the `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. 
This is intended only to demonstrate the use of the `ExperimentAxisQueryIterableDataset`, and not as an example of how to train a biologically useful model.\n", + "\n", + "**Prerequisites**\n", + "\n", + "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n", + "\n", + "> pip install tiledbsoma_ml scikit-learn\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SOMA Experiment query as training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import tiledbsoma_ml as soma_ml\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma as soma\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'lung' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Lightning module" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class LogisticRegressionLightning(pl.LightningModule):\n", + " def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n", + " super(LogisticRegressionLightning, self).__init__()\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + " self.cell_type_encoder = cell_type_encoder\n", + " self.learning_rate = learning_rate\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " X_batch, y_batch = batch\n", + " # X_batch = X_batch.float()\n", + " X_batch = torch.from_numpy(X_batch).float().to(self.device)\n", + "\n", + " # Perform prediction\n", + " outputs = self(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute loss\n", + " # y_batch = y_batch.flatten()\n", + " y_batch = torch.from_numpy(\n", + " self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n", + " ).to(self.device)\n", + " loss = self.loss_fn(outputs, y_batch.long())\n", + "\n", + " # Compute accuracy\n", + " train_correct = (predictions == y_batch).sum().item()\n", + " train_accuracy = train_correct / len(predictions)\n", + "\n", + " # Log loss and accuracy\n", + " self.log(\"train_loss\", loss, prog_bar=True)\n", + " self.log(\"train_accuracy\", train_accuracy, prog_bar=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), 
lr=self.learning_rate)\n", + " return optimizer\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "# Initialize the PyTorch Lightning model\n", + "model = LogisticRegressionLightning(\n", + " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", + ")\n", + "\n", + "# Define the PyTorch Lightning Trainer\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " # accelerator=args.accelerator,\n", + " # strategy=\"ddp\",\n", + ")\n", + "\n", + "# Train the model\n", + "trainer.fit(model, train_dataloaders=dataloader)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 656e2e80ed9bf5b316a53f24dec15ef1f5a42bd4 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 16 Sep 2024 23:53:52 +0000 Subject: [PATCH 57/70] lint --- src/tiledbsoma_ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index f6c4351..f4d86a6 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -57,7 +57,7 @@ XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix] else: NDArrayNumber: TypeAlias = np.ndarray - XDatum: TypeAlias = np.ndarray | sparse.csr_matrix + XDatum: TypeAlias = Union[np.ndarray, sparse.csr_matrix] XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame] """Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, From f9e13b0278a57becbcaf52d8c80cf0a2b5301de1 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 16 Sep 2024 17:26:35 -0700 Subject: [PATCH 58/70] update notebook for lightning --- notebooks/tutorial_lightning.ipynb | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index e975f95..e17a752 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -8,6 +8,8 @@ "\n", "This tutorial provides a quick overview of training a toy model with Lightning, using the `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. 
This is intended only to demonstrate the use of the `ExperimentAxisQueryIterableDataset`, and not as an example of how to train a biologically useful model.\n", "\n", + "For more information on these APIs, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).\n", + "\n", "**Prerequisites**\n", "\n", "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n", @@ -130,7 +132,7 @@ "metadata": {}, "outputs": [], "source": [ - "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", + "dataloader = soma_ml.experiment_dataloader(experiment_dataset, num_workers=2, persistent_workers=True)\n", "\n", "# The size of the input dimension is the number of genes\n", "input_dim = experiment_dataset.shape[1]\n", @@ -150,6 +152,9 @@ " # strategy=\"ddp\",\n", ")\n", "\n", + "# set precision\n", + "torch.set_float32_matmul_precision('medium')\n", + "\n", "# Train the model\n", "trainer.fit(model, train_dataloaders=dataloader)\n" ] } From eaeaab4409a4c3eb4f2f696b60075cc37339c915 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 17 Sep 2024 01:10:07 +0000 Subject: [PATCH 59/70] run notebooks --- notebooks/tutorial_lightning.ipynb | 88 ++++++++++++-- notebooks/tutorial_pytorch.ipynb | 184 +++++++++++++++-------------- src/tiledbsoma_ml/pytorch.py | 2 +- 3 files changed, 176 insertions(+), 98 deletions(-) diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index e17a752..b007a23 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ " CZI_Census_Homo_Sapiens_URL,\n", " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", ")\n", - "obs_value_filter = \"tissue_general == 'lung' and is_primary_data == True\"\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", "\n", "with experiment.axis_query(\n", " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", @@ -57,7 +57,8 @@ " obs_column_names=[\"cell_type\"],\n", " batch_size=128,\n", " shuffle=True,\n", - " )\n" + " seed=12345,\n", + " )" ] }, { @@ -69,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +117,7 @@ "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", - " return optimizer\n" + " return optimizer" ] }, { @@ -128,11 +129,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------\n", + "0 | linear | Linear | 726 K | train\n", + "1 | loss_fn | CrossEntropyLoss | 0 | train\n", + "-----------------------------------------------------\n", + "726 K Trainable params\n", + "0 Non-trainable params\n", + "726 K Total params\n", + "2.905 Total estimated model params size (MB)\n", + "2 Modules in train mode\n", + "0 Modules in eval mode\n", 
"/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", + "/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n", + "/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.87it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=20` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:17<00:00, 6.86it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]\n" + ] + } + ], "source": [ - "dataloader = soma_ml.experiment_dataloader(experiment_dataset, num_workers=2, persistent_workers=True)\n", + "dataloader = soma_ml.experiment_dataloader(\n", + " experiment_dataset, num_workers=2, persistent_workers=True\n", + ")\n", "\n", "# The size of the input dimension is the number of genes\n", "input_dim = experiment_dataset.shape[1]\n", @@ -147,16 +213,16 @@ "\n", "# Define the PyTorch Lightning Trainer\n", "trainer = pl.Trainer(\n", - " max_epochs=10,\n", + " max_epochs=20,\n", " # accelerator=args.accelerator,\n", " # strategy=\"ddp\",\n", ")\n", "\n", "# set precision\n", - "torch.set_float32_matmul_precision('medium')\n", + "torch.set_float32_matmul_precision(\"high\")\n", "\n", "# Train the model\n", - "trainer.fit(model, train_dataloaders=dataloader)\n" + "trainer.fit(model, train_dataloaders=dataloader)" ] } ], diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index cce57fb..36116c4 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -67,6 +67,7 @@ " obs_column_names=[\"cell_type\"],\n", " batch_size=128,\n", " shuffle=True,\n", + " seed=12345,\n", " )\n" ] }, @@ -97,21 +98,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can inspect the shape of the full dataset, without causing the full dataset to be loaded:" + "You can inspect the shape of the full dataset, without causing the full dataset to be 
loaded. The `shape` property returns the number of batches on the first dimension:" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(15020, 60530)" + "(118, 60530)" ] }, - "execution_count": 2, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -131,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -149,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -200,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -282,23 +283,33 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0161885 Accuracy 0.2666\n", - "Epoch 2: Train Loss: 0.0148625 Accuracy 0.4692\n", - "Epoch 3: Train Loss: 0.0144180 Accuracy 0.5561\n", - "Epoch 4: Train Loss: 0.0141377 Accuracy 0.6762\n", - "Epoch 5: Train Loss: 0.0139051 Accuracy 0.7803\n", - "Epoch 6: Train Loss: 0.0137547 Accuracy 0.8508\n", - "Epoch 7: Train Loss: 0.0136169 Accuracy 0.8972\n", - "Epoch 8: Train Loss: 0.0134998 Accuracy 0.9135\n", - "Epoch 9: Train Loss: 0.0134092 Accuracy 0.9220\n", - "Epoch 10: Train Loss: 0.0133577 Accuracy 0.9245\n" + "Epoch 1: Train Loss: 0.0169901 Accuracy 0.1925\n", + "Epoch 2: Train Loss: 0.0152003 Accuracy 0.3209\n", + "Epoch 3: Train Loss: 0.0146477 Accuracy 0.4041\n", + "Epoch 4: Train Loss: 0.0141752 Accuracy 0.4333\n", + "Epoch 5: Train Loss: 0.0140288 Accuracy 0.4538\n", + "Epoch 6: Train Loss: 0.0139012 Accuracy 0.5109\n", + "Epoch 7: Train Loss: 0.0137932 Accuracy 0.5773\n", + "Epoch 8: Train Loss: 0.0136451 Accuracy 0.6737\n", + "Epoch 9: Train Loss: 0.0135466 Accuracy 0.7634\n", + "Epoch 10: Train Loss: 0.0134686 Accuracy 0.8611\n", + "Epoch 11: Train Loss: 0.0134011 Accuracy 0.9100\n", + "Epoch 12: Train Loss: 0.0133098 Accuracy 0.9257\n", + "Epoch 13: Train Loss: 0.0132513 Accuracy 0.9289\n", + "Epoch 14: Train Loss: 0.0132211 Accuracy 0.9345\n", + "Epoch 15: Train Loss: 0.0131944 Accuracy 0.9405\n", + "Epoch 16: Train Loss: 0.0131586 Accuracy 0.9456\n", + "Epoch 17: Train Loss: 0.0131391 Accuracy 0.9456\n", + "Epoch 18: Train Loss: 0.0131139 Accuracy 0.9510\n", + "Epoch 19: Train Loss: 0.0130988 Accuracy 0.9550\n", + "Epoch 20: Train Loss: 0.0130793 Accuracy 0.9557\n" ] } ], @@ -315,7 +326,7 @@ "loss_fn = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", "\n", - "for epoch in range(10):\n", + "for epoch in range(20):\n", " train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n", " print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" ] @@ -331,7 +342,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -350,20 +361,20 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([ 1, 7, 5, 7, 7, 1, 7, 5, 1, 7, 7, 8, 11, 1, 7, 1, 1, 5,\n", - " 8, 8, 5, 1, 1, 1, 7, 1, 1, 1, 8, 1, 1, 7, 1, 7, 1, 7,\n", - " 6, 
7, 1, 1, 5, 8, 1, 1, 8, 1, 1, 7, 8, 1, 1, 1, 1, 1,\n", - " 5, 1, 5, 8, 1, 5, 1, 8, 7, 1, 7, 7, 1, 1, 1, 1, 7, 1,\n", - " 1, 1, 8, 1, 1, 7, 1, 5, 5, 1, 1, 7, 7, 9, 1, 5, 1, 1,\n", - " 8, 1, 7, 7, 11, 7, 1, 7, 1, 7, 1, 8, 1, 1, 1, 11, 1, 1,\n", - " 1, 8, 1, 5, 1, 1, 7, 1, 7, 8, 7, 1, 5, 7, 7, 5, 1, 7,\n", - " 8, 7], device='cuda:0')" + "tensor([ 1, 8, 1, 7, 7, 7, 1, 1, 7, 1, 11, 1, 1, 7, 1, 1, 7, 1,\n", + " 1, 1, 5, 1, 1, 1, 1, 7, 1, 5, 7, 8, 7, 6, 1, 7, 7, 1,\n", + " 8, 1, 1, 1, 7, 1, 1, 1, 1, 11, 7, 9, 1, 1, 1, 1, 1, 1,\n", + " 1, 8, 1, 1, 1, 8, 1, 8, 5, 7, 7, 5, 11, 8, 7, 1, 1, 1,\n", + " 8, 9, 8, 1, 1, 7, 5, 1, 7, 5, 7, 7, 5, 7, 1, 1, 1, 7,\n", + " 5, 7, 5, 2, 7, 7, 1, 6, 1, 7, 7, 8, 1, 1, 8, 1, 1, 7,\n", + " 11, 5, 7, 5, 1, 5, 5, 1, 5, 7, 9, 8, 5, 7, 1, 1, 11, 7,\n", + " 8, 5], device='cuda:0')" ] }, "metadata": {}, @@ -393,45 +404,46 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(['basal cell', 'keratinocyte', 'epithelial cell', 'keratinocyte',\n", - " 'keratinocyte', 'basal cell', 'keratinocyte', 'epithelial cell',\n", - " 'basal cell', 'keratinocyte', 'keratinocyte', 'leukocyte',\n", - " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'epithelial cell', 'leukocyte',\n", - " 'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + "array(['basal cell', 'leukocyte', 'basal cell', 'keratinocyte',\n", + " 'keratinocyte', 'keratinocyte', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'vein endothelial cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", + " 'epithelial cell', 'keratinocyte', 'leukocyte', 'keratinocyte',\n", + " 'fibroblast', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'vein endothelial cell',\n", + " 'keratinocyte', 'pericyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'keratinocyte', 'fibroblast', 'keratinocyte', 'basal cell',\n", - " 'basal cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'leukocyte',\n", + " 'epithelial cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'vein endothelial cell', 'leukocyte',\n", + " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'pericyte', 'leukocyte', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'epithelial cell', 'basal cell', 'keratinocyte',\n", + " 'epithelial cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'keratinocyte', 'basal cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'epithelial cell', 'keratinocyte',\n", + " 'epithelial cell', 'capillary endothelial cell', 'keratinocyte',\n", + " 'keratinocyte', 'basal cell', 'fibroblast', 'basal cell',\n", + " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n", " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal 
cell', 'basal cell', 'epithelial cell',\n", - " 'basal cell', 'epithelial cell', 'leukocyte', 'basal cell',\n", - " 'epithelial cell', 'basal cell', 'leukocyte', 'keratinocyte',\n", - " 'basal cell', 'keratinocyte', 'keratinocyte', 'basal cell',\n", - " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'keratinocyte', 'pericyte', 'basal cell',\n", - " 'epithelial cell', 'basal cell', 'basal cell', 'leukocyte',\n", - " 'basal cell', 'keratinocyte', 'keratinocyte',\n", - " 'vein endothelial cell', 'keratinocyte', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'vein endothelial cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'leukocyte', 'basal cell', 'epithelial cell', 'basal cell',\n", - " 'basal cell', 'keratinocyte', 'basal cell', 'keratinocyte',\n", - " 'leukocyte', 'keratinocyte', 'basal cell', 'epithelial cell',\n", - " 'keratinocyte', 'keratinocyte', 'epithelial cell', 'basal cell',\n", - " 'keratinocyte', 'leukocyte', 'keratinocyte'], dtype=object)" + " 'keratinocyte', 'vein endothelial cell', 'epithelial cell',\n", + " 'keratinocyte', 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'basal cell', 'epithelial cell', 'keratinocyte',\n", + " 'pericyte', 'leukocyte', 'epithelial cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'vein endothelial cell',\n", + " 'keratinocyte', 'leukocyte', 'epithelial cell'], dtype=object)" ] }, "metadata": {}, @@ -453,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -489,13 +501,13 @@ " \n", " \n", " 1\n", - " keratinocyte\n", - " keratinocyte\n", + " leukocyte\n", + " leukocyte\n", " \n", " \n", " 2\n", " basal cell\n", - " epithelial cell\n", + " basal cell\n", " \n", " \n", " 3\n", @@ -504,7 +516,7 @@ " \n", " \n", " 4\n", - " basal cell\n", + " keratinocyte\n", " keratinocyte\n", " \n", " \n", @@ -514,13 +526,13 @@ " \n", " \n", " 123\n", - " epithelial cell\n", - " epithelial cell\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 124\n", - " basal cell\n", - " basal cell\n", + " vein endothelial cell\n", + " vein endothelial cell\n", " \n", " \n", " 125\n", @@ -534,8 +546,8 @@ " \n", " \n", " 127\n", - " keratinocyte\n", - " keratinocyte\n", + " epithelial cell\n", + " epithelial cell\n", " \n", " \n", "\n", @@ -543,18 +555,18 @@ "" ], "text/plain": [ - " actual cell type predicted cell type\n", - "0 basal cell basal cell\n", - "1 keratinocyte keratinocyte\n", - "2 basal cell epithelial cell\n", - "3 keratinocyte keratinocyte\n", - "4 basal cell keratinocyte\n", - ".. ... ...\n", - "123 epithelial cell epithelial cell\n", - "124 basal cell basal cell\n", - "125 keratinocyte keratinocyte\n", - "126 leukocyte leukocyte\n", - "127 keratinocyte keratinocyte\n", + " actual cell type predicted cell type\n", + "0 basal cell basal cell\n", + "1 leukocyte leukocyte\n", + "2 basal cell basal cell\n", + "3 keratinocyte keratinocyte\n", + "4 keratinocyte keratinocyte\n", + ".. ... 
...\n", + "123 basal cell basal cell\n", + "124 vein endothelial cell vein endothelial cell\n", + "125 keratinocyte keratinocyte\n", + "126 leukocyte leukocyte\n", + "127 epithelial cell epithelial cell\n", "\n", "[128 rows x 2 columns]" ] @@ -593,7 +605,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index f4d86a6..beeaf3e 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -163,7 +163,7 @@ def __init__( If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will return ``X`` data as a :class:`numpy.ndarray`. seed: - The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using + The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *MUST* be specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker processes. use_eager_fetch: From 61face39b3ad5ee79c4ea6d7323c8dc8b757bf59 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 18 Sep 2024 12:08:48 -0700 Subject: [PATCH 60/70] fix RNG state bug in shuffle; add multi-worker notebook --- notebooks/tutorial_lightning.ipynb | 62 +++---- notebooks/tutorial_multiworker.ipynb | 245 +++++++++++++++++++++++++++ notebooks/tutorial_pytorch.ipynb | 198 +++++++++++----------- src/tiledbsoma_ml/pytorch.py | 103 ++++++++++- 4 files changed, 464 insertions(+), 144 deletions(-) create mode 100644 notebooks/tutorial_multiworker.ipynb diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index b007a23..885f502 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -26,9 +26,25 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! 
Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], "source": [ "import pytorch_lightning as pl\n", "import tiledbsoma_ml as soma_ml\n", @@ -57,7 +73,6 @@ " obs_column_names=[\"cell_type\"],\n", " batch_size=128,\n", " shuffle=True,\n", - " seed=12345,\n", " )" ] }, @@ -70,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -99,7 +114,6 @@ " predictions = torch.argmax(probabilities, axis=1)\n", "\n", " # Compute loss\n", - " # y_batch = y_batch.flatten()\n", " y_batch = torch.from_numpy(\n", " self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n", " ).to(self.device)\n", @@ -129,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -139,6 +153,7 @@ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode \n", @@ -152,32 +167,15 @@ "2.905 Total estimated model params size (MB)\n", "2 Modules in train mode\n", "0 Modules in eval mode\n", - "/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", - "/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", - "################################################################################\n", - "WARNING!\n", - "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", - "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", - "to learn more and leave feedback.\n", - "################################################################################\n", - "\n", - " deprecation_warning()\n", - "/home/ubuntu/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", - "################################################################################\n", - "WARNING!\n", - "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", - "future torchdata release! 
Please see https://github.com/pytorch/data/issues/1196\n",
-    "to learn more and leave feedback.\n",
-    "################################################################################\n",
-    "\n",
-    "  deprecation_warning()\n"
+    "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
+    "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-      "Epoch 19: 100%|██████████| 118/118 [00:17<00:00,  6.87it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]"
+      "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-      "Epoch 19: 100%|██████████| 118/118 [00:17<00:00,  6.86it/s, v_num=6, train_loss=1.680, train_accuracy=0.977]\n"
+      "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n"
     ]
    }
   ],
   "source": [
-    "dataloader = soma_ml.experiment_dataloader(\n",
-    "    experiment_dataset, num_workers=2, persistent_workers=True\n",
-    ")\n",
+    "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n",
     "\n",
     "# The size of the input dimension is the number of genes\n",
     "input_dim = experiment_dataset.shape[1]\n",
@@ -212,11 +208,7 @@
     ")\n",
     "\n",
     "# Define the PyTorch Lightning Trainer\n",
-    "trainer = pl.Trainer(\n",
-    "    max_epochs=20,\n",
-    "    # accelerator=args.accelerator,\n",
-    "    # strategy=\"ddp\",\n",
-    ")\n",
+    "trainer = pl.Trainer(max_epochs=20)\n",
     "\n",
     "# set precision\n",
     "torch.set_float32_matmul_precision(\"high\")\n",
diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb
new file mode 100644
index 0000000..1711f3a
--- /dev/null
+++ b/notebooks/tutorial_multiworker.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multi-process training\n",
+    "\n",
+    "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` includes both:\n",
+    "* using a `torch.utils.data.DataLoader` with one or more workers (i.e., with an argument of `num_workers=1` or greater)\n",
+    "* using a multi-process training configuration, such as `DistributedDataParallel`\n",
+    "\n",
+    "In these configurations, `ExperimentAxisQueryIterableDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n",
+    "\n",
+    "1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n",
+    "2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. 
This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], + "source": [ + "import tiledbsoma_ml as soma_ml\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma as soma\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + " \n", + "\n", + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + " model.train()\n", + " train_loss = 0\n", + " train_correct = 0\n", + " train_total = 0\n", + "\n", + " for X_batch, y_batch in train_dataloader:\n", + " optimizer.zero_grad()\n", + "\n", + " X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + " # Perform prediction\n", + " outputs = model(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute the loss and perform back propagation\n", + " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + " train_correct += (predictions == y_batch).sum().item()\n", + " train_total += len(predictions)\n", + "\n", + " loss = loss_fn(outputs, y_batch.long())\n", + " train_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss /= train_total\n", + " train_accuracy = train_correct / train_total\n", + " return train_loss, train_accuracy" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Multi-worker DataLoader\n", + "\n", + "If you use a multi-worker data loader (i.e., `num_workers` with a value other than `0`), and `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.\n", + "\n", + "The same approach should be taken for parallel training, e.g., when using DDP or DP.\n", + "\n", + "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterDataset` will automatically increment the epoch count each time the iterator completes." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n", + "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n", + "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n", + "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n", + "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n", + "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n", + "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n", + "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n", + "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n", + "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n", + "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n", + "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n", + "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n", + "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n", + "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n", + "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n", + "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n", + "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n", + "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n", + "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = 
LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "\n", + "# define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n", + "# that a different shuffle is applied on each epoch.\n", + "experiment_dataloader = soma_ml.experiment_dataloader(\n", + " experiment_dataset, num_workers=2, persistent_workers=True\n", + ")\n", + "\n", + "for epoch in range(20):\n", + " experiment_dataset.set_epoch(epoch)\n", + " train_loss, train_accuracy = train_epoch(\n", + " model, experiment_dataloader, loss_fn, optimizer, device\n", + " )\n", + " print(\n", + " f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index 36116c4..581beaa 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +67,6 @@ " obs_column_names=[\"cell_type\"],\n", " batch_size=128,\n", " shuffle=True,\n", - " seed=12345,\n", " )\n" ] }, @@ -103,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -112,7 +111,7 @@ "(118, 60530)" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -150,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -175,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -201,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -224,9 +223,7 @@ " predictions = torch.argmax(probabilities, axis=1)\n", "\n", " # Compute the loss and perform back propagation\n", - "\n", " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", - "\n", " train_correct += (predictions == y_batch).sum().item()\n", " train_total += len(predictions)\n", "\n", @@ -283,33 +280,33 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1: Train Loss: 0.0169901 Accuracy 0.1925\n", - "Epoch 2: Train Loss: 0.0152003 Accuracy 0.3209\n", - "Epoch 3: Train Loss: 0.0146477 Accuracy 0.4041\n", - "Epoch 4: Train Loss: 0.0141752 Accuracy 0.4333\n", - "Epoch 5: Train Loss: 0.0140288 Accuracy 0.4538\n", - "Epoch 6: Train Loss: 0.0139012 Accuracy 0.5109\n", - "Epoch 7: Train Loss: 0.0137932 Accuracy 0.5773\n", - "Epoch 8: Train Loss: 0.0136451 Accuracy 0.6737\n", - "Epoch 9: Train Loss: 0.0135466 Accuracy 0.7634\n", - "Epoch 10: Train Loss: 
0.0134686 Accuracy 0.8611\n", - "Epoch 11: Train Loss: 0.0134011 Accuracy 0.9100\n", - "Epoch 12: Train Loss: 0.0133098 Accuracy 0.9257\n", - "Epoch 13: Train Loss: 0.0132513 Accuracy 0.9289\n", - "Epoch 14: Train Loss: 0.0132211 Accuracy 0.9345\n", - "Epoch 15: Train Loss: 0.0131944 Accuracy 0.9405\n", - "Epoch 16: Train Loss: 0.0131586 Accuracy 0.9456\n", - "Epoch 17: Train Loss: 0.0131391 Accuracy 0.9456\n", - "Epoch 18: Train Loss: 0.0131139 Accuracy 0.9510\n", - "Epoch 19: Train Loss: 0.0130988 Accuracy 0.9550\n", - "Epoch 20: Train Loss: 0.0130793 Accuracy 0.9557\n" + "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n", + "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n", + "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n", + "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n", + "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n", + "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n", + "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n", + "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n", + "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n", + "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n", + "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n", + "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n", + "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n", + "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n", + "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n", + "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n", + "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n", + "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n", + "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n", + "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n" ] } ], @@ -342,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -361,20 +358,20 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([ 1, 8, 1, 7, 7, 7, 1, 1, 7, 1, 11, 1, 1, 7, 1, 1, 7, 1,\n", - " 1, 1, 5, 1, 1, 1, 1, 7, 1, 5, 7, 8, 7, 6, 1, 7, 7, 1,\n", - " 8, 1, 1, 1, 7, 1, 1, 1, 1, 11, 7, 9, 1, 1, 1, 1, 1, 1,\n", - " 1, 8, 1, 1, 1, 8, 1, 8, 5, 7, 7, 5, 11, 8, 7, 1, 1, 1,\n", - " 8, 9, 8, 1, 1, 7, 5, 1, 7, 5, 7, 7, 5, 7, 1, 1, 1, 7,\n", - " 5, 7, 5, 2, 7, 7, 1, 6, 1, 7, 7, 8, 1, 1, 8, 1, 1, 7,\n", - " 11, 5, 7, 5, 1, 5, 5, 1, 5, 7, 9, 8, 5, 7, 1, 1, 11, 7,\n", - " 8, 5], device='cuda:0')" + "tensor([ 8, 1, 1, 1, 1, 1, 1, 8, 8, 5, 1, 7, 8, 1, 1, 1, 1, 7,\n", + " 7, 8, 1, 1, 5, 5, 1, 8, 1, 1, 1, 7, 8, 7, 7, 7, 8, 7,\n", + " 5, 1, 1, 8, 1, 5, 8, 5, 1, 11, 1, 7, 1, 1, 5, 5, 1, 11,\n", + " 1, 6, 8, 5, 1, 8, 11, 8, 1, 8, 1, 8, 1, 5, 1, 1, 1, 8,\n", + " 8, 7, 5, 1, 1, 8, 1, 7, 2, 1, 7, 1, 5, 1, 1, 7, 1, 8,\n", + " 1, 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 7, 7, 5, 7, 8, 5, 1,\n", + " 5, 1, 5, 5, 5, 1, 1, 1, 8, 5, 1, 1, 7, 8, 1, 1, 1, 1,\n", + " 8, 1], device='cuda:0')" ] }, "metadata": {}, @@ -404,46 +401,45 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(['basal cell', 'leukocyte', 'basal cell', 'keratinocyte',\n", - " 'keratinocyte', 'keratinocyte', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'basal cell', 'vein endothelial cell',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'basal cell',\n", - " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'keratinocyte', 'basal 
cell',\n", - " 'epithelial cell', 'keratinocyte', 'leukocyte', 'keratinocyte',\n", - " 'fibroblast', 'basal cell', 'keratinocyte', 'keratinocyte',\n", - " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'keratinocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'basal cell', 'vein endothelial cell',\n", - " 'keratinocyte', 'pericyte', 'basal cell', 'basal cell',\n", + "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", - " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'leukocyte', 'basal cell', 'leukocyte',\n", - " 'epithelial cell', 'keratinocyte', 'keratinocyte',\n", - " 'epithelial cell', 'vein endothelial cell', 'leukocyte',\n", - " 'keratinocyte', 'basal cell', 'basal cell', 'basal cell',\n", - " 'leukocyte', 'pericyte', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'epithelial cell', 'basal cell', 'keratinocyte',\n", - " 'epithelial cell', 'keratinocyte', 'keratinocyte',\n", - " 'epithelial cell', 'keratinocyte', 'basal cell', 'basal cell',\n", - " 'basal cell', 'keratinocyte', 'epithelial cell', 'keratinocyte',\n", - " 'epithelial cell', 'capillary endothelial cell', 'keratinocyte',\n", - " 'keratinocyte', 'basal cell', 'fibroblast', 'basal cell',\n", " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n", - " 'basal cell', 'leukocyte', 'basal cell', 'basal cell',\n", - " 'keratinocyte', 'vein endothelial cell', 'epithelial cell',\n", - " 'keratinocyte', 'epithelial cell', 'basal cell', 'epithelial cell',\n", - " 'epithelial cell', 'basal cell', 'epithelial cell', 'keratinocyte',\n", - " 'pericyte', 'leukocyte', 'epithelial cell', 'keratinocyte',\n", - " 'basal cell', 'basal cell', 'vein endothelial cell',\n", - " 'keratinocyte', 'leukocyte', 'epithelial cell'], dtype=object)" + " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n", + " 'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n", + " 'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'keratinocyte',\n", + " 'capillary endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n", + " 'basal 
cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell'], dtype=object)" ] }, "metadata": {}, @@ -465,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -496,13 +492,13 @@ " \n", " \n", " 0\n", - " basal cell\n", - " basal cell\n", + " leukocyte\n", + " leukocyte\n", " \n", " \n", " 1\n", - " leukocyte\n", - " leukocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 2\n", @@ -511,13 +507,13 @@ " \n", " \n", " 3\n", - " keratinocyte\n", - " keratinocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 4\n", - " keratinocyte\n", - " keratinocyte\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " ...\n", @@ -526,18 +522,18 @@ " \n", " \n", " 123\n", - " basal cell\n", + " fibroblast\n", " basal cell\n", " \n", " \n", " 124\n", - " vein endothelial cell\n", - " vein endothelial cell\n", + " basal cell\n", + " basal cell\n", " \n", " \n", " 125\n", " keratinocyte\n", - " keratinocyte\n", + " basal cell\n", " \n", " \n", " 126\n", @@ -546,8 +542,8 @@ " \n", " \n", " 127\n", - " epithelial cell\n", - " epithelial cell\n", + " basal cell\n", + " basal cell\n", " \n", " \n", "\n", @@ -555,18 +551,18 @@ "" ], "text/plain": [ - " actual cell type predicted cell type\n", - "0 basal cell basal cell\n", - "1 leukocyte leukocyte\n", - "2 basal cell basal cell\n", - "3 keratinocyte keratinocyte\n", - "4 keratinocyte keratinocyte\n", - ".. ... ...\n", - "123 basal cell basal cell\n", - "124 vein endothelial cell vein endothelial cell\n", - "125 keratinocyte keratinocyte\n", - "126 leukocyte leukocyte\n", - "127 epithelial cell epithelial cell\n", + " actual cell type predicted cell type\n", + "0 leukocyte leukocyte\n", + "1 basal cell basal cell\n", + "2 basal cell basal cell\n", + "3 basal cell basal cell\n", + "4 basal cell basal cell\n", + ".. ... ...\n", + "123 fibroblast basal cell\n", + "124 basal cell basal cell\n", + "125 keratinocyte basal cell\n", + "126 leukocyte leukocyte\n", + "127 basal cell basal cell\n", "\n", "[128 rows x 2 columns]" ] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index beeaf3e..cd2a8ea 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -181,6 +181,15 @@ def __init__( Lifecycle: experimental + + .. warning:: + When using this class in any distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you + must provide a seed, ensuring that the same shuffle is used across all replicas. 
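+
+            A minimal sketch of the intended per-epoch pattern (illustrative only;
+            ``ds`` stands for a shuffling instance of this class, constructed with an
+            explicit ``seed``; ``num_epochs`` is a placeholder)::
+
+                for epoch in range(num_epochs):
+                    ds.set_epoch(epoch)  # must happen before the epoch's iterator is created
+                    for X_batch, obs_batch in ds:
+                        ...  # train on the mini-batch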
""" super().__init__() @@ -199,9 +208,13 @@ def __init__( self.use_eager_fetch = use_eager_fetch self._obs_joinids: npt.NDArray[np.int64] | None = None self._var_joinids: npt.NDArray[np.int64] | None = None - self._shuffle_rng = np.random.default_rng(seed) if shuffle else None + self.seed = ( + seed if seed is not None else np.random.default_rng().integers(0, 2**32 - 1) + ) + self._user_specified_seed = seed is not None self.shuffle_chunk_size = shuffle_chunk_size self._initialized = False + self.epoch = 0 if self.shuffle: # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. @@ -239,19 +252,23 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: _gpu_splits = _splits(len(obs_joinids), world_size) _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] - # 2. Trim to be all of equal length + # 2. Trim to be all of equal length - equivalent to a "drop_last" + # TODO: may need to add an option to do padding as well. min_len = np.diff(_gpu_splits).min() assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 _gpu_split = _gpu_split[:min_len] # 3. Chunk and optionally shuffle chunks if self.shuffle: - assert self._shuffle_rng is not None assert self.io_batch_size % self.shuffle_chunk_size == 0 shuffle_split = np.array_split( _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) ) - self._shuffle_rng.shuffle(shuffle_split) + + # Deterministically create RNG - state must be same across all processes, ensuring + # that the joinid partitions are identical across all processes. + rng = np.random.default_rng(self.seed + self.epoch + 99) + rng.shuffle(shuffle_split) obs_joinids_chunked = list( np.concatenate(b) for b in _batched( @@ -272,7 +289,8 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: if logger.isEnabledFor(logging.DEBUG): logger.debug( - f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}, " + f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, " + f"n_workers={n_workers}, epoch={self.epoch}, " f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" ) @@ -332,6 +350,16 @@ def __iter__(self) -> Iterator[XObsDatum]: "(see https://github.com/pytorch/pytorch/issues/20248)" ) + world_size, rank = _get_distributed_world_rank() + n_workers, worker_id = _get_worker_world_rank() + logger.debug( + f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}, seed={self.seed}, epoch={self.epoch}" + ) + if world_size > 1 and self.shuffle and self._user_specified_seed is None: + raise ValueError( + "ExperimentAxisQueryIterable requires an explicit `seed` when shuffle is used in a multi-process configuration." + ) + with self.experiment_locator.open_experiment() as exp: self._init_once(exp) X = exp.ms[self.measurement_name].X[self.layer_name] @@ -349,6 +377,8 @@ def __iter__(self) -> Iterator[XObsDatum]: yield from _mini_batch_iter + self.epoch += 1 + def __len__(self) -> int: """Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) @@ -388,6 +418,18 @@ def shape(self) -> Tuple[int, int]: div, rem = divmod(partition_len, self.batch_size) return div + bool(rem), len(self._var_joinids) + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. 
+ + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + """ + self.epoch = epoch + def __getitem__(self, index: int) -> XObsDatum: raise NotImplementedError( "``ExperimentAxisQueryIterable can only be iterated - does not support mapping" @@ -409,6 +451,10 @@ def _io_batch_iter( """ assert self._var_joinids is not None + # Create RNG - does not need to be identical across processes, but use the seed anyway + # for reproducibility. + shuffle_rng = np.random.default_rng(self.seed + self.epoch) + obs_column_names = ( list(self.obs_column_names) if "soma_joinid" in self.obs_column_names @@ -419,9 +465,7 @@ def _io_batch_iter( for obs_coords in obs_joinid_iter: st_time = time.perf_counter() obs_shuffled_coords = ( - obs_coords - if self._shuffle_rng is None - else self._shuffle_rng.permuted(obs_coords) + obs_coords if not self.shuffle else shuffle_rng.permuted(obs_coords) ) obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context) logger.debug( @@ -637,6 +681,25 @@ def shape(self) -> Tuple[int, int]: """ return self._exp_iter.shape + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. + + Lifecycle: + experimental + """ + self._exp_iter.set_epoch(epoch) + + @property + def epoch(self) -> int: + return self._exp_iter.epoch + class ExperimentAxisQueryIterableDataset( torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] @@ -692,6 +755,11 @@ class ExperimentAxisQueryIterableDataset( (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness, and decrease I/O performance. + This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader` + and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition + data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any + data tail will be dropped. + Lifecycle: experimental """ @@ -828,6 +896,25 @@ def shape(self) -> Tuple[int, int]: """ return self._exp_iter.shape + def set_epoch(self, epoch: int) -> None: + """ + Set the epoch for this Data iterator. + + When :attr:`shuffle=True`, this will ensure that all replicas use a different + random ordering for each epoch. Failure to call this method before each epoch + will result in the same data ordering. + + This call must be made before the per-epoch iterator is created. 
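+
+        For example (a sketch; ``ds`` is assumed to be this dataset and ``dl`` a
+        multi-worker :class:`torch.utils.data.DataLoader` constructed from it;
+        ``num_epochs`` is a placeholder)::
+
+            for epoch in range(num_epochs):
+                ds.set_epoch(epoch)  # before creating the DataLoader iterator
+                for X_batch, obs_batch in dl:
+                    ...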
+
+        Lifecycle:
+            experimental
+        """
+        self._exp_iter.set_epoch(epoch)
+
+    @property
+    def epoch(self) -> int:
+        return self._exp_iter.epoch
+
 
 def experiment_dataloader(
     ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,

From 6ccad8ece00b8b38e4d049df204f108aa1672628 Mon Sep 17 00:00:00 2001
From: Ryan Williams
Date: Wed, 18 Sep 2024 15:37:03 -0400
Subject: [PATCH 61/70] add `rehome-census.sh`, used to construct this repo's history

---
 scripts/rehome-census.sh | 60 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100755 scripts/rehome-census.sh

diff --git a/scripts/rehome-census.sh b/scripts/rehome-census.sh
new file mode 100755
index 0000000..b742b2b
--- /dev/null
+++ b/scripts/rehome-census.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+#
+# Create a fresh TileDB-SOMA-ML clone, with some "re-homed" history:
+# 1. Reproduce the Git history of `api/python/cellxgene_census/src/cellxgene_census/experimental/ml` in the CELLxGENE Census repo:
+#    - This was developed between May 2023 and July 2024
+#    - A few files that are not relevant to the "PyTorch loaders" work (namely the "huggingface" subdirectory) are omitted
+# 2. Insert one commit moving these files to `apis/python/src/tiledbsoma/ml` (which they were copied to, in the TileDB-SOMA repo, in July 2024)
+# 3. Replay the `bkmartinjr/experimentdatapipe` branch of TileDB-SOMA (developed between July 2024 and September 2024) on top of this
+#    - Just the commits that touch the `apis/python/src/tiledbsoma/ml` directory, or a few other relevant paths (e.g. `other_packages`, where they were moved later in that branch's development)
+
+set -ex
+
+pip install git-filter-repo
+
+# Create a Census clone, filter to files/commits relevant to PyTorch loaders:
+git clone -o origin https://github.com/chanzuckerberg/cellxgene-census census-ml && cd census-ml
+ml=api/python/cellxgene_census/src/cellxgene_census/experimental/ml
+git filter-repo \
+  --path $ml/__init__.py \
+  --path $ml/pytorch.py \
+  --path $ml/encoders.py \
+  --path $ml/util
+cd ..
+
+# Create a TileDB-SOMA clone, filter to files/commits relevant to PyTorch loaders:
+git clone -o origin -b bkmartinjr/experimentdatapipe git@github.com:single-cell-data/TileDB-SOMA.git soma-pytorch && cd soma-pytorch
+git branch -m main
+renames=()
+for p in CHANGELOG.md README.md notebooks pyproject.toml src tests; do
+  renames+=(--path-rename "other_packages/python/tiledbsoma_ml/$p:$p")
+done
+git filter-repo --force \
+  --path other_packages \
+  --path apis/python/src/tiledbsoma/ml \
+  --path .github/workflows/python-tiledbsoma-ml.yml \
+  "${renames[@]}"
+cd ..
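+
+# (Optional sanity check, not part of the original flow: list every path ever touched
+# by each filtered history, to confirm that only the intended files survived
+# `git filter-repo`.)
+#   git -C census-ml log --name-only --pretty=format: | sort -u
+#   git -C soma-pytorch log --name-only --pretty=format: | sort -u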
+ +# Initialize TileDB-SOMA-ML clone, fetch filtered Census and TileDB-SOMA branches from the adjacent directories above: +git clone https://github.com/ryan-williams/TileDB-SOMA-ML soma-ml && cd soma-ml +git remote add c ../census-ml && git fetch c +git remote add t ../soma-pytorch && git fetch t +git reset --hard c/main + +# From the filtered Census HEAD, `git mv` the files to where the TileDB-SOMA branch ported them +tdbs=apis/python/src/tiledbsoma +mkdir -p $tdbs +git mv $ml $tdbs/ + +# Cherry-pick the root commit of the TileDB-SOMA port +root="$(git rev-list --max-parents=0 t/main)" +git cherry-pick $root +# Ensure all files match the TileDB-SOMA root commit +git status --porcelain | grep '^UU' | cut -c4- | xargs git checkout --theirs -- +# Verify there are no diffs vs TileDB-SOMA root commit +git diff --exit-code $root + +# Rebase `$root..t/main` (the rest of the filtered TileDB-SOMA commits) onto cherry-picked HEAD +git reset --hard t/main +git rebase --onto "HEAD@{1}" $root From 9804ee7bf6d5b16c92b7e1e376558f448310f5cf Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 18 Sep 2024 15:46:39 -0400 Subject: [PATCH 62/70] update GHA, repo info --- .github/workflows/python-tiledbsoma-ml.yml | 38 ++++------------------ pyproject.toml | 13 ++++---- 2 files changed, 13 insertions(+), 38 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 6d3e846..a8b3773 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,19 +1,14 @@ -name: python-tiledbsoma-ml +name: python-tiledbsoma-ml CI on: pull_request: branches: ["*"] - paths: - - "!**" - - "other_packages/python/tiledbsoma_ml/**" - - ".github/workflows/python-tiledbsoma-ml.yml" - + paths-ignore: + - 'scripts/**' push: branches: [main] - paths: - - "!**" - - "other_packages/python/tiledbsoma_ml/**" - - ".github/workflows/python-tiledbsoma-ml.yml" + paths-ignore: + - 'scripts/**' workflow_dispatch: @@ -44,7 +39,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -52,32 +47,14 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: pip - cache-dependency-path: python-spec/requirements-py${{ matrix.python-version }}.txt - name: Install prereqs - working-directory: ./other_packages/python/tiledbsoma_ml/ run: | pip install --upgrade pip wheel pytest pytest-cov setuptools pip install . - name: Run tests - run: | - PYTHONPATH=$(pwd)/other_packages/python/tiledbsoma_ml python -m pytest \ - --cov=other_packages/python/tiledbsoma_ml/src \ - --cov-report=xml other_packages/python/tiledbsoma_ml/tests \ - -v - - # - name: Report coverage to Codecov - # if: ${{ matrix.python-version == '3.11' }} - # uses: codecov/codecov-action@v4 - # with: - # flags: python - # # Although Codecov isn't supposed to require an auth token for public repos like this one, - # # the uploader can be unreliable without one; see - # # https://github.com/codecov/codecov-action/issues/557#issuecomment-1216749652 - # # As of this writing (8 Nov 2022) the CODECOV_TOKEN was generated by @aaronwolen in his - # # Codecov settings page for this repo, then filled into the GitHub Actions secrets. 
- # token: ${{ secrets.CODECOV_TOKEN }} + run: pytest -v --cov=src --cov-report=xml tests build: # for now, just do a test build to ensure that it works @@ -90,7 +67,6 @@ jobs: python-version: "3.11" - name: Do build - working-directory: ./other_packages/python/tiledbsoma_ml/ run: | pip install --upgrade build pip wheel setuptools setuptools-scm python -m build . diff --git a/pyproject.toml b/pyproject.toml index 530d1f4..3b2517e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "pyarrow", "scipy" ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" description = "Machine learning tools for use with tiledbsoma" readme = "README.md" authors = [ @@ -39,7 +39,6 @@ classifiers = [ "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -47,9 +46,9 @@ classifiers = [ ] [project.urls] -Repository = "https://github.com/single-cell-data/TileDB-SOMA.git" -Issues = "https://github.com/single-cell-data/TileDB-SOMA/issues" -Changelog = "https://github.com/single-cell-data/TileDB-SOMA/blob/main/CHANGELOG.md" +Repository = "https://github.com/TileDB-Inc/TileDB-SOMA-ML.git" +Issues = "https://github.com/TileDB-Inc/TileDB-SOMA-ML/issues" +Changelog = "https://github.com/TileDB-Inc/TileDB-SOMA-ML/blob/main/CHANGELOG.md" [tool.setuptools.dynamic] version = {attr = "tiledbsoma_ml.__version__"} @@ -65,7 +64,7 @@ show_error_codes = true ignore_missing_imports = true warn_unreachable = true strict = true -python_version = 3.8 +python_version = 3.9 plugins = "numpy.typing.mypy_plugin" [tool.ruff] @@ -73,7 +72,7 @@ lint.select = ["E", "F", "B", "I"] lint.ignore = ["E501"] # line too long lint.extend-select = ["I001"] # unsorted-imports fix = true -target-version = "py38" +target-version = "py39" line-length = 120 From 5347805cea13dc0d16d85a4fa3eb8bc4305a54d0 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 18 Sep 2024 16:46:23 -0400 Subject: [PATCH 63/70] add `.pre-commit-config.yaml`, run lint --- .pre-commit-config.yaml | 23 +++++++++++++++++++++++ notebooks/tutorial_lightning.ipynb | 3 ++- notebooks/tutorial_multiworker.ipynb | 3 ++- notebooks/tutorial_pytorch.ipynb | 3 ++- tests/test_pytorch.py | 26 +++++++++++++------------- 5 files changed, 42 insertions(+), 16 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8735e40 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/psf/black + rev: "24.8.0" + hooks: + - id: black + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.2 + hooks: + - id: ruff + name: "ruff for tiledbsoma_ml" + args: ["--config=pyproject.toml"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.1 + hooks: + - id: mypy + pass_filenames: false + args: ["--config-file=pyproject.toml", "src"] + additional_dependencies: + - attrs + - numpy + - pandas-stubs>=2 diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index 885f502..72f11c1 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -47,12 +47,13 @@ ], "source": [ "import pytorch_lightning as pl\n", - "import tiledbsoma_ml as soma_ml\n", "import torch\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "import 
tiledbsoma as soma\n", "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", "\n", "experiment = soma.open(\n", diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb index 1711f3a..5daa6b0 100644 --- a/notebooks/tutorial_multiworker.ipynb +++ b/notebooks/tutorial_multiworker.ipynb @@ -40,12 +40,13 @@ } ], "source": [ - "import tiledbsoma_ml as soma_ml\n", "import torch\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "import tiledbsoma as soma\n", "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", "\n", "experiment = soma.open(\n", diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index 581beaa..da6b08a 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -42,11 +42,12 @@ "metadata": {}, "outputs": [], "source": [ - "import tiledbsoma_ml as soma_ml\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "import tiledbsoma as soma\n", "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", "\n", "experiment = soma.open(\n", diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 1e7ed02..bf4c385 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -26,13 +26,14 @@ # This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent # tests. try: + from torch.utils.data._utils.worker import WorkerInfo + from tiledbsoma_ml.pytorch import ( ExperimentAxisQueryIterable, ExperimentAxisQueryIterableDataset, ExperimentAxisQueryIterDataPipe, experiment_dataloader, ) - from torch.utils.data._utils.worker import WorkerInfo except ImportError: # this should only occur when not running `ml`-marked tests pass @@ -538,11 +539,11 @@ def test_distributed__returns_data_partition_for_rank( """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, using mocks to avoid having to do real PyTorch distributed setup.""" - with patch("torch.distributed.is_initialized") as mock_dist_is_initialized, patch( - "torch.distributed.get_rank" - ) as mock_dist_get_rank, patch( - "torch.distributed.get_world_size" - ) as mock_dist_get_world_size: + with ( + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): mock_dist_is_initialized.return_value = True mock_dist_get_rank.return_value = rank mock_dist_get_world_size.return_value = world_size @@ -593,13 +594,12 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch setup or real DataLoader multiprocessing.""" - with patch("torch.utils.data.get_worker_info") as mock_get_worker_info, patch( - "torch.distributed.is_initialized" - ) as mock_dist_is_initialized, patch( - "torch.distributed.get_rank" - ) as mock_dist_get_rank, patch( - "torch.distributed.get_world_size" - ) as mock_dist_get_world_size: + with ( + patch("torch.utils.data.get_worker_info") as mock_get_worker_info, + patch("torch.distributed.is_initialized") as 
mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): mock_get_worker_info.return_value = WorkerInfo( id=worker_id, num_workers=num_workers, seed=1234 ) From da1389ae28d6d7c7d7d7e462638208ccca8dd9b2 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 18 Sep 2024 17:55:58 -0700 Subject: [PATCH 64/70] add .gitignore --- .gitignore | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..efa407c --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file From 672d306b71fe3dc7a4e32a3151d12a8585febf47 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 18 Sep 2024 17:56:40 -0700 Subject: [PATCH 65/70] autoupdate pre-commit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8735e40..6385ca9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,14 +5,14 @@ repos: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.2 + rev: v0.6.5 hooks: - id: ruff name: "ruff for tiledbsoma_ml" args: ["--config=pyproject.toml"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.1 + rev: v1.11.2 hooks: - id: mypy pass_filenames: false From 975dd1264b7b51b8eb075b433eef4fd2cb25c9da Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Wed, 18 Sep 2024 18:12:34 -0700 Subject: [PATCH 66/70] remove tiledbsoma-specific ruff/isort rules --- notebooks/tutorial_lightning.ipynb | 3 +-- notebooks/tutorial_multiworker.ipynb | 3 +-- notebooks/tutorial_pytorch.ipynb | 3 +-- pyproject.toml | 10 ---------- src/tiledbsoma_ml/pytorch.py | 3 +-- tests/test_pytorch.py | 3 +-- 6 files changed, 5 insertions(+), 20 deletions(-) diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb index 72f11c1..daeedf5 100644 --- a/notebooks/tutorial_lightning.ipynb +++ b/notebooks/tutorial_lightning.ipynb @@ -47,11 +47,10 @@ ], "source": [ "import pytorch_lightning as pl\n", + "import tiledbsoma as soma\n", "import torch\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", - "import tiledbsoma as soma\n", - "\n", "import tiledbsoma_ml as soma_ml\n", "\n", "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb index 5daa6b0..37e17e8 100644 --- a/notebooks/tutorial_multiworker.ipynb +++ b/notebooks/tutorial_multiworker.ipynb @@ -40,11 +40,10 @@ } ], "source": [ + "import tiledbsoma as soma\n", "import torch\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", - "import tiledbsoma as soma\n", - "\n", "import tiledbsoma_ml as soma_ml\n", "\n", "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb index da6b08a..70c62e3 100644 --- a/notebooks/tutorial_pytorch.ipynb +++ b/notebooks/tutorial_pytorch.ipynb @@ -42,9 +42,8 @@ "metadata": {}, "outputs": [], 
"source": [ - "from sklearn.preprocessing import LabelEncoder\n", - "\n", "import tiledbsoma as soma\n", + "from sklearn.preprocessing import LabelEncoder\n", "\n", "import tiledbsoma_ml as soma_ml\n", "\n", diff --git a/pyproject.toml b/pyproject.toml index 3b2517e..ab21890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,13 +74,3 @@ lint.extend-select = ["I001"] # unsorted-imports fix = true target-version = "py39" line-length = 120 - - -[tool.ruff.lint.isort] -# HACK: tiledb needs to come after tiledbsoma: https://github.com/single-cell-data/TileDB-SOMA/issues/2293 -section-order = ["future", "standard-library", "third-party", "tiledbsoma", "tiledb", "first-party", "local-folder"] -no-lines-before = ["tiledb"] - -[tool.ruff.lint.isort.sections] -"tiledbsoma" = ["tiledbsoma"] -"tiledb" = ["tiledb"] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index cd2a8ea..0e3bb33 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -37,13 +37,12 @@ import pandas as pd import pyarrow as pa import scipy.sparse as sparse +import tiledbsoma as soma import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator from typing_extensions import Self, TypeAlias -import tiledbsoma as soma - logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index bf4c385..ed6440e 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -15,10 +15,9 @@ import pandas as pd import pyarrow as pa import pytest +import tiledbsoma as soma from scipy import sparse from scipy.sparse import coo_matrix, spmatrix - -import tiledbsoma as soma from tiledbsoma import Experiment, _factory from tiledbsoma._collection import CollectionBase From 96c1516e300291eeaa9f46965bd325a28cb64b4b Mon Sep 17 00:00:00 2001 From: Bruce Martin Date: Thu, 19 Sep 2024 10:36:57 -0700 Subject: [PATCH 67/70] Package dependency pins & test (#2) * add minimum version for several dependencies * add compat test for primary dependencies * fix typo in workflow * fix another typo * compat matrix refinement * fix quoting * refine compat test matrix * further simplify matrix * update changelog to use correct links (#3) --- .../python-tilledbsoma-ml-compat.yml | 46 +++++++++++++++++++ CHANGELOG.md | 8 +++- pyproject.toml | 6 +-- 3 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/python-tilledbsoma-ml-compat.yml diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml new file mode 100644 index 0000000..61a5001 --- /dev/null +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -0,0 +1,46 @@ +name: python-tiledbsoma-ml past tiledbsoma compat # Latest tiledbsoma version covered by another workflow + +on: + pull_request: + branches: ["*"] + paths-ignore: + - "scripts/**" + - "notebooks/**" + push: + branches: [main] + paths-ignore: + - "scripts/**" + - "notebooks/**" + +jobs: + unit_tests: + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] # could add 'macos-latest', but the matrix is already huge... + python-version: ["3.9", "3.10", "3.11"] # TODO: add 3.12 when tiledbsoma releases wheels for it. 
pkg-version:
+          - "tiledbsoma~=1.9.0 'numpy<2.0.0'"
+          - "tiledbsoma~=1.10.0 'numpy<2.0.0'"
+          - "tiledbsoma~=1.11.0"
+          - "tiledbsoma~=1.12.0"
+          - "tiledbsoma~=1.13.0"
+          - "tiledbsoma~=1.14.0"
+
+    runs-on: ${{ matrix.os }}
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+
+      - name: Install prereqs
+        run: |
+          pip install --upgrade pip pytest setuptools
+          pip install ${{ matrix.pkg-version }} .
+
+      - name: Run tests
+        run: pytest -v tests
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f0d697..957c631 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,12 +8,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 
 ## [Unreleased] - yyyy-mm-dd
 
-Porting and enhancing initial code contribution from the Chan Zuckerberg Initiative Foundation
+Port and enhance the initial contribution from the Chan Zuckerberg Initiative Foundation
 [CELLxGENE](https://cellxgene.cziscience.com/) project.
 
+This is not a one-for-one migration of the contributed code. Substantial changes have
+been made to improve the package's utility (e.g., multi-GPU support), API usability, etc.
+
 ### Added
 
-- Initial commits via PR [#2823](https://github.com/single-cell-data/TileDB-SOMA/pull/2823)
+- Initial commits via [PR #1](https://github.com/single-cell-data/TileDB-SOMA-ML/pull/1)
+- Refine package dependency pins and compatibility tests via [PR #2](https://github.com/single-cell-data/TileDB-SOMA-ML/pull/2)
 
 ### Changed
 
diff --git a/pyproject.toml b/pyproject.toml
index ab21890..7901253 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,9 +6,9 @@ build-backend = "setuptools.build_meta"
 name = "tiledbsoma-ml"
 dynamic = ["version"]
 dependencies = [
-    "attrs",
-    "tiledbsoma",
-    "torch",
+    "attrs>=22.2",
+    "tiledbsoma>=1.9.0",
+    "torch>=2.0",
     "torchdata<=0.9",
     "numpy",
     "numba",

From 2a61940dcc780d05fa152e21656904bd865d9e2c Mon Sep 17 00:00:00 2001
From: bkmartinjr
Date: Mon, 23 Sep 2024 10:48:40 -0700
Subject: [PATCH 68/70] additional tests of implementation base class

---
 tests/test_pytorch.py | 207 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 171 insertions(+), 36 deletions(-)

diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
index ed6440e..9df6d7f 100644
--- a/tests/test_pytorch.py
+++ b/tests/test_pytorch.py
@@ -154,10 +154,19 @@ def soma_experiment(
 )
 @pytest.mark.parametrize("return_sparse_X", [True, False])
 @pytest.mark.parametrize(
-    "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset)
+    "PipeClass",
+    (
+        ExperimentAxisQueryIterable,
+        
ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_uneven_soma_and_result_batches( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool, @@ -247,10 +265,19 @@ def test_uneven_soma_and_result_batches( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_batching__all_batches_full_size( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -287,10 +314,19 @@ def test_batching__all_batches_full_size( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_unique_soma_joinids( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -314,10 +350,19 @@ def test_unique_soma_joinids( [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_batching__partial_final_batch_size( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -345,10 +390,19 @@ def test_batching__partial_final_batch_size( [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_batching__exactly_one_batch( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -376,10 +430,19 @@ def test_batching__exactly_one_batch( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def 
test_batching__empty_query_result( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -404,10 +467,19 @@ def test_batching__empty_query_result( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_sparse_output__non_batched( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -432,10 +504,19 @@ def test_sparse_output__non_batched( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_sparse_output__batched( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -464,10 +545,19 @@ def test_sparse_output__batched( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_batching__partial_soma_batches_are_concatenated( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -491,7 +581,8 @@ def test_batching__partial_soma_batches_are_concatenated( "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)] ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), ) def test_multiprocessing__returns_full_result( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, @@ -526,10 +617,19 @@ def test_multiprocessing__returns_full_result( [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_distributed__returns_data_partition_for_rank( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, obs_range: int, world_size: int, @@ -578,10 +678,19 @@ def test_distributed__returns_data_partition_for_rank( 
], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test_distributed_and_multiprocessing__returns_data_partition_for_rank( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, obs_range: int, world_size: int, @@ -633,7 +742,8 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), ) def test_experiment_dataloader__non_batched( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, @@ -663,7 +773,8 @@ def test_experiment_dataloader__non_batched( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), ) def test_experiment_dataloader__batched( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, @@ -694,7 +805,8 @@ def test_experiment_dataloader__batched( ], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), ) def test_experiment_dataloader__batched_length( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, @@ -719,7 +831,8 @@ def test_experiment_dataloader__batched_length( [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset), ) def test_experiment_dataloader__collate_fn( PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, @@ -757,10 +870,19 @@ def collate_fn( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test__X_tensor_dtype_matches_X_matrix( - PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -803,10 +925,19 @@ def test__pytorch_splitting( "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] ) @pytest.mark.parametrize( - "PipeClass", (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset) + "PipeClass", + ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, + ), ) def test__shuffle( - PipeClass: ExperimentAxisQueryIterDataPipe | 
ExperimentAxisQueryIterableDataset, + PipeClass: ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset + ), soma_experiment: Experiment, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: @@ -817,7 +948,11 @@ def test__shuffle( ) all_rows = list(iter(dp)) - assert all(r[0].shape == (1,) for r in all_rows) + print([r[0].shape for r in all_rows]) + if PipeClass is ExperimentAxisQueryIterable: + assert all(np.squeeze(r[0], axis=0).shape == (1,) for r in all_rows) + else: + assert all(r[0].shape == (1,) for r in all_rows) soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows] X_values = [row[0][0].item() for row in all_rows] From ccdfccd11c892a27efef9be3c4e73e7cd6637ec5 Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 23 Sep 2024 11:15:46 -0700 Subject: [PATCH 69/70] refactoring test params --- tests/test_pytorch.py | 235 ++++++++---------------------------------- 1 file changed, 42 insertions(+), 193 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 9df6d7f..c7aeb20 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -38,6 +38,20 @@ pass +# These control which classes are tested (for most, but not all tests). +# Centralized to allow easy add/delete of specific test parameters. +PipeClassType = ( + ExperimentAxisQueryIterable + | ExperimentAxisQueryIterDataPipe + | ExperimentAxisQueryIterableDataset +) +PipeClassImplementation = ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, +) + + def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: occupied_shape = ( obs_range.stop - obs_range.start, @@ -153,20 +167,9 @@ def soma_experiment( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize("return_sparse_X", [True, False]) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_non_batched( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool, @@ -211,20 +214,9 @@ def test_non_batched( [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) @pytest.mark.parametrize("return_sparse_X", [True, False]) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_uneven_soma_and_result_batches( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool, @@ -264,20 +256,9 @@ def test_uneven_soma_and_result_batches( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__all_batches_full_size( - PipeClass: ( - 
ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -313,20 +294,9 @@ def test_batching__all_batches_full_size( for use_eager_fetch in (True, False) ], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_unique_soma_joinids( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -349,20 +319,9 @@ def test_unique_soma_joinids( "obs_range,var_range,X_value_gen,use_eager_fetch", [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__partial_final_batch_size( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -389,20 +348,9 @@ def test_batching__partial_final_batch_size( "obs_range,var_range,X_value_gen,use_eager_fetch", [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__exactly_one_batch( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -429,20 +377,9 @@ def test_batching__exactly_one_batch( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__empty_query_result( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, ) -> None: @@ -466,22 +403,9 @@ def test_batching__empty_query_result( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_sparse_output__non_batched( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), - soma_experiment: Experiment, - use_eager_fetch: bool, + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: with 
soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -503,22 +427,9 @@ def test_sparse_output__non_batched( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_sparse_output__batched( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), - soma_experiment: Experiment, - use_eager_fetch: bool, + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -544,22 +455,9 @@ def test_sparse_output__batched( for use_eager_fetch in (True, False) ], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__partial_soma_batches_are_concatenated( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), - soma_experiment: Experiment, - use_eager_fetch: bool, + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -616,20 +514,9 @@ def test_multiprocessing__returns_full_result( "world_size,rank", [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_distributed__returns_data_partition_for_rank( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, obs_range: int, world_size: int, @@ -677,20 +564,9 @@ def test_distributed__returns_data_partition_for_rank( (3, 1, 2, 1), ], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_distributed_and_multiprocessing__returns_data_partition_for_rank( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), + PipeClass: PipeClassType, soma_experiment: Experiment, obs_range: int, world_size: int, @@ -869,22 +745,9 @@ def collate_fn( "obs_range,var_range,X_value_gen,use_eager_fetch", [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test__X_tensor_dtype_matches_X_matrix( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), - soma_experiment: Experiment, - use_eager_fetch: bool, + PipeClass: PipeClassType, soma_experiment: Experiment, 
use_eager_fetch: bool ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( @@ -924,22 +787,8 @@ def test__pytorch_splitting( @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)] ) -@pytest.mark.parametrize( - "PipeClass", - ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterDataPipe, - ExperimentAxisQueryIterableDataset, - ), -) -def test__shuffle( - PipeClass: ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset - ), - soma_experiment: Experiment, -) -> None: +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test__shuffle(PipeClass: PipeClassType, soma_experiment: Experiment) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: dp = PipeClass( query, From aef4c15d7678735fa37578e3f35263e1ff5f135e Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 23 Sep 2024 11:20:41 -0700 Subject: [PATCH 70/70] fix py 3.9 incompatibility --- tests/test_pytorch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index c7aeb20..c268158 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -40,11 +40,11 @@ # These control which classes are tested (for most, but not all tests). # Centralized to allow easy add/delete of specific test parameters. -PipeClassType = ( - ExperimentAxisQueryIterable - | ExperimentAxisQueryIterDataPipe - | ExperimentAxisQueryIterableDataset -) +PipeClassType = Union[ + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, +] PipeClassImplementation = ( ExperimentAxisQueryIterable, ExperimentAxisQueryIterDataPipe,
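
A note on [PATCH 70/70] above: its one-line subject ("fix py 3.9 incompatibility") is easy to miss in the mechanical diff. PEP 604's `X | Y` union syntax only gained runtime support in Python 3.10, and the module-level PipeClassType alias is an ordinary expression evaluated at import time, so even `from __future__ import annotations` (which defers only annotations) cannot help; typing.Union is the 3.9-compatible spelling. A minimal sketch, with hypothetical classes A and B standing in for the three pipe classes:

    from typing import Union


    class A: ...


    class B: ...


    # On Python 3.9 (the oldest interpreter in the compat matrix), the
    # commented-out line below raises:
    #   TypeError: unsupported operand type(s) for |: 'type' and 'type'
    # PipeClassType = A | B

    # typing.Union spells the same union and works on 3.9+; static type
    # checkers treat the two forms identically:
    PipeClassType = Union[A, B]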
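Relatedly, the np.squeeze calls that [PATCH 68/70] threads through test_non_batched and test__shuffle normalize a real shape difference among the classes under test: judging from the assertions, the base ExperimentAxisQueryIterable yields X with a leading batch axis even when a batch holds a single row, while the DataPipe/Dataset wrappers yield the bare row. A small illustration with made-up arrays:

    import numpy as np

    x_from_iterable = np.array([[0, 1, 0]])  # shape (1, 3): leading batch axis
    x_from_dataset = np.array([0, 1, 0])     # shape (3,): bare row

    # np.squeeze drops every length-1 axis, so one assertion covers both:
    assert np.squeeze(x_from_iterable).shape == (3,)
    assert np.squeeze(x_from_dataset).shape == (3,)  # already 1-D; no-op

    # test__shuffle pins the axis instead; squeeze(..., axis=0) raises
    # ValueError when axis 0 is not length 1, making the check stricter:
    assert np.squeeze(x_from_iterable, axis=0).shape == (3,)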
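Finally, on the ~= specifiers in the compat matrix added by [PATCH 67/70]: tiledbsoma~=1.9.0 is a PEP 440 "compatible release" pin, equivalent to >=1.9.0,<1.10.0, which is why the matrix needs one entry per tiledbsoma minor release. A quick check, using the third-party packaging library (an assumption here; it is not one of this package's dependencies):

    from packaging.specifiers import SpecifierSet

    compatible = SpecifierSet("~=1.9.0")

    assert compatible.contains("1.9.5")       # later patch release: allowed
    assert not compatible.contains("1.10.0")  # next minor release: excluded
    assert not compatible.contains("1.8.0")   # earlier minor release: excluded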