simple_gen_test.py
# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.
# Authors: Kilian Fatras
#          Alexander Tong
"""Generate CelebA samples from a trained checkpoint using its EMA weights."""
import copy

import torch
from absl import app, flags

from torchcfm.models.unet.unet import UNetModelWrapper
from utils_celeba import ema, generate_samples

FLAGS = flags.FLAGS

flags.DEFINE_string("model", "IDFF", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results/", help="output directory")

# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel count of the UNet")
flags.DEFINE_float("sigma", 0.2, help="sigma")
flags.DEFINE_float("flow_w", 2, help="flow weight")

# Training
flags.DEFINE_float("ema_decay", 0.9999, help="EMA decay rate")
flags.DEFINE_bool("parallel", False, help="multi-GPU execution")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
def sample_gen(argv):
    del argv  # unused; required by absl.app.run

    # MODELS
    net_model = UNetModelWrapper(
        dim=(6, 64, 64),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 4, 4],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)
    sigma = FLAGS.sigma

    # Report the model size in millions of parameters.
    model_size = sum(p.data.nelement() for p in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1e6))
    #################################
    # OT-CFM
    #################################
    run_name = f"{FLAGS.model}-{FLAGS.flow_w}-{sigma}"
    savedir = FLAGS.output_dir + run_name + "/"

    # Load the final checkpoint and pull out its EMA weights.
    PATH = f"{FLAGS.output_dir}/{run_name}/{FLAGS.model}_celeba_weights_step_final.pt"
    print("path: ", PATH)
    checkpoint = torch.load(PATH, map_location=torch.device("cpu"))
    state_dict = checkpoint["ema_model"]
    try:
        net_model.load_state_dict(state_dict)
    except RuntimeError:
        # Checkpoints saved from nn.DataParallel prefix every key with
        # "module."; strip the prefix before loading into a single-GPU model.
        from collections import OrderedDict

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k[7:]] = v
        net_model.load_state_dict(new_state_dict)
    net_model.eval()

    # Sample from a copy of the loaded (EMA) weights.
    ema_model = copy.deepcopy(net_model)
    ema(net_model, ema_model, FLAGS.ema_decay)
    # generate_samples(net_model, FLAGS.parallel, savedir, FLAGS.step, net_="normal", sde_enable=True, sigma=sigma, model_name=FLAGS.model)
    generate_samples(
        ema_model,
        FLAGS.parallel,
        savedir,
        "final",
        net_="ema",
        sde_enable=True,
        sigma=sigma,
        model_name=FLAGS.model,
    )


if __name__ == "__main__":
    app.run(sample_gen)
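
# Usage sketch: with the default flags above, the script resolves the
# checkpoint path to ./results/IDFF-2.0-0.2/IDFF_celeba_weights_step_final.pt
# (this assumes training has already written that file). A typical invocation
# would then be:
#
#   python simple_gen_test.py --model IDFF --flow_w 2 --sigma 0.2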