-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
242 lines (204 loc) · 10.2 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
#!/usr/bin/env python
# coding: utf-8
"""MultiWOZ baseline test/decoding script.

Parses command-line arguments at import time, then (under __main__) decodes
every saved checkpoint with both greedy and beam search and reports
evaluation metrics via MultiWozEvaluator.
"""
from __future__ import division, print_function, unicode_literals
import argparse
import json
import os
import shutil
import time
import numpy as np
import torch
from utils import util, multiwoz_dataloader
from models.evaluator import *
from models.model import Model
from utils.util import detected_device, pp_mkdir
from multiwoz.Evaluators import *
# pp added: print out env
util.get_env_info()
parser = argparse.ArgumentParser(description='multiwoz1-bsl-te')
# 1. Data & Dir
data_arg = parser.add_argument_group('Data')
data_arg.add_argument('--data_dir', type=str, default='data/multi-woz', help='the root directory of data')
data_arg.add_argument('--result_dir', type=str, default='results/bsl/')
data_arg.add_argument('--model_name', type=str, default='translate.ckpt')
# 2. MISC
misc_arg = parser.add_argument_group('Misc')
misc_arg.add_argument('--dropout', type=float, default=0.0)
misc_arg.add_argument('--use_emb', type=str, default='False')
misc_arg.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
misc_arg.add_argument('--no_models', type=int, default=20, help='how many models to evaluate')
misc_arg.add_argument('--beam_width', type=int, default=10, help='Beam width used in beamsearch')
misc_arg.add_argument('--write_n_best', type=util.str2bool, nargs='?', const=True, default=False, help='Write n-best list (n=beam_width)')
# 3. Here add new args
new_arg = parser.add_argument_group('New')
new_arg.add_argument('--intent_type', type=str, default=None, help='separate experts by intents: None, domain, sysact or domain_act') # pp added
new_arg.add_argument('--lambda_expert', type=float, default=0.5) # use xx percent of training data
new_arg.add_argument('--mu_expert', type=float, default=0.5) # use xx percent of training data
new_arg.add_argument('--gamma_expert', type=float, default=0.5) # use xx percent of training data
new_arg.add_argument('--debug', type=util.str2bool, nargs='?', const=True, default=False, help='if True use small data for debugging')
args = parser.parse_args()
# Prefer GPU when available; downstream code reads args.device.
args.device = "cuda" if torch.cuda.is_available() else "cpu"
print('args.device={}'.format(args.device))
# construct dirs: checkpoint directory plus output locations for generated dialogues
args.model_dir = '%s/model/' % args.result_dir
args.train_output = '%s/data/train_dials/' % args.result_dir
args.valid_output = '%s/data/valid_dials/' % args.result_dir
args.decode_output = '%s/data/test_dials/' % args.result_dir
print(args)
# pp added: init seed so decoding runs are reproducible
util.init_seed(args.seed)
def load_config(args):
    """Load the model's saved JSON config and overlay the runtime arguments.

    Reads '<model_dir><model_name>.json' and returns a dict in which any key
    also present on the argparse Namespace is overridden by the runtime value.

    Args:
        args: the parsed argparse.Namespace (must carry model_dir/model_name).

    Returns:
        dict: merged configuration.
    """
    config_path = '{}{}.json'.format(args.model_dir, args.model_name)
    # Use a context manager so the file handle is closed (the original
    # leaked an open file object).
    with open(config_path, 'rb') as f:
        config = util.unicode_to_utf8(json.load(f))
    # vars(args) is the Namespace's attribute dict; the original accessed
    # the non-existent attribute 'args.__args', which raised AttributeError.
    # Plain argparse values have no '.value' member, so assign directly
    # instead of the old try/bare-except dance.
    for key, value in vars(args).items():
        config[key] = value
    return config
