-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_unet.py
58 lines (48 loc) · 4.18 KB
/
model_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Tuple, List, Text, Dict, Any, Iterator
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Activation, Dropout, Flatten, Dense, Reshape
from keras.layers.merge import Concatenate
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.utils import plot_model
def _encoder_block(x, filters: int, ker_init: str, batch_norm: bool = True):
    """Apply LeakyReLU(0.2) -> 4x4 stride-2 Conv2D -> optional BatchNormalization.

    One downsampling step of the U-Net encoder; outputs ``filters`` channels.
    """
    x = LeakyReLU(0.2)(x)
    x = Conv2D(filters, kernel_size=(4, 4), strides=(2, 2), padding="same",
               kernel_initializer=ker_init)(x)
    if batch_norm:
        x = BatchNormalization()(x)
    return x


def _decoder_block(x, skip, filters: int, ker_init: str, dropout: bool):
    """Apply ReLU -> 4x4 stride-2 Conv2DTranspose -> BatchNormalization,
    optional Dropout(0.5), then concatenate with the encoder tensor ``skip``.

    One upsampling step of the U-Net decoder.
    """
    x = Activation("relu")(x)
    x = Conv2DTranspose(filters, kernel_size=(4, 4), strides=(2, 2), padding="same",
                        kernel_initializer=ker_init)(x)
    x = BatchNormalization()(x)
    if dropout:
        x = Dropout(0.5)(x)
    return Concatenate()([x, skip])


def create_unet(in_shape: Tuple[int, int, int], output_ch: int, filters: int,
                ker_init: str = "glorot_uniform") -> "Model":
    """Build a pix2pix-style U-Net: 8 downsampling and 8 upsampling steps with skips.

    Reference models:
    * https://github.com/phillipi/pix2pix/blob/master/models.lua#L47
    * https://github.com/tdeboissiere/DeepLearningImplementations/blob/master/pix2pix/src/model/models.py#L317

    Args:
        in_shape: (height, width, channels) of the input image. Known height and
            width must be multiples of 256 (2**8). The encoder halves the
            resolution eight times, and each decoder output is concatenated
            with the encoder tensor of the same resolution. ``None`` (unknown)
            dimensions are not checked.
        output_ch: number of output channels. With 1 channel the channel axis is
            dropped (output shape (batch, height, width)) and a sigmoid is
            applied. Otherwise a softmax activation is applied.
        filters: base filter count; deeper layers use up to ``filters * 8``.
        ker_init: kernel initializer used by every (transposed) convolution.

    Returns:
        An uncompiled Keras ``Model`` mapping the input image to the output map.

    Raises:
        ValueError: if ``in_shape`` does not have exactly 3 entries, or a known
            spatial dimension is not a multiple of 256.
    """
    # Validate before building any layer so a bad shape fails with a clear
    # message instead of a Concatenate shape mismatch deep in the decoder.
    if len(in_shape) != 3:
        raise ValueError(
            "in_shape must be (height, width, channels), got {!r}".format(in_shape))
    downsample = 2 ** 8  # eight stride-2 convolutions in the encoder
    for axis_name, dim in zip(("height", "width"), in_shape[:2]):
        if dim is not None and dim % downsample != 0:
            raise ValueError(
                "input {} must be a multiple of {} so the U-Net skip connections "
                "line up, got {!r}".format(axis_name, downsample, dim))

    input_tensor = Input(shape=in_shape)

    # Encoder. e1 is applied directly to the input: no activation, no BN.
    x = Conv2D(filters, kernel_size=(4, 4), strides=(2, 2), padding="same",
               kernel_initializer=ker_init)(input_tensor)
    skips = [x]
    for mult in (2, 4, 8, 8, 8, 8):  # e2 .. e7
        x = _encoder_block(x, filters * mult, ker_init)
        skips.append(x)
    # e8, the innermost layer, has no BatchNormalization.
    x = _encoder_block(x, filters * 8, ker_init, batch_norm=False)

    # Decoder d1 .. d7. Each step is concatenated with the mirrored encoder
    # output (e7 .. e1). Dropout is applied only in the first three steps.
    for i, (mult, skip) in enumerate(zip((8, 8, 8, 8, 4, 2, 1), reversed(skips))):
        x = _decoder_block(x, skip, filters * mult, ker_init, dropout=i < 3)

    # Final upsampling back to the input resolution with output_ch channels.
    x = Conv2DTranspose(output_ch, kernel_size=(4, 4), strides=(2, 2), padding="same",
                        kernel_initializer=ker_init)(Activation("relu")(x))

    if output_ch == 1:
        # Drop the channel axis -> (batch, height, width). The original author
        # notes this is the shape their dice_coef expects.
        x = Reshape((in_shape[0], in_shape[1]))(x)
        x = Activation("sigmoid")(x)
    else:
        x = Activation("softmax")(x)

    return Model(inputs=[input_tensor], outputs=[x])
if __name__ == '__main__':
    # Demo: build a 512x512 RGB -> 1-channel U-Net with 64 base filters, print
    # its layer summary and render the architecture diagram to unet.png.
    # NOTE: Keras' plot_model needs pydot and graphviz installed.
    unet = create_unet((512, 512, 3), 1, 64, "he_normal")
    unet.summary()
    plot_model(unet, to_file='unet.png', show_shapes=True, show_layer_names=True)
    # No trailing exit(): the site-provided helper is meant for the interactive
    # shell, and the script ends here with status 0 anyway.