-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
508 lines (430 loc) · 22.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
# from transformers.file_utils import get_file_from_repo
from importlib.machinery import SourceFileLoader
import os
from transformers import EncoderDecoderModel, AutoConfig, AutoModel, AutoTokenizer, EncoderDecoderConfig, RobertaForCausalLM
from transformers.modeling_utils import PreTrainedModel, logging
import torch
from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithCrossAttentions
from typing import Dict, Any, Optional, Tuple
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from model_config import InvertTextNormalizationConfig, PretrainedConfig, DecoderInvertTextNormalizationConfig
# Local directory created at import time (relative to the current working directory).
cache_dir = './cache'
# Hub identifiers of the pretrained checkpoints used for the encoder and the decoder.
encoder_model_name = 'vinai/phobert-base'
decoder_model_name = 'vinai/phobert-base'
# `exist_ok=True` replaces the exists()/makedirs() pair, which could raise
# FileExistsError if another process created the directory in between.
os.makedirs(cache_dir, exist_ok=True)
# Module-level logger, obtained from the `logging` helper imported from transformers.modeling_utils.
logger = logging.get_logger(__name__)
@dataclass
class InvertTextNormalizationOutput(ModelOutput):
    """Result container returned by `InvertTextNormalization.forward` when `return_dict` is true.

    Every field defaults to ``None``; all of them are therefore annotated as
    ``Optional``.
    """

    # Sum of the decoder cross-entropy loss and the spoken-tagging loss (whichever were computed).
    loss: Optional[torch.FloatTensor] = None
    # Vocabulary scores produced by the decoder.
    logits: Optional[torch.FloatTensor] = None
    # 3-way tagging scores produced by the classifier on top of the encoder states.
    logits_spoken_tagging: Optional[torch.FloatTensor] = None
    # The remaining decoder fields are copied straight from the decoder's output object.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Encoder fields are declared but not populated by `InvertTextNormalization.forward`.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
def invert_text_features(encoder_hidden_states, word_src_lengths, spoken_label):
    """Gather the encoder states of every selected word into one padded batch.

    For each sample, every entry of `spoken_label` that is greater than zero is
    treated as a word index. The word's token span is located through the
    per-word token counts in `word_src_lengths`, sliced out of the sample's
    hidden states and right-padded with zero rows up to the largest count in
    `word_src_lengths`.

    :param encoder_hidden_states: indexed as [sample, token, hidden]
    :param word_src_lengths: indexed as [sample, word]; integer token counts
    :param spoken_label: indexed as [sample, slot]; word indices, values <= 0 are skipped
    :return: (features, mask) stacked over all selected words; the mask is 1 on
        real tokens and 0 on padding and has the dtype of `spoken_label`
    :raises RuntimeError: from `torch.cat` when no index is greater than zero
    """
    longest = word_src_lengths.max()
    pad_row = torch.zeros_like(encoder_hidden_states[0, :1, :])
    # 0-dim templates that carry the dtype/device of the index tensor.
    keep_flag = torch.ones_like(spoken_label[0, 0])
    drop_flag = torch.zeros_like(spoken_label[0, 0])

    phrase_features = []
    phrase_masks = []
    for states, lengths, word_indices in zip(encoder_hidden_states, word_src_lengths, spoken_label):
        for word_idx in word_indices:
            if word_idx <= 0:
                continue
            n_tokens = lengths[word_idx]
            n_pad = longest - n_tokens
            # Token offset of the word = total length of all preceding words.
            begin = sum(lengths[:word_idx])
            finish = begin + n_tokens

            mask_row = torch.cat([keep_flag.expand(n_tokens), drop_flag.expand(n_pad)])
            phrase_masks.append(mask_row.unsqueeze(0))

            padded = torch.cat([states[begin:finish], pad_row.expand(n_pad, pad_row.size(-1))])
            phrase_features.append(padded.unsqueeze(0))

    return torch.cat(phrase_features), torch.cat(phrase_masks)
