-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_ddpm.py
69 lines (53 loc) · 2.38 KB
/
train_ddpm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import yaml
import tqdm
import os
import argparse
import math
from easydict import EasyDict
from dataset.dataset import CustomDataset
from models.unet import Unet
from utils.noise_scheduler import NoiseScheduler
from utils.utils import get_optimizer
import torch
from torch.utils.data import DataLoader
def train(args):
    """Train the DDPM noise-prediction U-Net and write checkpoints to disk.

    Loads the YAML config at ``args.config_path``, builds the dataset loader,
    the linear noise schedule, the model and the optimizer, then runs
    ``config.train.epoch`` epochs. In each step a random timestep is drawn per
    sample, noise is applied to the batch via the scheduler, and the model is
    trained to predict that noise under a Smooth L1 loss.

    Checkpoints written to ``config.train.saved_ddpm_model_dir``:
      * ``saved_ddpm_min_loss.pth``  - weights of the epoch with the lowest
        mean loss seen so far.
      * ``saved_ddpm_epoch_{i}.pth`` - weights after every epoch ``i``.

    Args:
        args: Namespace-like object exposing ``config_path`` (path to the
            YAML configuration file).

    Raises:
        yaml.YAMLError: If the configuration file cannot be parsed.
    """
    with open(args.config_path, 'r') as file:
        try:
            config = EasyDict(yaml.safe_load(file))
        except yaml.YAMLError as exc:
            # Report and propagate: continuing would only fail later with an
            # unrelated "config is not defined" error.
            print(exc)
            raise
    print(config)
    print(config.dataset.mnist.train)

    mnist = CustomDataset(
        config.dataset.mnist.train,
        os.path.join(config.dataset.mnist.train, 'labels.csv'),
    )
    mnist_loader = DataLoader(
        mnist,
        batch_size=config.train.batch_size,
        shuffle=True,
        num_workers=0,
    )

    noise_scheduler = NoiseScheduler(num_timesteps=config.diffusion.num_timesteps)
    noise_scheduler.linear_noise_scheduler(
        beta_start=config.diffusion.beta_start,
        beta_end=config.diffusion.beta_end,
    )

    model = Unet(config.model)
    model.train()
    optimizer = get_optimizer(config.train.optimizer, model)

    # The loss module holds no per-batch state, so build it once.
    criterion = torch.nn.SmoothL1Loss()

    # Make sure the checkpoint directory exists before the first torch.save.
    save_dir = config.train.saved_ddpm_model_dir
    os.makedirs(save_dir, exist_ok=True)

    best_loss = math.inf
    for epoch_idx in range(config.train.epoch):
        losses = []
        for batch, (image, _) in enumerate(mnist_loader, start=1):
            optimizer.zero_grad()

            # One uniformly sampled diffusion timestep per sample in the batch.
            t = torch.randint(0, config.diffusion.num_timesteps, (image.shape[0], ))
            noise = torch.randn_like(image)
            noisy_image = noise_scheduler.apply_noise(image, noise, t)
            noise_pred = model(noisy_image, t)

            loss = criterion(noise, noise_pred)
            print("Epoch : {} | Batch : {} | Loss : {}".format(epoch_idx, batch, loss.item()))
            losses.append(loss.item())

            loss.backward()
            optimizer.step()

        epoch_loss = torch.mean(torch.tensor(losses)).item()
        print("Epoch : {} | Loss : {}".format(epoch_idx, epoch_loss))

        if best_loss > epoch_loss:
            # Record the new best so later, worse epochs do not overwrite
            # the min-loss checkpoint.
            best_loss = epoch_loss
            torch.save(model.state_dict(), os.path.join(save_dir, 'saved_ddpm_min_loss.pth'))

        # NOTE(review): original indentation was lost in extraction; the
        # per-epoch checkpoint is assumed to be unconditional - confirm.
        torch.save(model.state_dict(), os.path.join(save_dir, 'saved_ddpm_epoch_{}.pth'.format(epoch_idx)))
if __name__ == '__main__':
    # Script entry point: read the config location from the command line
    # (falls back to the bundled default) and start training.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        dest='config_path',
        default='configs/ddpm.yml',
        type=str,
    )
    train(cli.parse_args())