-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathddrnet_23_slim.py
258 lines (217 loc) · 9.77 KB
/
ddrnet_23_slim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# ------------------------------------------------------------------------------
# Written by Hamid Ali (hamidriasat@gmail.com)
# ------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models
from resnet import basic_block, bottleneck_block, make_layer
from resnet import basicblock_expansion, bottleneck_expansion
def DAPPPM(x_in, branch_planes, outplanes):
    """
    Deep Aggregation Pyramid Pooling Module.

    One un-pooled branch and four average-pooled branches are derived from ``x_in``. Every pooled branch is
    resized back to the spatial size of ``x_in``, added to the branch built just before it and refined by a
    3*3 conv. All five branches are then concatenated, compressed by a 1*1 conv and added to a 1*1-conv
    shortcut computed from ``x_in``.
    :param x_in: input tensor; entries 1 and 2 of its static shape are read as height and width
    :param branch_planes: number of filters used by every conv inside the branches
    :param outplanes: number of filters of the compression conv and of the shortcut conv
    :return: element-wise sum of the compressed branches and the shortcut
    """

    def bn_relu_conv(tensor, filters, kernel, padding="valid"):
        # Pre-activation unit used everywhere in this module: BatchNorm -> ReLU -> bias-free conv.
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.Activation("relu")(tensor)
        return layers.Conv2D(filters, kernel_size=kernel, use_bias=False, padding=padding)(tensor)

    static_shape = tf.keras.backend.int_shape(x_in)
    rows, cols = static_shape[1], static_shape[2]

    # (pooling window, pooling stride) per pooled branch; the last entry pools over the whole feature map.
    pooling_specs = (
        ((5, 5), (2, 2)),
        ((9, 9), (4, 4)),
        ((17, 17), (8, 8)),
        ((rows, cols), (rows, cols)),
    )

    # y1: the branch that sees the input without any pooling
    branches = [bn_relu_conv(x_in, branch_planes, (1, 1))]

    for window, step in pooling_specs:
        # average pooling followed by the 1*1 pre-activation conv
        pooled = layers.AveragePooling2D(pool_size=window, strides=step, padding="same")(x_in)
        pooled = bn_relu_conv(pooled, branch_planes, (1, 1))
        # bring the pooled features back to the input resolution (tf.image.resize default method)
        pooled = tf.image.resize(pooled, size=(rows, cols))
        # fuse with the most recently built branch, then refine with a 3*3 conv -> y[i+1]
        merged = layers.Add()([pooled, branches[-1]])
        branches.append(bn_relu_conv(merged, branch_planes, (3, 3), padding="same"))

    # concatenate every branch along the channel axis and compress to `outplanes`
    fused = layers.concatenate(branches, axis=-1)
    fused = bn_relu_conv(fused, outplanes, (1, 1))

    # shortcut taken straight from the module input
    shortcut = bn_relu_conv(x_in, outplanes, (1, 1))

    return layers.Add()([fused, shortcut])
def segmentation_head(x_in, interplanes, outplanes, scale_factor=None):
    """
    Segmentation head
    BN -> ReLU -> 3*3 conv -> BN -> ReLU -> 1*1 conv -> optional rescale
    :param x_in: input feature tensor
    :param interplanes: number of filters of the intermediate 3*3 conv
    :param outplanes: number of filters of the final 1*1 conv (one per output channel)
    :param scale_factor: when not None, the spatial size read from the static shape (entries 1 and 2) is
        multiplied by this factor and the result is resized to it; when None no resizing happens
    :return: output tensor of the head, resized only if ``scale_factor`` is given
    """
    x = layers.BatchNormalization()(x_in)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(interplanes, kernel_size=(3, 3), use_bias=False, padding="same")(x)

    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    # The final projection keeps its bias. The previous code passed the builtin `range` here, which only
    # behaved like True by accident and put a non-serializable object into the layer configuration.
    x = layers.Conv2D(outplanes, kernel_size=(1, 1), use_bias=True, padding="valid")(x)

    if scale_factor is not None:
        input_shape = tf.keras.backend.int_shape(x)
        # NOTE: relies on a statically known height and width (they are multiplied as plain numbers).
        height2 = input_shape[1] * scale_factor
        width2 = input_shape[2] * scale_factor
        x = tf.image.resize(x, size=(height2, width2))

    return x