def loadModelAndData(num):
    """Build the model, restore checkpoint *num*, and load the val/test dialogues.

    Returns:
        (model, val_dials, test_dials, input_lang_word2index,
         output_lang_word2index, intent2index, index2intent)
    """
    # Vocabulary dictionaries (word <-> index for encoder and decoder sides).
    input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index = util.loadDictionaries(mdir=args.data_dir)
    # Intent dictionaries are only needed for the mixture-of-experts variants.
    if args.intent_type:
        intent2index, index2intent = util.loadIntentDictionaries(intent_type=args.intent_type, intent_file='{}/intents.json'.format(args.data_dir))
    else:
        intent2index, index2intent = None, None
    # Instantiate the seq2seq model and (optionally) restore saved parameters.
    model = Model(args, input_lang_index2word, output_lang_index2word, input_lang_word2index, output_lang_word2index, intent2index)
    model = model.to(detected_device)
    if args.load_param:
        model.loadModel(iter=num)
    # Ground-truth dialogues for the validation and test splits.
    with open('{}/val_dials.json'.format(args.data_dir)) as fin:
        val_dials = json.load(fin)
    with open('{}/test_dials.json'.format(args.data_dir)) as fin:
        test_dials = json.load(fin)
    return model, val_dials, test_dials, input_lang_word2index, output_lang_word2index, intent2index, index2intent
def _predict_split(model, dials, input_lang_word2index, output_lang_word2index, intent2index, step):
    """Run the model over one data split (validation or test).

    Args:
        model: the loaded seq2seq model (its .beam_search flag is already set).
        dials: dict mapping dialogue name -> dialogue data.
        step: 0 for the full split; otherwise only the last `step` dialogues
            are decoded (debug mode) — note that list[-0:] is the whole list.

    Returns:
        (dials_gen, total_loss): generated responses keyed by dialogue name,
        and the summed per-dialogue sentence loss.
    """
    dials_gen = {}
    total_loss = 0
    for name, dial_file in list(dials.items())[-step:]:
        loader = multiwoz_dataloader.get_loader_by_dialogue(dial_file, name,
                                                            input_lang_word2index, output_lang_word2index,
                                                            args.intent_type, intent2index)
        # next(iter(...)) instead of the Python-2-only iter(...).next(),
        # which raises AttributeError on Python 3.
        data = next(iter(loader))
        # Transfer tensors to GPU when available.
        if torch.cuda.is_available():
            data = [d.cuda() if isinstance(d, torch.Tensor) else d for d in data]
        input_tensor, input_lengths, target_tensor, target_lengths, bs_tensor, db_tensor, mask_tensor = data
        output_words, loss_sentence = model.predict(input_tensor, input_lengths, target_tensor, target_lengths,
                                                    db_tensor, bs_tensor, mask_tensor)
        total_loss += loss_sentence
        dials_gen[name] = output_words
    return dials_gen, total_loss

def decode(num=1, beam_search=False):
    """Decode checkpoint `num` on the validation and test splits and score both.

    Args:
        num: checkpoint index to load.
        beam_search: use beam search instead of greedy decoding.

    Returns:
        (Valid_Score, val_dials_gen, valid PPL, Test_Score, test_dials_gen, test PPL)
    """
    model, val_dials, test_dials, input_lang_word2index, output_lang_word2index, intent2index, index2intent = loadModelAndData(num)
    start_time = time.time()
    model.beam_search = beam_search
    step = 0 if not args.debug else 2  # small sample for debug

    # VALIDATION
    val_dials_gen, valid_loss = _predict_split(model, val_dials, input_lang_word2index,
                                               output_lang_word2index, intent2index, step)
    # Average the loss like the test branch does; the original left it summed,
    # so np.exp(valid_loss) below overflowed and was not comparable to test PPL.
    valid_loss /= len(val_dials)
    print('Current VALID LOSS:', valid_loss)
    Valid_Score = evaluator.summarize_report(val_dials_gen, mode='Valid')

    # TESTING
    test_dials_gen, test_loss = _predict_split(model, test_dials, input_lang_word2index,
                                               output_lang_word2index, intent2index, step)
    test_loss /= len(test_dials)
    print('Current TEST LOSS:', test_loss)
    Test_Score = evaluator.summarize_report(test_dials_gen, mode='Test')

    print('TIME:', time.time() - start_time)
    return Valid_Score, val_dials_gen, np.exp(valid_loss), Test_Score, test_dials_gen, np.exp(test_loss)
