# ff_centroids.py
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Flatten
from torch.optim import Adam
from torch.utils.data import DataLoader

import mnist
from ff_utils import LayerOutputs, UnitLength

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


def distance_to_centroids(h, y_true, epsilon=1e-12):
    """
    Calculates the mean squared distance to the centroid of each class.
    Returns a tensor of shape [n_examples, 10].
    """
    safe_mean = lambda x, dim: x.sum(dim) / (x.shape[dim] + epsilon)
    # TODO: what if a class is missing from the mini-batch?
    #  * determine centroids only for the classes that are present, and also
    #    return torch.unique(y_true)
    #  * or treat centroids as trainable parameters, so they update slowly
    class_centroids = torch.stack(
        [safe_mean(h[y_true == i], 0) for i in range(10)], dim=1
    )  # [n_in, 10]
    x_to_centroids = h.unsqueeze(2) - class_centroids  # [n_examples, n_in, 10]
    return x_to_centroids.pow(2).mean(1)  # [n_examples, 10]
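
# Illustrative only (hypothetical values, not used by the training below):
# for a batch of 8 examples with 16 features, we get one squared distance
# per class.
_h = torch.randn(8, 16)
_y = torch.randint(0, 10, (8,))
assert distance_to_centroids(_h, _y).shape == (8, 10)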


@torch.no_grad()
def predict(model: nn.Sequential, x, y_true, skip_layers=1):
    """
    Predict by finding the class whose centroid is closest to each example.
    Note: the centroids are recomputed from the labels of the batch being
    evaluated.
    """
    d = sum(
        [distance_to_centroids(h, y_true) for h in LayerOutputs(model, x)][skip_layers:]
    )
    return d.argmin(1)  # type: ignore
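
# LayerOutputs comes from ff_utils; a minimal sketch of the assumed behaviour
# (not the actual implementation) is a generator over per-layer activations:
#
#     def LayerOutputs(model, x):
#         for layer in model:
#             x = layer(x)
#             yield x
#
# skip_layers=1 then drops the first layer's distances from the vote.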


def centroid_loss(h, y_true, alpha=4.0, epsilon=1e-12, temperature=1.0):
    """
    Loss based on the (squared) distance to the true class centroid versus a
    nearby class centroid. Achieves an error rate of ~1.7%.
    """
    # Distance from h to the centroid of each class
    d = distance_to_centroids(h, y_true)
    # Choose a nearby class at random, using the inverse distance as a
    # probability distribution. To stop torch.multinomial getting
    # out-of-range values, we first normalise by the minimum distance.
    min_d = torch.min(d, 1, keepdim=True)[0]
    norm_d = (d + epsilon) / (min_d + epsilon)
    y_near = torch.multinomial(norm_d.pow(-temperature), 1).squeeze(1)
    # Smoothed (SiLU) version of the triplet loss max(0, d2_true - d2_near + margin)
    d_true = d[range(d.shape[0]), y_true]  # ||anchor - positive||^2
    d_near = d[range(d.shape[0]), y_near]  # ||anchor - negative||^2
    return F.silu(alpha * (d_true - d_near)).mean()
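
# Illustrative only (made-up margins): silu(alpha * x) behaves like a
# smoothed hinge max(0, alpha * x): near zero for clear wins, roughly linear
# for clear violations, and with a useful gradient in between.
_m = torch.tensor([-1.0, -0.1, 0.0, 0.1, 1.0])
print("silu:", F.silu(4.0 * _m))
print("relu:", F.relu(4.0 * _m))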

# %%
# Define the model
#
# Must be an iterable of layers. I find it works best if each layer starts
# with a UnitLength() sub-layer.

n_units = 500  # 2000 improves error rate
model = nn.Sequential(
    nn.Sequential(Flatten(), UnitLength(), nn.Linear(784, n_units), nn.ReLU()),
    nn.Sequential(UnitLength(), nn.Linear(n_units, n_units), nn.ReLU()),
).to(device)
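
# UnitLength comes from ff_utils; a minimal sketch of the assumed behaviour
# (not the actual implementation) is per-example L2 normalisation, so each
# layer sees only the direction of the previous layer's activity, not its
# magnitude:
#
#     class UnitLength(nn.Module):
#         def forward(self, x):
#             return x / (x.norm(dim=1, keepdim=True) + 1e-12)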


def error_rate(model: nn.Sequential, data_loader: DataLoader) -> float:
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        predicted = predict(model, x, y)
        correct += (predicted == y).sum().item()
        total += y.size(0)
    return 1 - correct / total

# %%
# Train the model

torch.manual_seed(42)
learning_rate = 0.05
temperature = 4
optimiser = Adam(model.parameters(), lr=learning_rate)
num_epochs = 120
batch_size = 4096
train_loader = DataLoader(
    list(zip(mnist.train_x, mnist.train_y)), batch_size=batch_size, shuffle=False
)
test_loader = DataLoader(
    list(zip(mnist.test_x, mnist.test_y)), batch_size=batch_size, shuffle=False
)
print(
    "[init] Training: {:.2%}, Test: {:.2%}".format(
        error_rate(model, train_loader),
        error_rate(model, test_loader),
    )
)
for epoch in range(num_epochs):
    # Mini-batch training
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Train the layers in turn on the same mini-batch, backpropagating
        # within each layer only
        model.train()
        for layer in model:
            h = layer(x)
            loss = centroid_loss(h, y, temperature=temperature)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            # Recompute the layer's output without gradients, so the next
            # layer's loss cannot backpropagate through this one
            with torch.no_grad():
                x = layer(x)
    # Evaluate the model on the training and test sets
    if epoch % 5 == 0:
        print(
            "[{:>4d}] Training: {:.2%}, Test: {:.2%}".format(
                epoch,
                error_rate(model, train_loader),
                error_rate(model, test_loader),
            )
        )