-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcapsule_lib.py
135 lines (103 loc) · 4.52 KB
/
capsule_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
'''
Implementation of fundamental parts of CapsNET https://arxiv.org/abs/1710.09829
'''
import tensorflow as tf
from tensorflow import keras
class RoutingByAgreement:
    '''
    Skeleton of the routing-by-agreement algorithm, driven by tf.while_loop.

    The per-iteration routing update is not implemented yet (see the TODO in
    loop_body): the input tensor is passed through unchanged while the
    iteration counter advances, so the loop terminates after `iterations`
    steps.
    '''

    def __init__(self, iterations):
        # Number of routing iterations the while-loop performs.
        self.iterations = iterations

    def stop_condition(self, input, counter):
        '''
        `cond` callable for tf.while_loop.

        Despite its name this is a *continue* predicate: tf.while_loop keeps
        iterating while it returns True, i.e. while counter < iterations.
        '''
        return tf.less(counter, self.iterations)

    def loop_body(self, input, counter):
        '''
        `body` callable for tf.while_loop: one routing iteration.

        Must return a structure matching the loop variables ([input, counter]).
        The counter is advanced here; without that the loop never terminates.
        '''
        # TODO: implement the actual routing-by-agreement update of `input`.
        return [input, counter + 1]

    def __call__(self, input):
        '''
        Run `iterations` routing steps over `input`.

        Returns the final loop variables as produced by tf.while_loop, i.e. a
        structure matching [input, counter].
        '''
        with tf.name_scope("routing_by_agreement_operation"):
            # Explicit integer tensor so the counter dtype is fixed in graph mode.
            counter = tf.constant(0)
            result = tf.while_loop(
                self.stop_condition, self.loop_body, [input, counter]
            )
        return result
# class CapsuleNetwork(keras.Model):
# def __init__(
# self,
# input_size,
# no_of_conv_kernels,
# conv_strides,
# )
# self.epsilon=1e-7
# self.input_size=input_size
# self.number_of_conv_kernels
@tf.function
def squash(s, axis=-1, epsilon=1e-7):
    '''
    Capsule "squashing" non-linearity (Sabour et al., 2017).

    Rescales the vectors in `s` along `axis` to length
    ||s||^2 / (1 + ||s||^2) while keeping their direction: short vectors are
    shrunk toward zero, long vectors approach unit length.

    Args:
        s: tensor of capsule vectors.
        axis: axis holding the vector components (default: the last axis).
        epsilon: small constant added under the square root so the division
            by the norm stays finite when a vector is all zeros.

    Returns:
        Tensor broadcast to the shape of `s`.

    For more information, see for example:
    - https://arxiv.org/abs/1710.09829
    - https://pechyonkin.me/capsules-2/
    - https://www.kaggle.com/code/giovanimachado/capsnet-tensorflow-implementation
    '''
    norm_sq = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    # Direction of s, normalised by an epsilon-stabilised norm.
    direction = s / tf.sqrt(norm_sq + epsilon)
    # Squashed length factor times the direction.
    return (norm_sq / (1. + norm_sq)) * direction
class SecondaryCapsule(keras.layers.Layer):
    '''
    Secondary capsule layer for CapsNet (work in progress).

    Expects a rank-4 input. build() treats axis 2 as the number of input
    capsules and axis 3 as the feature size of each input capsule, and pins
    axis 1 to size 1 through the layer's InputSpec. Axis 0 is left
    unconstrained (presumably the batch axis). The forward pass is not
    implemented yet.

    Attributes set in build():
        capsules_input: size of input axis 2.
        features: size of input axis 3.
        affine_transform_matrix: trainable weight (shape still a TODO).
        routing_weights: trainable weight of shape [capsules, capsules_input].
    '''

    def __init__(
        self,
        capsules,
        transformation_initializer="glorot_uniform",
        routing_initializer="zeros",
        transformation_regularizer=None,
        routing_regularizer=None,
        activity_regularizer=None,
        transformation_constraint=None,
        routing_constraint=None,
        **kwargs
    ):
        # The base Layer resolves and stores activity_regularizer itself;
        # re-assigning the raw argument afterwards would undo that resolution.
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        # Number of output capsules.
        self.capsules = capsules
        # Resolve identifiers (strings / config dicts) to Keras objects once, so
        # add_weight() gets objects and get_config() can serialize them back.
        self.transformation_initializer = keras.initializers.get(
            transformation_initializer
        )
        self.routing_initializer = keras.initializers.get(routing_initializer)
        self.transformation_regularizer = keras.regularizers.get(
            transformation_regularizer
        )
        self.routing_regularizer = keras.regularizers.get(routing_regularizer)
        self.transformation_constraint = keras.constraints.get(
            transformation_constraint
        )
        self.routing_constraint = keras.constraints.get(routing_constraint)

    def build(self, input_shape):
        '''
        Create the layer weights for an input of shape
        (?, 1, capsules_input, features).

        Raises:
            ValueError: if input_shape does not have exactly 4 dimensions
                (a subclass of Exception, so existing broad handlers still match).
        '''
        if len(input_shape) != 4:
            raise ValueError(
                "Invalid dimension of input shape: expected 4 dimensions, "
                f"got {input_shape}"
            )
        self.capsules_input = input_shape[2]
        self.features = input_shape[3]
        self.input_spec = keras.layers.InputSpec(
            shape=(None, 1, self.capsules_input, self.features)
        )
        # name= is passed by keyword: in Keras 3 `shape` is the first positional
        # parameter of add_weight, so a positional name would collide with shape=.
        self.affine_transform_matrix = self.add_weight(
            name="affine_transform_matrix",
            shape=[],  # TODO: real transformation-matrix shape
            initializer=self.transformation_initializer,
            regularizer=self.transformation_regularizer,
            constraint=self.transformation_constraint,
            trainable=True,
        )
        # Uses the routing_* settings (the original wrongly reused the
        # transformation_* ones, silently ignoring the routing arguments).
        self.routing_weights = self.add_weight(
            name="routing_weights",
            shape=[self.capsules, self.capsules_input],
            initializer=self.routing_initializer,
            regularizer=self.routing_regularizer,
            constraint=self.routing_constraint,
            trainable=True,
        )

    def call(self, inputs):
        # TODO: transform `inputs` with affine_transform_matrix and route with
        # routing_weights. Fail loudly until then instead of returning None.
        raise NotImplementedError("SecondaryCapsule.call is not implemented yet")

    def get_config(self):
        '''
        Return a config dict that from_config() (cls(**config)) can consume:
        every key mirrors an __init__ parameter name.
        '''
        config = super().get_config()
        config.update(
            {
                "capsules": self.capsules,
                "transformation_initializer": keras.initializers.serialize(
                    self.transformation_initializer
                ),
                "routing_initializer": keras.initializers.serialize(
                    self.routing_initializer
                ),
                "transformation_regularizer": keras.regularizers.serialize(
                    self.transformation_regularizer
                ),
                "routing_regularizer": keras.regularizers.serialize(
                    self.routing_regularizer
                ),
                "activity_regularizer": keras.regularizers.serialize(
                    self.activity_regularizer
                ),
                "transformation_constraint": keras.constraints.serialize(
                    self.transformation_constraint
                ),
                "routing_constraint": keras.constraints.serialize(
                    self.routing_constraint
                ),
            }
        )
        return config