# hubconf.py (forked from wconstab/BERT-pytorch)
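"""Benchmark/torch.hub entry point for BERT-pytorch pre-training on the
bundled toy corpus (data/corpus.small, data/vocab.small).

The Model wrapper exposes three entry points:
  * get_module() -> (model, example_inputs), for direct forward calls
  * train(niter) -> full pre-training steps (forward, loss, backward, optimizer)
  * eval(niter)  -> the same forward and loss computation, without backward

A minimal usage sketch, assuming the data/ files above are present:

    m = Model(device='cpu', jit=False)
    m.train(niter=1)
    m.eval(niter=1)
"""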
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

from bert_pytorch import parse_args
from bert_pytorch.dataset import BERTDataset, WordVocab
from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
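
# Pin all RNG seeds and force deterministic cuDNN kernels so every run of
# the benchmark sees identical data and kernel choices.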
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class Model:
    def __init__(self, device=None, jit=False):
        self.device = device
        self.jit = jit
        args = parse_args(args=[
            '--train_dataset', 'data/corpus.small',
            '--test_dataset', 'data/corpus.small',
            '--vocab_path', 'data/vocab.small',
            '--output_path', 'bert.model',
        ])  # Avoid reading sys.argv here
        args.with_cuda = self.device == 'cuda'
        args.script = self.jit

        print("Loading Vocab", args.vocab_path)
        vocab = WordVocab.load_vocab(args.vocab_path)
        print("Vocab Size: ", len(vocab))

        train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                    corpus_lines=args.corpus_lines, on_memory=args.on_memory)
        test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
            if args.test_dataset is not None else None

        print("Creating Dataloader")
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
            if test_dataset is not None else None

        print("Building BERT model")
        bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)

        if args.script:
            print("Scripting BERT model")
            bert = torch.jit.script(bert)

        self.trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                                   lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                   with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, debug=args.debug)

        # Cache a single batch so train()/eval() always see the same inputs.
        example_batch = next(iter(train_data_loader))
        self.example_inputs = (example_batch['bert_input'].to(self.device),
                               example_batch['segment_label'].to(self.device))
        self.is_next = example_batch['is_next'].to(self.device)
        self.bert_label = example_batch['bert_label'].to(self.device)
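
    # Entry point for harnesses that want to drive the model directly:
    # returns the (possibly scripted) BERT module and a matching input tuple.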
    def get_module(self):
        return self.trainer.model, self.example_inputs
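
    # eval() mirrors the forward/loss computation of train() but performs
    # no backward pass or optimizer step.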
    def eval(self, niter=1):
        trainer = self.trainer
        for _ in range(niter):
            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = trainer.model.forward(*self.example_inputs)

            # 2-1. NLL (negative log likelihood) loss of the is_next classification result
            # 2-2. NLLLoss of predicting the masked token words
            # 2-3. next_loss + mask_loss, per section 3.4 "Pre-training Procedure"
            next_loss = trainer.criterion(next_sent_output, self.is_next)
            mask_loss = trainer.criterion(mask_lm_output.transpose(1, 2), self.bert_label)
            loss = next_loss + mask_loss  # computed for parity with train(); no backward here
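
    # train() runs the same forward/loss computation and then backpropagates
    # and steps the optimizer and learning-rate schedule.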
    def train(self, niter=1):
        trainer = self.trainer
        for _ in range(niter):
            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = trainer.model.forward(*self.example_inputs)

            # 2-1. NLL (negative log likelihood) loss of the is_next classification result
            # 2-2. NLLLoss of predicting the masked token words
            # 2-3. next_loss + mask_loss, per section 3.4 "Pre-training Procedure"
            next_loss = trainer.criterion(next_sent_output, self.is_next)
            mask_loss = trainer.criterion(mask_lm_output.transpose(1, 2), self.bert_label)
            loss = next_loss + mask_loss

            # 3. backward and optimization, only in train
            trainer.optim_schedule.zero_grad()
            loss.backward()
            trainer.optim_schedule.step_and_update_lr()
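

# Smoke test: exercise each device/JIT combination end to end.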
if __name__ == '__main__':
    # Guard the CUDA runs so the script also works on CPU-only machines.
    devices = ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']
    for device in devices:
        for jit in [True, False]:
            print("Testing device {}, JIT {}".format(device, jit))
            m = Model(device=device, jit=jit)
            bert, example_inputs = m.get_module()
            bert(*example_inputs)
            m.train()
            m.eval()