-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathretinaface_detect.py
129 lines (106 loc) · 4.16 KB
/
retinaface_detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import cv2
from retinaface_pytorch.data import cfg_mnet, cfg_re50
from retinaface_pytorch.layers.functions.prior_box import PriorBox
from retinaface_pytorch.utils.nms.py_cpu_nms import py_cpu_nms
from retinaface_pytorch.models.retinaface import RetinaFace
from retinaface_pytorch.utils.box_utils import decode, decode_landm
def check_keys(model, pretrained_state_dict):
    """Verify that a checkpoint shares at least one parameter key with a model.

    Args:
        model: Object exposing ``state_dict()``; only its keys are inspected.
        pretrained_state_dict: Mapping of parameter names to values, as loaded
            from a checkpoint.

    Returns:
        ``True`` when at least one key is present in both mappings.

    Raises:
        AssertionError: If the checkpoint and the model have no key in common.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # when Python runs with -O; the exception type callers see is unchanged.
    if model_keys.isdisjoint(ckpt_keys):
        raise AssertionError('load NONE from pretrained checkpoint')
    return True
def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have a leading ``prefix`` stripped.

    Only one leading occurrence is removed, and keys that do not start with
    ``prefix`` are kept unchanged. Values are passed through untouched.

    Args:
        state_dict: Mapping with string keys.
        prefix: String to strip from the start of each key. An empty string
            leaves every key as it is.

    Returns:
        A new ``dict`` with the rewritten keys.
    """
    # Slice instead of ``str.split``: splitting on an empty prefix raises
    # ValueError, whereas slicing by zero characters is a harmless no-op.
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def load_model(model, pretrained_path, load_to_cpu):
    """Load checkpoint weights from ``pretrained_path`` into ``model``.

    The checkpoint may either be the weight mapping itself or a mapping that
    stores the weights under a ``'state_dict'`` entry. A leading ``'module.'``
    is stripped from every key before loading.

    Args:
        model: Module to receive the weights; it is updated in place.
        pretrained_path: Location of the checkpoint, handed to ``torch.load``.
        load_to_cpu: When true, storages are left where ``torch.load`` puts
            them by default; otherwise they are moved to the current CUDA
            device.

    Returns:
        The same ``model`` object, after ``load_state_dict(..., strict=False)``.

    Raises:
        AssertionError: From ``check_keys`` when no checkpoint key matches the
            model.
    """
    if load_to_cpu:
        def place(storage, loc):
            return storage
    else:
        gpu_index = torch.cuda.current_device()

        def place(storage, loc):
            return storage.cuda(gpu_index)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(pretrained_path, map_location=place)

    wrapped = "state_dict" in checkpoint.keys()
    raw_weights = checkpoint['state_dict'] if wrapped else checkpoint
    weights = remove_prefix(raw_weights, 'module.')

    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
def create_net(network='mobile0.25',
               weights='./retinaface_pytorch/weights/mobilenet0.25_Final.pth'):
    """Build a RetinaFace network in test phase and load its weights.

    Side effects: disables autograd globally via
    ``torch.set_grad_enabled(False)`` and sets ``cudnn.benchmark = True``.

    Args:
        network: Backbone selector, either ``"mobile0.25"`` or ``"resnet50"``.
        weights: Path of the checkpoint passed to ``load_model``.

    Returns:
        Tuple ``(net, cfg)``: the network in eval mode, moved to CUDA when
        available and to CPU otherwise, and the configuration selected for
        ``network``.

    Raises:
        ValueError: If ``network`` is not one of the supported selectors.
    """
    # Validate before touching any global torch state so that a bad selector
    # fails with a clear message instead of passing cfg=None to the model.
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError(
            "unsupported network {!r}; expected 'mobile0.25' or 'resnet50'"
            .format(network))

    torch.set_grad_enabled(False)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, weights, not use_cuda)
    net.eval()
    cudnn.benchmark = True
    net = net.to(device)
    return net, cfg
def detect(img,
           net,
           cfg,
           confidence_threshold=0.02,
           top_k=5000,
           nms_threshold=0.4,
           keep_top_k=750):
    """Run ``net`` on one image and return the detections that survive NMS.

    The input image is never modified; the function works on a float32 copy.

    Args:
        img: Image indexed as (height, width, channel) with three channels.
            Presumably BGR as produced by OpenCV -- confirm against the caller.
        net: Callable model returning ``(loc, conf, landms)`` for a batch.
        cfg: Configuration passed to ``PriorBox``; its ``'variance'`` entry is
            used when decoding boxes and landmarks.
        confidence_threshold: Detections scoring at or below this are dropped.
        top_k: Maximum number of highest-scoring candidates passed to NMS.
        nms_threshold: Threshold handed to ``py_cpu_nms``.
        keep_top_k: Maximum number of detections returned after NMS.

    Returns:
        A NumPy array with one row per detection: four box values, the score,
        then ten landmark values, with box and landmark values scaled by the
        image width and height.
    """
    resize = 1
    # Always copy: np.float32(img) hands back the caller's own array when it
    # is already float32, and the in-place mean subtraction below would then
    # silently alter the caller's image.
    img = np.array(img, dtype=np.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    im_height, im_width, _ = img.shape

    # Multipliers applied to decoded boxes (two x/y pairs) and landmarks
    # (five x/y pairs).
    scale = torch.Tensor([im_width, im_height] * 2).to(device)
    scale1 = torch.Tensor([im_width, im_height] * 5).to(device)

    # Subtract the per-channel mean, then reorder to (channel, height, width)
    # and add a batch axis.
    img -= (104, 117, 123)
    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).unsqueeze(0).to(device)

    loc, conf, landms = net(img)  # forward pass

    priorbox = PriorBox(cfg, image_size=(im_height, im_width))
    priors = priorbox.forward()
    priors = priors.to(device)
    prior_data = priors.data

    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
    boxes = boxes * scale / resize
    boxes = boxes.cpu().numpy()

    # Column 1 of the class output is used as the detection score.
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
    landms = landms * scale1 / resize
    landms = landms.cpu().numpy()

    # ignore low scores
    inds = np.where(scores > confidence_threshold)[0]
    boxes = boxes[inds]
    landms = landms[inds]
    scores = scores[inds]

    # keep top-K before NMS
    order = scores.argsort()[::-1][:top_k]
    boxes = boxes[order]
    landms = landms[order]
    scores = scores[order]

    # do NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
        np.float32, copy=False)
    keep = py_cpu_nms(dets, nms_threshold)
    dets = dets[keep, :]
    landms = landms[keep]

    # keep top-K after NMS
    dets = dets[:keep_top_k, :]
    landms = landms[:keep_top_k, :]

    dets = np.concatenate((dets, landms), axis=1)
    return dets