data_engine.py
from torch.utils.data import Dataset
from nltk.stem import WordNetLemmatizer
import numpy as np
import spacy
import os
import pickle
from utils import print_time_info
from tokenizer import Tokenizer
from text_token import _UNK, _PAD, _BOS, _EOS
from data.E2ENLG import E2ENLG


class DataEngine(Dataset):
    """Dataset that reads an NLG corpus, builds a vocabulary, tokenizes the
    samples, and caches the processed data to disk."""

    def __init__(
            self,
            data_dir,
            dataset,
            save_path='data.pkl',
            vocab_path='vocab.pkl',
            is_spacy=True,
            is_lemma=True,
            fold_attr=True,
            use_punct=False,
            vocab_size=20000,
            n_layers=4,
            min_length=5,
            en_max_length=None,
            de_max_length=None,
            regen=False,
            train=True
    ):
        # Choose the parser backend: spaCy's English pipeline or NLTK's lemmatizer.
        if is_spacy:
            self.spacy_parser = spacy.load('en')
            print_time_info("Use Spacy as the parser")
        else:
            self.nltk_lemmatizer = WordNetLemmatizer()
            print_time_info("Use NLTK as the parser")
        self.is_spacy = is_spacy
        self.is_lemma = is_lemma
        self.fold_attr = fold_attr
        self.use_punct = use_punct
        self.data_dir = data_dir
        self.save_path = save_path
        self.vocab_path = vocab_path
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.dataset = dataset
        self.min_length = min_length
        # -1 means "no cropping" in crop().
        self.en_max_length = en_max_length if en_max_length else -1
        self.de_max_length = de_max_length if de_max_length else -1
        self.regen = regen
        self.split_vocab = True
        self.tokenizer = Tokenizer(vocab_path, self.split_vocab, regen, train)
        self.counter = 0
        self.train = train
        self.prepare_data()

    def prepare_data(self):
        """Read the raw dataset (or load a cached pickle), build the vocabulary,
        tokenize and crop the sequences, and write the processed save file."""
        if not os.path.exists(self.save_path) or self.regen:
            if self.regen:
                print_time_info("Regenerate the data...")
            else:
                print_time_info("There isn't any usable save...")
            if not os.path.isdir(self.data_dir):
                print_time_info("Error: The dataset doesn't exist")
                exit()
            print_time_info("Start reading dataset {} from {}".format(
                self.dataset, self.data_dir))
            if self.dataset == "E2ENLG":
                self.input_data, self.input_attr_seqs, self.output_labels, \
                    self.refs, self.sf_data = E2ENLG(
                        self.data_dir, self.is_spacy, self.is_lemma,
                        self.fold_attr, self.use_punct,
                        self.min_length, self.train)
        else:
            # A processed save already exists; load it instead of re-reading the corpus.
            self.input_data, self.input_attr_seqs, self.output_labels, self.refs, self.sf_data = \
                pickle.load(open(self.save_path, 'rb'))
            print_time_info("Load the data from {}".format(self.save_path))

        if not os.path.exists(self.vocab_path) or (self.regen and self.train):
            self.build_vocab()
        if not os.path.exists(self.save_path) or self.regen:
            self.tokenize_sents()
            self.crop()
            pickle.dump(
                [self.input_data, self.input_attr_seqs, self.output_labels,
                 self.refs, self.sf_data],
                open(self.save_path, 'wb'))
            print_time_info(
                "Create the save file {}".format(self.save_path))

        self.tokenizer.shrink_vocab(self.vocab_size)
        self.add_unk()
        self.training_set_label_samples = self.input_data

    def build_vocab(self):
        """Build either a joint vocabulary or separate vocabularies for
        output sentences and input attribute tokens."""
        if not self.split_vocab:
            corpus = []
            for sent in self.input_data:
                corpus.extend(sent)
            for sent in self.output_labels:
                corpus.extend(sent)
            self.tokenizer.build_vocab(corpus)
        else:
            corpus = []
            for sent in self.output_labels:
                corpus.extend(sent)
            tokens = []
            for attrs in self.input_data:
                tokens.extend(attrs)
            self.tokenizer.build_vocab(corpus, tokens)

    def tokenize_sents(self):
        """Convert words to token ids and append an EOS token to each output label."""
        for idx, sent in enumerate(self.input_data):
            self.input_data[idx] = self.tokenizer.tokenize(sent, True)
        for idx, sent in enumerate(self.output_labels):
            self.output_labels[idx] = self.tokenizer.tokenize(sent, False)
            self.output_labels[idx].append(_EOS)
        for idx, refs in enumerate(self.refs):
            self.refs[idx] = [self.tokenizer.tokenize(ref, False) for ref in refs]

    def add_unk(self):
        """Replace out-of-vocabulary token ids with UNK (joint-vocabulary mode only)."""
        if not self.split_vocab:
            for idx, sent in enumerate(self.input_data):
                for w_idx, word in enumerate(sent):
                    if word >= self.vocab_size + 4:
                        self.input_data[idx][w_idx] = _UNK
            for idx, sent in enumerate(self.output_labels):
                for w_idx, word in enumerate(sent):
                    if word >= self.vocab_size + 4:
                        self.output_labels[idx][w_idx] = _UNK

    def crop(self):
        """Truncate encoder inputs and decoder labels to their maximum lengths."""
        if self.en_max_length != -1:
            for idx, sent in enumerate(self.input_data):
                self.input_data[idx] = sent[:self.en_max_length]
        if self.de_max_length != -1:
            for idx, sent in enumerate(self.output_labels):
                self.output_labels[idx] = sent[:self.de_max_length]

    def tokenize(self, sent):
        return self.tokenizer.tokenize(sent, True)

    def untokenize(self, sent, is_token=False):
        return self.tokenizer.untokenize(sent, is_token)

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        return (
            self.input_data[idx],
            self.output_labels[idx],
            self.refs[idx],
            self.sf_data[idx],
            self.input_attr_seqs[idx]
        )


class DataEngineSplit(Dataset):
    """Lightweight Dataset over an already-processed subset of the data."""

    def __init__(self, input_data, output_labels, refs, sf_data, input_attr_seqs):
        super(DataEngineSplit, self).__init__()
        self.input_data = input_data
        self.output_labels = output_labels
        self.refs = refs
        self.sf_data = sf_data
        self.input_attr_seqs = input_attr_seqs

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        return (
            self.input_data[idx],
            self.output_labels[idx],
            self.refs[idx],
            self.sf_data[idx],
            self.input_attr_seqs[idx]
        )
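

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how DataEngine and DataEngineSplit might be driven from a training script.
# The data directory "data/E2ENLG", the 90/10 split, and the batching choices
# below are assumptions for illustration only. Because __getitem__ returns
# variable-length Python lists, the DataLoader needs a collate_fn that simply
# gathers the raw samples; any padding would happen later in the pipeline.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    def collate_raw(batch):
        # Keep each field as a list of per-example sequences instead of
        # letting the default collate try (and fail) to stack ragged lists.
        encoder_inputs, decoder_labels, refs, sf_data, attr_seqs = zip(*batch)
        return (list(encoder_inputs), list(decoder_labels),
                list(refs), list(sf_data), list(attr_seqs))

    # Hypothetical paths and arguments; adjust to the actual dataset layout.
    data_engine = DataEngine(
        data_dir='data/E2ENLG',
        dataset='E2ENLG',
        save_path='data.pkl',
        vocab_path='vocab.pkl',
        regen=False,
        train=True
    )

    # A hypothetical train/validation split built with DataEngineSplit: slice
    # every parallel field at the same cut point so samples stay aligned.
    cut = int(len(data_engine) * 0.9)
    valid_set = DataEngineSplit(
        data_engine.input_data[cut:],
        data_engine.output_labels[cut:],
        data_engine.refs[cut:],
        data_engine.sf_data[cut:],
        data_engine.input_attr_seqs[cut:]
    )

    loader = DataLoader(
        data_engine, batch_size=32, shuffle=True, collate_fn=collate_raw)
    for encoder_inputs, decoder_labels, refs, sf_data, attr_seqs in loader:
        # encoder_inputs: token-id sequences for the input attributes;
        # decoder_labels: token-id sequences each ending with _EOS.
        break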