-- mnist_classification.lua
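-- Trains a small fully-connected network on MNIST with mini-batch SGD,
-- early-stops on validation accuracy, and reports accuracy on the test set.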
require 'torch'
require 'nn'
require 'optim'
mnist = require 'mnist'
require 'ReQLu'
require 'ReQLuScaled'
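-- ReQLu and ReQLuScaled are custom activation modules, assumed to be defined
-- in ReQLu.lua and ReQLuScaled.lua alongside this script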
-- parameters
max_iters = 40
-- select one of the activation functions below
--activation_function = nn.ReLU
--activation_function = nn.Tanh
--activation_function = ReQLu
activation_function = ReQLuScaled
-- optimization parameters
sgd_params = {
    learningRate = 1e-2,
    learningRateDecay = 1e-4,
    weightDecay = 1e-3,
    momentum = 1e-4
}
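-- these fields match the config options read by optim.sgd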
-- load dataset
fullset = mnist.traindataset()
testset = mnist.testdataset()
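-- traindataset() holds the 60,000 MNIST training images (28x28, labels 0-9);
-- testdataset() holds the 10,000 test images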
-- training set: first 50,000 examples
trainset = {
    size = 50000,
    data = fullset.data[{{1,50000}}]:double(),
    label = fullset.label[{{1,50000}}]
}
-- validation set: last 10,000 examples
validationset = {
    size = 10000,
    data = fullset.data[{{50001,60000}}]:double(),
    label = fullset.label[{{50001,60000}}]
}
-- model definition
model = nn.Sequential()
model:add(nn.Reshape(28*28))
model:add(nn.Linear(28*28, 225))
model:add(activation_function())
model:add(nn.Linear(225, 144))
model:add(nn.Tanh())
model:add(nn.Linear(144, 10))
model:add(nn.LogSoftMax())
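-- the model is a 784 -> 225 -> 144 -> 10 multi-layer perceptron; the LogSoftMax
-- output pairs with the negative log-likelihood criterion below to give a
-- cross-entropy loss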
-- loss function
criterion = nn.ClassNLLCriterion()
-- get the flat parameters
x, dl_dx = model:getParameters()
-- initialize the parameters
x:uniform(-0.05, 0.05)
step = function(batch_size)
    local current_loss = 0
    local count = 0
    local shuffle = torch.randperm(trainset.size)
    batch_size = batch_size or 200
    for t = 1,trainset.size,batch_size do
        -- set up inputs and targets for this mini-batch
        local size = math.min(t + batch_size - 1, trainset.size) - t + 1
        local inputs = torch.Tensor(size, 28, 28)
        local targets = torch.Tensor(size)
        for i = 1,size do
            local input = trainset.data[shuffle[t + i - 1]]
            local target = trainset.label[shuffle[t + i - 1]]
            inputs[i] = input
            targets[i] = target
        end
        -- shift labels from 0-9 to 1-10: ClassNLLCriterion expects 1-based classes
        targets:add(1)
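        -- feval follows the optim interface: given a flat parameter vector it
        -- returns the mini-batch loss and the flat gradient vector dl_dx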
        local feval = function(x_new)
            -- copy in the new parameters and reset the gradients
            if x ~= x_new then x:copy(x_new) end
            dl_dx:zero()
            -- forward/backward pass over the mini-batch
            local loss = criterion:forward(model:forward(inputs), targets)
            model:backward(inputs, criterion:backward(model.output, targets))
            return loss, dl_dx
        end
        local _, fs = optim.sgd(feval, x, sgd_params)
        -- fs is a table containing the value(s) of the loss function
        -- (just one value for SGD)
        count = count + 1
        current_loss = current_loss + fs[1]
    end
    -- average loss per mini-batch over the epoch
    return current_loss / count
end
eval = function(dataset, batch_size)
    local count = 0
    batch_size = batch_size or 200
    for i = 1,dataset.size,batch_size do
        local size = math.min(i + batch_size - 1, dataset.size) - i + 1
        local inputs = dataset.data[{{i,i+size-1}}]
        local targets = dataset.label[{{i,i+size-1}}]:long()
        local outputs = model:forward(inputs)
        -- the prediction is the class with the highest log-probability
        local _, indices = torch.max(outputs, 2)
        -- shift predictions from 1-10 back to the 0-9 label range
        indices:add(-1)
        local guessed_right = indices:eq(targets):sum()
        count = count + guessed_right
    end
    -- fraction of correctly classified examples
    return count / dataset.size
end
do
    local last_accuracy = 0
    local decreasing = 0
    local threshold = 1 -- how many decreasing epochs we allow
    for i = 1,max_iters do
        local loss = step()
        print(string.format('Epoch: %d Current loss: %4f', i, loss))
        local accuracy = eval(validationset)
        print(string.format('Accuracy on the validation set: %4f', accuracy))
        -- early stopping: break once validation accuracy has decreased
        -- too many epochs in a row
        if accuracy < last_accuracy then
            if decreasing > threshold then break end
            decreasing = decreasing + 1
        else
            decreasing = 0
        end
        last_accuracy = accuracy
    end
    -- final evaluation on the test set
    testset.data = testset.data:double()
    print(string.format('Accuracy on the test set: %4f', eval(testset, 200)))
end