train.py
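"""Training entry point for DyNet: wraps DyNet_large in a PyTorch Lightning
module, trains it with an L1 reconstruction loss, AdamW, and a linear-warmup
cosine-annealing LR schedule, and logs to Weights & Biases or TensorBoard."""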
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from torchmetrics.image import PeakSignalNoiseRatio

from net.DyNet_large import DyNet_large
from options import options as opt
from utils.dataset_utils import PromptTrainDataset
from utils.schedulers import LinearWarmupCosineAnnealingLR
class DyNetModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = DyNet_large(decoder=True)
        self.loss_fn = nn.L1Loss()
        self.psnr = PeakSignalNoiseRatio()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        degrad_patch, clean_patch, _clean_name = batch
        # Clamp the restored image to the valid [0, 1] range before the loss.
        restored = self.net(degrad_patch).clamp(0.0, 1.0)
        loss = self.loss_fn(restored, clean_patch)
        self.log('train_loss', loss, on_step=True, prog_bar=True, sync_dist=True)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        # Step the scheduler by epoch index instead of Lightning's default call.
        scheduler.step(self.current_epoch)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=2e-4)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer=optimizer, warmup_epochs=15, max_epochs=150)
        return [optimizer], [scheduler]
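# `self.psnr` is instantiated above but never used in this script. A sketch of
# a validation step that would report it is shown below; it assumes a
# validation DataLoader yielding the same (degraded, clean, name) batch layout
# as training, which this script does not wire up:
#
# def validation_step(self, batch, batch_idx):
#     degrad_patch, clean_patch, _ = batch
#     restored = self.net(degrad_patch).clamp(0.0, 1.0)
#     self.log('val_psnr', self.psnr(restored, clean_patch), sync_dist=True)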
def main():
    print("Options")
    print(opt)

    # Log to Weights & Biases if a project name is given, else to TensorBoard.
    if opt.wblogger is not None:
        logger = WandbLogger(project=opt.wblogger, name="DyNet-Training")
    else:
        logger = TensorBoardLogger(save_dir="logs/")

    trainset = PromptTrainDataset(opt)
    trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True,
                             shuffle=True, drop_last=True,
                             num_workers=opt.num_workers)

    # Keep a checkpoint from every epoch (save_top_k=-1 disables pruning).
    checkpoint_callback = ModelCheckpoint(dirpath=opt.ckpt_dir,
                                          every_n_epochs=1, save_top_k=-1)

    # The Trainer moves the model to the GPU(s); no explicit .cuda() is needed.
    model = DyNetModel()
    trainer = pl.Trainer(max_epochs=opt.epochs, accelerator="gpu",
                         devices=opt.num_gpus,
                         strategy="ddp_find_unused_parameters_true",
                         logger=logger, callbacks=[checkpoint_callback])
    trainer.fit(model, trainloader)


if __name__ == '__main__':
    main()
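# Example invocation, as a sketch only: the exact CLI flags depend on
# options.py, which is not shown here, so the flag names below are assumptions
# based on the attributes read from `opt` above (epochs, batch_size,
# num_workers, num_gpus, ckpt_dir, wblogger):
#
#   python train.py --epochs 150 --batch_size 8 --num_gpus 2 --ckpt_dir ckpt/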