net.py
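A fully convolutional detection network: a VGG16-style backbone whose intermediate feature maps feed paired 1x1 heads (one 2-channel and one 16-channel head per stage), fused top-down with bilinear upsampling to 1/2 of the input resolution ("2s") or 1/4 otherwise. The paired 2-/16-channel outputs and the "2s" switch suggest a PixelLink-style pixel/link formulation, though the file itself does not name the method.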
import torch.nn as nn
import torch.nn.functional as F

import config  # project config; expects config.dilation (bool) and config.version ("2s" for the finer output scale)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # TODO: modify padding
        # VGG16-style backbone, stage 1: two 3x3 convs, then 2x2 max-pool.
        self.conv1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.relu1_1 = nn.ReLU()
        self.conv1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.relu1_2 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        # Stage 2.
        self.conv2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.relu2_1 = nn.ReLU()
        self.conv2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.relu2_2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Stage 3: three convs.
        self.conv3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.relu3_1 = nn.ReLU()
        self.conv3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.relu3_2 = nn.ReLU()
        self.conv3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.relu3_3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)
        # Stage 4.
        self.conv4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.relu4_1 = nn.ReLU()
        self.conv4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.relu4_2 = nn.ReLU()
        self.conv4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.relu4_3 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(2)
        # Stage 5; pool5 keeps resolution (3x3 kernel, stride 1, padding 1).
        self.conv5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.relu5_1 = nn.ReLU()
        self.conv5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.relu5_2 = nn.ReLU()
        self.conv5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.relu5_3 = nn.ReLU()
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # fc6/fc7 replacements; conv6 is optionally dilated.
        if config.dilation:
            self.conv6 = nn.Conv2d(512, 1024, 3, stride=1, padding=6, dilation=6)
        else:
            self.conv6 = nn.Conv2d(512, 1024, 3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(1024, 1024, 1, stride=1, padding=0)
        self.relu7 = nn.ReLU()
        # Per-stage 1x1 prediction heads: *_1 outputs 2 channels,
        # *_2 outputs 16 channels.
        self.out1_1 = nn.Conv2d(128, 2, 1, stride=1, padding=0)
        self.out1_2 = nn.Conv2d(128, 16, 1, stride=1, padding=0)
        self.out2_1 = nn.Conv2d(256, 2, 1, stride=1, padding=0)
        self.out2_2 = nn.Conv2d(256, 16, 1, stride=1, padding=0)
        self.out3_1 = nn.Conv2d(512, 2, 1, stride=1, padding=0)
        self.out3_2 = nn.Conv2d(512, 16, 1, stride=1, padding=0)
        self.out4_1 = nn.Conv2d(512, 2, 1, stride=1, padding=0)
        self.out4_2 = nn.Conv2d(512, 16, 1, stride=1, padding=0)
        self.out5_1 = nn.Conv2d(1024, 2, 1, stride=1, padding=0)
        self.out5_2 = nn.Conv2d(1024, 16, 1, stride=1, padding=0)
        # Final 1x1 convs; defined but currently unused in forward().
        self.final_1 = nn.Conv2d(2, 2, 1, stride=1, padding=0)
        self.final_2 = nn.Conv2d(16, 16, 1, stride=1, padding=0)

    def forward(self, x):
        # Backbone stages 1-2; the stage-2 features (1/2 resolution) feed the first heads.
        x = self.pool1(self.relu1_2(self.conv1_2(self.relu1_1(self.conv1_1(x)))))
        x = self.relu2_2(self.conv2_2(self.relu2_1(self.conv2_1(x))))
        l1_1x = self.out1_1(x)
        l1_2x = self.out1_2(x)
        # Stage 3 (1/4 resolution).
        x = self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(self.relu3_1(self.conv3_1(self.pool2(x)))))))
        l2_1x = self.out2_1(x)
        l2_2x = self.out2_2(x)
        # Stage 4 (1/8 resolution).
        x = self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(self.relu4_1(self.conv4_1(self.pool3(x)))))))
        l3_1x = self.out3_1(x)
        l3_2x = self.out3_2(x)
        # Stage 5 (1/16 resolution).
        x = self.relu5_3(self.conv5_3(self.relu5_2(self.conv5_2(self.relu5_1(self.conv5_1(self.pool4(x)))))))
        l4_1x = self.out4_1(x)
        l4_2x = self.out4_2(x)
        # conv6/conv7 head; pool5 has stride 1, so still 1/16 resolution.
        x = self.relu7(self.conv7(self.relu6(self.conv6(self.pool5(x)))))
        l5_1x = self.out5_1(x)
        l5_2x = self.out5_2(x)
        # Top-down fusion of the 2-channel maps: add, then upsample 2x.
        # F.interpolate replaces the deprecated nn.functional.upsample.
        upsample1_1 = F.interpolate(l5_1x + l4_1x, scale_factor=2, mode="bilinear", align_corners=True)
        upsample2_1 = F.interpolate(upsample1_1 + l3_1x, scale_factor=2, mode="bilinear", align_corners=True)
        if config.version == "2s":
            # "2s": fuse down to stage-2 resolution (1/2 of the input).
            upsample3_1 = F.interpolate(upsample2_1 + l2_1x, scale_factor=2, mode="bilinear", align_corners=True)
            out_1 = upsample3_1 + l1_1x
        else:
            # Otherwise stop at stage-3 resolution (1/4 of the input).
            out_1 = upsample2_1 + l2_1x
        # out_1 = self.final_1(out_1)
        # Same fusion for the 16-channel maps.
        upsample1_2 = F.interpolate(l5_2x + l4_2x, scale_factor=2, mode="bilinear", align_corners=True)
        upsample2_2 = F.interpolate(upsample1_2 + l3_2x, scale_factor=2, mode="bilinear", align_corners=True)
        if config.version == "2s":
            upsample3_2 = F.interpolate(upsample2_2 + l2_2x, scale_factor=2, mode="bilinear", align_corners=True)
            out_2 = upsample3_2 + l1_2x
        else:
            out_2 = upsample2_2 + l2_2x
        # out_2 = self.final_2(out_2)
        return [out_1, out_2]

    def num_flat_features(self, x):
        # Number of features per sample when flattening; unused by forward().
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
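
A minimal smoke test for the module above. The `config` stub is a hypothetical stand-in registered before importing net, so the snippet runs in isolation; the real values live in the repository's config module.

import sys
import types

import torch

# Hypothetical stand-in for the project's `config` module.
config_stub = types.ModuleType("config")
config_stub.dilation = True
config_stub.version = "2s"
sys.modules["config"] = config_stub

from net import Net

net = Net().eval()
x = torch.randn(1, 3, 512, 512)  # NCHW; H and W should be divisible by 16 so the skip additions align
with torch.no_grad():
    out_1, out_2 = net(x)
print(out_1.shape)  # torch.Size([1, 2, 256, 256])  - 2-channel map at 1/2 resolution
print(out_2.shape)  # torch.Size([1, 16, 256, 256]) - 16-channel map at 1/2 resolution

Both outputs are raw, pre-softmax maps: no activation follows the 1x1 heads, and the final_1/final_2 convolutions defined in __init__ are left commented out in forward().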