-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_lstm.py
100 lines (77 loc) · 3.37 KB
/
train_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import print_function

import os
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from lstm_seq2seq.library.seq2seq import Seq2SeqSummarizer
from lstm_seq2seq.library.utility.plot_utils import plot_and_save_history
# Default caps used by fit_text() when building the vocabulary/config.
MAX_INPUT_SEQ_LENGTH = 500  # input texts truncated to this many space-separated tokens
MAX_TARGET_SEQ_LENGTH = 50  # target summaries truncated (includes START/END sentinels)
MAX_INPUT_VOCAB_SIZE = 5000  # keep only this many most-frequent input words
MAX_TARGET_VOCAB_SIZE = 2000  # keep only this many most-frequent target words
def fit_text(X, Y, input_seq_max_length=None, target_seq_max_length=None,
             input_vocab_size=None, target_vocab_size=None):
    """Build the seq2seq vocabulary/config from raw input and target texts.

    Args:
        X: iterable of input (source) text strings; tokenized on single
            spaces and lower-cased.
        Y: iterable of target (summary) text strings; each is lower-cased
            and wrapped in the 'START' / 'END' sentinel tokens the decoder
            uses.
        input_seq_max_length: truncate each input to this many tokens
            (defaults to MAX_INPUT_SEQ_LENGTH).
        target_seq_max_length: truncate each target to this many tokens,
            sentinels included (defaults to MAX_TARGET_SEQ_LENGTH).
        input_vocab_size: keep at most this many input words
            (defaults to MAX_INPUT_VOCAB_SIZE).
        target_vocab_size: keep at most this many target words
            (defaults to MAX_TARGET_VOCAB_SIZE).

    Returns:
        dict with word<->index maps ('input_word2idx', 'input_idx2word',
        'target_word2idx', 'target_idx2word'), vocabulary sizes
        ('num_input_tokens', 'num_target_tokens') and the observed
        post-truncation maximum sequence lengths
        ('max_input_seq_length', 'max_target_seq_length').
    """
    if input_seq_max_length is None:
        input_seq_max_length = MAX_INPUT_SEQ_LENGTH
    if target_seq_max_length is None:
        target_seq_max_length = MAX_TARGET_SEQ_LENGTH
    if input_vocab_size is None:
        input_vocab_size = MAX_INPUT_VOCAB_SIZE
    if target_vocab_size is None:
        target_vocab_size = MAX_TARGET_VOCAB_SIZE

    input_counter = Counter()
    target_counter = Counter()
    max_input_seq_length = 0
    max_target_seq_length = 0

    for line in X:
        # Lower-case and truncate BEFORE counting, so the vocabulary only
        # reflects tokens the model will actually see.
        text = [word.lower() for word in line.split(' ')][:input_seq_max_length]
        input_counter.update(text)
        max_input_seq_length = max(max_input_seq_length, len(text))

    for line in Y:
        # Decoder targets are bracketed by START/END sentinel tokens
        # (upper-case on purpose, so they cannot collide with real words).
        text = ('START ' + line.lower() + ' END').split(' ')[:target_seq_max_length]
        target_counter.update(text)
        max_target_seq_length = max(max_target_seq_length, len(text))

    # Input vocabulary: indices 0 and 1 are reserved for PAD/UNK,
    # real words start at 2, most frequent first.
    input_word2idx = {word: idx + 2
                      for idx, (word, _) in enumerate(input_counter.most_common(input_vocab_size))}
    input_word2idx['PAD'] = 0
    input_word2idx['UNK'] = 1
    input_idx2word = {idx: word for word, idx in input_word2idx.items()}

    # Target vocabulary: index 0 is reserved for UNK, real words start at 1.
    target_word2idx = {word: idx + 1
                       for idx, (word, _) in enumerate(target_counter.most_common(target_vocab_size))}
    target_word2idx['UNK'] = 0
    target_idx2word = {idx: word for word, idx in target_word2idx.items()}

    return {
        'input_word2idx': input_word2idx,
        'input_idx2word': input_idx2word,
        'target_word2idx': target_word2idx,
        'target_idx2word': target_idx2word,
        'num_input_tokens': len(input_word2idx),
        'num_target_tokens': len(target_word2idx),
        'max_input_seq_length': max_input_seq_length,
        'max_target_seq_length': max_target_seq_length,
    }
# Whether to resume training from a previously saved checkpoint.
LOAD_EXISTING_WEIGHTS = True

np.random.seed(170110)  # reproducibility for numpy-based randomness

report_dir_path = './reports'
model_dir_path = './models'

# Pre-processed training data: one input text and one target summary per row.
df = pd.read_csv('./data/df_to_model.csv')
Y = df['target_text']
X = df['input_text']

# Build the vocabulary/config from the full data set before splitting.
config = fit_text(X, Y)

summarizer = Seq2SeqSummarizer(config)

weight_file_path = Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path)
if LOAD_EXISTING_WEIGHTS and os.path.exists(weight_file_path):
    # Resume from a checkpoint when one exists; on a fresh run there is no
    # weight file yet, so skip loading instead of crashing.
    summarizer.load_weights(weight_file_path=weight_file_path)

Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)

history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=100)

# Persist the loss/accuracy curves for this run.
history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history.png'
plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'accuracy'})