-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
158 lines (145 loc) · 4.82 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import argparse
from torch.backends import cudnn
from config import config, dataset_config, merge_cfg_arg
from dataloder import get_loader
from solver_makeup import Solver_makeupGAN
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token -- including ``"False"`` and ``"0"`` -- becomes
    ``True``. This converter maps the usual spellings explicitly instead.

    Args:
        value: The raw option value (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False: it was the only way to switch a flag
    # off while the options were declared with ``type=bool``.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description="Train GAN")

    # general
    parser.add_argument(
        "--data_path",
        default="./mtdataset",
        type=str,
        help="training and test data path",
    )
    parser.add_argument("--dataset", default="MAKEUP", type=str)
    parser.add_argument(
        "--gpus", default="0", type=str, help="GPU device to train with"
    )
    parser.add_argument("--batch_size", default=1, type=int, help="batch_size")
    parser.add_argument(
        "--vis_step", default=1260, type=int, help="steps between visualization"
    )
    parser.add_argument(
        "--task_name", default="makeup_branch_1", type=str, help="task name"
    )
    parser.add_argument("--checkpoint", default="", type=str, help="checkpoint to load")
    parser.add_argument(
        "--ndis", default=1, type=int, help="train discriminator steps"
    )
    parser.add_argument("--LR", default=2e-4, type=float, help="Learning rate")
    parser.add_argument(
        "--decay", default=0, type=int, help="epochs number for training"
    )
    parser.add_argument(
        "--model",
        default="makeupGAN",
        type=str,
        help="which model to use: cycleGAN/ makeupGAN",
    )
    parser.add_argument("--epochs", default=300, type=int, help="nums of epochs")
    parser.add_argument(
        "--whichG",
        default="branch",
        type=str,
        help="which Generator to choose, normal/branch, branch means two input branches",
    )
    parser.add_argument(
        "--norm",
        default="SN",
        type=str,
        help="normalization of discriminator, SN means spectrum normalization, none means no normalization",
    )
    parser.add_argument(
        "--d_repeat",
        default=3,
        type=int,
        help="the repeat Res-block in discriminator",
    )
    parser.add_argument(
        "--g_repeat", default=6, type=int, help="the repeat Res-block in Generator"
    )

    # loss weights
    parser.add_argument(
        "--lambda_cls", default=1.0, type=float, help="the lambda_cls weight"
    )
    parser.add_argument(
        "--lambda_rec", default=10, type=int, help="lambda_A and lambda_B"
    )
    parser.add_argument(
        "--lambda_his", default=1.0, type=float, help="histogram loss on lips"
    )
    parser.add_argument(
        "--lambda_skin_1",
        default=0.1,
        type=float,
        help="histogram loss on skin equals to lambda_his* lambda_skin",
    )
    parser.add_argument(
        "--lambda_skin_2",
        default=0.1,
        type=float,
        help="histogram loss on skin equals to lambda_his* lambda_skin",
    )
    parser.add_argument(
        "--lambda_eye",
        default=1.0,
        type=float,
        help="histogram loss on eyes equals to lambda_his*lambda_eye",
    )
    parser.add_argument(
        "--content_layer",
        default="r41",
        type=str,
        help="vgg layer we use to output features",
    )
    parser.add_argument(
        "--lambda_vgg", default=5e-3, type=float, help="the param of vgg loss"
    )
    parser.add_argument(
        "--cls_list", default="N,M", type=str, help="the classes of makeup to train"
    )

    # Boolean switches: parsed with _str2bool so that "--lips False" really
    # disables the option (type=bool would turn any non-empty text into True).
    parser.add_argument(
        "--direct",
        type=_str2bool,
        default=True,
        help="direct means to add local cosmetic loss at the first, unified training",
    )
    parser.add_argument(
        "--lips", type=_str2bool, default=True, help="whether to finetune lips color"
    )
    parser.add_argument(
        "--skin",
        type=_str2bool,
        default=True,
        help="whether to finetune foundation color",
    )
    parser.add_argument(
        "--eye",
        type=_str2bool,
        default=True,
        help="whether to finetune eye shadow color",
    )

    return parser.parse_args(argv)
def train_net():
    """Build the data loaders and the solver for the chosen model, then train.

    Reads the module-level ``args``, ``config`` and ``dataset_config``.
    ``args`` is only assigned inside the ``__main__`` guard of this file, so
    this function is intended to be called when the file runs as a script.

    Raises:
        SystemExit: With status 1 when ``args.model`` is not ``"makeupGAN"``,
            the only model this function builds a solver for.
    """
    # Reject an unsupported model before any data loading, and exit with a
    # non-zero status so that calling scripts can detect the failure.
    if args.model != "makeupGAN":
        print("model that not support")
        raise SystemExit(1)

    # enable cudnn
    cudnn.benchmark = True

    data_loaders = get_loader(config, mode="train")

    # get the solver
    solver = Solver_makeupGAN(data_loaders, config, dataset_config)
    solver.train()
if __name__ == "__main__":
    # `args` must remain a module-level name: train_net() reads it as a global.
    args = parse_args()
    config = merge_cfg_arg(config, args)
    dataset_config.name = args.dataset

    # Dump the merged configuration, one right-aligned "name: value" row each.
    print("show config")
    for arg in vars(config):
        print(f"{str(arg):>15}: {str(getattr(config, arg)):>30}")
    print()

    # The data path is only checked, never created: abort with a non-zero
    # status when it is missing so the failure is visible to the caller.
    if not os.path.exists(config.data_path):
        print(f"No data path!! ({config.data_path})")
        raise SystemExit(1)

    train_net()