-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathkeras_model.py
58 lines (49 loc) · 2.16 KB
/
keras_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#
# The SELDnet architecture
#
from keras.layers import Bidirectional, Conv2D, MaxPooling2D, Input, MaxPooling3D, Conv3D, merge
from keras.layers.core import Dense, Activation, Dropout, Reshape, Permute
from keras.layers.recurrent import GRU
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.layers.wrappers import TimeDistributed
from keras.optimizers import Adam
import keras
# Make channels_first the default data format, so the Conv2D / MaxPooling2D
# layers built below treat axis 1 of the (batch-less) input shape as channels.
keras.backend.set_image_data_format('channels_first')
from IPython import embed
def _time_distributed_head(features, hidden_sizes, dropout_rate, nb_outputs,
                           activation, name):
    """Stack per-time-step dense layers on `features` and return the named output.

    Every entry of `hidden_sizes` adds a TimeDistributed Dense layer followed
    by Dropout. A last TimeDistributed Dense layer with `nb_outputs` units
    feeds an Activation layer of type `activation` called `name`.
    """
    head = features
    for nb_units in hidden_sizes:
        head = TimeDistributed(Dense(nb_units))(head)
        head = Dropout(dropout_rate)(head)
    head = TimeDistributed(Dense(nb_outputs))(head)
    return Activation(activation, name=name)(head)


def get_model(data_in, data_out, dropout_rate, nb_cnn2d_filt, pool_size,
              rnn_size, fnn_size, classification_mode, weights):
    """Build, compile and print the summary of the two-output SELDnet model.

    The graph is: Conv2D blocks -> bidirectional GRU layers -> two parallel
    time-distributed dense heads ('sed_out' with sigmoid, 'doa_out' with tanh).

    Args:
        data_in: Indexable shape; its last three entries form the input
            tensor shape. `data_in[-2]` is reused as the sequence length
            handed to the recurrent layers.
        data_out: Indexable pair of shapes; `data_out[0][-1]` is the width of
            the SED output and `data_out[1][-1]` the width of the DOA output.
        dropout_rate: Rate used by every Dropout layer and by the GRU input
            and recurrent dropout.
        nb_cnn2d_filt: Number of filters in each Conv2D layer.
        pool_size: Iterable with one entry per convolutional block; each
            entry is the max-pooling size along the last axis.
        rnn_size: Iterable of GRU unit counts, one bidirectional layer each.
        fnn_size: Iterable of hidden Dense sizes used in both output heads.
        classification_mode: Accepted but not used by this function.
        weights: Forwarded to `compile` as `loss_weights`, in the order
            [sed, doa].

    Returns:
        The compiled Keras `Model` with outputs `[sed, doa]`.
    """
    seq_len = data_in[-2]
    net_input = Input(shape=(data_in[-3], seq_len, data_in[-1]))

    # Convolutional front end: one block per pooling factor. Pooling is
    # (1, p), so only the last axis shrinks and axis 2 keeps length seq_len.
    cnn_out = net_input
    for last_axis_pool in pool_size:
        cnn_out = Conv2D(filters=nb_cnn2d_filt, kernel_size=(3, 3), padding='same')(cnn_out)
        cnn_out = BatchNormalization()(cnn_out)
        cnn_out = Activation('relu')(cnn_out)
        cnn_out = MaxPooling2D(pool_size=(1, last_axis_pool))(cnn_out)
        cnn_out = Dropout(dropout_rate)(cnn_out)

    # Bring the sequence axis to the front, then flatten the remaining two
    # axes into one feature vector per time step for the recurrent layers.
    cnn_out = Permute((2, 1, 3))(cnn_out)
    rnn_out = Reshape((seq_len, -1))(cnn_out)

    for nb_gru_units in rnn_size:
        gru = GRU(
            nb_gru_units,
            activation='tanh',
            dropout=dropout_rate,
            recurrent_dropout=dropout_rate,
            return_sequences=True,
        )
        # Forward and backward outputs are multiplied element-wise.
        rnn_out = Bidirectional(gru, merge_mode='mul')(rnn_out)

    # The DOA head is built before the SED head, matching the original layer
    # creation order (and therefore the auto-generated layer names).
    doa = _time_distributed_head(
        rnn_out, fnn_size, dropout_rate, data_out[1][-1], 'tanh', 'doa_out')
    sed = _time_distributed_head(
        rnn_out, fnn_size, dropout_rate, data_out[0][-1], 'sigmoid', 'sed_out')

    model = Model(inputs=net_input, outputs=[sed, doa])
    model.compile(
        optimizer=Adam(),
        loss=['binary_crossentropy', 'mse'],
        loss_weights=weights,
    )
    model.summary()
    return model