-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopm-rnn-retrain.py
82 lines (66 loc) · 2.44 KB
/
opm-rnn-retrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
import argparse
import os
import utils
def encode(y, num_classes=None):
    """One-hot encode integer class labels.

    Thin wrapper around Keras' ``np_utils.to_categorical``.

    Args:
        y: Integer class labels (array-like).
        num_classes: Total number of classes. When ``None`` (the default,
            identical to the previous behavior) Keras infers it as
            ``max(y) + 1``. That makes the encoding narrower than the
            vocabulary whenever the highest class index never occurs in *y*.
            Pass it explicitly to pin the output width.

    Returns:
        The one-hot encoded labels as returned by ``to_categorical``.
    """
    return np_utils.to_categorical(y, num_classes=num_classes)
def create_model(X, y, params):
    """Build (but do not compile) a two-layer stacked LSTM classifier.

    Architecture: LSTM -> Dropout -> LSTM -> Dropout -> Dense.

    Args:
        X: Training inputs. Only ``X.shape[1]`` (timesteps) and
            ``X.shape[2]`` (features per timestep) are read, to size the
            input layer.
        y: Training targets. Only ``y.shape[1]`` is read, as the number of
            output units.
        params: Hyper-parameter dict. Reads ``"LSTM-1"`` and ``"LSTM-2"``
            (unit counts), ``"Dropout-1"`` and ``"Dropout-2"`` (dropout
            rates), and ``"activation"`` (output-layer activation).

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = Input(shape=(X.shape[1], X.shape[2]))
    # The first LSTM emits its full output sequence so the second LSTM can
    # consume it timestep by timestep.
    hidden = LSTM(params["LSTM-1"], return_sequences=True)(inputs)
    hidden = Dropout(params["Dropout-1"])(hidden)
    # The second LSTM (default return_sequences=False) emits only its last output.
    hidden = LSTM(params["LSTM-2"])(hidden)
    hidden = Dropout(params["Dropout-2"])(hidden)
    outputs = Dense(y.shape[1], activation=params["activation"])(hidden)
    return Model(inputs=inputs, outputs=outputs)
def train(dataset, output):
    """Retrain the stacked-LSTM character model on a text corpus.

    Steps:
    - Read the corpus at *dataset* and preprocess it with the project's
      ``utils`` helpers.
    - Build training windows of ``seq_length`` (80) characters.
    - Fit the model from ``create_model``, checkpointing the lowest-loss
      model into *output*.
    - Write the architecture to ``model-opm.json`` in the current working
      directory.

    Args:
        dataset: Path to the training text file.
        output: Directory that receives the checkpoint files. It is created
            if it does not already exist.

    Raises:
        ValueError: If ``utils.prepare_dset`` yields no training patterns.
    """
    # Context manager so the file handle is closed even if reading fails.
    with open(dataset) as corpus:
        raw_text = corpus.read()
    pre_text = utils.pre_process(raw_text)
    char_map = utils.map_chars_to_int(pre_text)
    params = {
        "seq_length": 80,
        "n_chars": len(pre_text),
        "n_vocab": len(char_map),
    }

    X, y = utils.prepare_dset(pre_text, char_map, params)
    params["n_patterns"] = len(X)
    if params["n_patterns"] == 0:
        # Without this check, to_categorical fails later with an opaque
        # "zero-size array" error.
        raise ValueError(
            "no training patterns could be built from %r" % (dataset,)
        )

    # LSTM input must be (samples, timesteps, features); one feature per step.
    X = np.reshape(X, (params["n_patterns"], params["seq_length"], 1))
    # NOTE(review): with no num_classes, the one-hot width is max(y) + 1,
    # not necessarily n_vocab — confirm that is acceptable for the consumer.
    y = encode(y)

    model_params = {
        "LSTM-1": 512,
        "LSTM-2": 256,
        "Dropout-1": 0.3,
        "Dropout-2": 0.2,
        "activation": "softmax",
        "loss": "categorical_crossentropy",
        "optimizer": "adam",
        "epochs": 100,
        "batch_size": 32,
    }
    model = create_model(X, y, model_params)

    # Make sure the checkpoint directory exists before training starts;
    # an empty string means "current directory" and needs no creation.
    if output and not os.path.isdir(output):
        os.makedirs(output)
    filepath = os.path.join(
        output, "weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
    )
    # Only write a checkpoint when the training loss improves on its best so far.
    checkpoint = ModelCheckpoint(filepath, monitor='loss',
                                 verbose=1, save_best_only=True, mode='min')

    model.compile(loss=model_params["loss"],
                  optimizer=model_params["optimizer"])
    model.fit(X, y, epochs=model_params["epochs"],
              batch_size=model_params["batch_size"],
              callbacks=[checkpoint])

    # NOTE(review): the architecture JSON goes to the current working
    # directory, not to *output* next to the checkpoints — confirm that the
    # script loading it expects this location.
    with open("model-opm.json", "w") as json_file:
        json_file.write(model.to_json())
if __name__ == "__main__":
    # Command-line entry point: read the corpus/output paths, then retrain.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dataset",
        default="opm-lyrics.txt",
        help="Path of the training dataset",
    )
    cli.add_argument(
        "--output",
        default="weights/",
        help="Path where to save the weights",
    )
    cli_args = cli.parse_args()
    train(cli_args.dataset, cli_args.output)