ddpm_train.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm
from diffusion.autoencoders import VQModel
from diffusion.loss import DiffusionHybridLoss
from diffusion.models import CFGuidance, DDPModule, UNet
from diffusion.predictor import NoisePredictor
from diffusion.schedule import DiscreteGaussianSchedule, linear_beta_schedule
from diffusion.transformers import RandomDiffusionSteps
WITH_VAE = True
def main():
    # Noise schedule and predictor: 1000-step linear beta schedule; the
    # predictor clamps reconstructed samples to the [-1, 1] data range.
    schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
    predictor = NoisePredictor(schedule, lambda x: torch.clamp(x, -1, 1))

    # Prepare the model.
    # unet = UNet(input_channel=1, time_dim=32, digit_dim=32)  # pixel-space variant (1 grayscale channel); use when WITH_VAE is False
    unet = UNet(input_channel=64, time_dim=32, digit_dim=32)  # the VQ-VAE latent has 64 channels
    # Classifier-free guidance wrapper with a 32-dim conditioning embedding and
    # guidance scale 2.0 (see cfg_combine() below for the standard combination rule).
    unet = CFGuidance(unet, 32, guidance=2.0)
    model = DDPModule(unet, schedule, predictor)
    # Class-label encoder: 10 FashionMNIST classes -> 32-dim embeddings.
    encoder = nn.Embedding(10, 32)

    epochs = 25
    device = "cuda:0"
    encoder.to(device)
    model.to(device)

    if WITH_VAE:
        # Frozen VQ-VAE used only to map images into its latent space.
        vqvae = VQModel().to(device)
        vqvae.load_state_dict(torch.load("vae.pt", map_location=device))
        vqvae.eval()

    # Prepare the data. RandomDiffusionSteps draws a random timestep per sample
    # and produces the corresponding noised input (see forward_noising_sketch() below).
    diffusion_transform = RandomDiffusionSteps(schedule, batched=True)
    transforms = Compose([Resize((32, 32)), ToTensor()])
    train_dataset = FashionMNIST("fashion_mnist", train=True, download=True, transform=transforms)
    train_dataloader = DataLoader(train_dataset, batch_size=192, shuffle=True, num_workers=2, pin_memory=True)

    # One optimizer over both the diffusion model and the label encoder for joint training.
    optimizer = torch.optim.AdamW(
        [{"params": encoder.parameters()}, {"params": model.parameters()}], lr=0.0001
    )

    # Hybrid loss: simple noise-prediction MSE plus a variational-bound term for the learned variance.
    h_loss = DiffusionHybridLoss(schedule)

    encoder.train()
    model.train()
    for e in range(epochs):
        for sample in (pbar := tqdm(train_dataloader)):
            x, c = sample
            x = x.to(device)
            if WITH_VAE:
                # The VQ-VAE is frozen, so encode without tracking gradients.
                with torch.no_grad():
                    latents = vqvae.encode(x)
            else:
                latents = x  # train directly on the image itself
            noisy_latents = diffusion_transform({"x": latents})  # add noise at a random timestep
            x0, xt, noise, t, c = (
                noisy_latents["x"].to(device),
                noisy_latents["xt"].to(device),
                noisy_latents["noise"].to(device),
                noisy_latents["t"].to(device),
                c.to(device),
            )

            optimizer.zero_grad()
            # Compute the loss from the model's noise prediction and learned variance.
            embedding = encoder(c)
            out = model(xt, t, embedding)
            loss = h_loss(out.prediction, noise, out.mean, out.log_variance, x0, xt, t)
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Epoch {e+1} | loss: {loss.item():.4f}")

    # Save the trained diffusion model and label encoder.
    torch.save(model.state_dict(), "model.pt")
    torch.save(encoder.state_dict(), "encoder.pt")
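
# --- Illustrative sketch, not part of the training loop above ---
# RandomDiffusionSteps is assumed to implement the standard DDPM forward
# (noising) process for the linear beta schedule:
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
# where alpha_bar_t is the cumulative product of (1 - beta_t). The function
# below is a self-contained, hypothetical restatement of that math in plain
# PyTorch; its name, signature, and defaults are assumptions, not the library's API.
def forward_noising_sketch(x0: torch.Tensor, t: torch.Tensor, steps: int = 1000):
    betas = torch.linspace(1e-4, 0.02, steps, device=x0.device)  # linear beta schedule
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)                # cumulative signal fraction
    noise = torch.randn_like(x0)
    a = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))           # broadcast over non-batch dims
    xt = a.sqrt() * x0 + (1.0 - a).sqrt() * noise
    return xt, noise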
if __name__ == "__main__":
    main()
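
# --- Illustrative sketch, not part of the training loop above ---
# The CFGuidance wrapper (guidance=2.0) is assumed to apply the standard
# classifier-free guidance rule at sampling time: run the model once with the
# class embedding and once with a learned "null" embedding, then blend the two
# noise predictions. The function below restates that rule; the name and
# signature are hypothetical, not the library's API.
def cfg_combine(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float = 2.0) -> torch.Tensor:
    # eps_hat = eps_uncond + w * (eps_cond - eps_uncond); w = 1 recovers the conditional prediction.
    return eps_uncond + w * (eps_cond - eps_uncond)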