# ann_baseline.py
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import utility.utils as utils
# parse the options
options = utils.parse_args()
batchsize = options.batch
te_batchsize = options.batch

# dataset path (EMNIST bymerge split)
# the MNIST path below was dead code in the original script: it was assigned
# and then immediately overwritten, so it is kept only as a comment
# CODES_DIR = os.path.dirname(os.getcwd())
# DATAROOT = os.path.join(CODES_DIR, 'MNIST_CLS/data/MNIST/processed')
DATAROOT = 'dataset/EMNIST/bymerge/processed'
# oect data path
DEVICE_DIR = os.path.join(os.getcwd(), 'data')
TRAIN_PATH = os.path.join(DATAROOT, 'training_bymerge.pt')
TEST_PATH = os.path.join(DATAROOT, 'test_bymerge.pt')
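# assumption, not stated in the script: the .pt files are torch-serialized
# (images, labels) pairs in torchvision's processed-EMNIST layout, which
# utils.SimpleDataset is expected to load from these paths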
# training and test splits; with ori_img=True each sample is a three-element
# tuple (the loops below unpack data, an extra field, and the target)
tr_dataset = utils.SimpleDataset(TRAIN_PATH,
                                 num_pulse=options.num_pulse,
                                 crop=options.crop,
                                 sampling=options.sampling,
                                 ori_img=True)
te_dataset = utils.SimpleDataset(TEST_PATH,
                                 num_pulse=options.num_pulse,
                                 crop=options.crop,
                                 sampling=options.sampling,
                                 ori_img=True)

# baseline readout: a single linear layer from the flattened 28x28 image
# (784 values) to the 47 EMNIST-bymerge classes
model = torch.nn.Sequential(
    nn.Linear(784, 47)
)
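# quick shape check (illustrative, not part of the original script): one
# batch of flattened 28x28 images should map to one logit per class
assert model(torch.randn(8, 784)).shape == (8, 47)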
# shuffle only the training data; the test set is evaluated in order
train_loader = DataLoader(tr_dataset,
                          batch_size=batchsize,
                          shuffle=True)
test_dataloader = DataLoader(te_dataset, batch_size=batchsize)

num_epoch = 50
learning_rate = 1e-3
num_data = len(tr_dataset)
num_te_data = len(te_dataset)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
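# with step_size=15 and gamma=0.1 the learning rate is 1e-3 for epochs 0-14,
# 1e-4 for 15-29, 1e-5 for 30-44, and 1e-6 for the remaining epochs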
# train
acc_list = []
loss_list = []
for epoch in range(num_epoch):
    acc = []
    loss = 0.0
    for i, (data, _, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(torch.float)
        # readout layer
        logits = torch.squeeze(model(data))
        batch_loss = criterion(logits, target)
        # accumulate with .item() so the running loss does not keep each
        # batch's autograd graph alive
        loss += batch_loss.item()
        batch_acc = torch.sum(logits.argmax(dim=-1) == target) / batchsize
        acc.append(batch_acc)
        batch_loss.backward()
        optimizer.step()
        # if i % 300 == 0:
        #     print('%d batches trained' % i)
    scheduler.step()
    # sum(acc) * batchsize recovers the raw count of correct predictions,
    # so the epoch accuracy is exact even if the last batch is smaller
    acc_epoch = (sum(acc) * batchsize / num_data).numpy()
    acc_list.append(acc_epoch)
    loss_list.append(loss)
    print("epoch: %d, loss: %.2f, acc: %.6f" % (epoch, loss, acc_epoch))
# test
te_accs = []
te_outputs = []
targets = []
model.eval()
with torch.no_grad():
    # te_dataset is built with the same ori_img=True flag as training, so it
    # yields the same three-element sample; unpack it the same way
    for i, (data, _, target) in enumerate(test_dataloader):
        output = torch.squeeze(model(data.to(torch.float)))
        te_outputs.append(output)
        acc = torch.sum(output.argmax(dim=-1) == target) / te_batchsize
        te_accs.append(acc)
        targets.append(target)
# as in training, multiplying by te_batchsize recovers raw correct counts
te_acc = (sum(te_accs) * te_batchsize / num_te_data).numpy()
print("test acc: %.6f" % te_acc)