-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
73 lines (66 loc) · 2.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from PIL import Image
import os
import numpy as np
# Default directory that save_images() writes PNG files into.
save_dir = "./output"


@torch.no_grad()
def save_images(generated_images, epoch, step, edges=False, output_dir=None):
    """Write every image in a batch to disk as a PNG file.

    Args:
        generated_images: Batch tensor indexed as (N, C, H, W). Pixel values are
            assumed to lie in [-1, 1] (the code maps that range onto [0, 255]);
            anything outside it is clipped instead of wrapping around.
        epoch: Epoch number embedded in each file name.
        step: Step number embedded in each file name.
        edges: If True, file names end in "sketch"; otherwise in "generated".
        output_dir: Target directory. Defaults to the module-level ``save_dir``
            (read at call time). It is created if it does not exist.

    Files are named ``epoch_{epoch}_step_{step}_img_{k}_{suffix}.png``, where
    ``k`` counts from 1 within the batch.
    """
    if output_dir is None:
        output_dir = save_dir
    os.makedirs(output_dir, exist_ok=True)
    postfix = "sketch" if edges else "generated"

    # Map [-1, 1] -> [0, 1] and clamp. Casting out-of-range floats to uint8
    # wraps around (or is undefined), which shows up as speckled pixels.
    images = ((generated_images.float() + 1) / 2).clamp(0.0, 1.0)
    # One host transfer for the whole batch; (N, C, H, W) -> (N, H, W, C).
    pixels = (images.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
    if pixels.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W) -> mode "L".
        pixels = pixels[..., 0]

    for i, pixel_array in enumerate(pixels, start=1):
        Image.fromarray(pixel_array).save(
            os.path.join(
                output_dir, f"epoch_{epoch}_step_{step}_img_{i}_{postfix}.png"
            )
        )
def _checkpoint_entries(models, optimizers, schedulers):
    """Return ``(file_prefix, component)`` pairs for every checkpointed object.

    load_checkpoint and save_checkpoint both iterate this table, so the on-disk
    file naming scheme is defined in exactly one place. Each argument must be a
    2-item (generator-side, discriminator-side) pair; unpacking raises
    ValueError otherwise.
    """
    generator, discriminator = models
    g_optimizer, d_optimizer = optimizers
    g_scheduler, d_scheduler = schedulers
    return (
        ("generator_state_dict", generator),
        ("discriminator_state_dict", discriminator),
        ("g_optimizer_state_dict", g_optimizer),
        ("d_optimizer_state_dict", d_optimizer),
        ("g_scheduler_state_dict", g_scheduler),
        ("d_scheduler_state_dict", d_scheduler),
    )


def load_checkpoint(
    models, optimizers, schedulers, epoch, checkpoints_dir, map_location=None
):
    """Restore generator/discriminator training state written by save_checkpoint.

    Args:
        models: ``(generator, discriminator)`` to load weights into.
        optimizers: ``(g_optimizer, d_optimizer)`` to restore.
        schedulers: ``(g_scheduler, d_scheduler)`` to restore.
        epoch: Epoch suffix of the checkpoint files to read.
        checkpoints_dir: Directory containing the checkpoint files.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"`` to
            load a checkpoint written on a GPU machine. ``None`` keeps torch's
            default.

    Raises:
        FileNotFoundError: If any of the six files is missing. This is checked
            before anything is loaded, so a partial checkpoint never leaves the
            components half-restored.
    """
    entries = [
        (component, os.path.join(checkpoints_dir, f"{prefix}_{epoch}"))
        for prefix, component in _checkpoint_entries(models, optimizers, schedulers)
    ]
    missing = [path for _, path in entries if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(f"Missing checkpoint file(s): {', '.join(missing)}")
    for component, path in entries:
        # torch.load unpickles its input (subject to the installed version's
        # weights_only default); only load checkpoints from trusted sources.
        component.load_state_dict(torch.load(path, map_location=map_location))


def save_checkpoint(models, optimizers, schedulers, epoch, output_dir):
    """Write the state dicts of both models, optimizers and schedulers.

    One file per component is written into ``output_dir`` as
    ``<prefix>_<epoch>``; the prefixes come from ``_checkpoint_entries``, so
    load_checkpoint reads back exactly these names. ``output_dir`` is created
    if it does not already exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    for prefix, component in _checkpoint_entries(models, optimizers, schedulers):
        torch.save(
            component.state_dict(),
            os.path.join(output_dir, f"{prefix}_{epoch}"),
        )