'''
@author: niceliu
@contact: nicehuster@gmail.com
@file: main.py
@time: 1/1/19 7:04 PM
@desc: train and evaluate a heatmap-based landmark detection model
'''
import sys
import time
from copy import deepcopy

import PIL
import torch

from utils import (prepare_seed, time_for_file, obtain_args, Logger,
                   load_configure, AverageMeter, convert_secs2time, time_string)
from data import transforms, GeneralDataset
from models import obtain_model, save_checkpoint
from optimizer import obtain_optimizer
from exps import basic_eval_all as eval_all
from exps import basic_train as train
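

# Trains a heatmap-based landmark model and evaluates it on the image lists
# given via the command-line arguments; utils, data, models, optimizer and
# exps are helper modules from this repository.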
def main(args):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)

    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN version : {}".format(torch.backends.cudnn.version()))

    # General data augmentation
    mean_fill = tuple([int(x * 255) for x in [0.485, 0.456, 0.406]])  # ImageNet mean in 0-255, used as crop padding
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    assert not args.arg_flip, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
    train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    # if args.arg_flip:
    #     train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)
    eval_transform = transforms.Compose(
        [transforms.PreCrop(args.pre_crop_expand),
         transforms.TrainScale2WH((args.crop_width, args.crop_height)),
         transforms.ToTensor(), normalize])
    assert (args.scale_min + args.scale_max) / 2 == args.scale_eval, \
        'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
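
    # The eval pipeline keeps only the deterministic ops (pre-crop and resize);
    # the assert above ensures the eval scale is the midpoint of the random
    # train-time scaling range.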

    # Load the model configuration
    model_config = load_configure(args.model_config, logger)
    # sigma sets the spread of the target heatmaps; rescale it to the eval scale
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))
# Training Dataset
train_data = GeneralDataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
train_data.load_list(args.train_lists, args.num_pts, True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
shuffle=True,num_workers=args.workers,
pin_memory=True)

    # Evaluation dataloaders, one per image list
    eval_loaders = []
    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = GeneralDataset(eval_transform, args.sigma, model_config.downsample,
                                        args.heatmap_type, args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size,
                                                       shuffle=False, num_workers=args.workers,
                                                       pin_memory=True)
            eval_loaders.append((eval_iloader, False))
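    # eval_loaders holds (loader, is_video) pairs; this script only builds
    # image-list loaders, so the video flag is always False.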

    # Define the network
    logger.log('configure : {:}'.format(model_config))
    net = obtain_model(model_config, args.num_pts + 1)  # one output channel per landmark, plus one extra
    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(
        model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))
    logger.log('Training-data : {:}'.format(train_data))
    for i, (eval_loader, is_video) in enumerate(eval_loaders):
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(
            i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    logger.log('arguments : {:}'.format(args))

    # Optimizer, LR scheduler and loss
    opt_config = load_configure(args.opt_config, logger)
    if hasattr(net, 'specify_parameter'):
        # let the model assign per-parameter LR / weight decay if it supports it
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()
    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    net = torch.nn.DataParallel(net)
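
    # Resume from the last checkpoint if the logger's last-info file exists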
    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        last_info = torch.load(str(last_info))
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint['epoch'], \
            'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0
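
    # Optionally run a single evaluation pass and exit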
    if args.eval_once:
        logger.log("=> only evaluate the model once")
        eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
        logger.close()
        return

    # Main training and evaluation loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):
        # this code base steps the scheduler at the start of each epoch to set the LR
        # (newer PyTorch releases expect scheduler.step() after optimizer.step())
        scheduler.step()
        need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs - epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(
            time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

        # train for one epoch
        train_loss, train_nme = train(args, train_loader, net, criterion,
                                      optimizer, epoch_str, logger, opt_config)
        # log the results
        logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(
            time_string(), epoch_str, train_loss, train_nme * 100))

        # save a full checkpoint, then record its path in the last-info file for resuming
        save_path = save_checkpoint({
            'epoch': epoch,
            'args': deepcopy(args),
            'arch': model_config.arch,
            'state_dict': net.state_dict(),
            'scheduler': scheduler.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, str(logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str)), logger)
        last_info = save_checkpoint({
            'epoch': epoch,
            'last_checkpoint': save_path,
        }, str(logger.last_info()), logger)

        # evaluate on every evaluation list
        eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.close()


if __name__ == "__main__":
    args = obtain_args()
    main(args)