CapsuleNet.py
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers, models
from keras.utils import to_categorical
class Length(layers.Layer):
    """Compute the length (L2 norm) of each capsule's output vector.

    The length of a capsule vector is interpreted as the probability that
    the entity the capsule represents is present.
    """
    def call(self, inputs, **kwargs):
        return K.sqrt(K.sum(K.square(inputs), -1))

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]
class Mask(layers.Layer):
    """Mask out all capsule vectors except the selected one.

    During training the true label is supplied as a one-hot mask; without a
    label, the capsule with the longest output vector is selected instead.
    """
    def call(self, inputs, **kwargs):
        if isinstance(inputs, list):
            # True label provided: inputs = [capsules, one-hot mask].
            assert len(inputs) == 2
            inputs, mask = inputs
        else:
            # No label given: mask by the longest capsule vector.
            # Compute capsule lengths, shape = [batch_size, num_capsule].
            x = K.sqrt(K.sum(K.square(inputs), -1))
            # Rescale so the maximum entry maps to 1 and all others fall below 0,
            # then clip to obtain a one-hot mask.
            x = (x - K.max(x, 1, True)) / K.epsilon() + 1
            mask = K.clip(x, 0, 1)
        # Keep only the selected capsule's vector, shape = [batch_size, dim_vector].
        inputs_masked = K.batch_dot(inputs, mask, [1, 1])
        return inputs_masked

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape[0], tuple):  # two inputs were given
            return tuple([None, input_shape[0][-1]])
        else:
            return tuple([None, input_shape[-1]])
def squash(vectors, axis=-1):
    """Squashing nonlinearity: v = ||s||^2 / (1 + ||s||^2) * s / ||s||.

    Shrinks short vectors toward zero and long vectors to just under unit
    length, preserving their direction.
    """
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # K.epsilon() guards against division by zero for all-zero vectors.
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors
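
# Worked example (a sketch, not from the original file): for a capsule
# vector s with ||s|| = 1, the scale is 1 / (1 + 1) = 0.5, so ||v|| = 0.5;
# as ||s|| grows, ||v|| = ||s||^2 / (1 + ||s||^2) approaches 1 while the
# direction of s is preserved.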
class CapsuleLayer(layers.Layer):
    """Capsule layer with dynamic routing.

    Maps `input_num_capsule` capsules of size `input_dim_vector` to
    `num_capsule` capsules of size `dim_vector` via `num_routing`
    routing-by-agreement iterations.
    """
    def __init__(self, num_capsule, dim_vector, num_routing=3,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_vector = dim_vector
        self.num_routing = num_routing
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, \
            'The input Tensor should have shape=[None, input_num_capsule, input_dim_vector]'
        self.input_num_capsule = input_shape[1]
        self.input_dim_vector = input_shape[2]
        # Transformation matrices, one per (input capsule, output capsule) pair.
        self.W = self.add_weight(shape=[self.input_num_capsule, self.num_capsule,
                                        self.input_dim_vector, self.dim_vector],
                                 initializer=self.kernel_initializer,
                                 name='W')
        # Routing logits; updated by the routing loop, not by gradient descent.
        self.bias = self.add_weight(shape=[1, self.input_num_capsule, self.num_capsule, 1, 1],
                                    initializer=self.bias_initializer,
                                    name='bias',
                                    trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        # [batch, input_num_capsule, input_dim_vector]
        # -> [batch, input_num_capsule, num_capsule, 1, input_dim_vector]
        inputs_expand = K.expand_dims(K.expand_dims(inputs, 2), 2)
        inputs_tiled = K.tile(inputs_expand, [1, 1, self.num_capsule, 1, 1])
        # Prediction vectors u_hat = u . W, computed per sample with tf.scan.
        inputs_hat = tf.scan(lambda ac, x: K.batch_dot(x, self.W, [3, 2]),
                             elems=inputs_tiled,
                             initializer=K.zeros([self.input_num_capsule, self.num_capsule,
                                                  1, self.dim_vector]))
        assert self.num_routing > 0, 'The num_routing should be > 0.'
        # Routing by agreement.
        for i in range(self.num_routing):
            c = tf.nn.softmax(self.bias, axis=2)  # coupling coefficients
            outputs = squash(K.sum(c * inputs_hat, 1, keepdims=True))
            if i != self.num_routing - 1:
                # Raise the logits of predictions that agree with the output.
                self.bias += K.sum(inputs_hat * outputs, -1, keepdims=True)
        return K.reshape(outputs, [-1, self.num_capsule, self.dim_vector])

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_vector])
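
# Shape walk-through for CapsuleLayer (a sketch; the 1152 figure assumes the
# 28x28 inputs used in the usage sketch at the bottom of this file, giving
# 6 * 6 * 32 = 1152 primary capsules):
#   inputs     : (batch, 1152, 8)               primary capsule vectors
#   inputs_hat : (batch, 1152, n_class, 1, 16)  predictions u_hat = u . W
#   c          : (1, 1152, n_class, 1, 1)       coupling coefficients
#   outputs    : (batch, n_class, 16)           digit capsules after routing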
def PrimaryCap(inputs, dim_vector, n_channels, kernel_size, strides, padding):
    """Apply a Conv2D, then reshape and squash its output into capsules of size dim_vector."""
    output = layers.Conv2D(filters=dim_vector * n_channels, kernel_size=kernel_size,
                           strides=strides, padding=padding)(inputs)
    outputs = layers.Reshape(target_shape=[-1, dim_vector])(output)
    return layers.Lambda(squash)(outputs)
def CapsNet(input_shape, n_class, num_routing):
    """Build the training-time CapsNet: classification plus reconstruction decoder."""
    width, breadth = input_shape[0], input_shape[1]  # spatial size of the RGB input
    x = layers.Input(shape=input_shape)
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid',
                          activation='relu', name='conv1')(x)
    primarycaps = PrimaryCap(conv1, dim_vector=8, n_channels=32, kernel_size=9,
                             strides=2, padding='valid')
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_vector=16,
                             num_routing=num_routing, name='digitcaps')(primarycaps)
    out_caps = Length(name='out_caps')(digitcaps)

    # Decoder: reconstruct the input from the capsule selected by the true label.
    y = layers.Input(shape=(n_class,))
    masked = Mask()([digitcaps, y])
    x_recon = layers.Dense(512, activation='relu')(masked)
    x_recon = layers.Dense(1024, activation='relu')(x_recon)
    x_recon = layers.Dense(width * breadth * 3, activation='sigmoid')(x_recon)
    x_recon = layers.Reshape(target_shape=[width, breadth, 3], name='out_recon')(x_recon)
    return models.Model([x, y], [out_caps, x_recon])
def margin_loss(y_true, y_pred):
    """Margin loss from the CapsNet paper (m+ = 0.9, m- = 0.1, lambda = 0.5)."""
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, 1))
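
# Usage sketch: a minimal, hypothetical example of wiring the pieces above
# together on random data. The 28x28 RGB input size, optimizer, batch size,
# and the 0.0005 reconstruction loss weight are assumptions for illustration,
# not values taken from this repository. The true-label input feeds the Mask
# layer, so the training model takes [image, one-hot label] and predicts
# [class capsule lengths, reconstructed image].
if __name__ == '__main__':
    import numpy as np

    model = CapsNet(input_shape=(28, 28, 3), n_class=10, num_routing=3)
    model.compile(optimizer='adam',
                  loss=[margin_loss, 'mse'],  # capsule loss + reconstruction loss
                  loss_weights=[1., 0.0005],  # assumed weighting
                  metrics={'out_caps': 'accuracy'})
    model.summary()
    # plot_model(model, to_file='model.png', show_shapes=True)  # requires pydot

    # Dummy data just to illustrate the expected input/output shapes.
    x_train = np.random.rand(8, 28, 28, 3).astype('float32')
    y_train = to_categorical(np.random.randint(0, 10, size=8), num_classes=10)
    model.fit([x_train, y_train], [y_train, x_train], batch_size=4, epochs=1)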