run_lookup.py
"""
This script demonstrates the improvement in entity extraction recall gained
by using lookup tables, a feature in rasa_nlu.

Run the demo with

    python run_lookup.py <demo>

where <demo> is one of {food, company}. If <demo> is omitted, both demos are
run back to back. See the README.md for more information.
"""
import logging
import re
import sys

import matplotlib.pyplot as plt

from rasa_nlu.training_data import load_data
from rasa_nlu.model import Trainer
from rasa_nlu import config
from rasa_nlu import evaluate

# to get the logging stream into a string
try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO
DEMO_KEYS = ['food', 'company']


def train_model(td_file, config_file, model_dir):
    # trains a model using the training data and config
    td = load_data(td_file)
    trainer = Trainer(config.load(config_file))
    trainer.train(td)
    # persists the model and returns its path for evaluation
    model_loc = trainer.persist(model_dir)
    return model_loc


def train_test(td_file, config_file, model_dir):
    # helper function to split the data into train and test sets and
    # evaluate on the results
    td = load_data(td_file)
    trainer = Trainer(config.load(config_file))
    train, test = td.train_test_split(train_frac=0.6)
    trainer.train(train)
    model_loc = trainer.persist(model_dir)
    with open('data/tmp/temp_test.json', 'w', encoding="utf8") as f:
        f.write(test.as_json())
    with open('data/tmp/temp_train.json', 'w', encoding="utf8") as f:
        f.write(train.as_json())
    evaluate_model('data/tmp/temp_test.json', model_loc)


def CV_eval(td_file, config_file, Nfolds=10):
    # trains a model with cross-validation using the training data and config
    td = load_data(td_file)
    configuration = config.load(config_file)
    evaluate.run_cv_evaluation(td, Nfolds, configuration)


def evaluate_model(td_file, model_loc):
    # evaluates the model on the given data file
    # wrapper for rasa_nlu.evaluate.run_evaluation
    evaluate.run_evaluation(td_file, model_loc)


def get_path_dicts(key):
    # gets the right training data and model directory for the given demo
    training_data_dict = {
        'food': 'data/food/food_train.md',
        'company': 'data/company/company_train.json',
    }
    training_data_lookup_dict = {
        'food': 'data/food/food_train_lookup.md',
        'company': 'data/company/company_train_lookup.json',
    }
    test_data_dict = {
        'food': 'data/food/food_test.md',
        'company': 'data/company/company_test.json',
    }
    model_dir_dict = {
        'food': 'models/food',
        'company': 'models/company',
    }
    training_data = training_data_dict[key]
    training_data_lookup = training_data_lookup_dict[key]
    test_data = test_data_dict[key]
    model_dir = model_dir_dict[key]
    return training_data, training_data_lookup, test_data, model_dir


def run_demo(key, disp_bar=True):
    # runs the demo specified by key

    # get the data for this key and the configs
    training_data, training_data_lookup, test_data, model_dir = get_path_dicts(key)
    config_file = 'configs/config.yaml'
    config_baseline = 'configs/config_no_features.yaml'
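    # For context, a sketch of what the two configs are assumed to contain
    # (hypothetical contents; see the real files under configs/): both run a
    # CRF entity extractor, but only config.yaml keeps the "pattern" feature,
    # which is what lets lookup-table and regex matches (computed by
    # intent_entity_featurizer_regex) feed into ner_crf:
    #
    #   language: en
    #   pipeline:
    #   - name: "tokenizer_whitespace"
    #   - name: "intent_entity_featurizer_regex"
    #   - name: "ner_crf"
    #     features: [["low", "title"], ["low", "suffix3", "pattern"], ["low", "title"]]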

    # run a baseline
    model_loc = train_model(training_data, config_baseline, model_dir)
    evaluate_model(test_data, model_loc)

    # run with more features in the CRF
    model_loc = train_model(training_data, config_file, model_dir)
    evaluate_model(test_data, model_loc)

    # run with the lookup table
    model_loc = train_model(training_data_lookup, config_file, model_dir)
    evaluate_model(test_data, model_loc)

    # get the metrics
    metric_list = strip_metrics(key)

    # either print or plot them
    if disp_bar:
        plot_metrics(metric_list)
    else:
        print_metrics(metric_list)


def parse_metrics(match, key):
    # from the regex match, parse out the precision, recall, and f1 scores
    elements = match.split(' ')[1:]
    elements = [e for e in elements if len(e) > 2]
    elements = [float(e) for e in elements]
    metrics = dict(zip(['key', 'precision', 'recall', 'f1'], [key] + elements))
    return metrics


def strip_metrics(key):
    # reads the captured logger stream (log_stream, set up in __main__) and
    # returns the metrics associated with key
    stream_string = log_stream.getvalue()
    stream_literal = repr(stream_string)
    p_re = re.compile(key + r'[ ]+\d\.\d\d[ ]+\d\.\d\d[ ]+\d\.\d\d')
    matches = p_re.findall(stream_literal)
    metric_list = [parse_metrics(m, key) for m in matches]
    return metric_list
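

# For reference, the evaluation log contains sklearn classification_report
# rows; a hypothetical row for the 'food' entity might look like
#
#     food       0.83      0.64      0.72       100
#
# from which strip_metrics('food') would produce
# [{'key': 'food', 'precision': 0.83, 'recall': 0.64, 'f1': 0.72}]
# (the trailing support column is not captured by the regex above).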


def print_metrics(metric_list):
    # prints out the demo performance
    key = metric_list[0]['key']
    print("baseline, demo '{}' had:".format(key))
    display_metrics(metric_list[0])
    print("before adding lookup table(s), demo '{}' had:".format(key))
    display_metrics(metric_list[1])
    print("after adding lookup table(s), demo '{}' had:".format(key))
    display_metrics(metric_list[2])


def display_metrics(metrics):
    # helper function for print_metrics
    for key, val in metrics.items():
        print("\t{}:\t{}".format(key, val))


def plot_metrics(metric_list, save_path=None):
    # runs through each test case, adds a set of bars to a plot, and saves it
    f, ax1 = plt.subplots(1, 1)
    plt.grid(True)
    print_metrics(metric_list)
    bar_metrics(metric_list[0], ax1, index=0)
    bar_metrics(metric_list[1], ax1, index=1)
    bar_metrics(metric_list[2], ax1, index=2)
    if save_path is None:
        # default to a bar plot named after the demo key
        key = metric_list[0]['key']
        save_path = 'img/bar_' + key + '.png'
    plt.savefig(save_path, dpi=400)


def bar_metrics(metrics, ax, index=0):
    # adds a set of metric bars to the axis 'ax' of the plot
    precision = metrics['precision']
    recall = metrics['recall']
    f1 = metrics['f1']
    title = metrics['key']
    width = 0.2
    shift = index * width
    indices = [r + shift for r in range(3)]
    metric_list = [precision, recall, f1]
    ax.set_title(title)
    ax.set_axisbelow(True)
    ax.bar(indices, metric_list, width)
    ax.set_xticks([r + width / 2 for r in range(3)])
    ax.set_xticklabels(('precision', 'recall', 'f1'))
    ax.legend(['baseline', 'no lookup', 'with lookup'])
    ax.set_ylabel('score')


if __name__ == "__main__":
    # capture logging to a string
    log_stream = StringIO()
    logging.basicConfig(stream=log_stream, level=logging.INFO)

    # whether to create and save a bar plot
    disp_bar = True

    argv = sys.argv
    if len(argv) < 2:
        # run all of the demos
        for key in DEMO_KEYS:
            run_demo(key, disp_bar=disp_bar)
    else:
        key = argv[1]
        if key not in DEMO_KEYS:
            raise ValueError(
                "first argument to run_lookup.py must be one of {'food', 'company'}")
        else:
            run_demo(key, disp_bar=disp_bar)