def _reinitialize_kernels_he_normal(model):
    """Record a He-normal initializer on every layer that has one and re-draw its already-created kernel."""
    # (initializer attribute, candidate weight attributes holding the matching kernel variable)
    targets = (
        ("kernel_initializer", ("kernel",)),
        ("depthwise_initializer", ("depthwise_kernel", "kernel")),
    )
    for layer in model.layers:
        for initializer_attr, weight_attrs in targets:
            if not hasattr(layer, initializer_attr):
                continue
            initializer = tf.keras.initializers.he_normal()
            # Keep the attribute in sync so the layer configuration reports the initializer really used.
            setattr(layer, initializer_attr, initializer)
            for weight_attr in weight_attrs:
                weight = getattr(layer, weight_attr, None)
                if weight is not None:
                    # The layers are already built at this point, so changing the attribute alone would
                    # never touch the weights; overwrite the existing variable with a fresh sample.
                    weight.assign(initializer(shape=tuple(weight.shape), dtype=weight.dtype))
                    break


def ddrnet_23_slim(input_shape=[1024, 2048, 3], layers_arg=[2, 2, 2, 2], num_classes=19, planes=32, spp_planes=128,
                   head_planes=64, scale_factor=8, augment=False):
    """
    ddrnet 23 slim

    Builds a two-branch network: a low-resolution branch (``x``) that is repeatedly down-sampled and a
    high-resolution branch (``x_``) kept at 1/8 of the input size. The branches exchange features twice,
    the low branch is finished by DAPPPM, both are summed and passed to the segmentation head.
    :param input_shape: shape of input data (height, width, channels); height and width must be known
        because they are integer-divided by 8. The default list is only read, never mutated.
    :param layers_arg: four values forwarded to ``make_layer`` for layers 1-4 (presumably the number of
        residual blocks per layer - confirm against ``resnet.make_layer``)
    :param num_classes: output classes
    :param planes: base filter count; the other layer widths are multiples of it
    :param spp_planes: filter count used inside the DAPPPM branches
    :param head_planes: intermediate filter count of the segmentation head
    :param scale_factor: up-scaling factor applied by the segmentation head(s)
    :param augment: when True an auxiliary head is attached to the high branch after the first fusion
        and its softmax output is returned as a second model output
    :return: ``tf.keras`` functional model
    """

    x_in = layers.Input(input_shape)

    highres_planes = planes * 2
    # Separate name so the `input_shape` argument is not overwritten by the batched static shape.
    static_shape = tf.keras.backend.int_shape(x_in)
    height_output = static_shape[1] // 8
    width_output = static_shape[2] // 8

    layers_inside = []
    temp_output = None  # only filled when `augment` is True

    # 1 -> 1/2 first conv layer
    x = layers.Conv2D(planes, kernel_size=(3, 3), strides=2, padding='same')(x_in)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    # 1/2 -> 1/4 second conv layer
    x = layers.Conv2D(planes, kernel_size=(3, 3), strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # layer 1
    # 1/4 -> 1/4 first basic residual block not mentioned in the image
    x = make_layer(x, basic_block, planes, planes, layers_arg[0], expansion=basicblock_expansion)
    layers_inside.append(x)

    # layer 2
    # 2 High :: 1/4 -> 1/8 storing results at index:1
    x = layers.Activation("relu")(x)
    x = make_layer(x, basic_block, planes, planes * 2, layers_arg[1], stride=2, expansion=basicblock_expansion)
    layers_inside.append(x)

    # For the next layers
    #   x : low branch
    #   x_: high branch

    # layer 3
    # 3 Low :: 1/8 -> 1/16 storing results at index:2
    x = layers.Activation("relu")(x)
    x = make_layer(x, basic_block, planes * 2, planes * 4, layers_arg[2], stride=2, expansion=basicblock_expansion)
    layers_inside.append(x)
    # 3 High :: 1/8 -> 1/8 retrieving from index:1
    x_ = layers.Activation("relu")(layers_inside[1])
    x_ = make_layer(x_, basic_block, planes * 2, highres_planes, 2, expansion=basicblock_expansion)

    # Fusion 1
    # High to Low: the 1/8 high branch is strided down to 1/16 and added to the low branch
    x_temp = layers.Activation("relu")(x_)
    x_temp = layers.Conv2D(planes * 4, kernel_size=(3, 3), strides=2, padding='same', use_bias=False)(x_temp)
    x_temp = layers.BatchNormalization()(x_temp)
    x = layers.Add()([x, x_temp])
    # Low to High: the stored layer-3 low output is projected and resized from 1/16 to 1/8
    x_temp = layers.Activation("relu")(layers_inside[2])
    x_temp = layers.Conv2D(highres_planes, kernel_size=(1, 1), use_bias=False)(x_temp)
    x_temp = layers.BatchNormalization()(x_temp)
    x_temp = tf.image.resize(x_temp, (height_output, width_output))  # 1/16 -> 1/8
    x_ = layers.Add()([x_, x_temp])  # next high branch input, 1/8

    if augment:
        temp_output = x_  # Auxiliary loss from high branch

    # layer 4
    # 4 Low :: 1/16 -> 1/32 storing results at index:3
    x = layers.Activation("relu")(x)
    x = make_layer(x, basic_block, planes * 4, planes * 8, layers_arg[3], stride=2, expansion=basicblock_expansion)
    layers_inside.append(x)
    # 4 High :: 1/8 -> 1/8
    x_ = layers.Activation("relu")(x_)
    x_ = make_layer(x_, basic_block, highres_planes, highres_planes, 2, expansion=basicblock_expansion)

    # Fusion 2
    # High to Low: two stride-2 convs take the high branch from 1/8 to 1/32
    x_temp = layers.Activation("relu")(x_)
    x_temp = layers.Conv2D(planes * 4, kernel_size=(3, 3), strides=2, padding='same', use_bias=False)(x_temp)
    x_temp = layers.BatchNormalization()(x_temp)
    x_temp = layers.Activation("relu")(x_temp)
    x_temp = layers.Conv2D(planes * 8, kernel_size=(3, 3), strides=2, padding='same', use_bias=False)(x_temp)
    x_temp = layers.BatchNormalization()(x_temp)
    x = layers.Add()([x, x_temp])
    # Low to High: the stored layer-4 low output is projected and resized from 1/32 to 1/8
    x_temp = layers.Activation("relu")(layers_inside[3])
    x_temp = layers.Conv2D(highres_planes, kernel_size=(1, 1), use_bias=False)(x_temp)
    x_temp = layers.BatchNormalization()(x_temp)
    x_temp = tf.image.resize(x_temp, (height_output, width_output))
    x_ = layers.Add()([x_, x_temp])

    # layer 5
    # 5 High :: 1/8 -> 1/8
    x_ = layers.Activation("relu")(x_)
    x_ = make_layer(x_, bottleneck_block, highres_planes, highres_planes, 1, expansion=bottleneck_expansion)
    x = layers.Activation("relu")(x)
    # 5 Low :: 1/32 -> 1/64
    x = make_layer(x, bottleneck_block, planes * 8, planes * 8, 1, stride=2, expansion=bottleneck_expansion)

    # Deep Aggregation Pyramid Pooling Module
    x = DAPPPM(x, spp_planes, planes * 4)

    # resize from 1/64 to 1/8
    x = tf.image.resize(x, (height_output, width_output))

    x_ = layers.Add()([x, x_])

    x_ = segmentation_head(x_, head_planes, num_classes, scale_factor)

    # apply softmax at the output layer
    x_ = tf.nn.softmax(x_)

    if augment:
        # The auxiliary head receives the same `scale_factor`, so its output is up-scaled like the main one.
        x_extra = segmentation_head(temp_output, head_planes, num_classes, scale_factor)
        x_extra = tf.nn.softmax(x_extra)
        model_output = [x_, x_extra]
    else:
        model_output = x_

    # NOTE(review): with augment=True this passes a nested list ([[main, aux]]) as `outputs`. Kept as-is so
    # the output structure seen by existing callers does not change - confirm whether a flat list is wanted.
    model = models.Model(inputs=[x_in], outputs=[model_output])

    # Apply He-normal initialization. Every layer is already built here, so assigning the initializer
    # attribute alone (what this code used to do) left the default-initialized weights untouched.
    _reinitialize_kernels_he_normal(model)

    return model
if __name__ == "__main__":
    # Model compilation demo: build the default DDRNet-23-slim on the CPU, compile it and print a summary.
    INPUT_SHAPE = [1024, 2048, 3]
    OUTPUT_CHANNELS = 19

    with tf.device("cpu:0"):
        # create model
        ddrnet_model = ddrnet_23_slim(num_classes=OUTPUT_CHANNELS, input_shape=INPUT_SHAPE)

        # `learning_rate` is the supported keyword; the old `lr` alias is deprecated and is rejected by
        # newer Keras optimizer implementations.
        optimizer = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.045)

        # compile model; the network already ends in a softmax, hence from_logits=False
        ddrnet_model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False), optimizer=optimizer,
                             metrics=['accuracy'])

        # show model summary in output
        ddrnet_model.summary()

        # Optional: save model architecture as png
        # tf.keras.utils.plot_model(ddrnet_model, show_layer_names=True, show_shapes=True)

        # Optional: save model
        # ddrnet_model.save("temp.hdf5")