diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index d7918e79..9785c585 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -39,6 +39,7 @@ def __call__( inputs: torch.Tensor, diffusion_model: Callable[..., torch.Tensor], noise: torch.Tensor, + timesteps: torch.Tensor, condition: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -48,10 +49,9 @@ def __call__( inputs: Input image to which noise is added. diffusion_model: diffusion model. noise: random noise, of the same shape as the input. + timesteps: random timesteps. condition: Conditioning for network input. """ - num_timesteps = self.scheduler.num_train_timesteps - timesteps = torch.randint(0, num_timesteps, (inputs.shape[0],), device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition) @@ -123,6 +123,7 @@ def __call__( autoencoder_model: Callable[..., torch.Tensor], diffusion_model: Callable[..., torch.Tensor], noise: torch.Tensor, + timesteps: torch.Tensor, condition: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -133,6 +134,7 @@ def __call__( autoencoder_model: first stage model. diffusion_model: diffusion model. noise: random noise, of the same shape as the latent representation. + timesteps: random timesteps. condition: conditioning for network input. """ with torch.no_grad(): @@ -142,6 +144,7 @@ def __call__( inputs=latent, diffusion_model=diffusion_model, noise=noise, + timesteps=timesteps, condition=condition, ) diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index c397949f..364cc592 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -55,6 +55,9 @@ class DDIMScheduler(nn.Module): steps_offset: an offset added to the inference steps. 
You can use a combination of `steps_offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the + diffusion process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ def __init__( @@ -66,6 +69,7 @@ def __init__( clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, + prediction_type: str = "epsilon", ) -> None: super().__init__() self.beta_schedule = beta_schedule @@ -79,6 +83,12 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]: + raise ValueError( + f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" + ) + + self.prediction_type = prediction_type self.num_train_timesteps = num_train_timesteps self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -171,7 +181,14 @@ def step( # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # convert the predicted v into the predicted noise (epsilon) + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample # 4. 
Clip "predicted x_0" if self.clip_sample: @@ -231,3 +248,21 @@ def add_noise( noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/generative/networks/schedulers/ddpm.py b/generative/networks/schedulers/ddpm.py index f2ea54e2..f0a64482 100644 --- a/generative/networks/schedulers/ddpm.py +++ b/generative/networks/schedulers/ddpm.py @@ -61,6 +61,7 @@ def __init__( beta_schedule: str = "linear", variance_type: str = "fixed_small", clip_sample: bool = True, + prediction_type: str = "epsilon", ) -> None: super().__init__() self.beta_schedule = beta_schedule @@ -74,6 +75,13 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]: + raise ValueError( + f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" + ) + + self.prediction_type = prediction_type + self.num_train_timesteps = num_train_timesteps self.alphas = 1.0 - self.betas 
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -170,10 +178,12 @@ def step( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if predict_epsilon: + if self.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif self.prediction_type == "sample": pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" if self.clip_sample: @@ -233,3 +243,21 @@ def add_noise( noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 0ceb71ef..8b3cd511 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -63,7 +63,8 @@ def test_call(self, model_params, input_shape): ) inferer = 
DiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - sample = inferer(inputs=input, noise=noise, diffusion_model=model) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps) self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 81cb8002..450b940e 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -98,7 +98,10 @@ def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params, ) inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) - prediction = inferer(inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + prediction = inferer( + inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps + ) self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb index ae642e6e..8e0a1013 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb @@ -780,8 +780,13 @@ " # Generate random noise\n", " noise = torch.randn_like(images).to(device)\n", "\n", + " # Create timesteps\n", + " timesteps = torch.randint(\n", + " 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n", + " ).long()\n", + "\n", " # Get model prediction\n", - " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n", + " noise_pred = 
inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", "\n", " loss = F.mse_loss(noise_pred.float(), noise.float())\n", "\n", @@ -806,7 +811,10 @@ " with torch.no_grad():\n", " with autocast(enabled=True):\n", " noise = torch.randn_like(images).to(device)\n", - " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n", + " timesteps = torch.randint(\n", + " 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n", + " ).long()\n", + " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", " val_loss = F.mse_loss(noise_pred.float(), noise.float())\n", "\n", " val_epoch_loss += val_loss.item()\n", diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py index 4d592d4c..80db0e35 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py @@ -207,8 +207,13 @@ # Generate random noise noise = torch.randn_like(images).to(device) + # Create timesteps + timesteps = torch.randint( + 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device + ).long() + # Get model prediction - noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise) + noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps) loss = F.mse_loss(noise_pred.float(), noise.float()) @@ -233,7 +238,10 @@ with torch.no_grad(): with autocast(enabled=True): noise = torch.randn_like(images).to(device) - noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise) + timesteps = torch.randint( + 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device + ).long() + noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps) val_loss = F.mse_loss(noise_pred.float(), noise.float()) val_epoch_loss += val_loss.item() diff --git 
a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb index a3be5d84..a9f90771 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb @@ -394,8 +394,9 @@ ")\n", "model.to(device)\n", "\n", + "num_train_timesteps = 1000\n", "scheduler = DDPMScheduler(\n", - " num_train_timesteps=1000,\n", + " num_train_timesteps=num_train_timesteps,\n", ")\n", "\n", "optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n", @@ -433,13 +434,17 @@ "\n", " \"\"\"\n", "\n", - " def __init__(self, condition_name: Optional[str] = None):\n", + " def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n", " self.condition_name = condition_name\n", + " self.num_train_timesteps = num_train_timesteps\n", "\n", " def get_noise(self, images):\n", " \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n", " return torch.randn_like(images)\n", "\n", + " def get_timesteps(self, images):\n", + " return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n", + "\n", " def __call__(\n", " self,\n", " batchdata: Dict[str, torch.Tensor],\n", @@ -449,8 +454,9 @@ " ):\n", " images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n", " noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n", + " timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n", "\n", - " kwargs = {\"noise\": noise}\n", + " kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n", "\n", " if self.condition_name is not None and isinstance(batchdata, Mapping):\n", " kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n", @@ -2159,7 +2165,7 @@ " val_data_loader=val_loader,\n", " network=model,\n", " 
inferer=inferer,\n", - " prepare_batch=DiffusionPrepareBatch(),\n", + " prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n", " key_val_metric={\"val_mean_abs_error\": MeanAbsoluteError(output_transform=from_engine([\"pred\", \"label\"]))},\n", " val_handlers=val_handlers,\n", ")\n", @@ -2178,7 +2184,7 @@ " optimizer=optimizer,\n", " loss_function=torch.nn.MSELoss(),\n", " inferer=inferer,\n", - " prepare_batch=DiffusionPrepareBatch(),\n", + " prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n", " key_train_metric={\"train_acc\": MeanSquaredError(output_transform=from_engine([\"pred\", \"label\"]))},\n", " train_handlers=train_handlers,\n", ")\n", diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py index fcb22bf8..1db6ad83 100644 --- a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py @@ -177,8 +177,9 @@ ) model.to(device) +num_train_timesteps = 1000 scheduler = DDPMScheduler( - num_train_timesteps=1000, + num_train_timesteps=num_train_timesteps, ) optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5) @@ -203,13 +204,17 @@ class DiffusionPrepareBatch(PrepareBatch): """ - def __init__(self, condition_name: Optional[str] = None): + def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None): self.condition_name = condition_name + self.num_train_timesteps = num_train_timesteps def get_noise(self, images): """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" return torch.randn_like(images) + def get_timesteps(self, images): + return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() + def __call__( self, batchdata: Dict[str, torch.Tensor], @@ -219,8 +224,9 @@ def __call__( ): images, _ = default_prepare_batch(batchdata, device, non_blocking, 
**kwargs) noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) + timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) - kwargs = {"noise": noise} + kwargs = {"noise": noise, "timesteps": timesteps} if self.condition_name is not None and isinstance(batchdata, Mapping): kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs) @@ -244,7 +250,7 @@ def __call__( val_data_loader=val_loader, network=model, inferer=inferer, - prepare_batch=DiffusionPrepareBatch(), + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps), key_val_metric={"val_mean_abs_error": MeanAbsoluteError(output_transform=from_engine(["pred", "label"]))}, val_handlers=val_handlers, ) @@ -263,7 +269,7 @@ def __call__( optimizer=optimizer, loss_function=torch.nn.MSELoss(), inferer=inferer, - prepare_batch=DiffusionPrepareBatch(), + prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps), key_train_metric={"train_acc": MeanSquaredError(output_transform=from_engine(["pred", "label"]))}, train_handlers=train_handlers, ) diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb new file mode 100644 index 00000000..613a877e --- /dev/null +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb @@ -0,0 +1,999 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2278a8b4", + "metadata": {}, + "source": [ + "# Denoising Diffusion Probabilistic Models using v-prediction parameterization\n", + "\n", + "This tutorial illustrates how to use MONAI for training a denoising diffusion probabilistic model (DDPM)[1] to create synthetic 2D images using v-prediction parameterization (Section 2.4 from [2]).\n", + "\n", + "[1] - Ho et al. \"Denoising Diffusion Probabilistic Models\" https://arxiv.org/abs/2006.11239\n", + "[2] - Ho et al. 
\"Imagen Video: High Definition Video Generation with Diffusion Models\" https://arxiv.org/abs/2210.02303\n", + "\n", + "TODO: Add Open in Colab\n", + "\n", + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "87d63656", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm, einops]\"\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "id": "de9a09ad", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e991ba58", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MONAI version: 1.1.dev2239\n", + "Numpy version: 1.23.3\n", + "Pytorch version: 1.8.0+cu111\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: 13b24fa92b9d98bd0dc6d5cdcb52504fd09e297b\n", + "MONAI __file__: /media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.10\n", + "Nibabel version: 4.0.2\n", + "scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n", + "Pillow version: 9.2.0\n", + "Tensorboard version: 2.11.0\n", + "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n", + "TorchVision version: 0.9.0+cu111\n", + "tqdm version: 4.64.1\n", + "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n", + "psutil version: 5.9.3\n", + "pandas version: NOT INSTALLED or UNKNOWN VERSION.\n", + "einops version: 0.6.0\n", + "transformers version: NOT INSTALLED or UNKNOWN VERSION.\n", + "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n", + "pynrrd version: NOT INSTALLED or UNKNOWN VERSION.\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " 
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "# Copyright 2020 MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "import os\n", + "import shutil\n", + "import tempfile\n", + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from monai import transforms\n", + "from monai.apps import MedNISTDataset\n", + "from monai.config import print_config\n", + "from monai.data import CacheDataset, DataLoader\n", + "from monai.utils import first, set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "\n", + "from generative.inferers import DiffusionInferer\n", + "\n", + "# TODO: Add right import reference after deployed\n", + "from generative.networks.nets import DiffusionModelUNet\n", + "from generative.networks.schedulers import DDPMScheduler\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "a414cba7", + "metadata": {}, + "source": [ + "## Setup data directory\n", + "\n", + "You can specify a directory with the MONAI_DATA_DIRECTORY environment variable.\n", + "\n", + "This allows you to save results and reuse downloads.\n", + "\n", + "If not specified a temporary directory will be used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "22061fd8", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmplcl8fv6u\n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "8962d9b4", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "57bb62f0", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "21210b86", + "metadata": {}, + "source": [ + "## Setup MedNIST Dataset and training and validation dataloaders\n", + "In this tutorial, we will train our models on the MedNIST dataset available on MONAI\n", + "(https://docs.monai.io/en/stable/apps.html#monai.apps.MedNISTDataset). In order to train faster, we will select just\n", + "one of the available classes (\"Hand\"), resulting in a training set with 7999 2D images." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "46fc4bfb", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022-12-11 12:03:19,625 - INFO - Downloaded: /tmp/tmplcl8fv6u/MedNIST.tar.gz\n", + "2022-12-11 12:03:19,696 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2022-12-11 12:03:19,696 - INFO - Writing into directory: /tmp/tmplcl8fv6u.\n" + ] + } + ], + "source": [ + "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, progress=False, seed=0)\n", + "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"Hand\"]" + ] + }, + { + "cell_type": "markdown", + "id": "50c48aef", + "metadata": {}, + "source": [ + "Here we use transforms to augment the training dataset:\n", + "\n", + "1. `LoadImaged` loads the hand images from files.\n", + "1. `EnsureChannelFirstd` ensures the loaded data has a \"channel first\" shape.\n", + "1. `ScaleIntensityRanged` clips intensities to the range [0, 255] and rescales them to [0, 1].\n", + "1. `RandAffined` performs rotation, scaling, shearing, translation, etc. in a single step using a PyTorch affine transform." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a03c3f45", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7999/7999 [00:04<00:00, 1775.57it/s]\n" + ] + } + ], + "source": [ + "train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n", + " transforms.RandAffined(\n", + " keys=[\"image\"],\n", + " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n", + " translate_range=[(-1, 1), (-1, 1)],\n", + " scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n", + " spatial_size=[64, 64],\n", + " padding_mode=\"zeros\",\n", + " prob=0.5,\n", + " ),\n", + " ]\n", + ")\n", + "train_ds = CacheDataset(data=train_datalist, transform=train_transforms)\n", + "train_loader = DataLoader(train_ds, batch_size=96, shuffle=True, num_workers=4, persistent_workers=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7855726e", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022-12-11 12:03:42,998 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2022-12-11 12:03:42,999 - INFO - File exists: /tmp/tmplcl8fv6u/MedNIST.tar.gz, skipped downloading.\n", + "2022-12-11 12:03:42,999 - INFO - Non-empty folder exists in /tmp/tmplcl8fv6u/MedNIST, skipped extracting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1005/1005 [00:00<00:00, 1831.70it/s]\n" + ] + } + ], + "source": [ + "val_data = MedNISTDataset(root_dir=root_dir, section=\"validation\", download=True, progress=False, seed=0)\n", + "val_datalist = [{\"image\": item[\"image\"]} for item in val_data.data if item[\"class_name\"] == \"Hand\"]\n", + "val_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n", + " ]\n", + ")\n", + "val_ds = CacheDataset(data=val_datalist, transform=val_transforms)\n", + "val_loader = DataLoader(val_ds, batch_size=96, shuffle=False, num_workers=4, persistent_workers=True)" + ] + }, + { + "cell_type": "markdown", + "id": "01452490", + "metadata": {}, + "source": [ + "### Visualisation of the training images" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3f68cdfe", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch shape: (96, 1, 64, 64)\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "check_data = first(train_loader)\n", + "print(f\"batch shape: {check_data['image'].shape}\")\n", + "image_visualisation = torch.cat(\n", + " [check_data[\"image\"][0, 0], check_data[\"image\"][1, 0], check_data[\"image\"][2, 0], check_data[\"image\"][3, 0]], dim=1\n", + ")\n", + "plt.figure(\"training images\", (12, 6))\n", + "plt.imshow(image_visualisation, vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.axis(\"off\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "7d026c35", + "metadata": {}, + "source": [ + "### Define network, scheduler, optimizer, and inferer\n", + "At this step, we instantiate the MONAI components to create a DDPM, the UNET, the noise scheduler, and the inferer used for training and sampling. We are using\n", + "the original DDPM scheduler containing 1000 timesteps in its Markov chain, and a 2D UNET with attention mechanisms\n", + "in the 2nd and 3rd levels, each with 1 attention head." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f7ba0c0f", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\")\n", + "\n", + "model = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " num_channels=(128, 256, 256),\n", + " attention_levels=(False, True, True),\n", + " num_res_blocks=1,\n", + " num_head_channels=256,\n", + ")\n", + "model.to(device)\n", + "\n", + "scheduler = DDPMScheduler(\n", + " prediction_type=\"v_prediction\",\n", + " num_train_timesteps=1000,\n", + " beta_start = 0.00085,\n", + " beta_end = 0.0120,\n", + ")\n", + "\n", + "optimizer = torch.optim.Adam(params=model.parameters(), lr=1.0e-4)\n", + "\n", + "inferer = DiffusionInferer(scheduler)" + ] + }, + { + "cell_type": "markdown", + "id": "315f8b47", + "metadata": {}, + "source": [ + "### Model training\n", + "Here, we are training our model for 75 epochs (training time: ~50 minutes)." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f5f58b7e", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|█████████████| 84/84 [00:38<00:00, 2.17it/s, loss=0.13]\n", + "Epoch 1: 100%|███████████| 84/84 [00:38<00:00, 2.18it/s, loss=0.0485]\n", + "Epoch 2: 100%|███████████| 84/84 [00:38<00:00, 2.17it/s, loss=0.0433]\n", + "Epoch 3: 100%|███████████| 84/84 [00:38<00:00, 2.17it/s, loss=0.0406]\n", + "Epoch 4: 100%|███████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0385]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 71.19it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 5: 100%|███████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0366]\n", + "Epoch 6: 100%|███████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0357]\n", + "Epoch 7: 100%|████████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.035]\n", + "Epoch 8: 100%|███████████| 84/84 [00:38<00:00, 2.15it/s, loss=0.0338]\n", + "Epoch 9: 100%|████████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.034]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 71.24it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 10: 100%|██████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0332]\n", + "Epoch 11: 100%|██████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0325]\n", + "Epoch 12: 100%|███████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.032]\n", + "Epoch 13: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0326]\n", + "Epoch 14: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0315]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 72.64it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 15: 100%|██████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0315]\n", + "Epoch 16: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0313]\n", + "Epoch 17: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0307]\n", + "Epoch 18: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0309]\n", + "Epoch 19: 100%|████████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.03]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 73.03it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 20: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0305]\n", + "Epoch 21: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0301]\n", + "Epoch 22: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0307]\n", + "Epoch 23: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0303]\n", + "Epoch 24: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0299]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.70it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 25: 100%|████████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.03]\n", + "Epoch 26: 100%|████████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.03]\n", + "Epoch 27: 100%|████████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.03]\n", + "Epoch 28: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0294]\n", + "Epoch 29: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0294]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.75it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 30: 100%|██████████| 84/84 [00:38<00:00, 2.15it/s, loss=0.0291]\n", + "Epoch 31: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0297]\n", + "Epoch 32: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0294]\n", + "Epoch 33: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0294]\n", + "Epoch 34: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0287]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.80it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 35: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0292]\n", + "Epoch 36: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0289]\n", + "Epoch 37: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0284]\n", + "Epoch 38: 100%|███████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.029]\n", + "Epoch 39: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0295]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.70it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 40: 100%|██████████| 84/84 [00:38<00:00, 2.16it/s, loss=0.0289]\n", + "Epoch 41: 100%|███████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.029]\n", + "Epoch 42: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0288]\n", + "Epoch 43: 100%|███████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.029]\n", + "Epoch 44: 100%|███████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.029]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 69.12it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 45: 100%|██████████| 84/84 [00:39<00:00, 2.13it/s, loss=0.0282]\n", + "Epoch 46: 100%|██████████| 84/84 [00:40<00:00, 2.06it/s, loss=0.0286]\n", + "Epoch 47: 100%|██████████| 84/84 [00:40<00:00, 2.09it/s, loss=0.0282]\n", + "Epoch 48: 100%|███████████| 84/84 [00:40<00:00, 2.09it/s, loss=0.028]\n", + "Epoch 49: 100%|██████████| 84/84 [00:40<00:00, 2.07it/s, loss=0.0289]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.00it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 50: 100%|██████████| 84/84 [00:41<00:00, 2.02it/s, loss=0.0282]\n", + "Epoch 51: 100%|██████████| 84/84 [00:39<00:00, 2.10it/s, loss=0.0285]\n", + "Epoch 52: 100%|██████████| 84/84 [00:40<00:00, 2.08it/s, loss=0.0283]\n", + "Epoch 53: 100%|██████████| 84/84 [00:42<00:00, 2.00it/s, loss=0.0281]\n", + "Epoch 54: 100%|██████████| 84/84 [00:42<00:00, 1.99it/s, loss=0.0285]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 67.98it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 55: 100%|██████████| 84/84 [00:39<00:00, 2.13it/s, loss=0.0287]\n", + "Epoch 56: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0281]\n", + "Epoch 57: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0288]\n", + "Epoch 58: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0285]\n", + "Epoch 59: 100%|███████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.028]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.07it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 60: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0281]\n", + "Epoch 61: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0282]\n", + "Epoch 62: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0279]\n", + "Epoch 63: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0283]\n", + "Epoch 64: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0278]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:13<00:00, 73.18it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 65: 100%|██████████| 84/84 [00:39<00:00, 2.15it/s, loss=0.0282]\n", + "Epoch 66: 100%|██████████| 84/84 [00:39<00:00, 2.14it/s, loss=0.0278]\n", + "Epoch 67: 100%|██████████| 84/84 [00:39<00:00, 2.11it/s, loss=0.0277]\n", + "Epoch 68: 100%|██████████| 84/84 [00:41<00:00, 2.04it/s, loss=0.0281]\n", + "Epoch 69: 100%|██████████| 84/84 [00:39<00:00, 2.11it/s, loss=0.0277]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 69.83it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 70: 100%|███████████| 84/84 [00:40<00:00, 2.07it/s, loss=0.028]\n", + "Epoch 71: 100%|██████████| 84/84 [00:40<00:00, 2.09it/s, loss=0.0275]\n", + "Epoch 72: 100%|██████████| 84/84 [00:41<00:00, 2.02it/s, loss=0.0274]\n", + "Epoch 73: 100%|██████████| 84/84 [00:40<00:00, 2.07it/s, loss=0.0276]\n", + "Epoch 74: 100%|██████████| 84/84 [00:39<00:00, 2.10it/s, loss=0.0278]\n", + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.73it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train completed, total time: 3196.6940882205963.\n" + ] + } + ], + "source": [ + "n_epochs = 75\n", + "val_interval = 5\n", + "epoch_loss_list = []\n", + "val_epoch_loss_list = []\n", + "\n", + "scaler = GradScaler()\n", + "total_start = time.time()\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " epoch_loss = 0\n", + " progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)\n", + " progress_bar.set_description(f\"Epoch {epoch}\")\n", + " for step, batch in progress_bar:\n", + " images = batch[\"image\"].to(device)\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " with autocast(enabled=True):\n", + " # Generate random noise\n", + " noise = torch.randn_like(images).to(device)\n", + " timesteps = torch.randint(\n", + " 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n", + " ).long()\n", + "\n", + " # Get target for the v-prediction parameterization\n", + " target = inferer.scheduler.get_velocity(images, noise, timesteps)\n", + "\n", + " # Get model prediction\n", + " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", + "\n", + " loss = F.mse_loss(noise_pred.float(), target.float())\n", + "\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + "\n", + " epoch_loss += loss.item()\n", + "\n", + " progress_bar.set_postfix(\n", + " {\n", + " \"loss\": epoch_loss / (step + 1),\n", + " }\n", + " )\n", + " epoch_loss_list.append(epoch_loss / (step + 1))\n", + "\n", + " if (epoch + 1) % val_interval == 0:\n", + " model.eval()\n", + " val_epoch_loss = 0\n", + " for step, batch in enumerate(val_loader):\n", + " images = batch[\"image\"].to(device)\n", + " with torch.no_grad():\n", + " with autocast(enabled=True):\n", + " noise = torch.randn_like(images).to(device)\n", + " 
timesteps = torch.randint(\n", + " 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n", + " ).long()\n", + " target = inferer.scheduler.get_velocity(images, noise, timesteps)\n", + " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", + " val_loss = F.mse_loss(noise_pred.float(), target.float())\n", + "\n", + " val_epoch_loss += val_loss.item()\n", + " progress_bar.set_postfix(\n", + " {\n", + " \"val_loss\": val_epoch_loss / (step + 1),\n", + " }\n", + " )\n", + " val_epoch_loss_list.append(val_epoch_loss / (step + 1))\n", + "\n", + " # Sampling image during training\n", + " noise = torch.randn((1, 1, 64, 64))\n", + " noise = noise.to(device)\n", + " scheduler.set_timesteps(num_inference_steps=1000)\n", + " with autocast(enabled=True):\n", + " image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)\n", + "\n", + " plt.figure(figsize=(2, 2))\n", + " plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + "total_time = time.time() - total_start\n", + "print(f\"train completed, total time: {total_time}.\")" + ] + }, + { + "cell_type": "markdown", + "id": "a70fd533", + "metadata": {}, + "source": [ + "### Learning curves" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3db336e6", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.style.use(\"seaborn-v0_8\")\n", + "plt.title(\"Learning Curves\", fontsize=20)\n", + "plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", + "plt.plot(\n", + " np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),\n", + " val_epoch_loss_list,\n", + " color=\"C1\",\n", + " linewidth=2.0,\n", + " label=\"Validation\",\n", + ")\n", + "plt.yticks(fontsize=12)\n", + "plt.xticks(fontsize=12)\n", + "plt.xlabel(\"Epochs\", fontsize=16)\n", + "plt.ylabel(\"Loss\", fontsize=16)\n", + "plt.legend(prop={\"size\": 14})\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a4ce9887", + "metadata": {}, + "source": [ + "### Plotting sampling process along DDPM's Markov chain" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7d3fee5b", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 69.29it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "noise = torch.randn((1, 1, 64, 64))\n", + "noise = noise.to(device)\n", + "scheduler.set_timesteps(num_inference_steps=1000)\n", + "with autocast(enabled=True):\n", + " image, intermediates = inferer.sample(\n", + " input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100\n", + " )\n", + "\n", + "chain = torch.cat(intermediates, dim=-1)\n", + "\n", + "plt.style.use(\"default\")\n", + "plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.tight_layout()\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "c07a83c4", + "metadata": {}, + "source": [ + "### Cleanup data directory\n", + "\n", + "Remove directory if a temporary was used." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f685296b", + "metadata": {}, + "outputs": [], + "source": [ + "if directory is None:\n", + " shutil.rmtree(root_dir)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "py:percent,ipynb" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.py b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.py new file mode 100644 index 00000000..743eb76c --- /dev/null +++ b/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.py @@ -0,0 +1,324 @@ +# --- +# jupyter: +# jupytext: +# formats: py:percent,ipynb +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.1 +# 
kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Denoising Diffusion Probabilistic Models using v-prediction parameterization +# +# This tutorial illustrates how to use MONAI for training a denoising diffusion probabilistic model (DDPM)[1] to create synthetic 2D images using v-prediction parameterization (Section 2.4 from [2]). +# +# [1] - Ho et al. "Denoising Diffusion Probabilistic Models" https://arxiv.org/abs/2006.11239 +# [2] - Ho et al. "Imagen Video: High Definition Video Generation with Diffusion Models" https://arxiv.org/abs/2210.02303 +# +# TODO: Add Open in Colab +# +# ## Setup environment + +# %% +# !python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm, einops]" +# !python -c "import matplotlib" || pip install -q matplotlib +# %matplotlib inline + +# %% [markdown] +# ## Setup imports + +# %% jupyter={"outputs_hidden": false} +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import DiffusionInferer

# TODO: Add right import reference after deployed
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

print_config()

# %% [markdown]
# ## Setup data directory
#
# You can specify a directory with the MONAI_DATA_DIRECTORY environment variable.
#
# This allows you to save results and reuse downloads.
#
# If not specified, a temporary directory will be used.

# %% jupyter={"outputs_hidden": false}
# `directory` stays None when the environment variable is unset; the cleanup cell at the end of the
# tutorial relies on that to decide whether `root_dir` is a temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

# %% [markdown]
# ## Set deterministic training for reproducibility

# %% jupyter={"outputs_hidden": false}
set_determinism(42)

# %% [markdown]
# ## Setup MedNIST Dataset and training and validation dataloaders
# In this tutorial, we will train our models on the MedNIST dataset available on MONAI
# (https://docs.monai.io/en/stable/apps.html#monai.apps.MedNISTDataset). In order to train faster, we will select just
# one of the available classes ("Hand"), resulting in a training set with 7999 2D images.

# %% jupyter={"outputs_hidden": false}
train_data = MedNISTDataset(root_dir=root_dir, section="training", download=True, progress=False, seed=0)
# Keep only the image path of the "Hand" class; the class label itself is not used for training.
train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "Hand"]

# %% [markdown]
# Here we use transforms to augment the training dataset:
#
# 1. `LoadImaged` loads the hand images from files.
# 1. `EnsureChannelFirstd` ensures the data has a "channel first" shape.
# 1. `ScaleIntensityRanged` extracts intensity range [0, 255] and scales to [0, 1].
# 1. `RandAffined` efficiently performs rotate, scale, shear, translate, etc. together based on PyTorch affine transform.

# %% jupyter={"outputs_hidden": false}
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        transforms.RandAffined(
            keys=["image"],
            rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
            translate_range=[(-1, 1), (-1, 1)],
            scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
            spatial_size=[64, 64],
            padding_mode="zeros",
            prob=0.5,
        ),
    ]
)
train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=96, shuffle=True, num_workers=4, persistent_workers=True)

# %% jupyter={"outputs_hidden": false}
val_data = MedNISTDataset(root_dir=root_dir, section="validation", download=True, progress=False, seed=0)
val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "Hand"]
# Validation uses the same loading and intensity scaling as training, but no random augmentation.
val_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ]
)
val_ds = CacheDataset(data=val_datalist, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=96, shuffle=False, num_workers=4, persistent_workers=True)

# %% [markdown]
# ### Visualisation of the training images

# %% jupyter={"outputs_hidden": false}
check_data = first(train_loader)
print(f"batch shape: {check_data['image'].shape}")
# Place the first four images of the batch side by side (concatenated along the width axis).
image_visualisation = torch.cat([check_data["image"][i, 0] for i in range(4)], dim=1)
plt.figure("training images", (12, 6))
plt.imshow(image_visualisation, vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()

# %% [markdown]
# ### Define network, scheduler, optimizer, and inferer
# At this step, we instantiate the MONAI components to create a DDPM: the UNET, the noise scheduler, and the inferer
# used for training and sampling. We are using the original DDPM scheduler containing 1000 timesteps in its Markov
# chain, configured with `prediction_type="v_prediction"`, and a 2D UNET with attention mechanisms
# in the 2nd and 3rd levels, each with 1 attention head.

# %% jupyter={"outputs_hidden": false}
# Prefer the GPU, but fall back to the CPU so the tutorial does not fail at `model.to(device)` on machines
# without CUDA. NOTE(review): the torch.cuda.amp helpers used in the training cell are expected to disable
# themselves with a warning when CUDA is unavailable -- confirm on the targeted PyTorch version.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=256,
)
model.to(device)

scheduler = DDPMScheduler(
    prediction_type="v_prediction",
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.0120,
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1.0e-4)

inferer = DiffusionInferer(scheduler)

# %% [markdown]
# ### Model training
# Here, we are training our model for 75 epochs (training time: ~50 minutes).

# %% jupyter={"outputs_hidden": false}
n_epochs = 75
val_interval = 5
epoch_loss_list = []
val_epoch_loss_list = []

scaler = GradScaler()
total_start = time.time()
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # Generate random noise (randn_like already allocates on the same device as `images`)
            noise = torch.randn_like(images)

            # Draw one random timestep per image in the batch
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
            )

            # Get target for the v-prediction parameterization
            target = inferer.scheduler.get_velocity(images, noise, timesteps)

            # Get model prediction; with v-prediction this is compared against the velocity target, not the noise
            v_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

            loss = F.mse_loss(v_pred.float(), target.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        # Show the running mean of the loss over the batches seen so far in this epoch
        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_epoch_loss = 0
        for step, batch in enumerate(val_loader):
            images = batch["image"].to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    noise = torch.randn_like(images)
                    timesteps = torch.randint(
                        0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                    )
                    target = inferer.scheduler.get_velocity(images, noise, timesteps)
                    v_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
                    val_loss = F.mse_loss(v_pred.float(), target.float())

            val_epoch_loss += val_loss.item()
        # The validation loss is only recorded for the learning curves; the training progress bar has already
        # finished at this point, so nothing is written to it.
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))

        # Sampling image during training
        # The noise is drawn on the CPU and then moved, so the random stream matches the seeded original run.
        noise = torch.randn((1, 1, 64, 64))
        noise = noise.to(device)
        scheduler.set_timesteps(num_inference_steps=1000)
        with autocast(enabled=True):
            image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)

        plt.figure(figsize=(2, 2))
        plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
        plt.tight_layout()
        plt.axis("off")
        plt.show()

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

# %% [markdown]
# ### Learning curves

# %% jupyter={"outputs_hidden": false}
plt.style.use("seaborn-v0_8")
plt.title("Learning Curves", fontsize=20)
# One training point per epoch; one validation point every `val_interval` epochs.
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
plt.plot(
    np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
    val_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()

# %% [markdown]
# ### Plotting sampling process along DDPM's Markov chain

# %% jupyter={"outputs_hidden": false}
model.eval()
noise = torch.randn((1, 1, 64, 64))
noise = noise.to(device)
scheduler.set_timesteps(num_inference_steps=1000)
with autocast(enabled=True):
    image, intermediates = inferer.sample(
        input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100
    )

# Lay the saved intermediate images out side by side along the last (width) axis.
chain = torch.cat(intermediates, dim=-1)

plt.style.use("default")
plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.tight_layout()
plt.axis("off")
plt.show()

# %% [markdown]
# ### Cleanup data directory
#
# Remove the directory if a temporary one was used.
+ +# %% +if directory is None: + shutil.rmtree(root_dir) diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb index 411d4c85..5f92921f 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb @@ -1360,7 +1360,8 @@ " z = autoencoderkl.sampling(z_mu, z_sigma)\n", "\n", " noise = torch.randn_like(z).to(device)\n", - " noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise)\n", + " timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (z.shape[0],), device=z.device).long()\n", + " noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps)\n", " loss = F.mse_loss(noise_pred.float(), noise.float())\n", "\n", " loss.backward()\n", @@ -1386,7 +1387,10 @@ " z = autoencoderkl.sampling(z_mu, z_sigma)\n", "\n", " noise = torch.randn_like(z).to(device)\n", - " noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise)\n", + " timesteps = torch.randint(\n", + " 0, inferer.scheduler.num_train_timesteps, (z.shape[0],), device=z.device\n", + " ).long()\n", + " noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps)\n", "\n", " loss = F.mse_loss(noise_pred.float(), noise.float())\n", "\n", @@ -1575,8 +1579,7 @@ "metadata": { "jupytext": { "cell_metadata_filter": "-all", - "formats": "ipynb,py", - "main_language": "python" + "formats": "ipynb,py" }, "kernelspec": { "display_name": "Python 3", diff --git a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py index 186daff2..defb6db4 100644 --- a/tutorials/generative/2d_ldm/2d_ldm_tutorial.py +++ b/tutorials/generative/2d_ldm/2d_ldm_tutorial.py @@ -320,7 +320,8 @@ z = autoencoderkl.sampling(z_mu, z_sigma) noise = torch.randn_like(z).to(device) - noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise) + timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, 
(z.shape[0],), device=z.device).long() + noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps) loss = F.mse_loss(noise_pred.float(), noise.float()) loss.backward() @@ -346,7 +347,10 @@ z = autoencoderkl.sampling(z_mu, z_sigma) noise = torch.randn_like(z).to(device) - noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise) + timesteps = torch.randint( + 0, inferer.scheduler.num_train_timesteps, (z.shape[0],), device=z.device + ).long() + noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps) loss = F.mse_loss(noise_pred.float(), noise.float()) diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py index 6f140606..34b2e9f0 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py @@ -321,8 +321,16 @@ def KL_loss(z_mu, z_sigma): with autocast(enabled=True): # Generate random noise noise = torch.randn_like(z).to(device) + + # Create timesteps + timesteps = torch.randint( + 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device + ).long() + # Get model prediction - noise_pred = inferer(inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise) + noise_pred = inferer( + inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise, timesteps=timesteps + ) loss = F.mse_loss(noise_pred.float(), noise.float())