-
Notifications
You must be signed in to change notification settings - Fork 241
/
extract_feature.py
128 lines (100 loc) · 4.45 KB
/
extract_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import functools
from collections import OrderedDict
# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``obj``.

    Each dot-separated component is looked up with ``getattr`` on the result
    of the previous lookup. Extra positional arguments are forwarded to every
    ``getattr`` call, so an optional default is applied at each step of the
    path rather than only at the last one.
    """
    target = obj
    for part in attr.split('.'):
        target = getattr(target, part, *args)
    return target
class IntermediateLayerGetter:
    """Callable wrapper that records the outputs of selected submodules."""

    def __init__(self, model, return_layers, keep_output=True):
        """Wraps a Pytorch module to get intermediate values

        Arguments:
            model {nn.module} -- The Pytorch module to call
            return_layers {dict} -- Dictionary with the selected submodules
            to return the output (format: {[current_module_name]: [desired_output_name]},
            current_module_name can be a nested submodule, e.g. submodule1.submodule2.submodule3)

        Keyword Arguments:
            keep_output {bool} -- If True model_output contains the final model's output
            in the other case model_output is None (default: {True})

        Returns:
            (mid_outputs {OrderedDict}, model_output {any}) -- mid_outputs keys are
            your desired_output_name (s) and their values are the returned tensors
            of those submodules (OrderedDict([(desired_output_name,tensor(...)), ...).
            See keep_output argument for model_output description.
            In case a submodule is called more than one time, all its outputs are
            stored in a list.
        """
        self._model = model
        self.return_layers = return_layers
        self.keep_output = keep_output

    def __call__(self, *args, **kwargs):
        """Run the wrapped model once and collect the requested outputs.

        All positional and keyword arguments are forwarded to the model.

        Returns:
            (mid_outputs, model_output) as described in ``__init__``.

        Raises:
            AttributeError: if a key of ``return_layers`` cannot be resolved on
            the model, or the resolved object cannot register a forward hook.
        """
        ret = OrderedDict()
        handles = []
        # Names whose entry in ``ret`` has already been turned into a list of
        # per-call outputs. Tracked explicitly so that a module whose own
        # output is a list is not mistaken for "already collected".
        repeated = set()

        def make_hook(new_name):
            # Factory binds ``new_name`` per layer (avoids late-binding closures).
            def hook(module, inputs, output):
                if new_name not in ret:
                    ret[new_name] = output
                elif new_name in repeated:
                    ret[new_name].append(output)
                else:
                    ret[new_name] = [ret[new_name], output]
                    repeated.add(new_name)
            return hook

        try:
            for name, new_name in self.return_layers.items():
                try:
                    # Both the lookup and the registration can fail with
                    # AttributeError, so both get the same clear message.
                    layer = rgetattr(self._model, name)
                    handle = layer.register_forward_hook(make_hook(new_name))
                except AttributeError as e:
                    raise AttributeError(f'Module {name} not found') from e
                handles.append(handle)

            output = self._model(*args, **kwargs)
        finally:
            # Always detach the hooks, even when registration or the forward
            # pass raised, so they do not accumulate on the model.
            for handle in handles:
                handle.remove()

        return ret, (output if self.keep_output else None)
def main(args, config):
    """Load a checkpointed model and extract intermediate features for one image.

    Arguments:
        args -- parsed CLI namespace; reads ``args.keys`` (iterable of submodule
        names to capture, may be None) and ``args.img`` (path to the image).
        config -- project config; reads ``MODEL.RESUME``, ``DATA.IMG_SIZE``,
        ``AUG.MEAN`` and ``AUG.STD``.

    Returns:
        (mid_outputs, model_output) as produced by ``IntermediateLayerGetter``.
    """
    # Imported locally: at module level ``torch`` is only imported under the
    # ``__main__`` guard, so relying on that global fails when this module is
    # imported and ``main`` is called from elsewhere.
    import torch
    import torchvision.transforms as T
    from PIL import Image

    from models import build_model

    model = build_model(config)
    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.cuda()
    # Inference only: eval mode so the captured features do not depend on
    # training-time behaviour of layers such as dropout or batch norm.
    model.eval()

    # examples:
    # return_layers = {
    #     'patch_embed': 'patch_embed',
    #     'levels.0.downsample': 'levels.0.downsample',
    #     'levels.0.blocks.0.dcn': 'levels.0.blocks.0.dcn',
    # }
    # ``--keys`` defaults to None; treat that as "no intermediate layers".
    return_layers = {k: k for k in (args.keys or [])}
    mid_getter = IntermediateLayerGetter(model, return_layers=return_layers, keep_output=True)

    transforms = T.Compose([
        T.Resize(config.DATA.IMG_SIZE),
        T.ToTensor(),
        T.Normalize(config.AUG.MEAN, config.AUG.STD)
    ])
    # Close the file handle deterministically and force 3 channels so
    # grayscale / RGBA inputs match the per-channel normalisation stats.
    with Image.open(args.img) as img:
        image = transforms(img.convert('RGB'))
    image = image.unsqueeze(0).cuda()

    # No gradients are needed for feature extraction.
    with torch.no_grad():
        mid_outputs, model_output = mid_getter(image)

    for k, v in mid_outputs.items():
        if isinstance(v, (list, tuple)):
            # A module that fired several times (or returns a sequence) has no
            # ``.shape`` itself; report the shape of each element instead.
            print(k, [getattr(item, 'shape', type(item)) for item in v])
        else:
            print(k, getattr(v, 'shape', type(v)))
    return mid_outputs, model_output
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, run the extraction and
    # optionally save the captured intermediate outputs next to the image.
    import argparse
    import os

    import torch

    from config import get_config

    parser = argparse.ArgumentParser('Get Intermediate Layer Output')
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='Path to config file')
    parser.add_argument('--img', type=str, required=True, metavar="FILE", help='Path to img file')
    # Required: without any keys there is nothing to extract, and a missing
    # value previously surfaced as an opaque TypeError inside main().
    parser.add_argument("--keys", required=True, nargs='+', help="The intermediate layer's keys you want to save.")
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--save', action='store_true', help='Save the results.')
    args = parser.parse_args()
    config = get_config(args)

    mid_outputs, model_output = main(args, config)

    if args.save:
        # splitext strips the extension whatever its length; slicing off the
        # last three characters kept the dot ("a.jpg" -> "a..pth") and broke
        # on extensions such as ".jpeg".
        save_path = os.path.splitext(args.img)[0] + '.pth'
        torch.save(mid_outputs, save_path)