# train.py
import time
from datetime import datetime as d

import torch
from torch import optim
from tqdm import tqdm

from data import prepare_minibatch, get_minibatch
from evaluation import evaluate_minibatch, test_model_snli
from models import AWESentenceEncoder, LSTMEncoder, BiLSTMEncoder

# TODO: Add TensorBoard support
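# Note: `writer` below is assumed to be a torch.utils.tensorboard.SummaryWriter
# (or None to disable logging). A minimal sketch of wiring one up:
#
#     from torch.utils.tensorboard import SummaryWriter
#     writer = SummaryWriter(log_dir="runs/snli")   # log_dir is an illustrative choice
#     train_losses, val_losses, val_accs, test_acc = train_model(..., writer=writer)
#     writer.close()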


def train_model(model, dataset, optimizer, criterion, scheduler, num_epochs,
                checkpoint_path="models/",
                batch_fn=get_minibatch,
                prep_fn=prepare_minibatch,
                eval_fn=evaluate_minibatch,
                batch_size=64, eval_batch_size=None,
                device="cpu",
                writer=None,
                lr_factor=5):
    """Train a model, validate on the dev set after every epoch, and test the best checkpoint on SNLI.

    Returns the per-epoch training losses, validation losses, validation
    accuracies, and the test accuracy of the best saved model.
    """
    train_data, dev_data, test_data = dataset.get_data()

    # store train loss and validation accuracy during training
    # so we can plot them afterwards
    train_losses = []
    val_losses = []
    val_accuracies = []
    test_acc = 0

    print_every = 1000
    best_eval = 0
    best_iter = 0
    best_model_path = None

    if eval_batch_size is None:
        eval_batch_size = batch_size
    for epoch in tqdm(range(num_epochs)):
        model.train()
        current_loss = 0.
        i = 0

        if writer is not None:
            writer.add_scalar("Learning Rate", optimizer.param_groups[0]['lr'], epoch)

        for batch in batch_fn(train_data, batch_size=batch_size):
            # forward pass
            premise_tup, hypothesis_tup, targets = prep_fn(batch, model.vocab, device)
            logits = model(premise_tup, hypothesis_tup)

            B = targets.size(0)  # number of examples in this update
            loss = criterion(logits.view([B, -1]), targets.view(-1))

            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            current_loss = current_loss + loss.item()
            i = i + 1
            if i % print_every == 0:
                print(f"Batch number {i}")

        train_losses.append(current_loss)
        print(f"Training Loss: {current_loss}")
        if writer is not None:
            writer.add_scalar("Training Loss", current_loss, epoch)

        # evaluate on the dev set after every epoch
        _, _, dev_acc, dev_loss = eval_fn(
            model, criterion, dev_data, batch_size=eval_batch_size,
            batch_fn=batch_fn, prep_fn=prep_fn, device=device)
        val_losses.append(dev_loss.item())
        val_accuracies.append(dev_acc)

        print(f"Validation Loss: {dev_loss.item()}")
        if writer is not None:
            writer.add_scalar("Validation Loss", dev_loss, epoch)
        print(f"Validation Accuracy: {dev_acc}")
        if writer is not None:
            writer.add_scalar("Validation Accuracy", dev_acc, epoch)

        if dev_acc > best_eval:
            # new best dev accuracy: save a checkpoint
            print("new highscore")
            best_eval = dev_acc
            best_iter = epoch
            best_model_path = createCheckpointPathName(
                checkpoint_path, model.encoder, dev_acc, model.complex)
            torch.save(model, best_model_path)
        else:
            # no improvement: shrink the learning rate, stop once it gets too small
            optimizer.param_groups[0]['lr'] /= lr_factor
            if optimizer.param_groups[0]['lr'] < 1e-5:
                print("Training stopped due to LR limit.")
                break

        scheduler.step()

    print("Loading best model to test...")
    test_acc = test_model_snli(best_model_path, test_data, criterion, batch_fn,
                               prep_fn, eval_fn, batch_size, device, writer)
    print(f"Test Accuracy: {test_acc}")

    return train_losses, val_losses, val_accuracies, test_acc
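

# Illustrative checkpoint name produced by createCheckpointPathName below
# (the dimension, accuracy, and timestamp are made-up values):
#   models/BiLSTMEncoder_pooling-max_4096_0.84_2024-05-01-12-00-00.pt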
def createCheckpointPathName(path, encoder, acc, is_complex):
    """Build a checkpoint filename from the encoder class, its settings, the dev accuracy, and a timestamp."""
    if path[-1] != "/" and path[-1] != "\\":
        path = path + "/"

    cls = encoder.__class__
    name = cls.__name__
    if cls == BiLSTMEncoder and encoder.pool_type is not None:
        name = name + "_pooling-" + encoder.pool_type

    date = d.now().strftime("%Y-%m-%d-%H-%M-%S")
    if not is_complex:
        return path + name + "_" + str(encoder.out_dim) + "_" + f"{acc:.2f}" + "_" + date + ".pt"
    else:
        return path + name + "_complex_" + str(encoder.out_dim) + "_" + f"{acc:.2f}" + "_" + date + ".pt"
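

# A minimal usage sketch. `SNLIDataset` and `NLIClassifier` are hypothetical
# placeholders for the dataset wrapper and classifier defined elsewhere in this
# repo; only train_model's own signature is taken from this file:
#
#     dataset = SNLIDataset(...)                    # must provide .get_data()
#     model = NLIClassifier(...).to("cuda")         # must expose .vocab, .encoder, .complex
#     criterion = torch.nn.CrossEntropyLoss()
#     opt = optim.SGD(model.parameters(), lr=0.1)
#     sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)
#     train_model(model, dataset, opt, criterion, sched, num_epochs=20,
#                 batch_size=64, device="cuda", writer=None)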