Commit 6fb118c: "order"
MartinAchondo committed Jul 5, 2024
1 parent: cc9ba93
Showing 1 changed file with 95 additions and 92 deletions.

xppbe/NN/NeuralNet.py (187 changes: 95 additions & 92 deletions)
@@ -87,15 +87,15 @@ def __init__(self,
         self.ub = tf.constant(scale[1])
 
         if not self.weight_factorization:
-            Dense_Layer = tf.keras.layers.Dense
+            self.Dense_Layer = tf.keras.layers.Dense
         elif self.weight_factorization:
-            Dense_Layer = CustomDenseLayer
+            self.Dense_Layer = CustomDenseLayer
 
         # Scale layer
         if self.scale_input:
             self.scale = tf.keras.layers.Lambda(
-                            lambda x: 2.0 * (x - self.lb) / (self.ub - self.lb) - 1.0,
-                            name=f'scale_layer')
+                lambda x: 2.0 * (x - self.lb) / (self.ub - self.lb) - 1.0,
+                name=f'scale_layer')
 
         # Fourier feature layer
         if self.use_fourier_features:
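
Note: this hunk turns the layer factory into an instance attribute (self.Dense_Layer) so that the new create_* helper methods further down can reuse it. For reference, a minimal sketch, not from the repository, of what the scale layer computes: a min-max map from the box [lb, ub] onto [-1, 1].

    import tensorflow as tf

    lb, ub = tf.constant([0.0]), tf.constant([10.0])
    scale = tf.keras.layers.Lambda(lambda x: 2.0 * (x - lb) / (ub - lb) - 1.0)
    print(scale(tf.constant([[0.0], [5.0], [10.0]])))  # -> [[-1.], [0.], [1.]]
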
@@ -106,125 +106,122 @@ def __init__(self,
                                 trainable=False,
                                 kernel_initializer=tf.initializers.RandomNormal(stddev=self.fourier_sigma),
                                 name='fourier_features'))
-            class SinCosLayer(tf.keras.layers.Layer):
-                def call(self, Z):
-                    return tf.concat([tf.sin(2.0*np.pi*Z), tf.cos(2.0*np.pi*Z)], axis=-1)
             self.fourier_features.add(SinCosLayer(name='fourier_sincos_layer'))
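
Note: the SinCosLayer class previously defined inline inside __init__ is removed here and re-added at module level at the end of the diff, which makes the model easier to serialize and the class reusable. A rough sketch, with an invented random projection B standing in for the frozen Dense layer, of the Fourier feature mapping the two layers implement together:

    import numpy as np
    import tensorflow as tf

    B = tf.random.normal((1, 4), stddev=1.0)   # fixed random projection (the frozen Dense)
    x = tf.constant([[0.25]])
    z = tf.matmul(x, B)
    features = tf.concat([tf.sin(2.0*np.pi*z), tf.cos(2.0*np.pi*z)], axis=-1)
    print(features.shape)  # (1, 8): sin/cos pairs, as in SinCosLayer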

         # FCNN or ModMLP architectures
-        if self.architecture_Net in ('FCNN','ModMLP','MLP'):
-            self.hidden_layers = list()
-            for i in range(self.num_hidden_layers):
-                layer = Dense_Layer(self.num_neurons_per_layer,
-                                    activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                activation=self.activation,
-                                                                adaptative_activation=self.adaptative_activation),
-                                    kernel_initializer=self.kernel_initializer,
-                                    name=f'layer_{i}')
-                self.hidden_layers.append(layer)
-
-            # ModMLP architecture
-            if self.architecture_Net == 'ModMLP':
-                self.U = Dense_Layer(self.num_neurons_per_layer,
-                                     activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                 activation=self.activation,
-                                                                 adaptative_activation=self.adaptative_activation),
-                                     kernel_initializer=self.kernel_initializer,
-                                     name=f'layer_u')
-                self.V = Dense_Layer(self.num_neurons_per_layer,
-                                     activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                 activation=self.activation,
-                                                                 adaptative_activation=self.adaptative_activation),
-                                     kernel_initializer=self.kernel_initializer,
-                                     name=f'layer_v')
+        if self.architecture_Net in ('FCNN','MLP'):
+            self.create_FCNN()
+
+        elif self.architecture_Net == 'ModMLP':
+            self.create_ModMLP()
 
         # ResNet architecture
         elif self.architecture_Net == 'ResNet':
-            self.first = Dense_Layer(self.num_neurons_per_layer,
-                                     activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                 activation=self.activation,
-                                                                 adaptative_activation=self.adaptative_activation),
-                                     kernel_initializer=self.kernel_initializer,
-                                     name=f'layer_0')
-            self.hidden_blocks = list()
-            self.hidden_blocks_activations = list()
-            for i in range(self.num_hidden_blocks):
-                block = tf.keras.Sequential(name=f"block_{i}")
-                block.add(Dense_Layer(self.num_neurons_per_layer,
-                                      activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                  activation=self.activation,
-                                                                  adaptative_activation=self.adaptative_activation),
-                                      kernel_initializer=self.kernel_initializer))
-                block.add(Dense_Layer(self.num_neurons_per_layer,
-                                      activation=None,
-                                      kernel_initializer=self.kernel_initializer))
-                self.hidden_blocks.append(block)
-                activation_layer = tf.keras.layers.Activation(activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                                          activation=self.activation,
-                                                                                          adaptative_activation=self.adaptative_activation))
-                self.hidden_blocks_activations.append(activation_layer)
-
-            self.last = Dense_Layer(self.num_neurons_per_layer,
-                                    activation=CustomActivation(units=self.num_neurons_per_layer,
-                                                                activation=self.activation,
-                                                                adaptative_activation=self.adaptative_activation),
-                                    kernel_initializer=self.kernel_initializer,
-                                    name=f'layer_1')
+            self.create_ResNet()
 
         # Output layer
-        self.out = Dense_Layer(output_dim, name=f'output_layer')
+        self.out = self.Dense_Layer(output_dim,
+                                    activation=None,
+                                    use_bias=False,
+                                    name=f'output_layer')
 
         # Scale output layer
         self.scale_out = tf.keras.layers.Lambda(
-                        lambda x: x*self.scale_NN,
-                        name=f'scale_output_layer')
+            lambda x: x*self.scale_NN,
+            name=f'scale_output_layer')
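
Note: the construction branches now delegate to create_FCNN / create_ModMLP / create_ResNet, and each of those helpers (added below) also binds self.call_architecture to the matching forward method, so call() no longer has to branch on the architecture string. A toy sketch, with a hypothetical class and made-up return values, of that bind-once dispatch pattern:

    class Net:
        def __init__(self, arch):
            if arch in ('FCNN', 'MLP'):
                self.call_architecture = self.call_FCNN
            elif arch == 'ResNet':
                self.call_architecture = self.call_ResNet

        def call_FCNN(self, X):
            return X + 1

        def call_ResNet(self, X):
            return X * 3

    print(Net('FCNN').call_architecture(1))    # 2
    print(Net('ResNet').call_architecture(1))  # 3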


+    def create_FCNN(self):
+        self.hidden_layers = list()
+        for i in range(self.num_hidden_layers):
+            layer = self.Dense_Layer(self.num_neurons_per_layer,
+                                     activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                                 activation=self.activation,
+                                                                 adaptative_activation=self.adaptative_activation),
+                                     kernel_initializer=self.kernel_initializer,
+                                     name=f'layer_{i}')
+            self.hidden_layers.append(layer)
+        self.call_architecture = self.call_FCNN
+
+    def create_ModMLP(self):
+        self.create_FCNN()
+        self.U = self.Dense_Layer(self.num_neurons_per_layer,
+                                  activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                              activation=self.activation,
+                                                              adaptative_activation=self.adaptative_activation),
+                                  kernel_initializer=self.kernel_initializer,
+                                  name=f'layer_u')
+        self.V = self.Dense_Layer(self.num_neurons_per_layer,
+                                  activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                              activation=self.activation,
+                                                              adaptative_activation=self.adaptative_activation),
+                                  kernel_initializer=self.kernel_initializer,
+                                  name=f'layer_v')
+        self.call_architecture = self.call_ModMLP
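
Note: call_ModMLP (further down) mixes each hidden output H with the two encoders U and V as H*U + (1 - H)*V, so H acts as a per-unit interpolation weight between the two encoded inputs; this resembles the modified MLP used in physics-informed networks. A small numeric sketch with made-up tensors standing in for layer(X), U and V:

    import tensorflow as tf

    U = tf.constant([[1.0, 1.0]])
    V = tf.constant([[-1.0, -1.0]])
    H = tf.constant([[0.25, 0.75]])   # stand-in for layer(X)
    X_next = H*U + (1.0 - H)*V
    print(X_next.numpy())             # [[-0.5  0.5]]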


+    def create_ResNet(self):
+        self.first = self.Dense_Layer(self.num_neurons_per_layer,
+                                      activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                                  activation=self.activation,
+                                                                  adaptative_activation=self.adaptative_activation),
+                                      kernel_initializer=self.kernel_initializer,
+                                      name=f'layer_0')
+        self.hidden_blocks = list()
+        self.hidden_blocks_activations = list()
+        for i in range(self.num_hidden_blocks):
+            block = tf.keras.Sequential(name=f"block_{i}")
+            block.add(self.Dense_Layer(self.num_neurons_per_layer,
+                                       activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                                   activation=self.activation,
+                                                                   adaptative_activation=self.adaptative_activation),
+                                       kernel_initializer=self.kernel_initializer))
+            block.add(self.Dense_Layer(self.num_neurons_per_layer,
+                                       activation=None,
+                                       kernel_initializer=self.kernel_initializer))
+            self.hidden_blocks.append(block)
+            activation_layer = tf.keras.layers.Activation(activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                                                      activation=self.activation,
+                                                                                      adaptative_activation=self.adaptative_activation))
+            self.hidden_blocks_activations.append(activation_layer)
+
+        self.last = self.Dense_Layer(self.num_neurons_per_layer,
+                                     activation=CustomActivation(units=self.num_neurons_per_layer,
+                                                                 activation=self.activation,
+                                                                 adaptative_activation=self.adaptative_activation),
+                                     kernel_initializer=self.kernel_initializer,
+                                     name=f'layer_1')
+        self.call_architecture = self.call_ResNet
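
Note: call_ResNet (below) applies each block as X = activation(block(X) + X), i.e. the skip connection is added before the activation, keeping an identity path through the trunk. A minimal sketch of one such residual step with plain Keras layers in place of the repository's custom ones:

    import tensorflow as tf

    block = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation='tanh'),
        tf.keras.layers.Dense(4, activation=None),
    ])
    X = tf.random.normal((2, 4))
    X_next = tf.keras.activations.tanh(block(X) + X)   # skip connection, then activation
    print(X_next.shape)                                # (2, 4)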


     def build_Net(self):
         self.build(self.input_shape_N)
 
     def call(self, X):
-        if self.architecture_Net in ('FCNN','MLP'):
-            return self.call_FCNN(X)
-        elif self.architecture_Net == 'ModMLP':
-            return self.call_ModMLP(X)
-        elif self.architecture_Net == 'ResNet':
-            return self.call_ResNet(X)
-
-    # Call NeuralNet functions with the desired architecture
-
-    def call_FCNN(self, X):
         if self.scale_input:
             X = self.scale(X)
         if self.use_fourier_features:
             X = self.fourier_features(X)
-        for layer in self.hidden_layers:
-            X = layer(X)
+        X = self.call_architecture(X)
         X = self.out(X)
         return self.scale_out(X)
 
+    def call_FCNN(self, X):
+        for layer in self.hidden_layers:
+            X = layer(X)
+        return X
+
     def call_ModMLP(self, X):
-        if self.scale_input:
-            X = self.scale(X)
-        if self.use_fourier_features:
-            X = self.fourier_features(X)
         U = self.U(X)
         V = self.V(X)
         for layer in self.hidden_layers:
             X = layer(X)*U + (1-layer(X))*V
-        X = self.out(X)
-        return self.scale_out(X)
+        return X
 
     def call_ResNet(self, X):
-        if self.scale_input:
-            X = self.scale(X)
-        if self.use_fourier_features:
-            X = self.fourier_features(X)
         X = self.first(X)
         for block,activation in zip(self.hidden_blocks,self.hidden_blocks_activations):
             X = activation(block(X) + X)
-        X = self.last(X)
-        X = self.out(X)
-        return self.scale_out(X)
+        return self.last(X)
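
Note: after this change, input scaling, Fourier features, the output layer and the output scaling each appear once in call(), and only the architecture-specific trunk varies. A toy sketch, with a hypothetical class that is not the repository's API, of that template-method shape:

    import tensorflow as tf

    class TinyNet(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.trunk = tf.keras.layers.Dense(4, activation='tanh')
            self.head = tf.keras.layers.Dense(1, activation=None, use_bias=False)

        def call(self, X):
            X = 2.0 * X - 1.0      # stand-in for the scale layer
            X = self.trunk(X)      # stand-in for call_architecture
            return self.head(X)    # stand-in for out / scale_out

    print(TinyNet()(tf.ones((1, 3))).shape)  # (1, 1)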


class CustomActivation(tf.keras.layers.Layer):
    ...

@@ -246,6 +243,12 @@ def call(self, inputs):
         activation_func = tf.keras.activations.get(self.activation)
         return activation_func(inputs * a_expanded)
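
Note: only the tail of CustomActivation.call is visible in this diff; it rescales pre-activations by a learned factor before applying the base activation (inputs * a_expanded). A rough, self-contained sketch of that adaptive-activation idea, with illustrative parameter names:

    import tensorflow as tf

    class AdaptiveTanh(tf.keras.layers.Layer):
        def build(self, input_shape):
            # one trainable slope per unit
            self.a = self.add_weight(name='a', shape=(input_shape[-1],),
                                     initializer='ones', trainable=True)

        def call(self, inputs):
            return tf.keras.activations.tanh(inputs * self.a)

    print(AdaptiveTanh()(tf.constant([[0.5, -0.5]])))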


+
+class SinCosLayer(tf.keras.layers.Layer):
+    def call(self, Z):
+        return tf.concat([tf.sin(2.0*np.pi*Z), tf.cos(2.0*np.pi*Z)], axis=-1)
+
+
 class CustomDenseLayer(tf.keras.layers.Layer):
 
     def __init__(self, units, activation=None, kernel_initializer='glorot_normal', **kwargs):
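
Note: the body of CustomDenseLayer is collapsed in this diff, so the following is an assumption-laden illustration only. Weight factorization is commonly implemented as W = V * diag(s) with both factors trainable; whether this repository uses exactly that scheme is not shown here.

    import tensorflow as tf

    class FactorizedDense(tf.keras.layers.Layer):
        def __init__(self, units, activation=None, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.activation = tf.keras.activations.get(activation)

        def build(self, input_shape):
            self.s = self.add_weight(name='s', shape=(self.units,), initializer='ones')
            self.V = self.add_weight(name='V', shape=(input_shape[-1], self.units),
                                     initializer='glorot_normal')
            self.b = self.add_weight(name='b', shape=(self.units,), initializer='zeros')

        def call(self, inputs):
            kernel = self.V * self.s   # column-wise scaling: W = V @ diag(s)
            return self.activation(tf.matmul(inputs, kernel) + self.b)

    print(FactorizedDense(3)(tf.ones((1, 2))).shape)  # (1, 3)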
