diff --git a/keras_unet_collection/layer_utils.py b/keras_unet_collection/layer_utils.py
index c9197d8..3f4e4f6 100644
--- a/keras_unet_collection/layer_utils.py
+++ b/keras_unet_collection/layer_utils.py
@@ -8,6 +8,7 @@ from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Lambda
 from tensorflow.keras.layers import BatchNormalization, Activation, concatenate, multiply, add
 from tensorflow.keras.layers import ReLU, LeakyReLU, PReLU, ELU, Softmax
+from tensorflow.keras.layers import Dropout
 
 def decode_layer(X, channel, pool_size, unpool, kernel_size=3,
                  activation='ReLU', batch_norm=False, name='decode'):
@@ -196,7 +197,7 @@ def attention_gate(X, g, channel,
 
 def CONV_stack(X, channel, kernel_size=3, stack_num=2,
                dilation_rate=1, activation='ReLU',
-               batch_norm=False, name='conv_stack'):
+               batch_norm=False, dropout=False, dropout_rate=0.5, name='conv_stack'):
     '''
     Stacked convolutional layers:
     (Convolutional layer --> batch normalization --> Activation)*stack_num
@@ -241,6 +242,11 @@ def CONV_stack(X, channel, kernel_size=3, stack_num=2,
         activation_func = eval(activation)
         X = activation_func(name='{}_{}_activation'.format(name, i))(X)
 
+        #dropout
+        if dropout:
+            X = Dropout(rate=dropout_rate)(X)
+
+
     return X
 
 def Res_CONV_stack(X, X_skip, channel, res_num, activation='ReLU', batch_norm=False, name='res_conv'):
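
For reference, a minimal usage sketch (not part of the patch) showing how the new `dropout` and `dropout_rate` arguments of `CONV_stack` would be exercised. It assumes the patched `keras_unet_collection` is importable; the input shape, channel counts, and layer names below are illustrative only.

```python
# Minimal usage sketch, not part of the patch. Assumes the patched
# keras_unet_collection (with the dropout arguments above) is installed.
# Input shape and layer names are illustrative assumptions.
from tensorflow import keras
from keras_unet_collection.layer_utils import CONV_stack

inputs = keras.Input(shape=(64, 64, 3))

# Default call: dropout=False, so behaviour matches the old signature.
baseline = CONV_stack(inputs, channel=32, stack_num=2,
                      batch_norm=True, name='baseline')

# New arguments: dropout=True inserts Dropout(rate=dropout_rate) after the
# activation inside the stack; dropout_rate defaults to 0.5 if not given.
regularized = CONV_stack(inputs, channel=32, stack_num=2, batch_norm=True,
                         dropout=True, dropout_rate=0.3, name='regularized')

model = keras.Model(inputs, regularized)
model.summary()  # the summary should list the Dropout layer(s) added by the patch
```

Because `dropout` defaults to `False`, existing callers of `CONV_stack` (including the model builders that do not pass the new flags) keep their current behaviour.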