diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 4a1e9f1..fdd82d0 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -3,10 +3,14 @@ name: python-tiledbsoma-ml CI on: pull_request: branches: ["**"] - paths-ignore: ['scripts/**'] + paths-ignore: + - "scripts/**" + - "notebooks/**" push: branches: [main] - paths-ignore: ['scripts/**'] + paths-ignore: + - "scripts/**" + - "notebooks/**" workflow_dispatch: jobs: diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb new file mode 100644 index 0000000..daeedf5 --- /dev/null +++ b/notebooks/tutorial_lightning.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a model with PyTorch Lightning\n", + "\n", + "This tutorial provides a quick overview of training a toy model with Lightning, using the `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryIterableDataset`, and not as an example of how to train a biologically useful model.\n", + "\n", + "For more information on these API, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).\n", + "\n", + "**Prerequesites**\n", + "\n", + "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n", + "\n", + "> pip install tiledbsoma_ml scikit-learn\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize SOMA Experiment query as training data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], + "source": [ + "import pytorch_lightning as pl\n", + "import tiledbsoma as soma\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Lightning module" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class LogisticRegressionLightning(pl.LightningModule):\n", + " def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n", + " super(LogisticRegressionLightning, self).__init__()\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + " self.cell_type_encoder = cell_type_encoder\n", + " self.learning_rate = learning_rate\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " X_batch, y_batch = batch\n", + " # X_batch = X_batch.float()\n", + " X_batch = torch.from_numpy(X_batch).float().to(self.device)\n", + "\n", + " # Perform prediction\n", + " outputs = self(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute loss\n", + " y_batch = torch.from_numpy(\n", + " self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n", + " ).to(self.device)\n", + " loss = self.loss_fn(outputs, y_batch.long())\n", + "\n", + " # Compute accuracy\n", + " train_correct = (predictions == y_batch).sum().item()\n", + " train_accuracy = train_correct / len(predictions)\n", + "\n", + " # Log loss and accuracy\n", + " self.log(\"train_loss\", loss, prog_bar=True)\n", + " self.log(\"train_accuracy\", train_accuracy, prog_bar=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------\n", + "0 | linear | Linear | 726 K | train\n", + "1 | loss_fn | CrossEntropyLoss | 0 | train\n", + "-----------------------------------------------------\n", + "726 K Trainable params\n", + "0 Non-trainable params\n", + "726 K Total params\n", + "2.905 Total estimated model params size (MB)\n", + "2 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=20` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n" + ] + } + ], + "source": [ + "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "# Initialize the PyTorch Lightning model\n", + "model = LogisticRegressionLightning(\n", + " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", + ")\n", + "\n", + "# Define the PyTorch Lightning Trainer\n", + "trainer = pl.Trainer(max_epochs=20)\n", + "\n", + "# set precision\n", + "torch.set_float32_matmul_precision(\"high\")\n", + "\n", + "# Train the model\n", + "trainer.fit(model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb new file mode 100644 index 0000000..37e17e8 --- /dev/null +++ b/notebooks/tutorial_multiworker.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-process training\n", + "\n", + "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterDataset` includes both:\n", + "* using the `torch.utils.data.DataLoader` with 1 or more worker (ie., with an argument of `n_workers=1` or greater)\n", + "* using a multi-process training configuration, such as `DistributedDataParallel`\n", + "\n", + "In these configurations, `ExperimentAxisQueryIterDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n", + "\n", + "1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n", + "2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + } + ], + "source": [ + "import tiledbsoma as soma\n", + "import torch\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs\n", + " \n", + "\n", + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + " model.train()\n", + " train_loss = 0\n", + " train_correct = 0\n", + " train_total = 0\n", + "\n", + " for X_batch, y_batch in train_dataloader:\n", + " optimizer.zero_grad()\n", + "\n", + " X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + " # Perform prediction\n", + " outputs = model(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute the loss and perform back propagation\n", + " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + " train_correct += (predictions == y_batch).sum().item()\n", + " train_total += len(predictions)\n", + "\n", + " loss = loss_fn(outputs, y_batch.long())\n", + " train_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss /= train_total\n", + " train_accuracy = train_correct / train_total\n", + " return train_loss, train_accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-worker DataLoader\n", + "\n", + "If you use a multi-worker data loader (i.e., `num_workers` with a value other than `0`), and `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.\n", + "\n", + "The same approach should be taken for parallel training, e.g., when using DDP or DP.\n", + "\n", + "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterDataset` will automatically increment the epoch count each time the iterator completes." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n", + "################################################################################\n", + "WARNING!\n", + "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n", + "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n", + "to learn more and leave feedback.\n", + "################################################################################\n", + "\n", + " deprecation_warning()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n", + "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n", + "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n", + "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n", + "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n", + "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n", + "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n", + "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n", + "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n", + "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n", + "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n", + "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n", + "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n", + "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n", + "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n", + "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n", + "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n", + "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n", + "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n", + "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "\n", + "# define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n", + "# that a different shuffle is applied on each epoch.\n", + "experiment_dataloader = soma_ml.experiment_dataloader(\n", + " experiment_dataset, num_workers=2, persistent_workers=True\n", + ")\n", + "\n", + "for epoch in range(20):\n", + " experiment_dataset.set_epoch(epoch)\n", + " train_loss, train_accuracy = train_epoch(\n", + " model, experiment_dataloader, loss_fn, optimizer, device\n", + " )\n", + " print(\n", + " f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb new file mode 100644 index 0000000..70c62e3 --- /dev/null +++ b/notebooks/tutorial_pytorch.ipynb @@ -0,0 +1,609 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a PyTorch Model\n", + "\n", + "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryIterDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n", + "\n", + "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n", + "\n", + "**Prerequisites**\n", + "\n", + "Install `tiledbsoma` with the optional `ml` dependencies, for example:\n", + "\n", + "> pip install tiledbsoma[ml]\n", + "\n", + "\n", + "**Contents**\n", + "\n", + "* [Create a DataLoader](#Create-a-DataLoader)\n", + "* [Define the model](#Define-the-model)\n", + "* [Train the model](#Train-the-model)\n", + "* [Make predictions with the model](#Make-predictions-with-the-model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an ExperimentAxisQueryIterDataPipe\n", + "\n", + "To train a model in PyTorch using this `census` data object, first instantiate open a SOMA Experiment, and create a `ExperimentAxisQueryIterDataPipe`. This example utilizes a recent CZI Census release, access directly from S3.\n", + "\n", + "We are also going to create an encoder for the `obs` labels at the same time, and train it on the `cell_type` labels. In this example we use the LabelEncoder from `scikit-learn`." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import tiledbsoma as soma\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "import tiledbsoma_ml as soma_ml\n", + "\n", + "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n", + "\n", + "experiment = soma.open(\n", + " CZI_Census_Homo_Sapiens_URL,\n", + " context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n", + ")\n", + "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n", + "\n", + "with experiment.axis_query(\n", + " measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n", + ") as query:\n", + " obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n", + " cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n", + "\n", + " experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n", + " query,\n", + " X_name=\"raw\",\n", + " obs_column_names=[\"cell_type\"],\n", + " batch_size=128,\n", + " shuffle=True,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `ExperimentAxisQueryIterDataPipe` class explained\n", + "\n", + "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n", + "\n", + "### `ExperimentAxisQueryIterDataPipe` parameters explained\n", + "\n", + "The constructor only requires a single parameter, `experiment`, which is a `soma.Experiment` containing the data of the organism to be used for training.\n", + "\n", + "To retrieve a subset of the Experiment's data, along either the `obs` or `var` axes, you may specify query filters via the `obs_query` and `var_query` parameters, which are both `soma.AxisQuery` objects.\n", + "\n", + "The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.\n", + "\n", + "The `batch_size` allows you to specify the number of obs rows (cells) to be returned by each return PyTorch tensor. You may exclude this parameter if you want single rows (`batch_size=1`).\n", + "\n", + "The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:\n", + "* You should use this flag instead of the `DataLoader` `shuffle` flag, primarily for performance reasons.\n", + "* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is alternate mechanism one can use to perform shuffling of an `IterableDataset`. However, the `Shuffler` will not \"globally\" randomize the training data, as it only \"locally\" randomizes the ordering of the training data within fixed-size \"windows\". Due to the layout of Census data, a given \"window\" of Census data may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can inspect the shape of the full dataset, without causing the full dataset to be loaded. The `shape` property returns the number of batches on the first dimension:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(118, 60530)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_dataset.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split the dataset\n", + "\n", + "You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create the DataLoader\n", + "\n", + "With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(train_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryIterDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the model\n", + "\n", + "With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "\n", + "class LogisticRegression(torch.nn.Module):\n", + " def __init__(self, input_dim, output_dim):\n", + " super(LogisticRegression, self).__init__() # noqa: UP008\n", + " self.linear = torch.nn.Linear(input_dim, output_dim)\n", + "\n", + " def forward(self, x):\n", + " outputs = torch.sigmoid(self.linear(x))\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define a function to train the model for a single epoch:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n", + " model.train()\n", + " train_loss = 0\n", + " train_correct = 0\n", + " train_total = 0\n", + "\n", + " for X_batch, y_batch in train_dataloader:\n", + " optimizer.zero_grad()\n", + "\n", + " X_batch = torch.from_numpy(X_batch).float().to(device)\n", + "\n", + " # Perform prediction\n", + " outputs = model(X_batch)\n", + "\n", + " # Determine the predicted label\n", + " probabilities = torch.nn.functional.softmax(outputs, 1)\n", + " predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + " # Compute the loss and perform back propagation\n", + " y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n", + " train_correct += (predictions == y_batch).sum().item()\n", + " train_total += len(predictions)\n", + "\n", + " loss = loss_fn(outputs, y_batch.long())\n", + " train_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss /= train_total\n", + " train_accuracy = train_correct / train_total\n", + " return train_loss, train_accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the line, `X_batch, y_batch = batch`. Since the `train_dataloader` was configured with `batch_size=16`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([[0., 0., 0., ..., 1., 0., 0.],\n", + " [0., 0., 2., ..., 0., 3., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " ...,\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 1., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 8.]])\n", + " \n", + "```\n", + "\n", + "For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:\n", + "\n", + "```\n", + "tensor([0., 0., 0., ..., 1., 0., 0.])\n", + "```\n", + " \n", + "For `y_batch`, this will contain the user-specified `obs` `cell_type` training labels. By default, these are encoded using a LabelEncoder and it will be a matrix where each column represents the encoded values of each column specified in `obs_column_names` when creating the datapipe (in this case, only the cell type). It will look like this:\n", + "\n", + "```\n", + "tensor([1, 1, 3, ..., 2, 1, 4])\n", + "\n", + "```\n", + "Note that cell type values are integer-encoded values, which can be decoded using `experiment_dataset.encoders` (more on this below).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the model\n", + "\n", + "Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method and then iterate through the desired number of training epochs. Note how the `train_dataloader` is passed into `train_epoch`, where for each epoch it will provide a new iterator through the training dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n", + "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n", + "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n", + "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n", + "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n", + "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n", + "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n", + "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n", + "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n", + "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n", + "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n", + "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n", + "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n", + "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n", + "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n", + "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n", + "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n", + "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n", + "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n", + "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "model = LogisticRegression(input_dim, output_dim).to(device)\n", + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n", + "\n", + "for epoch in range(20):\n", + " train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n", + " print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make predictions with the model\n", + "\n", + "To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the \"test\" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dataloader = soma_ml.experiment_dataloader(test_dataset)\n", + "X_batch, y_batch = next(iter(experiment_dataloader))\n", + "X_batch = torch.from_numpy(X_batch)\n", + "y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we invoke the model on the `X_batch` input data and extract the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 8, 1, 1, 1, 1, 1, 1, 8, 8, 5, 1, 7, 8, 1, 1, 1, 1, 7,\n", + " 7, 8, 1, 1, 5, 5, 1, 8, 1, 1, 1, 7, 8, 7, 7, 7, 8, 7,\n", + " 5, 1, 1, 8, 1, 5, 8, 5, 1, 11, 1, 7, 1, 1, 5, 5, 1, 11,\n", + " 1, 6, 8, 5, 1, 8, 11, 8, 1, 8, 1, 8, 1, 5, 1, 1, 1, 8,\n", + " 8, 7, 5, 1, 1, 8, 1, 7, 2, 1, 7, 1, 5, 1, 1, 7, 1, 8,\n", + " 1, 1, 1, 7, 7, 1, 1, 1, 7, 1, 1, 7, 7, 5, 7, 8, 5, 1,\n", + " 5, 1, 5, 5, 5, 1, 1, 1, 8, 5, 1, 1, 7, 8, 1, 1, 1, 1,\n", + " 8, 1], device='cuda:0')" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "\n", + "model.to(device)\n", + "outputs = model(X_batch.to(device))\n", + "\n", + "probabilities = torch.nn.functional.softmax(outputs, 1)\n", + "predictions = torch.argmax(probabilities, axis=1)\n", + "\n", + "display(predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The predictions are returned as the encoded values of `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n", + "\n", + "At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n", + " 'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'vein endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n", + " 'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n", + " 'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n", + " 'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n", + " 'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n", + " 'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'leukocyte', 'basal cell', 'keratinocyte',\n", + " 'capillary endothelial cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n", + " 'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n", + " 'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n", + " 'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n", + " 'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n", + " 'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n", + " 'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n", + " 'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n", + " 'basal cell'], dtype=object)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n", + "\n", + "display(predicted_cell_types)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we create a Pandas DataFrame to examine the predictions:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
actual cell typepredicted cell type
0leukocyteleukocyte
1basal cellbasal cell
2basal cellbasal cell
3basal cellbasal cell
4basal cellbasal cell
.........
123fibroblastbasal cell
124basal cellbasal cell
125keratinocytebasal cell
126leukocyteleukocyte
127basal cellbasal cell
\n", + "

128 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " actual cell type predicted cell type\n", + "0 leukocyte leukocyte\n", + "1 basal cell basal cell\n", + "2 basal cell basal cell\n", + "3 basal cell basal cell\n", + "4 basal cell basal cell\n", + ".. ... ...\n", + "123 fibroblast basal cell\n", + "124 basal cell basal cell\n", + "125 keratinocyte basal cell\n", + "126 leukocyte leukocyte\n", + "127 basal cell basal cell\n", + "\n", + "[128 rows x 2 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "display(\n", + " pd.DataFrame(\n", + " {\n", + " \"actual cell type\": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),\n", + " \"predicted cell type\": predicted_cell_types,\n", + " }\n", + " )\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tiledbsoma-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}