diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml
index 4a1e9f1..fdd82d0 100644
--- a/.github/workflows/python-tiledbsoma-ml.yml
+++ b/.github/workflows/python-tiledbsoma-ml.yml
@@ -3,10 +3,14 @@ name: python-tiledbsoma-ml CI
 on:
   pull_request:
     branches: ["**"]
-    paths-ignore: ['scripts/**']
+    paths-ignore:
+      - "scripts/**"
+      - "notebooks/**"
   push:
     branches: [main]
-    paths-ignore: ['scripts/**']
+    paths-ignore:
+      - "scripts/**"
+      - "notebooks/**"
   workflow_dispatch:

 jobs:
diff --git a/notebooks/tutorial_lightning.ipynb b/notebooks/tutorial_lightning.ipynb
new file mode 100644
index 0000000..daeedf5
--- /dev/null
+++ b/notebooks/tutorial_lightning.ipynb
@@ -0,0 +1,242 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Training a model with PyTorch Lightning\n",
+    "\n",
+    "This tutorial provides a quick overview of training a toy model with PyTorch Lightning, using the `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/). It is intended only to demonstrate the use of the `ExperimentAxisQueryIterableDataset`, and not as an example of how to train a biologically useful model.\n",
+    "\n",
+    "For more information on these APIs, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).\n",
+    "\n",
+    "**Prerequisites**\n",
+    "\n",
+    "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n",
+    "\n",
+    "> pip install tiledbsoma_ml scikit-learn\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Initialize a SOMA Experiment query as training data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
+      "################################################################################\n",
+      "WARNING!\n",
+      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
+      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
+      "to learn more and leave feedback.\n",
+      "################################################################################\n",
+      "\n",
+      "    deprecation_warning()\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pytorch_lightning as pl\n",
+    "import tiledbsoma as soma\n",
+    "import torch\n",
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import tiledbsoma_ml as soma_ml\n",
+    "\n",
+    "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+    "\n",
+    "experiment = soma.open(\n",
+    "    CZI_Census_Homo_Sapiens_URL,\n",
+    "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n",
+    ")\n",
+    "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+    "\n",
+    "with experiment.axis_query(\n",
+    "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+    ") as query:\n",
+    "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
+    "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
+    "\n",
+    "    experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n",
+    "        query,\n",
+    "        X_name=\"raw\",\n",
+    "        obs_column_names=[\"cell_type\"],\n",
+    "        batch_size=128,\n",
+    "        shuffle=True,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define the Lightning module"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class LogisticRegressionLightning(pl.LightningModule):\n",
+    "    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=1e-5):\n",
+    "        super(LogisticRegressionLightning, self).__init__()\n",
+    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
+    "        self.cell_type_encoder = cell_type_encoder\n",
+    "        self.learning_rate = learning_rate\n",
+    "        self.loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        outputs = torch.sigmoid(self.linear(x))\n",
+    "        return outputs\n",
+    "\n",
+    "    def training_step(self, batch, batch_idx):\n",
+    "        X_batch, y_batch = batch\n",
+    "        X_batch = torch.from_numpy(X_batch).float().to(self.device)\n",
+    "\n",
+    "        # Perform prediction\n",
+    "        outputs = self(X_batch)\n",
+    "\n",
+    "        # Determine the predicted label\n",
+    "        probabilities = torch.nn.functional.softmax(outputs, 1)\n",
+    "        predictions = torch.argmax(probabilities, axis=1)\n",
+    "\n",
+    "        # Compute loss\n",
+    "        y_batch = torch.from_numpy(\n",
+    "            self.cell_type_encoder.transform(y_batch[\"cell_type\"])\n",
+    "        ).to(self.device)\n",
+    "        loss = self.loss_fn(outputs, y_batch.long())\n",
+    "\n",
+    "        # Compute accuracy\n",
+    "        train_correct = (predictions == y_batch).sum().item()\n",
+    "        train_accuracy = train_correct / len(predictions)\n",
+    "\n",
+    "        # Log loss and accuracy\n",
+    "        self.log(\"train_loss\", loss, prog_bar=True)\n",
+    "        self.log(\"train_accuracy\", train_accuracy, prog_bar=True)\n",
+    "\n",
+    "        return loss\n",
+    "\n",
+    "    def configure_optimizers(self):\n",
+    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
+    "        return optimizer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Train the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------\n", + "0 | linear | Linear | 726 K | train\n", + "1 | loss_fn | CrossEntropyLoss | 0 | train\n", + "-----------------------------------------------------\n", + "726 K Trainable params\n", + "0 Non-trainable params\n", + "726 K Total params\n", + "2.905 Total estimated model params size (MB)\n", + "2 Modules in train mode\n", + "0 Modules in eval mode\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n", + "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.31it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=20` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19: 100%|██████████| 118/118 [00:08<00:00, 14.28it/s, v_num=5, train_loss=1.670, train_accuracy=0.977]\n" + ] + } + ], + "source": [ + "dataloader = soma_ml.experiment_dataloader(experiment_dataset)\n", + "\n", + "# The size of the input dimension is the number of genes\n", + "input_dim = experiment_dataset.shape[1]\n", + "\n", + "# The size of the output dimension is the number of distinct cell_type values\n", + "output_dim = len(cell_type_encoder.classes_)\n", + "\n", + "# Initialize the PyTorch Lightning model\n", + "model = LogisticRegressionLightning(\n", + " input_dim, output_dim, cell_type_encoder=cell_type_encoder\n", + ")\n", + "\n", + "# Define the PyTorch Lightning Trainer\n", + "trainer = pl.Trainer(max_epochs=20)\n", + "\n", + "# set precision\n", + "torch.set_float32_matmul_precision(\"high\")\n", + "\n", + "# Train the model\n", + "trainer.fit(model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "toymodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 
+ "nbformat_minor": 2
+}
diff --git a/notebooks/tutorial_multiworker.ipynb b/notebooks/tutorial_multiworker.ipynb
new file mode 100644
index 0000000..37e17e8
--- /dev/null
+++ b/notebooks/tutorial_multiworker.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multi-process training\n",
+    "\n",
+    "Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterableDataset` includes both:\n",
+    "* using the `torch.utils.data.DataLoader` with one or more workers (i.e., with an argument of `num_workers=1` or greater)\n",
+    "* using a multi-process training configuration, such as `DistributedDataParallel`\n",
+    "\n",
+    "In these configurations, `ExperimentAxisQueryIterableDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are two things to keep in mind:\n",
+    "\n",
+    "1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n",
+    "2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n",
+    "\n",
+    "Both points are illustrated in the sketch below.\n"
+   ]
+  },
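+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of the pattern (the `seed` keyword and the epoch count are illustrative assumptions, not taken from the cells below; `query` is the open `soma.ExperimentAxisQuery` created in the next cell):\n",
+    "\n",
+    "```python\n",
+    "experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n",
+    "    query, X_name=\"raw\", batch_size=128, shuffle=True,\n",
+    "    seed=42,  # assumed: every worker process must share the same seed\n",
+    ")\n",
+    "dataloader = soma_ml.experiment_dataloader(experiment_dataset, num_workers=2)\n",
+    "\n",
+    "for epoch in range(20):\n",
+    "    experiment_dataset.set_epoch(epoch)  # yields a different shuffle each epoch\n",
+    "    for X_batch, y_batch in dataloader:\n",
+    "        ...  # training step\n",
+    "```\n"
+   ]
+  },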
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
+      "################################################################################\n",
+      "WARNING!\n",
+      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
+      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
+      "to learn more and leave feedback.\n",
+      "################################################################################\n",
+      "\n",
+      "    deprecation_warning()\n"
+     ]
+    }
+   ],
+   "source": [
+    "import tiledbsoma as soma\n",
+    "import torch\n",
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import tiledbsoma_ml as soma_ml\n",
+    "\n",
+    "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+    "\n",
+    "experiment = soma.open(\n",
+    "    CZI_Census_Homo_Sapiens_URL,\n",
+    "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n",
+    ")\n",
+    "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+    "\n",
+    "with experiment.axis_query(\n",
+    "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+    ") as query:\n",
+    "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
+    "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
+    "\n",
+    "    experiment_dataset = soma_ml.ExperimentAxisQueryIterableDataset(\n",
+    "        query,\n",
+    "        X_name=\"raw\",\n",
+    "        obs_column_names=[\"cell_type\"],\n",
+    "        batch_size=128,\n",
+    "        shuffle=True,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class LogisticRegression(torch.nn.Module):\n",
+    "    def __init__(self, input_dim, output_dim):\n",
+    "        super(LogisticRegression, self).__init__()  # noqa: UP008\n",
+    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        outputs = torch.sigmoid(self.linear(x))\n",
+    "        return outputs\n",
+    "\n",
+    "\n",
+    "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n",
+    "    model.train()\n",
+    "    train_loss = 0\n",
+    "    train_correct = 0\n",
+    "    train_total = 0\n",
+    "\n",
+    "    for X_batch, y_batch in train_dataloader:\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        X_batch = torch.from_numpy(X_batch).float().to(device)\n",
+    "\n",
+    "        # Perform prediction\n",
+    "        outputs = model(X_batch)\n",
+    "\n",
+    "        # Determine the predicted label\n",
+    "        probabilities = torch.nn.functional.softmax(outputs, 1)\n",
+    "        predictions = torch.argmax(probabilities, axis=1)\n",
+    "\n",
+    "        # Compute the loss and perform back propagation\n",
+    "        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n",
+    "        train_correct += (predictions == y_batch).sum().item()\n",
+    "        train_total += len(predictions)\n",
+    "\n",
+    "        loss = loss_fn(outputs, y_batch.long())\n",
+    "        train_loss += loss.item()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "\n",
+    "    train_loss /= train_total\n",
+    "    train_accuracy = train_correct / train_total\n",
+    "    return train_loss, train_accuracy"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Multi-worker DataLoader\n",
+    "\n",
+    "If you use a multi-worker data loader (i.e., `num_workers` with a value other than `0`), and `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.\n",
+    "\n",
+    "The same approach should be taken for parallel training, e.g., when using DDP or DP.\n",
+    "\n",
+    "*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentAxisQueryIterableDataset` will automatically increment the epoch count each time the iterator completes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "switching torch multiprocessing start method from \"fork\" to \"spawn\"\n",
+      "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
+      "################################################################################\n",
+      "WARNING!\n",
+      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
+      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
+      "to learn more and leave feedback.\n",
+      "################################################################################\n",
+      "\n",
+      "    deprecation_warning()\n",
+      "/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
+      "################################################################################\n",
+      "WARNING!\n",
+      "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
+      "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
+      "to learn more and leave feedback.\n",
+      "################################################################################\n",
+      "\n",
+      "    deprecation_warning()\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1: Train Loss: 0.0169229 Accuracy 0.3124\n",
+      "Epoch 2: Train Loss: 0.0148674 Accuracy 0.4272\n",
+      "Epoch 3: Train Loss: 0.0144468 Accuracy 0.4509\n",
+      "Epoch 4: Train Loss: 0.0141778 Accuracy 0.4999\n",
+      "Epoch 5: Train Loss: 0.0139660 Accuracy 0.5619\n",
+      "Epoch 6: Train Loss: 0.0137670 Accuracy 0.6971\n",
+      "Epoch 7: Train Loss: 0.0136089 Accuracy 0.8670\n",
+      "Epoch 8: Train Loss: 0.0135203 Accuracy 0.9099\n",
+      "Epoch 9: Train Loss: 0.0134427 Accuracy 0.9262\n",
+      "Epoch 10: Train Loss: 0.0133607 Accuracy 0.9300\n",
+      "Epoch 11: Train Loss: 0.0133110 Accuracy 0.9348\n",
+      "Epoch 12: Train Loss: 0.0132749 Accuracy 0.9378\n",
+      "Epoch 13: Train Loss: 0.0132431 Accuracy 0.9413\n",
+      "Epoch 14: Train Loss: 0.0132194 Accuracy 0.9444\n",
+      "Epoch 15: Train Loss: 0.0131942 Accuracy 0.9465\n",
+      "Epoch 16: Train Loss: 0.0131739 Accuracy 0.9499\n",
+      "Epoch 17: Train Loss: 0.0131527 Accuracy 0.9526\n",
+      "Epoch 18: Train Loss: 0.0131369 Accuracy 0.9551\n",
+      "Epoch 19: Train Loss: 0.0131214 Accuracy 0.9563\n",
+      "Epoch 20: Train Loss: 0.0131061 Accuracy 0.9578\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
+    "\n",
+    "# The size of the input dimension is the number of genes\n",
+    "input_dim = experiment_dataset.shape[1]\n",
+    "\n",
+    "# The size of the output dimension is the number of distinct cell_type values\n",
+    "output_dim = len(cell_type_encoder.classes_)\n",
+    "\n",
+    "model = LogisticRegression(input_dim, output_dim).to(device)\n",
+    "loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n",
+    "\n",
+    "\n",
+    "# Define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure\n",
+    "# that a different shuffle is applied on each epoch.\n",
+    "experiment_dataloader = soma_ml.experiment_dataloader(\n",
+    "    experiment_dataset, num_workers=2, persistent_workers=True\n",
+    ")\n",
+    "\n",
+    "for epoch in range(20):\n",
+    "    experiment_dataset.set_epoch(epoch)\n",
+    "    train_loss, train_accuracy = train_epoch(\n",
+    "        model, experiment_dataloader, loss_fn, optimizer, device\n",
+    "    )\n",
+    "    print(\n",
+    "        f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\"\n",
+    "    )"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "toymodel",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/tutorial_pytorch.ipynb b/notebooks/tutorial_pytorch.ipynb
new file mode 100644
index 0000000..70c62e3
--- /dev/null
+++ b/notebooks/tutorial_pytorch.ipynb
@@ -0,0 +1,609 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Training a PyTorch Model\n",
+    "\n",
+    "This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma_ml.ExperimentAxisQueryIterDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. It is intended only to demonstrate the use of the `ExperimentAxisQueryIterDataPipe`, and not as an example of how to train a biologically useful model.\n",
+    "\n",
+    "This tutorial assumes a basic familiarity with PyTorch and the Census API.\n",
+    "\n",
+    "**Prerequisites**\n",
+    "\n",
+    "Install `tiledbsoma_ml` and `scikit-learn`, for example:\n",
+    "\n",
+    "> pip install tiledbsoma_ml scikit-learn\n",
+    "\n",
+    "\n",
+    "**Contents**\n",
+    "\n",
+    "* [Create an ExperimentAxisQueryIterDataPipe](#Create-an-ExperimentAxisQueryIterDataPipe)\n",
+    "* [Split the dataset](#Split-the-dataset)\n",
+    "* [Create the DataLoader](#Create-the-DataLoader)\n",
+    "* [Define the model](#Define-the-model)\n",
+    "* [Train the model](#Train-the-model)\n",
+    "* [Make predictions with the model](#Make-predictions-with-the-model)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create an ExperimentAxisQueryIterDataPipe\n",
+    "\n",
+    "To train a model in PyTorch on Census data, first open a SOMA Experiment and create an `ExperimentAxisQueryIterDataPipe`. This example uses a recent CZI Census release, accessed directly from S3.\n",
+    "\n",
+    "We are also going to create an encoder for the `obs` labels at the same time, and fit it on the `cell_type` labels. In this example we use the `LabelEncoder` from `scikit-learn`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import tiledbsoma as soma\n",
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import tiledbsoma_ml as soma_ml\n",
+    "\n",
+    "CZI_Census_Homo_Sapiens_URL = \"s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/\"\n",
+    "\n",
+    "experiment = soma.open(\n",
+    "    CZI_Census_Homo_Sapiens_URL,\n",
+    "    context=soma.SOMATileDBContext(tiledb_config={\"vfs.s3.region\": \"us-west-2\"}),\n",
+    ")\n",
+    "obs_value_filter = \"tissue_general == 'tongue' and is_primary_data == True\"\n",
+    "\n",
+    "with experiment.axis_query(\n",
+    "    measurement_name=\"RNA\", obs_query=soma.AxisQuery(value_filter=obs_value_filter)\n",
+    ") as query:\n",
+    "    obs_df = query.obs(column_names=[\"cell_type\"]).concat().to_pandas()\n",
+    "    cell_type_encoder = LabelEncoder().fit(obs_df[\"cell_type\"].unique())\n",
+    "\n",
+    "    experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(\n",
+    "        query,\n",
+    "        X_name=\"raw\",\n",
+    "        obs_column_names=[\"cell_type\"],\n",
+    "        batch_size=128,\n",
+    "        shuffle=True,\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `ExperimentAxisQueryIterDataPipe` class explained\n",
+    "\n",
+    "This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller as a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, avoiding having to load the entire training dataset into memory at once.\n",
+    "\n",
+    "### `ExperimentAxisQueryIterDataPipe` parameters explained\n",
+    "\n",
+    "The constructor requires a single positional parameter, `query`: a `soma.ExperimentAxisQuery` identifying the data to be used for training.\n",
+    "\n",
+    "To retrieve a subset of the Experiment's data along either the `obs` or `var` axes, specify query filters via the `obs_query` and `var_query` parameters of `experiment.axis_query` when constructing the query; both are `soma.AxisQuery` objects (a `var_query` example is sketched below).\n",
+    "\n",
+    "The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.\n",
+    "\n",
+    "The `batch_size` parameter specifies the number of `obs` rows (cells) returned in each batch. You may omit this parameter if you want single rows (`batch_size=1`).\n",
+    "\n",
+    "The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:\n",
+    "* You should use this flag instead of the `DataLoader` `shuffle` flag, primarily for performance reasons.\n",
+    "* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is an alternate mechanism one can use to shuffle an `IterableDataset`. However, the `Shuffler` will not \"globally\" randomize the training data, as it only \"locally\" randomizes the ordering within fixed-size \"windows\". Due to the layout of Census data, a given \"window\" may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models."
+   ]
+  },
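+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For example, a `var` axis filter could be added to the query above as follows (a sketch only; the `feature_name` values are illustrative and assume the Census `var` schema):\n",
+    "\n",
+    "```python\n",
+    "with experiment.axis_query(\n",
+    "    measurement_name=\"RNA\",\n",
+    "    obs_query=soma.AxisQuery(value_filter=obs_value_filter),\n",
+    "    # Assumed for illustration: restrict the var axis to two genes by name\n",
+    "    var_query=soma.AxisQuery(value_filter=\"feature_name in ['FOXP2', 'GJB2']\"),\n",
+    ") as query:\n",
+    "    ...\n",
+    "```\n"
+   ]
+  },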
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can inspect the shape of the full dataset without causing the full dataset to be loaded. The `shape` property returns the number of batches on the first dimension:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(118, 60530)"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "experiment_dataset.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Split the dataset\n",
+    "\n",
+    "You may split the overall dataset into training and test sets (and, in the same way, a validation set) by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_dataset, test_dataset = experiment_dataset.random_split(weights={\"train\": 0.8, \"test\": 0.2}, seed=1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create the DataLoader\n",
+    "\n",
+    "With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "experiment_dataloader = soma_ml.experiment_dataloader(train_dataset)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternatively, you can instantiate a `DataLoader` object directly via its constructor. However, many of its parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryIterDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, and `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage.\n",
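+    "\n",
+    "A direct construction would look something like the sketch below (leaving the disallowed parameters at their defaults; `batch_size=None` disables the `DataLoader`'s own batching, since the `DataPipe` already yields whole batches):\n",
+    "\n",
+    "```python\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "# Sketch of the general PyTorch pattern for pre-batched iterable datasets;\n",
+    "# `experiment_dataloader` sets up something along these lines for you.\n",
+    "dataloader = DataLoader(train_dataset, batch_size=None, num_workers=0)\n",
+    "```\n"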
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define the model\n",
+    "\n",
+    "With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "\n",
+    "class LogisticRegression(torch.nn.Module):\n",
+    "    def __init__(self, input_dim, output_dim):\n",
+    "        super(LogisticRegression, self).__init__()  # noqa: UP008\n",
+    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        outputs = torch.sigmoid(self.linear(x))\n",
+    "        return outputs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we define a function to train the model for a single epoch:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_epoch(model, train_dataloader, loss_fn, optimizer, device):\n",
+    "    model.train()\n",
+    "    train_loss = 0\n",
+    "    train_correct = 0\n",
+    "    train_total = 0\n",
+    "\n",
+    "    for X_batch, y_batch in train_dataloader:\n",
+    "        optimizer.zero_grad()\n",
+    "\n",
+    "        X_batch = torch.from_numpy(X_batch).float().to(device)\n",
+    "\n",
+    "        # Perform prediction\n",
+    "        outputs = model(X_batch)\n",
+    "\n",
+    "        # Determine the predicted label\n",
+    "        probabilities = torch.nn.functional.softmax(outputs, 1)\n",
+    "        predictions = torch.argmax(probabilities, axis=1)\n",
+    "\n",
+    "        # Compute the loss and perform back propagation\n",
+    "        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)\n",
+    "        train_correct += (predictions == y_batch).sum().item()\n",
+    "        train_total += len(predictions)\n",
+    "\n",
+    "        loss = loss_fn(outputs, y_batch.long())\n",
+    "        train_loss += loss.item()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "\n",
+    "    train_loss /= train_total\n",
+    "    train_accuracy = train_correct / train_total\n",
+    "    return train_loss, train_accuracy"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note the loop header, `for X_batch, y_batch in train_dataloader:`. Since the `experiment_dataset` was configured with `batch_size=128`, these variables will hold tensors of rank 2. The `X_batch` tensor will appear, for example, as:\n",
+    "\n",
+    "```\n",
+    "tensor([[0., 0., 0., ..., 1., 0., 0.],\n",
+    "        [0., 0., 2., ..., 0., 3., 0.],\n",
+    "        [0., 0., 0., ..., 0., 0., 0.],\n",
+    "        ...,\n",
+    "        [0., 0., 0., ..., 0., 0., 0.],\n",
+    "        [0., 1., 0., ..., 0., 0., 0.],\n",
+    "        [0., 0., 0., ..., 0., 0., 8.]])\n",
+    "```\n",
+    "\n",
+    "For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:\n",
+    "\n",
+    "```\n",
+    "tensor([0., 0., 0., ..., 1., 0., 0.])\n",
+    "```\n",
+    "\n",
+    "`y_batch` contains the user-specified `obs` training labels, with one column per entry in `obs_column_names` (in this case, only the cell type). In `train_epoch`, these string labels are integer-encoded with the `LabelEncoder` fitted earlier, yielding a tensor like:\n",
+    "\n",
+    "```\n",
+    "tensor([1, 1, 3, ..., 2, 1, 4])\n",
+    "```\n",
+    "\n",
+    "Note that the cell type values are integer-encoded, and can be decoded using `cell_type_encoder.inverse_transform` (more on this below).\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Train the model\n",
+    "\n",
+    "Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method, and then iterate through the desired number of training epochs. Note how the `experiment_dataloader` is passed into `train_epoch`, where for each epoch it will provide a new iterator through the training dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1: Train Loss: 0.0171090 Accuracy 0.1798\n",
+      "Epoch 2: Train Loss: 0.0151506 Accuracy 0.3480\n",
+      "Epoch 3: Train Loss: 0.0146299 Accuracy 0.4174\n",
+      "Epoch 4: Train Loss: 0.0142093 Accuracy 0.4765\n",
+      "Epoch 5: Train Loss: 0.0140261 Accuracy 0.5111\n",
+      "Epoch 6: Train Loss: 0.0138939 Accuracy 0.5634\n",
+      "Epoch 7: Train Loss: 0.0137783 Accuracy 0.6182\n",
+      "Epoch 8: Train Loss: 0.0136766 Accuracy 0.7050\n",
+      "Epoch 9: Train Loss: 0.0135647 Accuracy 0.8293\n",
+      "Epoch 10: Train Loss: 0.0134729 Accuracy 0.8793\n",
+      "Epoch 11: Train Loss: 0.0133968 Accuracy 0.8938\n",
+      "Epoch 12: Train Loss: 0.0133453 Accuracy 0.9013\n",
+      "Epoch 13: Train Loss: 0.0133143 Accuracy 0.9047\n",
+      "Epoch 14: Train Loss: 0.0132873 Accuracy 0.9102\n",
+      "Epoch 15: Train Loss: 0.0132666 Accuracy 0.9176\n",
+      "Epoch 16: Train Loss: 0.0132246 Accuracy 0.9219\n",
+      "Epoch 17: Train Loss: 0.0132161 Accuracy 0.9230\n",
+      "Epoch 18: Train Loss: 0.0131877 Accuracy 0.9295\n",
+      "Epoch 19: Train Loss: 0.0131658 Accuracy 0.9344\n",
+      "Epoch 20: Train Loss: 0.0131338 Accuracy 0.9382\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
+    "\n",
+    "# The size of the input dimension is the number of genes\n",
+    "input_dim = experiment_dataset.shape[1]\n",
+    "\n",
+    "# The size of the output dimension is the number of distinct cell_type values\n",
+    "output_dim = len(cell_type_encoder.classes_)\n",
+    "\n",
+    "model = LogisticRegression(input_dim, output_dim).to(device)\n",
+    "loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)\n",
+    "\n",
+    "for epoch in range(20):\n",
+    "    train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)\n",
+    "    print(f\"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Make predictions with the model\n",
+    "\n",
+    "To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the \"test\" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "experiment_dataloader = soma_ml.experiment_dataloader(test_dataset)\n",
+    "X_batch, y_batch = next(iter(experiment_dataloader))\n",
+    "X_batch = torch.from_numpy(X_batch)\n",
+    "y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we invoke the model on the `X_batch` input data and extract the predictions:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([ 8,  1,  1,  1,  1,  1,  1,  8,  8,  5,  1,  7,  8,  1,  1,  1,  1,  7,\n",
+       "         7,  8,  1,  1,  5,  5,  1,  8,  1,  1,  1,  7,  8,  7,  7,  7,  8,  7,\n",
+       "         5,  1,  1,  8,  1,  5,  8,  5,  1, 11,  1,  7,  1,  1,  5,  5,  1, 11,\n",
+       "         1,  6,  8,  5,  1,  8, 11,  8,  1,  8,  1,  8,  1,  5,  1,  1,  1,  8,\n",
+       "         8,  7,  5,  1,  1,  8,  1,  7,  2,  1,  7,  1,  5,  1,  1,  7,  1,  8,\n",
+       "         1,  1,  1,  7,  7,  1,  1,  1,  7,  1,  1,  7,  7,  5,  7,  8,  5,  1,\n",
+       "         5,  1,  5,  5,  5,  1,  1,  1,  8,  5,  1,  1,  7,  8,  1,  1,  1,  1,\n",
+       "         8,  1], device='cuda:0')"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "model.eval()\n",
+    "\n",
+    "model.to(device)\n",
+    "outputs = model(X_batch.to(device))\n",
+    "\n",
+    "probabilities = torch.nn.functional.softmax(outputs, 1)\n",
+    "predictions = torch.argmax(probabilities, axis=1)\n",
+    "\n",
+    "display(predictions)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The predictions are returned as encoded values of the `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.\n",
+    "\n",
+    "If the model inputs at inference time are not obtained via an `ExperimentAxisQueryIterDataPipe`, the encoder can be pickled at training time and saved along with the model, then unpickled at inference time and used as shown below.\n",
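+    "\n",
+    "For example (a sketch using Python's standard `pickle` module; the file name is a placeholder):\n",
+    "\n",
+    "```python\n",
+    "import pickle\n",
+    "\n",
+    "# At training time: save the fitted encoder alongside the model weights.\n",
+    "with open(\"cell_type_encoder.pkl\", \"wb\") as f:\n",
+    "    pickle.dump(cell_type_encoder, f)\n",
+    "\n",
+    "# At inference time: load it back to decode predicted labels.\n",
+    "with open(\"cell_type_encoder.pkl\", \"rb\") as f:\n",
+    "    cell_type_encoder = pickle.load(f)\n",
+    "```\n"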
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['leukocyte', 'basal cell', 'basal cell', 'basal cell',\n",
+       "       'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n",
+       "       'epithelial cell', 'basal cell', 'keratinocyte', 'leukocyte',\n",
+       "       'basal cell', 'basal cell', 'basal cell', 'basal cell',\n",
+       "       'keratinocyte', 'keratinocyte', 'leukocyte', 'basal cell',\n",
+       "       'basal cell', 'epithelial cell', 'epithelial cell', 'basal cell',\n",
+       "       'leukocyte', 'basal cell', 'basal cell', 'basal cell',\n",
+       "       'keratinocyte', 'leukocyte', 'keratinocyte', 'keratinocyte',\n",
+       "       'keratinocyte', 'leukocyte', 'keratinocyte', 'epithelial cell',\n",
+       "       'basal cell', 'basal cell', 'leukocyte', 'basal cell',\n",
+       "       'epithelial cell', 'leukocyte', 'epithelial cell', 'basal cell',\n",
+       "       'vein endothelial cell', 'basal cell', 'keratinocyte',\n",
+       "       'basal cell', 'basal cell', 'epithelial cell', 'epithelial cell',\n",
+       "       'basal cell', 'vein endothelial cell', 'basal cell', 'fibroblast',\n",
+       "       'leukocyte', 'epithelial cell', 'basal cell', 'leukocyte',\n",
+       "       'vein endothelial cell', 'leukocyte', 'basal cell', 'leukocyte',\n",
+       "       'basal cell', 'leukocyte', 'basal cell', 'epithelial cell',\n",
+       "       'basal cell', 'basal cell', 'basal cell', 'leukocyte', 'leukocyte',\n",
+       "       'keratinocyte', 'epithelial cell', 'basal cell', 'basal cell',\n",
+       "       'leukocyte', 'basal cell', 'keratinocyte',\n",
+       "       'capillary endothelial cell', 'basal cell', 'keratinocyte',\n",
+       "       'basal cell', 'epithelial cell', 'basal cell', 'basal cell',\n",
+       "       'keratinocyte', 'basal cell', 'leukocyte', 'basal cell',\n",
+       "       'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n",
+       "       'basal cell', 'basal cell', 'basal cell', 'keratinocyte',\n",
+       "       'basal cell', 'basal cell', 'keratinocyte', 'keratinocyte',\n",
+       "       'epithelial cell', 'keratinocyte', 'leukocyte', 'epithelial cell',\n",
+       "       'basal cell', 'epithelial cell', 'basal cell', 'epithelial cell',\n",
+       "       'epithelial cell', 'epithelial cell', 'basal cell', 'basal cell',\n",
+       "       'basal cell', 'leukocyte', 'epithelial cell', 'basal cell',\n",
+       "       'basal cell', 'keratinocyte', 'leukocyte', 'basal cell',\n",
+       "       'basal cell', 'basal cell', 'basal cell', 'leukocyte',\n",
+       "       'basal cell'], dtype=object)"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())\n",
+    "\n",
+    "display(predicted_cell_types)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we create a Pandas DataFrame to examine the predictions:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + " | actual cell type | \n", + "predicted cell type | \n", + "
---|---|---|
0 | \n", + "leukocyte | \n", + "leukocyte | \n", + "
1 | \n", + "basal cell | \n", + "basal cell | \n", + "
2 | \n", + "basal cell | \n", + "basal cell | \n", + "
3 | \n", + "basal cell | \n", + "basal cell | \n", + "
4 | \n", + "basal cell | \n", + "basal cell | \n", + "
... | \n", + "... | \n", + "... | \n", + "
123 | \n", + "fibroblast | \n", + "basal cell | \n", + "
124 | \n", + "basal cell | \n", + "basal cell | \n", + "
125 | \n", + "keratinocyte | \n", + "basal cell | \n", + "
126 | \n", + "leukocyte | \n", + "leukocyte | \n", + "
127 | \n", + "basal cell | \n", + "basal cell | \n", + "
128 rows × 2 columns
\n", + "