-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
78 lines (66 loc) · 2.22 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
""" utility functions"""
import re
import os
from os.path import basename
import gensim
import torch
from torch import nn
def count_data(path):
    """Count the data files in the given directory.

    A data file is an entry whose whole name is one or more digits followed
    by ``.json`` (e.g. ``0.json``, ``123.json``); every other entry is
    ignored.

    Args:
        path: directory to scan (passed to ``os.listdir``).

    Returns:
        int: the number of matching entries.
    """
    # fullmatch (not match): match() only anchors at the start of the name,
    # so "3.jsonl" or "4.json.bak" would otherwise be counted as data files.
    matcher = re.compile(r'[0-9]+\.json')
    return sum(1 for name in os.listdir(path) if matcher.fullmatch(name))
# Reserved ids for the special tokens that occupy the first slots of every
# vocabulary built by make_vocab: padding, unknown word, start and end marker.
PAD, UNK, START, END = 0, 1, 2, 3
def make_vocab(wc, vocab_size):
    """Build a word -> id mapping from word counts.

    The special tokens ``<pad>``, ``<unk>``, ``<start>`` and ``<end>`` are
    mapped to PAD, UNK, START and END; the ``vocab_size`` most common words
    of ``wc`` follow, numbered consecutively right after the special tokens.

    Args:
        wc: word counts exposing ``most_common(n)`` (presumably a
            ``collections.Counter`` -- confirm against callers).
        vocab_size: number of corpus words to keep.

    Returns:
        dict: ``word2id``. Note that if a special-token string itself occurs
        in ``wc`` it is reassigned to its frequency-rank id.
    """
    word2id = {
        '<pad>': PAD,
        '<unk>': UNK,
        '<start>': START,
        '<end>': END,
    }
    # Corpus words start right after the reserved tokens (id 4).
    first_id = len(word2id)
    for i, (w, _) in enumerate(wc.most_common(vocab_size), first_id):
        word2id[w] = i
    return word2id
def make_embedding(id2word, w2v_file, initializer=None):
    """Build an embedding weight initialised from a saved gensim Word2Vec model.

    Args:
        id2word: id -> word lookup for ids ``0 .. len(id2word) - 1``; may be a
            list or a dict keyed by int id.
        w2v_file: path to a model loadable with ``gensim.models.Word2Vec.load``.
        initializer: optional callable applied in place to the weight before
            pretrained vectors are copied in. Rows with no pretrained vector
            keep whatever it (or ``nn.Embedding``'s default initialization)
            produced.

    Returns:
        (embedding, oovs): the ``nn.Embedding`` weight parameter of shape
        ``(len(id2word), dim)`` where ``dim`` is the model's vector size, and
        the list of ids for which no pretrained vector was found.

    Raises:
        KeyError: if the model lacks the ``'<s>'`` / ``r'<\\s>'`` tokens used
            for the START / END rows.
    """
    # NOTE: gensim's load() unpickles the file -- only load trusted models.
    w2v = gensim.models.Word2Vec.load(w2v_file).wv
    vocab_size = len(id2word)
    # Take the dimension from the model itself instead of parsing it out of
    # the file name (the old "word2vec.{dim}d.{vsize}k.bin" convention),
    # which failed for any other naming and could disagree with the vectors.
    emb_dim = w2v.vector_size
    embedding = nn.Embedding(vocab_size, emb_dim).weight
    oovs = []
    with torch.no_grad():
        if initializer is not None:
            # Under no_grad so an initializer that writes the leaf parameter
            # in place does not trip autograd.
            initializer(embedding)
        for i in range(vocab_size):
            # NOTE: id2word can be list or dict
            if i == START:
                # Sentence-boundary tokens are looked up by their names in
                # the word2vec model, not by their vocabulary strings.
                vec = w2v['<s>']
            elif i == END:
                vec = w2v[r'<\s>']
            elif id2word[i] in w2v:
                vec = w2v[id2word[i]]
            else:
                oovs.append(i)
                continue
            embedding[i, :] = torch.Tensor(vec)
    return embedding, oovs
def reconstruct_topic_dis(data_topic, batch_size=None, num_topics=512,
                          device='cuda'):
    """Densify sparse topic distributions into float tensors.

    Each row entry is indexable as ``(indices, values)``: the listed topic
    indices of that row are set to the matching values and every other
    topic is 0.

    Args:
        data_topic: when ``batch_size`` is None, a sequence of row entries
            (one per sentence); otherwise a sequence whose first
            ``batch_size`` items are each such a sequence.
        batch_size: None for a single example, else the number of leading
            examples of ``data_topic`` to convert.
        num_topics: width of each dense row (was hard-coded to 512).
        device: device the result is allocated on (was always CUDA via
            ``.cuda()``); pass ``'cpu'`` to run without a GPU.

    Returns:
        A float tensor of shape ``(len(data_topic), num_topics)``, or, when
        ``batch_size`` is given, a list of ``batch_size`` such tensors.
    """
    if batch_size is not None:
        # Batched input: densify each example independently.
        return [
            reconstruct_topic_dis(data_topic[b], num_topics=num_topics,
                                  device=device)
            for b in range(batch_size)
        ]
    topic_dis = torch.zeros((len(data_topic), num_topics),
                            dtype=torch.float, device=device)
    for i, entry in enumerate(data_topic):
        # Explicit long dtype: torch.tensor([]) would be float, which
        # index_put_ rejects as an index, so empty rows used to crash.
        index = torch.tensor(entry[0], dtype=torch.long, device=device)
        value = torch.tensor(entry[1], dtype=torch.float, device=device)
        # topic_dis[i] is a view, so this writes into topic_dis in place.
        topic_dis[i].index_put_((index,), value)
    return topic_dis