-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpredict.py
141 lines (130 loc) · 7.43 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -----------------------------------------------------------------------#
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
# 整合到了一个py文件中,通过指定mode进行模式的修改。
# -----------------------------------------------------------------------#
import time
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from achelous import achelous
if __name__ == "__main__":
yolo = achelous()
# ----------------------------------------------------------------------------------------------------------#
# mode用于指定测试的模式:
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
# 'export' 表示遍历文件夹进行检测并保存。遍历图像的文件夹,保存至export_results文件夹,详情查看下方注释。
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
# 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
# ----------------------------------------------------------------------------------------------------------#
mode = 'predict'
# -------------------------------------------------------------------------#
# crop 指定了是否在单张图片预测后对目标进行截取
# count 指定了是否进行目标的计数
# crop、count仅在mode='predict'时有效
# -------------------------------------------------------------------------#
crop = False
count = False
# ----------------------------------------------------------------------------------------------------------#
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
# video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
# video_fps 用于保存的视频的fps
#
# video_path、video_save_path和video_fps仅在mode='video'时有效
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
# ----------------------------------------------------------------------------------------------------------#
video_path = 'images/video2.mp4'
video_save_path = "images/video_det_out.mp4"
video_fps = 33.0
# ----------------------------------------------------------------------------------------------------------#
# test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
# fps_image_path 用于指定测试的fps图片
#
# test_interval和fps_image_path仅在mode='fps'有效
# ----------------------------------------------------------------------------------------------------------#
test_interval = 100
fps_image_path = "images/example1.jpg"
# -------------------------------------------------------------------------#
# dir_origin_path 指定了用于检测的图片的文件夹路径
# dir_save_path 指定了检测完图片的保存路径
#
# dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
# -------------------------------------------------------------------------#
dir_origin_path = "img/"
dir_save_path = "img_out/"
# -------------------------------------------------------------------------#
# heatmap_save_path 热力图的保存路径,默认保存在model_data下
#
# heatmap_save_path仅在mode='heatmap'有效
# -------------------------------------------------------------------------#
heatmap_save_path = "model_data/heatmap_vision.png"
# -------------------------------------------------------------------------#
# simplify 使用Simplify onnx
# onnx_save_path 指定了onnx的保存路径
# -------------------------------------------------------------------------#
simplify = True
onnx_save_path = "model_data/models.onnx"
if mode == "predict":
'''
1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
在原图上利用矩阵的方式进行截取。
4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
'''
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
image_id = img[-20:-4]
except:
print('Open Error! Try again!')
continue
else:
r_image = yolo.detect_image(image, image_id, crop=crop, count=count)
r_image.show()
elif mode == "export":
img = input('Input image root:')
try:
image_list = os.listdir(img)
image_ids = [os.path.join(img, path) for path in image_list]
except:
print('Open Error! Try again!')
else:
for i in tqdm(range(len(image_list))):
image_id = image_list[i]
image = Image.open(image_ids[i])
r_image = yolo.detect_image(image, image_id[-20:-4], crop=crop, count=count, export_all=True)
elif mode == "dir_predict":
import os
from tqdm import tqdm
img_names = os.listdir(dir_origin_path)
for img_name in tqdm(img_names):
if img_name.lower().endswith(
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(dir_origin_path, img_name)
image = Image.open(image_path)
r_image = yolo.detect_image(image)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
elif mode == "heatmap":
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
image_id = img[-20:-4]
except:
print('Open Error! Try again!')
continue
else:
yolo.detect_heatmap(image, image_id, heatmap_save_path)
elif mode == "export_onnx":
yolo.convert_to_onnx(simplify, onnx_save_path)
else:
raise AssertionError(
"Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")