-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathextract_features.py
68 lines (56 loc) · 2.04 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import os
import imageio
import utils
import pickle as pkl
from extractor import Extractor
import time
# How many frames utils.rescale_list keeps from each video before extraction.
NUMBER_OF_FRAMES = 40

# Source tree, laid out as <path_to_read_data>/<category>/<video file>.
path_to_read_data = 'data/UCF-101'
# Destination tree mirroring the source tree, one pickle per video.
path_to_store_data = 'data/UCF-101-Inception-Features'

# Entries of the source directory; each is expected to be a category folder.
categories = os.listdir(path=path_to_read_data)

# One extractor instance, reused for every frame of every video.
model = Extractor()
def get_activations(path_to_video):
    """Extract per-frame features for a single video.

    The video is decoded with imageio and resampled to NUMBER_OF_FRAMES
    frames via utils.rescale_list. Each resulting frame is converted to
    float64 and passed through the module-level ``model.extract``.

    Args:
        path_to_video: path of the video file to read.

    Returns:
        ``np.array`` of the per-frame feature outputs. Returns an empty
        list if the video could not be read or processed.
    """
    try:
        # Context manager so the reader's resources are released even when
        # decoding fails part-way through the file.
        with imageio.get_reader(path_to_video) as video:
            frames = list(video)
        frame_length = len(frames)
        frames = utils.rescale_list(frames, NUMBER_OF_FRAMES)
        sequence = []
        try:
            for frame in frames:
                sequence.append(model.extract(frame.astype(np.float64)))
        except TypeError:
            # NOTE(review): the features gathered before the TypeError are
            # still returned (and the caller stores them). The result may
            # therefore hold fewer than NUMBER_OF_FRAMES entries. Confirm
            # this partial output is intended.
            print(path_to_video, frame_length)
        return np.array(sequence)
    except Exception:
        # Best-effort: report the failing video and let the caller skip it.
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still stop the run.
        print(path_to_video)
        return []
def store_activations(category):
    """Extract and pickle features for every video of one category.

    Videos are read from ``<path_to_read_data>/<category>``. One pickle per
    video is written to ``<path_to_store_data>/<category>``, named after the
    video with its extension replaced by ``.pkl``. Videos whose pickle
    already exists are skipped, so an interrupted run can be resumed.

    Args:
        category: name of the category sub-directory.
    """
    # Sorted so that repeated runs process videos in the same order.
    video_files = sorted(os.listdir(os.path.join(path_to_read_data, category)))
    path_to_dir = os.path.join(path_to_store_data, category)
    # os.makedirs also creates path_to_store_data on a first run, where
    # os.mkdir would raise FileNotFoundError. exist_ok removes the
    # check-then-create race.
    os.makedirs(path_to_dir, exist_ok=True)
    for video in video_files:
        path = os.path.join(path_to_read_data, category, video)
        # splitext swaps only the final extension. str.replace('.avi', ...)
        # would also rewrite '.avi' found elsewhere in the name.
        store_name = os.path.splitext(video)[0] + '.pkl'
        path_to_store_file = os.path.join(path_to_dir, store_name)
        if os.path.isfile(path_to_store_file):
            print(store_name + ' stored.')
            continue
        sequence = get_activations(path)
        # Test with len() rather than truthiness, because sequence may be a
        # NumPy array, whose truth value is ambiguous.
        if len(sequence) == 0:
            print(store_name + ' not stored.')
            continue
        with open(path_to_store_file, 'wb') as handle:
            pkl.dump(sequence, handle, protocol=pkl.HIGHEST_PROTOCOL)
        print(store_name + ' stored.')
if __name__ == '__main__':
    # Run the extraction only when executed as a script, not on import.
    start = time.perf_counter()
    for category in sorted(categories):
        # Skip stray files in the dataset root. store_activations lists the
        # entry as a directory and would crash on a plain file.
        if not os.path.isdir(os.path.join(path_to_read_data, category)):
            continue
        store_activations(category)
        checkpoint = time.perf_counter()
        print('---------------------*****---------------------')
        print(category + ' done!')
        # Elapsed time is cumulative since the start of the whole run.
        print('Time elapsed: {}s'.format(checkpoint - start))
        print('---------------------*****---------------------')