def decodeWrapper(beam_search=False):
    """Evaluate checkpoints 1..args.no_models, print the best test report,
    and dump the best checkpoint's generated dialogues to JSON.

    Args:
        beam_search: beam vs. greedy decoding; also picks the output file
            suffix ('bm' vs. 'gd').
    """
    # Overlay the training-time config stored next to the checkpoints,
    # keeping the current --data_dir so evaluation can target other data.
    with open('{}{}.config'.format(args.model_dir, args.model_name)) as f:
        add_args = json.load(f)
    for k, v in add_args.items():
        if k == 'data_dir':  # ignore this arg
            continue
        setattr(args, k, v)
    args.mode = 'test'
    args.load_param = True
    args.dropout = 0.0  # never apply dropout at evaluation time
    # Sweep every saved checkpoint, tracking the best one by validation score.
    Best_Valid_Score = None
    Best_Test_Score = None
    Best_PPL = None
    Best_model_id = 0
    Best_val_dials_gen = {}
    Best_test_dials_gen = {}
    for ii in range(1, args.no_models + 1):
        print(30 * '-' + 'EVALUATING EPOCH %s' % ii)
        with torch.no_grad():  # pure inference; no gradients needed
            Valid_Score, val_dials_gen, val_ppl, Test_Score, test_dials_gen, test_ppl = decode(ii, beam_search)
            # NOTE(review): selection key is Valid_Score[-2]; with the 7-tuple
            # unpacked below that would be Recall — confirm this is intended.
            if Best_Valid_Score is None or Best_Valid_Score[-2] < Valid_Score[-2]:
                Best_Valid_Score = Valid_Score
                Best_Test_Score = Test_Score
                Best_PPL = test_ppl
                Best_val_dials_gen = val_dials_gen
                Best_test_dials_gen = test_dials_gen
                Best_model_id = ii
    # Report the winning checkpoint's test-set metrics.
    print('Summary'+'~'*50)
    print('Best model: %s'%(Best_model_id))
    BLEU, MATCHES, SUCCESS, SCORE, P, R, F1 = Best_Test_Score
    mode = 'Test'
    print('%s PPL: %.2f' % (mode, Best_PPL))
    print('%s BLEU: %.4f' % (mode, BLEU))
    print('%s Matches: %2.2f%%' % (mode, MATCHES))
    print('%s Success: %2.2f%%' % (mode, SUCCESS))
    print('%s Score: %.4f' % (mode, SCORE))
    print('%s Precision: %.2f%%' % (mode, P))
    print('%s Recall: %.2f%%' % (mode, R))
    print('%s F1: %.2f%%' % (mode, F1))
    # Persist the best generations. A failed dump should not kill the run,
    # but the original bare 'except:' hid the cause (and swallowed even
    # KeyboardInterrupt) — catch the relevant errors and log them instead.
    suffix = 'bm' if beam_search else 'gd'
    try:
        with open(args.valid_output + 'val_dials_gen_%s.json' % suffix, 'w') as outfile:
            json.dump(Best_val_dials_gen, outfile, indent=4)
    except (IOError, OSError, TypeError, ValueError) as e:
        print('json.dump.err.valid', e)
    try:
        with open(args.decode_output + 'test_dials_gen_%s.json' % suffix, 'w') as outfile:
            json.dump(Best_test_dials_gen, outfile, indent=4)
    except (IOError, OSError, TypeError, ValueError) as e:
        print('json.dump.err.test', e)
if __name__ == '__main__':
    # create dir for generated outputs of valid and test set
    pp_mkdir(args.valid_output)
    pp_mkdir(args.decode_output)
    # Module-level evaluator used inside decode(); the constructor argument
    # is presumably just a display name — verify in multiwoz.Evaluators.
    evaluator = MultiWozEvaluator('MultiWozEvaluator')
    # Run the full checkpoint sweep twice: greedy first, then beam search.
    print('\n\nGreedy Search'+'='*50)
    decodeWrapper(beam_search=False)
    print('\n\nBeam Search' + '=' * 50)
    decodeWrapper(beam_search=True)
    # evaluteNLGFile(gen_dials_fpath='results/bsl_20190510161309/data/test_dials/test_dials_gen.json',
    #                 ref_dialogues_fpath='data/test_dials.json')
    # evaluteNLGFiles(gen_dials_fpaths=['results/bsl_20190510161309/data/test_dials/test_dials_gen.json',
    #                                   'results/moe1_20190510165545/data/test_dials/test_dials_gen.json'],
    #                 ref_dialogues_fpath='data/test_dials.json')
    # from nlgeval import compute_metrics
    # metrics_dict = compute_metrics(hypothesis='/Users/pp/Code/nlg-eval/examples/hyp.txt',
    #                                references=['/Users/pp/Code/nlg-eval/examples/ref1.txt'])