Welcome to this project exploring Diffusion Models on the MNIST dataset!
This repository focuses on generating and reconstructing handwritten digits by integrating:
- Autoencoders with Convolutional Attention Blocks (CABs)
- Denoising Diffusion Probabilistic Models (DDPM) using U-Net
This project aims to reconstruct MNIST digits by encoding them into a latent space and progressively denoising them through a Diffusion Model.
- Latent Space Representations - Using attention mechanisms for better feature extraction.
- Diffusion Process - Forward and reverse diffusion to model the data distribution.
- Visualization - Monitoring performance through SSIM/PSNR scores and latent-space scatter plots.
- Encoder compresses MNIST digits into latent representations using convolutional layers and attention.
- Decoder reconstructs the digits from the latent space (a minimal sketch of both follows below).
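The sketch below shows what such an encoder/decoder pair could look like. The layer sizes and latent dimension are illustrative assumptions, not necessarily the values used in this repo; `CALayer` refers to the channel-attention block defined further down.

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Compresses a 1x28x28 MNIST digit into a flat latent vector."""
    def __init__(self, latent_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(inplace=True),
            CALayer(64),                                 # channel attention on the feature maps
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, latent_dim),
        )

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    """Reconstructs a 1x28x28 digit from the latent vector."""
    def __init__(self, latent_dim=64):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),                                        # pixel values in [0, 1]
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 64, 7, 7)
        return self.net(h)
```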
CABs refine the feature maps by:
- Global Average Pooling to squeeze spatial information into per-channel statistics.
- Two 1x1 Conv2D Layers that reduce and then restore the channel dimension.
- Sigmoid Activation to produce channel-wise attention weights applied to the input.
```python
import torch.nn as nn

class CALayer(nn.Module):
    """Channel Attention layer: squeeze spatial information via global average
    pooling, then rescale each channel with learned attention weights."""
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        # Squeeze: global average pooling over the spatial dimensions
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: two 1x1 convolutions that shrink and restore the channel dimension
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)   # [B, C, 1, 1] channel descriptors
        y = self.conv_du(y)    # [B, C, 1, 1] attention weights in (0, 1)
        return x * y           # rescale the input feature map channel-wise
```
- Forward Diffusion gradually adds noise to the latent representation (a minimal sketch follows this list).
- Reverse Diffusion predicts and removes the noise step by step to reconstruct the original image.
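The sketch below shows the forward (noising) step under a linear beta schedule. The number of steps and schedule endpoints are illustrative assumptions, but the closed-form sampling itself is the standard DDPM formulation (derived in the theory section below).

```python
import torch

T = 1000                                    # number of diffusion steps (illustrative)
betas = torch.linspace(1e-4, 0.02, T)       # linear noise schedule beta_t
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)   # cumulative product alpha_bar_t

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in one shot:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))  # reshape for broadcasting
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise, noise
```

In the usual DDPM setup, the denoising network receives `x_t` and `t` and is trained with an MSE loss between the true `noise` and its prediction.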
The core of the diffusion model is a U-Net, enhanced with:
- Residual Blocks
- Attention Mechanisms
- Time Embeddings for each diffusion step (sketched below)
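As an illustration of the time-embedding point above, diffusion U-Nets commonly encode the timestep with a sinusoidal embedding followed by a small MLP whose output is injected into each residual block. A minimal sketch, with illustrative dimensions:

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim=128):
    """Encode integer timesteps t (shape [B]) as sinusoidal features (shape [B, dim])."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class TimeMLP(nn.Module):
    """Projects the sinusoidal embedding so each residual block can add it to its features."""
    def __init__(self, dim=128, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, out_dim),
            nn.SiLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, t):
        return self.net(sinusoidal_embedding(t))
```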
The project features multiple visual outputs that highlight the training and performance of the model.
A side-by-side comparison of original vs reconstructed images. High SSIM and PSNR scores indicate effective reconstructions.
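One way these scores can be computed is with scikit-image, assuming the originals and reconstructions are single-channel NumPy arrays scaled to [0, 1]; a minimal sketch:

```python
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

def reconstruction_scores(originals, reconstructions):
    """Average SSIM and PSNR over a batch of [H, W] images with values in [0, 1]."""
    ssim_scores = [ssim(o, r, data_range=1.0) for o, r in zip(originals, reconstructions)]
    psnr_scores = [psnr(o, r, data_range=1.0) for o, r in zip(originals, reconstructions)]
    return float(np.mean(ssim_scores)), float(np.mean(psnr_scores))
```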
Projection of the latent space using t-SNE, shown for a single labeled batch and for the full test dataset (a sketch of how these plots are generated follows).
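A minimal sketch of how such a scatter plot can be produced with scikit-learn's t-SNE on the encoder outputs (the perplexity and styling choices are illustrative):

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_latent_tsne(latents, labels, title="Latent space (t-SNE)"):
    """latents: [N, latent_dim] array of encoder outputs; labels: [N] digit classes."""
    coords = TSNE(n_components=2, perplexity=30).fit_transform(latents)
    scatter = plt.scatter(coords[:, 0], coords[:, 1], c=labels, cmap="tab10", s=5)
    plt.colorbar(scatter, label="digit")
    plt.title(title)
    plt.show()
```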
Loss Plot: tracking the diffusion model's loss over training epochs.
Images progress from noisy states (left) to denoised outputs (right), demonstrating the stepwise denoising process.
The project applies unconditional latent diffusion inspired by classic DDPMs, operating on the autoencoder's latent space rather than raw pixels. Below is a simplified breakdown of the key concepts.

Forward Process (Noising):

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \varepsilon$$

Where:
- $x_t$ is the latent at timestep $t$
- $\alpha_t = 1 - \beta_t$ comes from the noise schedule, and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ is its cumulative product
- $\varepsilon \sim \mathcal{N}(0, I)$ is the random noise
Reverse Process (Denoising):

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \varepsilon_\theta(x_t, t) \right) + \sigma_t z$$

Here $\varepsilon_\theta$ is the noise predicted by the U-Net and $z \sim \mathcal{N}(0, I)$ (with $z = 0$ at the final step). Applying this update iteratively from $t = T$ down to $t = 1$ reconstructs the original data.
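A minimal sketch of the corresponding ancestral sampling loop, reusing the `betas`, `alphas`, and `alpha_bars` schedule from the forward-process sketch above and assuming a trained noise-prediction network `eps_model(x, t)`:

```python
import math
import torch

@torch.no_grad()
def p_sample_loop(eps_model, shape, device="cpu"):
    """Start from pure Gaussian noise and iteratively denoise down to t = 0."""
    x = torch.randn(shape, device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = eps_model(x, t_batch)                    # predicted noise eps_theta(x_t, t)
        alpha = alphas[t].item()
        alpha_bar = alpha_bars[t].item()
        mean = (x - (1 - alpha) / math.sqrt(1 - alpha_bar) * eps) / math.sqrt(alpha)
        if t > 0:
            x = mean + math.sqrt(betas[t].item()) * torch.randn_like(x)  # add sigma_t * z
        else:
            x = mean                                   # final step: no extra noise
    return x
```

Saving `x` at intermediate timesteps yields the noisy-to-denoised progressions shown in the visualizations above.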
Here are a few ideas to extend this project:
- Larger U-Net Models for higher-quality image synthesis
- Dynamic Diffusion Schedules to speed up convergence
- Experiment with Other Datasets such as Fashion-MNIST or CIFAR-10