-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
60 lines (50 loc) · 2.28 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from torch.utils.data import Dataset
import torch
import numpy as np
import json
import pathlib as plb
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class VideoDataset(Dataset):
    """Dataset pairing pre-extracted video features with their captions.

    The captions JSON file provides the vocabulary ('word2ix' / 'ix2word'),
    the captions keyed by video ID ('captions') and the split membership
    lists ('splits'). Every ``<feat_path>/<ID>.npy`` file whose ID belongs to
    the requested split becomes one item of the dataset.

    NOTE(review): tensors are created directly on the module-level ``device``
    inside ``__getitem__``. When that device is CUDA, a ``DataLoader`` with
    ``num_workers > 0`` cannot initialise CUDA in forked worker processes;
    keep ``num_workers=0`` or move device placement into the training loop.
    """

    def __init__(self, captions_file, feat_path, max_len=80, mode='train'):
        """
        :param captions_file: path to the JSON file with keys 'word2ix',
            'ix2word', 'captions' and 'splits'
        :param feat_path: directory holding one ``<ID>.npy`` feature file per video
        :param max_len: every caption is truncated / zero-padded to this length
        :param mode: name of the split to load, looked up in ``splits``
            (e.g. 'train', 'valid', 'test')
        """
        with open(captions_file, encoding='utf-8') as f:
            data = json.load(f)
        self.word2ix = data['word2ix']
        self.ix2word = data['ix2word']
        # dict: video ID -> list of captions, each a list of integer token ids
        # (presumably indices into word2ix -- verify against the preprocessing)
        self.captions = data['captions']
        self.splits = data['splits']
        self.mode = mode
        self.max_len = max_len
        # Keep only the features of videos in the requested split. A set makes
        # each membership test O(1); sorting makes the index -> video mapping
        # reproducible, since glob order depends on the file system.
        wanted = set(self.splits[mode])
        self.feat_paths = sorted(
            path for path in plb.Path(feat_path).glob('*.npy') if path.stem in wanted
        )
        print("prepare {} dataset. vocab_size: {}, dataset_size: {}".format(mode, len(self.word2ix), len(self.feat_paths)))

    def __getitem__(self, index):
        """
        Load one video's features together with a randomly chosen caption.

        The caption is truncated to ``max_len`` tokens and zero-padded to
        exactly ``max_len``; ``mask`` is 1.0 on real caption positions and
        0.0 on padding. Padding happens for every split.

        :param index: index into ``self.feat_paths``
        :return: tuple(tensor(feat), tensor(pad_label), str(ID), tensor(mask)):
            float32 features as stored in the .npy file, a ``max_len`` int64
            caption, the video ID and a ``max_len`` float32 mask
        :raises ValueError: if the video has no captions
        """
        feat_file = self.feat_paths[index]
        ID = feat_file.stem
        feat = np.load(str(feat_file))
        # Input features are data, not parameters: they need no autograd graph
        # (requires_grad=True also breaks default_collate inside workers).
        feat = torch.tensor(feat, dtype=torch.float, device=device)
        labels = self.captions[ID]
        if not labels:
            raise ValueError("no captions for video {}".format(ID))
        # Keep drawing from numpy's RNG as before, but pick by index:
        # np.random.choice requires a 1-D array and rejects a list of
        # (possibly ragged) caption lists.
        label = labels[np.random.randint(len(labels))]
        label = label[:self.max_len]
        n_tokens = len(label)
        pad_label = torch.zeros([self.max_len], dtype=torch.long, device=device)
        pad_label[:n_tokens] = torch.tensor(label, dtype=torch.long, device=device)
        mask = torch.zeros([self.max_len], dtype=torch.float, device=device)
        mask[:n_tokens] = 1
        return feat, pad_label, ID, mask

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.feat_paths)
if __name__ == '__main__':
    # Debug smoke test: build the training split and fetch a single batch.
    captions_json = 'data/captions.json'
    features_dir = 'data/feats/vgg16_bn'
    train_set = VideoDataset(captions_json, features_dir)
    loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
    batch_iter = iter(loader)
    first_batch = next(batch_iter)