-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·62 lines (50 loc) · 2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
from criterion.criterion import LogSoftMax,SelectCrossEntropyLoss,WeightL1Loss
from network.backbone import buildlModel
from utils.utils import build_data_loader,build_optimizer,process_loss,adjust_optimizer
def train(args):
    """Run the training loop for the configured number of epochs.

    Builds the model, data loader, losses and optimizer from the project
    helpers, then iterates epochs; each epoch runs all batches and finally
    refreshes the optimizer via ``adjust_optimizer`` (learning-rate schedule).

    Args:
        args: parsed CLI namespace with ``epoch``, ``batch``, ``num_workers``
            and ``lr`` attributes (see the argparse setup in ``__main__``).
    """
    total_epoch = args.epoch
    batch = args.batch
    num_workers = args.num_workers
    lr_init = args.lr
    # Loss components: log-softmax over classification logits, selective
    # cross-entropy for classification, weighted L1 for localization.
    log_soft_max = LogSoftMax()
    sel_cro_ent_loss = SelectCrossEntropyLoss()
    weight_l1_loss = WeightL1Loss()
    model = buildlModel("train")
    train_loader = build_data_loader(batch, num_workers)
    optimizer = build_optimizer(model, lr_init)
    for epoch in range(total_epoch):
        # Keep the original best-effort-per-epoch behavior, but never swallow
        # errors silently: the old bare `except: pass` hid every failure
        # (including KeyboardInterrupt). Catch Exception only, and log it.
        try:
            for step, data in enumerate(train_loader):
                label_cls = data['label_cls']
                label_loc = data['label_loc']
                label_loc_weight = data['label_loc_weight']
                cls, loc = model(data)
                cls = log_soft_max(cls)
                cls_loss = sel_cro_ent_loss(cls, label_cls)
                loc_loss = weight_l1_loss(loc, label_loc, label_loc_weight)
                loss = process_loss(loc_loss, cls_loss)["total_loss"]
                print("epoch:{epoch} step:{step} loss:{loss}".format(epoch=epoch, step=step, loss=loss))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Per-epoch learning-rate adjustment; returns a fresh optimizer.
            optimizer = adjust_optimizer(model, total_epoch, epoch, lr_init)
        except Exception as exc:
            print("epoch {}: training failed, skipping to next epoch: {}".format(epoch, exc))
def demo(args):
    """Placeholder for demo/inference mode; not implemented yet."""
def main(args):
    """Dispatch to the routine selected by ``args.mode``.

    Recognized modes are "train" and "demo"; any other value is a no-op.
    """
    mode = args.mode
    if mode == "train":
        train(args)
    elif mode == "demo":
        demo(args)
if __name__=="__main__":
parser=argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epoch", type=int, default=50)
parser.add_argument("--batch",type=int,default=8)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--mode", type=str, default="train")
args=parser.parse_args()
main(args)