def invert_text_labels(decoder_input_ids, labels, word_tgt_lengths, spoken_idx):
    """Build per-word decoder inputs and labels for every selected word.

    Each entry of `spoken_idx` greater than zero selects target word
    ``idx - 1`` of its sample. For that word:

    * the decoder input row is ``[0] + decoder_inputs[start+1:end+1]`` padded with ``1``;
    * the label row is ``decoder_label[start:end] + [2]`` padded with ``-100``;

    where ``start``/``end`` delimit the word's tokens according to
    `word_tgt_lengths`, and padding extends each row to the largest count in
    `word_tgt_lengths` plus one.

    :param decoder_input_ids: indexed as [sample, token]
    :param labels: indexed as [sample, token]; also provides dtype/device of the constants
    :param word_tgt_lengths: indexed as [sample, word]; integer token counts
    :param spoken_idx: indexed as [sample, slot]; 1-based word indices, values <= 0 are skipped
    :return: (decoder_input_ids, labels) stacked over all selected words
    :raises RuntimeError: from `torch.cat` when no index is greater than zero
    """
    longest = word_tgt_lengths.max()

    def _constant(value):
        # One-element tensor matching the labels' device and dtype.
        return torch.tensor([value], device=labels.device, dtype=labels.dtype)

    start_token = _constant(0)
    pad_token = _constant(1)
    end_token = _constant(2)
    ignored = _constant(-100)

    input_rows = []
    label_rows = []
    samples = zip(decoder_input_ids, labels, word_tgt_lengths, spoken_idx)
    for sample_inputs, sample_labels, lengths, word_indices in samples:
        for word_idx in word_indices:
            if word_idx <= 0:
                continue
            n_tokens = lengths[word_idx - 1]
            begin = sum(lengths[:word_idx - 1])
            finish = begin + n_tokens

            # Inputs are read one position to the right of the label span.
            shifted = sample_inputs[begin + 1:finish + 1]
            input_row = torch.cat([start_token, shifted, pad_token.expand(longest - len(shifted))])
            input_rows.append(input_row.unsqueeze(0))

            label_row = torch.cat([sample_labels[begin:finish], end_token, ignored.expand(longest - n_tokens)])
            label_rows.append(label_row.unsqueeze(0))

    return torch.cat(input_rows), torch.cat(label_rows)
