-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy path train.py
67 lines (55 loc) · 1.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import random
import numpy as np
import torch
# original lib
import common as com
from networks.models import Models
########################################################################
# load parameter.yaml
########################################################################
# Configuration loaded once, at import time. main() flattens it into an
# argparse argument list and parses it first, then parses the real command
# line on top of it.
# NOTE(review): com.yaml_load() is project code not visible here; which file
# it reads (presumably parameter.yaml, per the banner above) should be
# confirmed in common.py.
param = com.yaml_load()
########################################################################
def _select_phases(train_only, test_only):
    """Return the ``(train, test)`` phase flags for the given mode switches.

    With neither switch set, both phases run.

    Raises:
        ValueError: if both ``train_only`` and ``test_only`` are truthy.
    """
    if train_only and test_only:
        raise ValueError("--train_only and --test_only cannot be used together.")
    # --train_only disables testing; --test_only disables training.
    return not test_only, not train_only


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can pick different kernels run-to-run; keep it off
    # for reproducibility.
    torch.backends.cudnn.benchmark = False
    # Previously this was written as `torch.use_deterministic_algorithms = True`,
    # which overwrote the torch function with a bool and enabled nothing.
    # warn_only=True reports ops lacking a deterministic implementation
    # instead of raising, so existing training runs keep working.
    # NOTE(review): on CUDA, full cuBLAS determinism may also require the
    # CUBLAS_WORKSPACE_CONFIG environment variable — confirm for your setup.
    torch.use_deterministic_algorithms(True, warn_only=True)


def main():
    """Build the configured model, then run training and/or testing.

    Arguments are layered: the module-level YAML ``param`` is parsed first,
    then the real command line is parsed into the same namespace, so any
    flag given on the command line overrides the YAML value.

    Raises:
        ValueError: if both ``--train_only`` and ``--test_only`` are given.
    """
    parser = com.get_argparse()
    # read parameters from yaml
    flat_param = com.param_to_args_list(params=param)
    args = parser.parse_args(args=flat_param)
    # read parameters from command line; argparse only fills defaults for
    # attributes missing from the namespace, so YAML values survive unless
    # explicitly overridden.
    args = parser.parse_args(namespace=args)
    print(args)

    train, test = _select_phases(args.train_only, args.test_only)

    args.cuda = args.use_cuda and torch.cuda.is_available()
    _seed_everything(args.seed)

    net = Models(args.model).net(
        args=args,
        train=train,
        test=test,
    )
    print(args.model)

    if train:
        print("============== BEGIN TRAIN ==============")
        # NOTE(review): this calls net.train for epochs 1..args.epochs + 1,
        # i.e. once more than args.epochs. Presumably the network uses the
        # final call (epoch > args.epochs) for a post-training pass — confirm
        # in the network implementation before changing the bound.
        for epoch in range(1, args.epochs + 2):
            net.train(epoch)
        print("============ END OF TRAIN ============")
    if test:
        net.test()
if __name__ == "__main__":
    # Entry point when run as a script (e.g. `python train.py ...`);
    # importing this module does not start training.
    main()