forked from shinianzhihou/ChangeDetection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_net.py
83 lines (67 loc) · 1.87 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import argparse
from configs import cfg
from build import (
build_dataloader,
build_model,
build_loss,
build_optimizer,
build_scheduler,
build_tensorboad,
build_checkpoint
)
from engine.trainer import train_epoch
from utils.states import States
def run_train(cfg):
    """Build every training component from ``cfg`` and run the epoch loop.

    Args:
        cfg: Configuration object (``main`` passes the frozen global ``cfg``).
            Read directly here: ``BUILD.TEST_WHEN_TRAIN``, ``MODEL.DEVICE``
            and ``SOLVER.NUM_EPOCH``; the ``build_*`` factories read the rest.

    The object returned by ``build_tensorboad`` is closed once the loop ends,
    whether normally or via an exception, provided it exposes a callable
    ``close`` attribute. This ensures pending logs are flushed.
    """
    states = States(cfg)
    train_loader = build_dataloader(cfg)
    # Evaluation data is only built when testing during training is enabled;
    # otherwise train_epoch receives None.
    test_loader = build_dataloader(cfg, test=True) if cfg.BUILD.TEST_WHEN_TRAIN else None
    model = build_model(cfg).to(cfg.MODEL.DEVICE)
    # The optimizer is created before build_checkpoint, which returns the
    # (possibly restored) model, optimizer and states used from here on.
    optimizer = build_optimizer(cfg, model)
    model, optimizer, states = build_checkpoint(cfg, model, optimizer, states)
    criterion = build_loss(cfg)
    # NOTE(review): max_iters receives len(train_loader), presumably the number
    # of batches in ONE epoch; confirm the scheduler is meant to span a single
    # epoch rather than SOLVER.NUM_EPOCH epochs.
    scheduler = build_scheduler(cfg, optimizer, max_iters=len(train_loader))
    writer = build_tensorboad(cfg)
    try:
        # NOTE(review): the epoch counter always starts at 0, even when
        # build_checkpoint restored earlier states; confirm resume behavior.
        for epoch in range(cfg.SOLVER.NUM_EPOCH):
            states.update("current_epoch", epoch)
            train_epoch(
                cfg,
                states,
                train_loader,
                model,
                optimizer,
                criterion,
                scheduler,
                writer,
                test_loader,
            )
    finally:
        # The writer type is not visible here, so close it only if it offers
        # a close() method (it may also be None when logging is disabled).
        close = getattr(writer, "close", None)
        if callable(close):
            close()
def _build_arg_parser():
    """Create the command-line parser used by ``main``."""
    arg_parser = argparse.ArgumentParser(description="easy2train for Change Detection")
    arg_parser.add_argument(
        "-cfg", "--config_file",
        type=str,
        default="configs/homo/default.yaml",
        metavar="FILE",
        help="Path to config file",
    )
    arg_parser.add_argument(
        "-se", "--skip_eval",
        action="store_true",
        help="Do not eval the models(checkpoints)",
    )
    # Every remaining token on the command line is collected into ``opts``.
    arg_parser.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    return arg_parser


def main():
    """Parse the command line, finalize the global ``cfg`` and start training.

    The file named by ``--config_file`` is merged into ``cfg`` first. Any
    trailing positional tokens (``opts``) are then passed to
    ``cfg.merge_from_list``, so they take precedence over the file. ``cfg``
    is frozen before ``run_train`` is called.
    """
    cli_args = _build_arg_parser().parse_args()
    # NOTE(review): --skip_eval is parsed but not read anywhere in this file.
    cfg.merge_from_file(cli_args.config_file)
    cfg.merge_from_list(cli_args.opts)
    cfg.freeze()
    run_train(cfg)


if __name__ == "__main__":
    main()