"""Generate captions for the test split with a pre-trained AVSSN encoder/decoder
and write them to <output_folder>/predictions.txt."""
import os
import argparse
import pickle

import h5py
import numpy as np
import torch

from utils import decode_from_tokens
from vocabulary import Vocabulary
from configuration_file import ConfigurationFile
from model.encoder import Encoder
from model.decoder import AVSSNDecoder
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Generate captions for test samples')
    parser.add_argument('-chckpt', '--checkpoint_path', type=str, default='pretrain/chckpt.pt',
                        help='Set the path to the pre-trained model (default is pretrain/chckpt.pt).')
    parser.add_argument('-data', '--dataset_folder', type=str, default='data/MSVD',
                        help='Set the path to the dataset folder (default is data/MSVD).')
    parser.add_argument('-out', '--output_folder', type=str, default='results/MSVD',
                        help='Set the path to the output folder (default is results/MSVD).')
    args = parser.parse_args()
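    # Example invocation (paths are illustrative, assuming the default
    # repository layout; not part of the original script):
    #   python test.py -chckpt pretrain/chckpt.pt -data data/MSVD -out results/MSVD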
    # Load vocabulary
    with open(os.path.join(args.dataset_folder, 'corpus.pkl'), "rb") as f:
        corpus = pickle.load(f)
    idx2word_dict = corpus[4]
    vocab = Vocabulary.from_idx2word_dict(idx2word_dict, False)
    print('Size of vocabulary: {}'.format(len(vocab)))

    # Pre-trained word embedding
    pretrained_embedding = torch.Tensor(corpus[5])
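    # Assumed corpus.pkl layout (inferred from the indices this script uses,
    # not documented here): corpus[2][1] holds the test-split video indices,
    # corpus[4] the idx2word dictionary, and corpus[5] the pre-trained
    # word-embedding matrix.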
    # Model hyper-parameters
    cnn_feature_size = 2048
    c3d_feature_size = 4096
    i3d_feature_size = 400
    res_eco_features_size = 3584
    projected_size = 512
    hidden_size = 1024  # number of hidden units in the recurrent network
    mid_size = 128      # dimension of the intermediate representation of the boundary-detection layer
    n_tags = 300
    global_tagger_hidden_size = 1024
    specific_tagger_hidden_size = 128
    embedding_size = 300
    rnn_in_size = 300
    rnn_hidden_size = 1024

    config = ConfigurationFile(os.path.join(args.dataset_folder, 'config.ini'), 'attn-vscn-max')
    # Models
    encoder = Encoder(cnn_feature_size=cnn_feature_size,
                      c3d_feature_size=c3d_feature_size,
                      i3d_feature_size=i3d_feature_size,
                      n_tags=n_tags,
                      hidden_size=hidden_size,
                      global_tagger_hidden_size=global_tagger_hidden_size,
                      specific_tagger_hidden_size=specific_tagger_hidden_size,
                      n_layers=config.encoder_num_layers,
                      input_dropout_p=config.encoder_dropout_p,
                      rnn_dropout_p=config.encoder_dropout_p,
                      bidirectional=config.encoder_bidirectional,
                      rnn_cell=config.encoder_rnn_cell,
                      device='cpu')
    decoder = AVSSNDecoder(in_seq_length=config.max_frames,
                           out_seq_length=config.max_words,
                           n_feats=res_eco_features_size + 512,
                           n_tags=n_tags,
                           embedding_size=embedding_size,
                           pretrained_embedding=pretrained_embedding,
                           hidden_size=hidden_size,
                           rnn_in_size=rnn_in_size,
                           rnn_hidden_size=rnn_hidden_size,
                           vocab=vocab,
                           device='cpu',
                           rnn_cell=config.decoder_rnn_cell,
                           encoder_num_layers=config.encoder_num_layers,
                           encoder_bidirectional=config.encoder_bidirectional,
                           num_layers=config.decoder_num_layers,
                           dropout_p=config.decoder_dropout_p,
                           beam_size=config.decoder_beam_size,
                           temperature=config.decoder_temperature,
                           train_sample_max=config.decoder_train_sample_max,
                           test_sample_max=config.decoder_test_sample_max,
                           beam_search_logic=config.decoder_beam_search_logic)
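    # Optional sanity check (not in the original script): report parameter
    # counts so a misconfigured model is easier to spot before loading weights.
    n_enc_params = sum(p.numel() for p in encoder.parameters())
    n_dec_params = sum(p.numel() for p in decoder.parameters())
    print('encoder parameters: {:,} | decoder parameters: {:,}'.format(n_enc_params, n_dec_params))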
    # Checkpoint
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    # Filter out checkpoint keys that the encoder defined here does not use
    chckpt_dict = {k: v for k, v in checkpoint['encoder'].items()
                   if k not in ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']}
    encoder_dict = encoder.state_dict()
    encoder_dict.update(chckpt_dict)
    encoder.load_state_dict(encoder_dict)
    decoder.load_state_dict(checkpoint['decoder'])
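    # Note: an equivalent, more concise alternative (assuming fc1/fc2 are the
    # only mismatched keys) would be:
    #   encoder.load_state_dict(checkpoint['encoder'], strict=False)
    # strict=False ignores missing and unexpected keys instead of raising.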
    # Load test-set features
    test_vidxs = sorted(list(set(corpus[2][1])))
    with h5py.File(os.path.join(args.dataset_folder, config.features_path), 'r') as feats_file:
        print('loading visual feats...')
        dataset = feats_file[config.dataset_name]
        cnn_feats = torch.from_numpy(dataset['cnn_features'][test_vidxs]).float()
        c3d_feats = torch.from_numpy(dataset['c3d_features'][test_vidxs]).float()
        cnn_globals = torch.zeros(cnn_feats.size(0), 512)  # torch.from_numpy(dataset['cnn_globals'][test_vidxs]).float()
        cnn_sem_globals = torch.from_numpy(dataset['cnn_sem_globals'][test_vidxs]).float()
        f_counts = dataset['count_features'][test_vidxs]
        print('visual feats loaded')
    res_eco_globals = torch.from_numpy(np.load(os.path.join(args.dataset_folder, 'resnext_eco.npy'))[test_vidxs])
    tags_globals = torch.from_numpy(np.load(os.path.join(args.dataset_folder, 'tag_feats.npy'))[test_vidxs])
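    # Optional sanity check (not in the original script): every per-video
    # feature tensor should have one row per test video.
    assert cnn_feats.size(0) == len(test_vidxs), 'feature/index count mismatch'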
    encoder.eval()
    decoder.eval()

    with torch.no_grad():
        video_encoded = encoder(cnn_feats, c3d_feats, cnn_globals, tags_globals, res_eco_globals)
        logits, tokens = decoder(video_encoded, None, teacher_forcing_ratio=0)
        # Confidence per caption: the max logit at each decoding step, averaged over the sequence
        scores = logits.max(dim=2)[0].mean(dim=1)

    confidences, sentences = [], []
    for score, seq in zip(scores, tokens):
        s = decode_from_tokens(seq, vocab)
        print(score, s)
        sentences.append(s)
        confidences.append(score)
    # Write one "<video index>\t<caption>" line per test video
    os.makedirs(args.output_folder, exist_ok=True)
    with open(os.path.join(args.output_folder, 'predictions.txt'), 'w') as fo:
        for vidx, sentence in zip(test_vidxs, sentences):
            fo.write(f'{vidx}\t{sentence}\n')
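    # Resulting predictions.txt pairs each test video index with its generated
    # caption, e.g. (illustrative values, not actual output):
    #   1201	a man is playing a guitar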