run.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Version : Python 3.6
import os

import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Eval
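

# Write predictions in the answer-key format ("<id>\t<relation>") read by the
# official SemEval-2010 Task 8 scorer; the task's test sentences are numbered
# from 8001, hence the default start_idx.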
def print_result(predict_label, id2rel, start_idx=8001):
    with open('predicted_result.txt', 'w', encoding='utf-8') as fw:
        for i in range(predict_label.shape[0]):
            fw.write('{}\t{}\n'.format(
                start_idx + i, id2rel[int(predict_label[i])]))
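

# Train with Adadelta (learning rate and L2 penalty from config), clipping
# gradients every step; after each epoch, evaluate on the train and dev
# splits and checkpoint whenever the dev micro-F1 improves.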
def train(model, criterion, loader, config):
    train_loader, dev_loader, _ = loader
    optimizer = optim.Adadelta(
        model.parameters(), lr=config.lr, weight_decay=config.L2_decay)

    print(model)
    print('training model parameters:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print('%s : %s' % (name, str(param.data.shape)))
    print('--------------------------------------')
    print('start to train the model ...')

    eval_tool = Eval(config)
    max_f1 = -float('inf')
    for epoch in range(1, config.epoch + 1):
        for step, (data, label) in enumerate(train_loader):
            model.train()
            data = data.to(config.device)
            label = label.to(config.device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            # Clip each gradient element to [-5, 5] to stabilize training.
            nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
            optimizer.step()

        _, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
        f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)
        print('[%03d] train_loss: %.3f | dev_loss: %.3f | micro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        # Keep only the checkpoint with the best dev micro-F1 so far.
        if f1 > max_f1:
            max_f1 = f1
            torch.save(model.state_dict(), os.path.join(
                config.model_dir, 'model.pkl'))
            print('>>> save models!')
        else:
            print()
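

# Reload the best checkpoint and report loss and micro-F1 on the test split,
# returning the predicted labels for print_result.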
def test(model, criterion, loader, config):
    print('--------------------------------------')
    print('start test ...')
    _, _, test_loader = loader
    # map_location keeps the load working when the checkpoint was saved on a
    # different device than config.device.
    model.load_state_dict(torch.load(
        os.path.join(config.model_dir, 'model.pkl'),
        map_location=config.device))
    eval_tool = Eval(config)
    f1, test_loss, predict_label = eval_tool.evaluate(
        model, criterion, test_loader)
    print('test_loss: %.3f | micro f1 on test: %.4f' % (test_loss, f1))
    return predict_label
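

# Entry point: load pretrained word embeddings and the relation vocabulary,
# build the SemEval data loaders, then train when config.mode == 1 and
# always evaluate on the test split.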
if __name__ == '__main__':
    config = Config()
    print('--------------------------------------')
    print('some config:')
    config.print_config()

    print('--------------------------------------')
    print('start to load data ...')
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)

    train_loader, dev_loader = None, None
    if config.mode == 1:  # train mode
        train_loader = loader.get_train()
        dev_loader = loader.get_dev()
    test_loader = loader.get_test()
    loader = [train_loader, dev_loader, test_loader]
    print('finish!')

    print('--------------------------------------')
    model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)
    model = model.to(config.device)
    criterion = nn.CrossEntropyLoss()

    if config.mode == 1:  # train mode
        train(model, criterion, loader, config)
    predict_label = test(model, criterion, loader, config)
    print_result(predict_label, id2rel)