-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathtrain_lstm.py
116 lines (100 loc) · 4.39 KB
/
train_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Runs a simple Neural Machine Translation model
Type `python train_lstm.py -h` for help with arguments.
"""
import os
import argparse
import numpy as np
import os
import keras
from keras import backend as K
from keras.layers import Lambda
#from keras import backend as k
from keras.models import Model
from keras.layers import Dense, Embedding, Activation, Permute
from keras import regularizers, constraints, initializers, activations
from keras.layers import Input, Flatten, Dropout
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed, Bidirectional
from keras.callbacks import ModelCheckpoint
from reader_lstm import Data,Vocabulary
from model.dualencoder import dualenc
def _ensure_weights_dir():
    """Create ./weights unless something already exists at that path."""
    if os.path.exists('./weights'):
        # Already there: nothing to do.
        return
    os.makedirs('./weights/')


# Runs at import time, before any training code.
_ensure_weights_dir()
def main(args):
    """Train the dual-encoder model and save its weights.

    Loads the vocabularies and the training/validation datasets, builds and
    compiles the ``dualenc`` model, trains it with ``fit_generator`` and
    writes the weights to ``cl_modellstm_weightsfk_kb.hdf5`` in the current
    working directory. Pressing Ctrl+C during training stops it early; the
    weights are saved either way.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``gpu``, ``padding``, ``training_data``, ``validation_data``,
            ``embedding``, ``batch_size`` and ``epochs``.
    """
    # NOTE(review): these are set after keras has been imported at module
    # level -- confirm the backend has not initialised CUDA by this point.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Both vocabularies read the same file; they differ only in padding.
    # The vocabulary is built once here (it used to be constructed twice
    # with identical arguments, discarding the first instance).
    vocab = Vocabulary('./fkdata/vocabl_fk.json', padding=args.padding)
    kb_vocab = Vocabulary('./fkdata/vocabl_fk.json', padding=4)

    print('Loading datasets.')
    training = Data(args.training_data, vocab, kb_vocab)
    validation = Data(args.validation_data, vocab, kb_vocab)
    # Call order is kept exactly as before; the Data methods may depend on it.
    training.load()
    validation.load()
    training.transform()
    training.kb_out()
    validation.transform()
    validation.kb_out()
    print('Datasets Loaded.')

    print('Compiling Model.')
    model = dualenc(pad_length=args.padding,
                    embedding_size=args.embedding,
                    vocab_size=vocab.size(),
                    batch_size=args.batch_size,
                    n_chars=vocab.size(),
                    n_labels=vocab.size(),
                    embedding_learnable=True,
                    encoder_units=256,
                    decoder_units=256,
                    trainable=True)
    model.summary()
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy', ])
    print('Model Compiled.')

    print('Training. Ctrl+C to end early.')
    try:
        model.fit_generator(
            generator=training.generator(args.batch_size),
            steps_per_epoch=300,
            validation_data=validation.generator(args.batch_size),
            validation_steps=10,
            workers=1,
            verbose=1,
            epochs=args.epochs)
    except KeyboardInterrupt:
        # A manual stop is not an error: fall through and save what we have.
        print('Model training stopped early.')

    # Saved on both normal completion and early stop.
    model.save_weights("cl_modellstm_weightsfk_kb.hdf5")
    print('Model training complete.')
def build_parser():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: parser with the optional arguments
        ``--epochs``, ``--embedding``, ``--gpu``, ``--padding``,
        ``--training-data``, ``--validation-data`` and ``--batch-size``.
    """
    parser = argparse.ArgumentParser()
    named_args = parser.add_argument_group('named arguments')

    named_args.add_argument('-e', '--epochs', metavar='|',
                            help="""Number of Epochs to Run""",
                            required=False, default=30, type=int)
    named_args.add_argument('-es', '--embedding', metavar='|',
                            help="""Size of the embedding""",
                            required=False, default=100, type=int)
    # Kept as a string: main() assigns it to an environment variable.
    named_args.add_argument('-g', '--gpu', metavar='|',
                            help="""GPU to use""",
                            required=False, default='1', type=str)
    named_args.add_argument('-p', '--padding', metavar='|',
                            help="""Amount of padding to use""",
                            required=False, default=20, type=int)
    named_args.add_argument('-t', '--training-data', metavar='|',
                            help="""Location of training data""",
                            required=False,
                            default='./data/cl_train_lstmfk_kb.csv')
    named_args.add_argument('-v', '--validation-data', metavar='|',
                            help="""Location of validation data""",
                            required=False,
                            default='./data/cl_val_lstmfk_kb.csv')
    # Help text used to be a copy of the validation-data description.
    named_args.add_argument('-b', '--batch-size', metavar='|',
                            help="""Size of each training batch""",
                            required=False, default=50, type=int)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    print(args)
    main(args)