-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathlenet.py
104 lines (77 loc) · 2.28 KB
/
lenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch.nn as nn
from collections import OrderedDict
class C1(nn.Module):
    """First LeNet-5 stage: 5x5 convolution (1 -> 6 channels), ReLU, 2x2 max-pool.

    The sub-layers live in ``self.c1`` under the names ``c1``, ``relu1`` and
    ``s1``; these names appear in ``state_dict`` keys, so they are kept stable.
    """

    def __init__(self):
        super().__init__()
        stack = OrderedDict()
        stack['c1'] = nn.Conv2d(1, 6, kernel_size=(5, 5))
        stack['relu1'] = nn.ReLU()
        # Stride 2 with a 2x2 window halves each spatial dimension.
        stack['s1'] = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.c1 = nn.Sequential(stack)

    def forward(self, img):
        """Run ``img`` through the conv -> ReLU -> pool stack and return the result."""
        return self.c1(img)
class C2(nn.Module):
    """Second LeNet-5 stage: 5x5 convolution (6 -> 16 channels), ReLU, 2x2 max-pool.

    Sub-layer names (``c2``, ``relu2``, ``s2``) inside ``self.c2`` form part of
    the ``state_dict`` keys and are therefore preserved.
    """

    def __init__(self):
        super().__init__()
        named_layers = [
            ('c2', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('relu2', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
        ]
        self.c2 = nn.Sequential(OrderedDict(named_layers))

    def forward(self, img):
        """Apply the conv -> ReLU -> pool stack to a 6-channel feature map."""
        result = self.c2(img)
        return result
class C3(nn.Module):
    """Third LeNet-5 stage: 5x5 convolution (16 -> 120 channels) followed by ReLU.

    There is no pooling here; on a 5x5 input the convolution yields a 1x1
    spatial map, i.e. 120 features per sample.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(16, 120, kernel_size=(5, 5))
        act = nn.ReLU()
        self.c3 = nn.Sequential(OrderedDict([('c3', conv), ('relu3', act)]))

    def forward(self, img):
        """Return ReLU(conv(img)) for a 16-channel feature map."""
        return self.c3(img)
class F4(nn.Module):
    """Fully connected layer 120 -> 84 followed by ReLU.

    Layers are registered in ``self.f4`` as ``f4`` and ``relu4``.
    """

    def __init__(self):
        super().__init__()
        layers = OrderedDict()
        layers['f4'] = nn.Linear(120, 84)
        layers['relu4'] = nn.ReLU()
        self.f4 = nn.Sequential(layers)

    def forward(self, img):
        """Map a batch of 120-feature vectors to 84 non-negative features."""
        hidden = self.f4(img)
        return hidden
class F5(nn.Module):
    """Output layer: Linear 84 -> 10 followed by LogSoftmax over the last dimension.

    The result is log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        linear = nn.Linear(84, 10)
        # NOTE: despite the "sig" prefix this is a LogSoftmax, not a sigmoid;
        # the name is kept so existing module paths stay valid.
        log_probs = nn.LogSoftmax(dim=-1)
        self.f5 = nn.Sequential(OrderedDict([
            ('f5', linear),
            ('sig5', log_probs),
        ]))

    def forward(self, img):
        """Return per-class log-probabilities for a batch of 84-feature vectors."""
        return self.f5(img)
class LeNet5(nn.Module):
    """LeNet-5 variant whose second stage runs two parallel C2 branches and sums them.

    Input - (N, 1, 32, 32). The 32x32 size is required: after c1/c2/c3 it
    leaves a 1x1 spatial map, giving the 120 features that ``F4`` expects.
    Output - (N, 10) log-probabilities (the final F5 layer applies LogSoftmax).
    """

    def __init__(self):
        super(LeNet5, self).__init__()
        self.c1 = C1()
        self.c2_1 = C2()
        self.c2_2 = C2()
        self.c3 = C3()
        self.f4 = F4()
        self.f5 = F5()

    def forward(self, img):
        output = self.c1(img)
        # Both C2 branches consume the same C1 features; their outputs are
        # summed out-of-place rather than with `+=`, so no tensor produced
        # inside the autograd graph is mutated.
        branch_1 = self.c2_1(output)
        branch_2 = self.c2_2(output)
        output = branch_2 + branch_1
        output = self.c3(output)
        # flatten(1) keeps the batch dimension and, unlike view(N, -1), also
        # works for an empty batch (N == 0) and non-contiguous tensors.
        output = output.flatten(1)
        output = self.f4(output)
        output = self.f5(output)
        return output