-
Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathrnn_crf.py
35 lines (24 loc) · 829 Bytes
/
rnn_crf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from utils import *
from rnn_encoder import *
from crf import *
class rnn_crf(nn.Module):
    """Tagging model that chains an ``rnn_encoder`` with a ``crf`` layer.

    ``forward`` is the training entry point and returns whatever the CRF
    layer yields for the gold tags; ``decode`` is the inference entry point
    and returns the result of ``crf.decode``.
    """

    def __init__(self, cti, wti, num_tags):
        """Create the encoder and CRF sub-modules.

        Args:
            cti: Forwarded to ``rnn_encoder`` (presumably the char-to-index
                vocabulary -- confirm against ``rnn_encoder``).
            wti: Forwarded to ``rnn_encoder`` (presumably the word-to-index
                vocabulary -- confirm against ``rnn_encoder``).
            num_tags: Forwarded to both ``rnn_encoder`` and ``crf``.
        """
        super().__init__()
        self.rnn = rnn_encoder(cti, wti, num_tags)
        self.crf = crf(num_tags)
        if CUDA:
            # Module.cuda() moves the module in place and returns it, so the
            # result does not need to be bound to anything.
            self.cuda()

    def forward(self, xc, xw, y0):  # for training
        """Run the encoder and return the CRF output for the gold tags.

        Args:
            xc: Forwarded to the encoder as its first input (presumably
                character-level indices -- confirm).
            xw: Forwarded to the encoder as its second input (presumably
                word-level indices -- confirm).
            y0: Gold tag tensor; handed to the CRF in full, while the mask
                is built from ``y0[1:]`` only.

        Returns:
            The value returned by ``self.crf(h, y0, mask)``.
        """
        # NOTE: gradients are cleared here, on every training forward pass.
        self.zero_grad()

        # Entries strictly greater than PAD_IDX are treated as real positions.
        # The first slice along dim 0 is skipped (presumably a start-of-
        # sequence marker -- TODO confirm against the data pipeline).
        shifted_tags = y0[1:]
        token_mask = shifted_tags.gt(PAD_IDX).float()

        encoded = self.rnn(xc, xw, token_mask)
        return self.crf(encoded, y0, token_mask)

    def decode(self, xc, xw, lens):  # for inference
        """Run the encoder and return the CRF decoding result.

        Args:
            xc: Forwarded to the encoder as its first input.
            xw: Forwarded to the encoder as its second input; also the
                source of the mask when ``HRE`` is falsy.
            lens: Indexable, iterable collection of lengths; only read when
                ``HRE`` is truthy. The mask width is taken from ``lens[0]``
                (assumes the first entry is the largest -- TODO confirm that
                callers sort by length).

        Returns:
            The value returned by ``self.crf.decode(h, mask)``.
        """
        if not HRE:
            # Mask straight from the second input: entries > PAD_IDX are real.
            token_mask = xw.gt(PAD_IDX).float()
        else:
            # One row per entry of ``lens``; column ``pos`` is set while
            # ``pos`` is below that entry. The result is then transposed.
            rows = [
                [length > pos for pos in range(lens[0])]
                for length in lens
            ]
            token_mask = Tensor(rows).transpose(0, 1)

        encoded = self.rnn(xc, xw, token_mask)
        return self.crf.decode(encoded, token_mask)