class InvertTextNormalization(EncoderDecoderModel):
    """Encoder-decoder model with an additional token tagging head on the encoder.

    On top of the inherited encoder/decoder pair, the model owns:

    * ``spoken_tagging_classifier`` - a linear layer mapping encoder states to
      3 classes (0: "O", 1: "B", 2: "I"), applied after a dropout of 0.3;
    * optional per-word re-batching in `forward` (see `invert_text_features`
      and `invert_text_labels`) when ``spoken_idx`` is given.
    """

    config_class = InvertTextNormalizationConfig

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Create the model from a config and/or ready-made encoder and decoder.

        :param config: combined configuration; must be an instance of `config_class` when given
        :param encoder: encoder model; built from ``config.encoder`` when omitted
        :param decoder: decoder model; built from ``config.decoder`` when omitted
        :raises ValueError: if neither a config nor both sub-models are given, if the
            config has the wrong type, if the decoder's ``cross_attention_hidden_size``
            disagrees with the encoder's ``hidden_size``, or if the encoder has an LM head
        """
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
                    "it has to be equal to the encoder's `hidden_size`. "
                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
                )

        # initialize with config
        # NOTE(review): this runs the parent EncoderDecoderModel initializer with only
        # the config; whatever sub-models it sets up are replaced just below - confirm
        # that the extra work is acceptable.
        super().__init__(config)

        if encoder is None:
            encoder = AutoModel.from_config(config.encoder)
        if decoder is None:
            decoder = DecoderInvertTextNormalization._from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # encoder outputs might need to be projected to different dimension for decoder
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = torch.nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

        # spoken tagging head: dropout followed by a 3-way classifier
        self.dropout = torch.nn.Dropout(0.3)
        # 0: "O", 1: "B", 2: "I"
        self.spoken_tagging_classifier = torch.nn.Linear(config.encoder.hidden_size, 3)

        # tie encoder, decoder weights if config set accordingly
        self.tie_weights()

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs
    ) -> PreTrainedModel:
        """Assemble the model from separately pretrained encoder and decoder checkpoints.

        Keyword arguments prefixed with ``encoder_`` / ``decoder_`` are routed (prefix
        stripped) to the respective sub-model loader; ``encoder_model`` /
        ``decoder_model`` may carry already instantiated models. Remaining keyword
        arguments go to the combined configuration.

        :param encoder_pretrained_model_name_or_path: checkpoint for the encoder
        :param decoder_pretrained_model_name_or_path: checkpoint for the decoder
        :param model_args: extra positional arguments for the encoder loader
        :raises ValueError: if a sub-model is neither passed nor has a checkpoint name
        :return: a new instance of `cls`
        """
        kwargs_encoder = {
            argument[len("encoder_"):]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # remove encoder, decoder kwargs from kwargs
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                        "from a decoder model. Cross-attention and casual mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args,
                                                **kwargs_encoder)

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config = DecoderInvertTextNormalizationConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                        f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                        f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                        "cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            decoder = DecoderInvertTextNormalization.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # instantiate config with corresponding kwargs
        config = InvertTextNormalizationConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
        return cls(encoder=encoder, decoder=decoder, config=config)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Translate generation-loop arguments into keyword arguments for `forward`.

        The decoder's own `prepare_inputs_for_generation` is consulted for the
        decoder ids, its attention mask (if it provides one) and the cache.
        """
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_inputs.get("attention_mask"),
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        spoken_label=None,
        word_src_lengths=None,
        word_tgt_lengths=None,
        spoken_idx=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        inputs_length=None,
        outputs=None,
        outputs_length=None,
        src=None,
        tgt=None,
        **kwargs,
    ):
        """Run encoder, tagging head and decoder, and compute the available losses.

        :param encoder_outputs: precomputed encoder outputs; the encoder is skipped when given
        :param labels: decoder targets; enables the decoder cross-entropy loss
        :param spoken_label: per-token tagging targets (3 classes); enables the tagging loss
        :param spoken_idx: when given, encoder states, the encoder attention mask,
            ``decoder_input_ids`` and ``labels`` are re-batched per selected word using
            ``word_src_lengths`` / ``word_tgt_lengths``
        :param return_dict: True - return an `InvertTextNormalizationOutput`;
            False - return a plain tuple; None - use ``config.use_return_dict``
        :param inputs_length, outputs, outputs_length, src, tgt: accepted but not used here
        :param kwargs: keys starting with ``decoder_`` go (prefix stripped) to the
            decoder, all other keys go to the encoder
        :return: output object, or ``(loss,) + decoder_outputs + encoder_outputs``
            (loss omitted when no loss was computed) in tuple mode
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        spoken_tagging_output = None
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
            spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))

        encoder_hidden_states = encoder_outputs[0]

        if spoken_idx is not None:
            # Switch from sentence-level to word-level batches; the encoder
            # attention mask is replaced by the per-word padding mask.
            encoder_hidden_states, attention_mask = invert_text_features(encoder_hidden_states,
                                                                         word_src_lengths,
                                                                         spoken_idx)
            decoder_input_ids, labels = invert_text_labels(decoder_input_ids, labels,
                                                           word_tgt_lengths,
                                                           spoken_idx)

        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        loss = None
        if labels is not None:
            # The decoder is called without labels, so in tuple mode its first
            # element is the prediction scores (index 0, not 1).
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if spoken_label is not None:
            if spoken_tagging_output is None:
                # The encoder was skipped because `encoder_outputs` was supplied;
                # the tagging scores still have to be computed for the loss.
                spoken_tagging_output = self.spoken_tagging_classifier(self.dropout(encoder_outputs[0]))
            loss_fct = CrossEntropyLoss()
            spoken_tagging_loss = loss_fct(spoken_tagging_output.reshape(-1, 3), spoken_label.view(-1))
            # `labels` may be absent, in which case the tagging loss is the only loss.
            loss = spoken_tagging_loss if loss is None else loss + spoken_tagging_loss

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        # Encoder fields of the output object are left at their None defaults.
        return InvertTextNormalizationOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            logits_spoken_tagging=spoken_tagging_output,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )
class DecoderInvertTextNormalization(RobertaForCausalLM):
    """RoBERTa causal-LM decoder with an extra dot-product attention over the encoder states.

    The decoder's hidden states are projected by ``dense_query_copy`` (followed by
    ReLU) and used as queries against the encoder states; the attended result is
    added to the hidden states before the LM head.
    """

    config_class = DecoderInvertTextNormalizationConfig

    def __init__(self, config):
        """Build the base model and the query projection used by the copy attention."""
        super().__init__(config)
        self.dense_query_copy = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward_copy_attention(self, query, values, values_mask):
        """Attend from decoder queries to encoder states with plain dot-product scores.

        Uses `torch.bmm`, so `query` and `values` must be 3-D with equal batch size.

        :param query: batch * output_steps * hidden_state
        :param values: batch * max_encoder_steps * hidden_state
        :param values_mask: batch * max_encoder_steps, 1 for positions to attend to
            and 0 for positions to ignore; ``None`` attends to every position
        :return: batch * output_steps * hidden_state
        """
        scores = torch.bmm(query, values.transpose(2, 1))
        if values_mask is not None:
            # Broadcast the per-position mask over the output steps.
            blocked = (1 - values_mask.unsqueeze(1)).bool()
            scores = scores.masked_fill(blocked, -float('inf'))
        weights = torch.softmax(scores, dim=-1)
        return torch.bmm(weights, values)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute vocabulary scores from the base model output plus the copy attention.

        `labels` only switches caching off; no loss is computed in this method.

        :return: `CausalLMOutputWithCrossAttentions`, or in tuple mode
            ``(prediction_scores,) + outputs[2:]`` of the base model
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        # attention with input encoded
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # output copy attention
        query_copy = torch.relu(self.dense_query_copy(sequence_output))
        sequence_atten_copy_output = self.forward_copy_attention(query_copy,
                                                                 encoder_hidden_states,
                                                                 encoder_attention_mask)

        prediction_scores = self.lm_head(sequence_output + sequence_atten_copy_output)

        if not return_dict:
            return (prediction_scores,) + outputs[2:]

        return CausalLMOutputWithCrossAttentions(
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
def init_tokenizer():
    """Load the slow tokenizer of the encoder checkpoint and set its model input names.

    :return: tokenizer whose ``model_input_names`` are ``input_ids``,
        ``attention_mask`` and ``labels``
    """
    input_names = ["input_ids", "attention_mask", "labels"]
    loaded = AutoTokenizer.from_pretrained(encoder_model_name, use_fast=False)
    loaded.model_input_names = input_names
    return loaded
def init_model():
    """Create the model from the pretrained encoder/decoder checkpoints and configure it.

    :return: ``(model, tokenizer)``
    """
    tokenizer = init_tokenizer()

    # Encoder and decoder weights are not tied (tie_encoder_decoder=False).
    model = InvertTextNormalization.from_encoder_decoder_pretrained(
        encoder_model_name,
        decoder_model_name,
        tie_encoder_decoder=False,
    )
    cfg = model.config

    # Special tokens come from the tokenizer.
    cfg.decoder_start_token_id = tokenizer.bos_token_id
    cfg.eos_token_id = tokenizer.eos_token_id
    cfg.pad_token_id = tokenizer.pad_token_id

    # Default decoding parameters, applied in this order.
    decoding_defaults = {
        "max_length": 200,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "length_penalty": 2.0,
        "num_beams": 1,
    }
    for option, value in decoding_defaults.items():
        setattr(cfg, option, value)

    # The top-level vocabulary size mirrors the encoder's.
    cfg.vocab_size = cfg.encoder.vocab_size
    return model, tokenizer