-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsuperpoint_extraction.py
95 lines (78 loc) · 2.77 KB
/
superpoint_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import argparse
import yaml
from tqdm import tqdm
import torch
import cv2
from model.build_model import build_superpoint_model
from model.inference import superpoint_inference
from cocoapi.PythonAPI.pycocotools.ytvos import YTVOS
def find(lst, key, value):
ind = []
id = []
for i, dic in enumerate(lst):
if value in dic[key][0]:
ind.append(i)
id.append(lst[i]['id'])
return ind, id
def _read_video_list(path):
    """Return the stripped, non-blank lines of the text file at ``path``."""
    # Blank lines (e.g. from a trailing newline) must be dropped: an empty
    # video id would substring-match every video in ``find``.
    with open(path, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _load_gray_image_tensor(path, device):
    """Load an image as a grayscale tensor replicated to 3 channels.

    The image is converted BGR -> gray, merged back into 3 identical
    channels, cast to float32 on ``device``, permuted to channel-first
    (3, H, W) and scaled by 1/255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    src = cv2.imread(path)
    if src is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
    gray3 = cv2.merge([gray, gray, gray])
    image = torch.from_numpy(gray3).type(torch.float32).to(device)
    image = image.permute(2, 0, 1)
    image /= 255
    return image


def inference(configs):
    """Run SuperPoint on every frame of every video listed in the video list.

    Config keys read: ``'data'``, ``'model' -> 'superpoint' ->
    'detection_threshold'``, ``'img_data_path'``, ``'points_dir'``,
    ``'annotations_path'`` and ``'video_list_path'``. Before building the
    model this also sets ``configs['num_gpu'] = [0]`` and
    ``configs['public_model'] = 0`` (presumably consumed by
    ``build_superpoint_model`` -- confirm).

    For each video id, its frame paths are taken from the first annotation
    video whose first file name contains the id, and each frame is passed to
    ``superpoint_inference`` with ``<points_dir>/<video_id>`` as the output
    directory argument.

    Raises:
        ValueError: if a listed video id matches no annotated video.
        FileNotFoundError: if a frame image cannot be read.
    """
    data_config = configs['data']
    detection_threshold = configs['model']['superpoint']['detection_threshold']

    configs['num_gpu'] = [0]
    configs['public_model'] = 0
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    base_dir = configs['img_data_path']
    base_points_dir = configs['points_dir']

    superpoint_model = build_superpoint_model(configs)

    # YouTube-VIS-format annotations provide each video's frame file names.
    ovis = YTVOS(configs['annotations_path'])
    videos = ovis.dataset['videos']
    video_list = _read_video_list(configs['video_list_path'])

    with torch.no_grad():
        for video_id in tqdm(video_list):
            points_dir = os.path.join(base_points_dir, video_id)
            vid_ind, _ = find(videos, 'file_names', video_id)
            if not vid_ind:
                raise ValueError(
                    f"Video {video_id!r} not found in annotations "
                    f"{configs['annotations_path']!r}"
                )
            image_paths = videos[vid_ind[0]]['file_names']
            for image_path in image_paths:
                data_path = os.path.join(base_dir, image_path)
                file_name = os.path.splitext(os.path.basename(image_path))[0]
                data = {
                    'image': [_load_gray_image_tensor(data_path, device)],
                    'image_name': [str(file_name)],
                }
                superpoint_inference(
                    superpoint_model, data, data_config, detection_threshold, points_dir
                )
def main():
    """Parse command-line arguments, load the YAML config and run extraction.

    Options:
        -c/--config_file: path to the YAML config file. An empty value (the
            default) is rejected with a usage error.
        -g/--gpu: integer stored as ``configs['use_gpu']`` (default 1).
    """
    parser = argparse.ArgumentParser(description="SuperPoint Feature Extraction")
    parser.add_argument(
        "-c", "--config_file",
        dest="config_file",
        type=str,
        default="",
    )
    parser.add_argument(
        "-g", "--gpu",
        dest="gpu",
        type=int,
        default=1,
    )
    args = parser.parse_args()

    if not args.config_file:
        # open("") would fail with a cryptic FileNotFoundError; show usage instead.
        parser.error("a config file must be given with -c/--config_file")

    # Context manager so the config file handle is closed promptly.
    with open(args.config_file, 'r', encoding='utf-8') as f:
        configs = yaml.safe_load(f)
    if not isinstance(configs, dict):
        # An empty or non-mapping YAML would otherwise crash on item assignment.
        parser.error(f"config file {args.config_file!r} must contain a YAML mapping")

    configs['use_gpu'] = args.gpu
    inference(configs)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()