-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathinference.py
134 lines (114 loc) · 3.7 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#coding=utf-8
import os
import time
import timeit
import argparse
import numpy as np
#import cv2
from PIL import Image
import torch
#import torch.nn.functional as F
from RMI import parser_params, full_model
from RMI.model import psp, deeplab
from RMI.dataloaders import factory
from RMI.utils.metrics import Evaluator
# Registry used by Trainer: the ``args.seg_model`` name selects which
# segmentation network class gets instantiated.
seg_model_obj_dict = dict([
    ('pspnet', psp.PSPNet),
    ('deeplabv3', deeplab.DeepLabv3),
    ('deeplabv3+', deeplab.DeepLabv3Plus),
])
class Trainer(object):
    """Run a segmentation model over a dataset split and save label maps.

    Despite its name, this class performs inference only. It builds the
    network, optionally restores a checkpoint, and writes one PNG of
    per-pixel arg-max class indices per image into ``args.output_dir``.
    """

    def __init__(self, args):
        """Build the dataset, model and evaluator from parsed arguments.

        Args:
            args: parsed options; reads ``cuda``, ``gpu_ids``,
                ``crf_iter_steps``, ``output_dir``, ``data_dir``, ``dataset``,
                ``train_split``, ``seg_model``, ``backbone``, ``out_stride``,
                ``bn_mom`` and ``resume``.

        Raises:
            ValueError: if ``args.seg_model`` is not a key of
                ``seg_model_obj_dict``.
            RuntimeError: if ``args.resume`` is set but is not a file.
        """
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.crf_iter_steps = args.crf_iter_steps
        self.output_dir = args.output_dir
        # step stored in the restored checkpoint; stays 0 when none is loaded
        self.global_step = 0

        # define dataloader (one image per step)
        # NOTE(review): the split comes from args.train_split although this is
        # an inference run — confirm that is intended.
        self.val_loader = factory.get_dataset(args.data_dir,
                                              batch_size=1,
                                              dataset=args.dataset,
                                              split=args.train_split)
        self.nclass = self.val_loader.NUM_CLASSES

        # define network
        # raise instead of assert: asserts are stripped under `python -O`
        if args.seg_model not in seg_model_obj_dict:
            raise ValueError("unknown seg_model '{}', expected one of {}".format(
                args.seg_model, sorted(seg_model_obj_dict)))
        self.seg_model = seg_model_obj_dict[args.seg_model](
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            norm_layer=torch.nn.BatchNorm2d,
            bn_mom=args.bn_mom,
            freeze_bn=True)
        # model='test' is passed to FullModel as in the original code
        # (presumably selects its inference mode — confirm in full_model)
        self.model = full_model.FullModel(seg_model=self.seg_model,
                                          model='test')
        # define evaluator
        self.evaluator = Evaluator(self.nclass)
        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids)
            self.model = self.model.cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            print('Restore parameters from the {}'.format(args.resume))
            # SECURITY: torch.load unpickles the file — only load trusted
            # checkpoints. On CPU-only runs remap GPU-saved tensors to the CPU,
            # otherwise deserialization fails on machines without CUDA.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            self.global_step = checkpoint['global_step']
            # DataParallel keeps the wrapped model under `.module`
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

    @staticmethod
    def _logits_to_label_map(output):
        """Convert the network output for one image into a uint8 label map.

        Args:
            output: a score tensor, or a tuple/list whose first element is the
                score tensor (pspnet returns a tuple). The layout is assumed
                to be (1, C, H, W) — TODO confirm against the models.

        Returns:
            numpy uint8 array holding the arg-max over the class axis.
        """
        # unwrap by type rather than by model name, so any tuple-returning
        # model is handled
        if isinstance(output, (tuple, list)):
            output = output[0]
        # drop only the batch axis; a bare squeeze() would also collapse
        # size-1 H/W axes and shift the class axis used by argmax below
        scores = output.squeeze(0).detach().cpu().numpy()
        return np.argmax(scores, axis=0).astype(np.uint8)

    def validation(self):
        """Predict every image in ``self.val_loader`` and save it as a PNG.

        Each prediction is written to ``<output_dir>/<image_id>.png``. A
        non-empty output directory is created if it is missing. Every 100
        steps, the current step and the time since the previous report are
        printed.
        """
        # set validation mode
        self.model.eval()
        self.evaluator.reset()
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        start = timeit.default_timer()
        with torch.no_grad():
            for i in range(len(self.val_loader)):
                sample = self.val_loader[i]
                image = sample['image'].unsqueeze(dim=0)
                if self.cuda:
                    image = image.cuda()
                # forward
                output = self.model(image)
                pred = self._logits_to_label_map(output)
                # save output
                path_to_output = os.path.join(
                    self.output_dir, self.val_loader.image_ids[i] + '.png')
                Image.fromarray(pred).save(path_to_output)
                # report the time spent since the previous report
                if not i % 100:
                    stop = timeit.default_timer()
                    print("current step = {} ({:.3f} sec)".format(i, stop - start))
                    start = timeit.default_timer()
def main():
    """Script entry point: parse options, build a Trainer, time one pass."""
    arg_parser = argparse.ArgumentParser(
        description="PyTorch Segmentation Model Testing")
    args = parser_params.add_parser_params(arg_parser)
    print(args)
    # seed torch's RNG from the command-line value
    torch.manual_seed(args.seed)
    runner = Trainer(args)
    began = time.time()
    runner.validation()
    print("The validation time is {:.5f} sec".format(time.time() - began))


if __name__ == "__main__":
    main()