'''
The main Dataset class for the entire model is defined here.
It keeps track of train_data, the corpus, and everything to do
with the book dataset. Corpus loading is handled asynchronously.
'''
import io
import re
import os
import time
import string
import asyncio
import warnings

import numpy as np
from torch.utils.data import Dataset
from typing import List, Optional, Tuple, Union
from multiprocessing import Pool

warnings.filterwarnings('ignore')

import utils
import logger


class BookCorpusDataset(Dataset):
    '''
    Class:
        - The main dataset, which contains all the required data for training
          the model.
        - Supports multiprocessing.
        - A minimal usage sketch follows the class definition below.

    Args:
        chunk_size:
            - The number of words in a batch.
            - This is set to None when just_corpus=True.
        just_corpus:
            - Whether the dataset should only prepare the corpus (when not training).
            - You can't run generate_batches() if this is set to True.
        save_corpus:
            - Whether to save the corpus to a file or not.
        cache_train_data:
            - Whether or not to cache the training data instead of processing it
              every time at runtime.
        train_data_file:
            - The filename to load the training data from.
        corpus_from_file:
            - The filename to load the corpus from.
    '''
    def __init__(self,
                 chunk_size=3,
                 just_corpus=False,
                 save_corpus=True,
                 cache_train_data=False,
                 train_data_file: Optional[str] = None,
                 corpus_from_file: Optional[str] = None):
        # train_data_file and corpus_from_file must be supplied together (or not at all).
        if bool(train_data_file) != bool(corpus_from_file):
            raise ValueError('If train_data_file is None, then so should corpus_from_file be. '
                             'corpus_from_file is dependent on train_data_file.')

        start = time.time()
        self.n_batches = 500
        self._just_corpus = just_corpus
        self.loop = asyncio.get_event_loop()
        self.chunk_size = chunk_size if not self._just_corpus else None

        # Nothing else to prepare if only the corpus is needed.
        if just_corpus:
            return
        file_contents: List[str]

        if corpus_from_file:
            # Remove newlines from the saved corpus file.
            self.corpus = [word.strip('\n') for word in io.open(corpus_from_file, encoding='utf-8').readlines()]
            file_contents = self._run_load_corpus(just_contents=True)
        else:
            self.corpus, file_contents = self._run_load_corpus()

        if train_data_file:
            self.train_data = np.loadtxt(train_data_file)
        else:
            logger.INFO('Preprocessing...')
            self.train_data: np.ndarray = utils.text2idx(
                file_contents,
                self.corpus,
                pbar=True
            )
            logger.INFO('Finished preprocessing')

        print(f'Process took: {time.time() - start}')

        if save_corpus:
            with io.open('corpus.txt', 'w', encoding='utf-8') as f:
                for word in self.corpus:
                    f.write(word + '\n')

        if cache_train_data:
            np.savetxt('train_data.csv.gz', self.train_data)

        self.prep_data = []

    def generate_batches(self):
        if self._just_corpus:
            raise AssertionError('If you want to run generate_batches(), you must set just_corpus=False in the constructor.')

        # Slide a window of chunk_size words over train_data: each sample pairs a
        # chunk_size-word input with the chunk_size words that follow it.
        beginning = 0
        last_idx = self.chunk_size
        for i in range(self.n_batches):
            sample = self._get_batch(beginning, last_idx)
            self.prep_data.append(sample)
            beginning = last_idx
            last_idx += self.chunk_size

    def _run_load_corpus(self, just_contents=False):
        return self.loop.run_until_complete(load_corpus('data', just_contents=just_contents))

    def _get_batch(self, beginning, last_idx):
        starting_phrase = self.train_data[beginning:last_idx]
        target_word = self.train_data[last_idx:last_idx + self.chunk_size]
        return (starting_phrase, target_word)

    def __getitem__(self, index):
        return self.prep_data[index]

    def __len__(self):
        return len(self.prep_data)
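

# A minimal usage sketch, not part of the original module API. It assumes a
# ./data directory of UTF-8 text files and that the custom `utils` and `logger`
# modules used above are importable. The flow: build the dataset, generate the
# (input, target) word chunks, then iterate them with a standard PyTorch
# DataLoader. Kept inside a function so importing this file has no side effects.
def _usage_sketch():
    from torch.utils.data import DataLoader

    dataset = BookCorpusDataset(chunk_size=3, save_corpus=False, cache_train_data=False)
    dataset.generate_batches()

    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    for starting_phrase, target_word in loader:
        # starting_phrase / target_word: index arrays for a chunk_size-word
        # input phrase and the chunk_size words that follow it.
        pass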


async def load_corpus(text_file_dir, just_contents=False) -> Union[list, str]:
    corpus = ''
    files_str = os.listdir(text_file_dir)
    # Open every file in the directory that was passed in.
    files = [open(os.path.join(text_file_dir, f), 'r', encoding='utf-8') for f in files_str]

    logger.INFO('Collecting tokens from:\n')
    files_str.sort(key=len)
    for c in files_str:
        logger.info(c)
    print()

    # Concatenate the raw contents of every file.
    for f in files:
        corpus += f.read()
        f.close()
    total_text = corpus

    if not just_contents:
        edited_corpus = re.sub('\n', ' ', corpus)
        words = edited_corpus.split(' ')

        # Strip each word down to letters and a small set of allowed punctuation.
        r = []
        for word in words:
            r.append(await proc(word))
        return (list(set(r)), total_text)

    return total_text


async def proc(word):
    # Keep only ASCII letters and a small set of allowed punctuation,
    # e.g. 'hello,' -> 'hello' and "it's" -> 'its'.
    new_word = ''
    allowed_punc = '?.!-'
    for char in word:
        if char in allowed_punc or char in string.ascii_letters:
            new_word += char
    if word != '':
        new_word = new_word.strip()
    return new_word


if __name__ == '__main__':
    loop = asyncio.get_event_loop()

    start = time.time()
    corpus, text = loop.run_until_complete(load_corpus('data'))
    end = time.time()

    print(end - start)