-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathonnx_inference.py
113 lines (99 loc) · 3.45 KB
/
onnx_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: onnx Inference
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
from hand_data_iter.datasets import draw_bd_handpose
class ONNXModel():
    """Minimal wrapper around an ``onnxruntime.InferenceSession``.

    The session's input and output names are read once at construction time;
    ``forward`` feeds one array to every model input and returns all outputs.
    """

    def __init__(self, onnx_path, gpu_cfg=False):
        """Load the model and cache its input/output names.

        :param onnx_path: path to the ``.onnx`` model file.
        :param gpu_cfg: if True, request the CUDA execution provider on
            device 0 with the CPU provider listed after it as a fallback;
            otherwise request the CPU provider only.
        """
        # Providers are passed at construction time: since onnxruntime 1.9 a
        # build that ships several execution providers raises ValueError when
        # `providers` is omitted, so "create first, set_providers later" could
        # fail before the GPU switch was ever reached.
        self.onnx_session = onnxruntime.InferenceSession(
            onnx_path, providers=self._build_providers(gpu_cfg)
        )
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    @staticmethod
    def _build_providers(gpu_cfg, device_id=0):
        """Return the provider list for ``InferenceSession(providers=...)``.

        :param gpu_cfg: whether to put the CUDA provider first.
        :param device_id: CUDA device index used when ``gpu_cfg`` is True.
        :return: providers in priority order; CPU is always present so that
            onnxruntime has something to fall back to.
        """
        if gpu_cfg:
            return [
                ("CUDAExecutionProvider", {"device_id": device_id}),
                "CPUExecutionProvider",
            ]
        return ["CPUExecutionProvider"]

    def get_output_name(self, onnx_session):
        """Return the names of all session outputs.

        :param onnx_session: an object exposing ``get_outputs()``, whose items
            carry a ``name`` attribute (e.g. an InferenceSession).
        :return: list of output names, in the order reported by the session.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all session inputs.

        :param onnx_session: an object exposing ``get_inputs()``, whose items
            carry a ``name`` attribute (e.g. an InferenceSession).
        :return: list of input names, in the order reported by the session.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` dict for ``InferenceSession.run``.

        NOTE: every input name is mapped to the *same* array, so this only
        makes sense for single-input models (or models whose inputs all take
        the same tensor).

        :param input_name: iterable of model input names.
        :param image_numpy: the array to feed.
        :return: dict ``{name: image_numpy}`` for each name.
        """
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """Run the model on ``image_numpy`` and return every output.

        :param image_numpy: input array; it must already have the shape and
            dtype the model expects (no batching or casting is done here).
        :return: list of outputs, one per name in ``self.output_name``.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)
def to_input_tensor(img):
    """Convert a resized HWC image into a 1xCxHxW float32 model input.

    :param img: array-like of shape (H, W, C), e.g. as returned by
        ``cv2.imread``/``cv2.resize``; channel order is passed through as-is.
    :return: float32 array of shape (1, C, H, W) with values divided by 255.
    """
    tensor = np.asarray(img).transpose((2, 0, 1)) / 255.
    return np.expand_dims(tensor, 0).astype('float32')


def keypoints_from_output(output, img_width, img_height):
    """Turn a flat ``[x0, y0, x1, y1, ...]`` output vector into pixel points.

    Each coordinate is multiplied by the image width/height, i.e. the model
    output is treated as relative coordinates. A trailing unpaired value
    (odd-length output) is ignored.

    :param output: 1-D array-like of interleaved x/y values.
    :param img_width: width of the image the points are mapped onto.
    :param img_height: height of the image the points are mapped onto.
    :return: dict keyed by the keypoint index as a string ("0", "1", ...),
        each value being ``{"x": x, "y": y}``.
    """
    coords = np.asarray(output)
    pts_hand = {}
    for i in range(coords.shape[0] // 2):
        pts_hand[str(i)] = {
            "x": coords[i * 2 + 0] * float(img_width),
            "y": coords[i * 2 + 1] * float(img_height),
        }
    return pts_hand


if __name__ == "__main__":
    img_size = 256
    model = ONNXModel("resnet_50_size-256.onnx")
    path_ = "./image/"
    # Sorted so the images are shown in a stable, predictable order.
    for f_ in sorted(os.listdir(path_)):
        img0 = cv2.imread(os.path.join(path_, f_))
        if img0 is None:
            # cv2.imread returns None for anything it cannot decode
            # (sub-directories, non-image files, ...): skip those entries.
            continue
        img_height, img_width = img0.shape[:2]

        img = cv2.resize(img0, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
        output = np.array(model.forward(to_input_tensor(img))[0][0])
        print(output.shape[0])

        # Build the keypoint structure used for the skeleton visualisation.
        pts_hand = keypoints_from_output(output, img_width, img_height)
        # Capture the dot centres before handing pts_hand to the drawing helper,
        # so the dots do not depend on whether that helper modifies its argument.
        centers = [(int(pt["x"]), int(pt["y"])) for pt in pts_hand.values()]
        draw_bd_handpose(img0, pts_hand, 0, 0)  # draw the lines connecting the keypoints

        # ------------- draw the keypoints themselves
        for center in centers:
            cv2.circle(img0, center, 3, (255, 50, 60), -1)
            cv2.circle(img0, center, 1, (255, 150, 180), -1)

        cv2.namedWindow('image', 0)
        cv2.imshow('image', img0)
        if cv2.waitKey(600) == 27:  # ESC stops the slideshow
            break

    cv2.waitKey(0)