Skip to content

Commit

Permalink
DBM implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinAchondo committed Jul 2, 2024
1 parent 237e5d2 commit 9fb6cdc
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 18 deletions.
11 changes: 2 additions & 9 deletions xppbe/Mesh/Mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,19 +308,12 @@ def adapt_meshes_domain(self,data,q_list):
X_plot[f'{type_b}_verts'] = X.numpy()
self.save_data_plot(X_plot)

elif type_b in ('Iu','Id','Ir'):
elif type_b in ('Iu','Id','Ir','IB1','IB2'):
N = self.mol_verts_normal
X = tf.constant(self.mol_verts, dtype=self.DTYPE)
X_I = (X, N)
self.domain_mesh_names.add(type_b)
self.domain_mesh_data['I'] = (X_I,flag)

elif 'IR' in type_b:
N = self.mol_verts_normal
X = tf.constant(self.mol_verts, dtype=self.DTYPE)
X_I = (X, None)
self.domain_mesh_names.add(type_b.replace('I',''))
self.domain_mesh_data[type_b.replace('I','')] = (X_I,flag)
self.domain_mesh_data['I'] = (X_I,flag)

elif type_b in ('G'):
self.domain_mesh_names.add(type_b)
Expand Down
26 changes: 24 additions & 2 deletions xppbe/Model/Equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,36 @@ def __init__(self,*args,**kwargs):
def get_phi(self,X,flag,model,value='phi'):
phi = model(X,flag)
phi_interface = tf.reshape(phi[:,0]+phi[:,1],(-1,1))/2
if flag != 'interface':
vertices = self.grid.vertices
faces_normals = self.grid.normals
elements = self.grid.elements
centroids = np.zeros((3, elements.shape[1]))
for i, element in enumerate(elements.T):
centroids[:, i] = np.mean(vertices[:, element], axis=1)
X = tf.reshape(tf.constant(centroids.transpose(),dtype=self.DTYPE), (-1,3))
Nv = tf.reshape(tf.constant(faces_normals.transpose(),dtype=self.DTYPE), (-1,3))
phi = model(X,'interface')
phi_mean = phi[:,0]
x = self.mesh.get_X(X)
nv = self.mesh.get_X(Nv)
u_interface = phi_mean.numpy().flatten()
du_1 = self.directional_gradient(self.mesh,model,x,nv,'interface',value='phi')
du_1_interface = du_1.numpy().flatten()
phi = self.bempp.GridFunction(self.space, coefficients=u_interface)
dphi = self.bempp.GridFunction(self.space, coefficients=du_1_interface)

if flag=='molecule':
slp = self.bempp.api.operators.potential.laplace.single_layer(self.neumann_space, X.numpy().transpose())
dlp = self.bempp.api.operators.potential.laplace.double_layer(self.dirichl_space, X.numpy().transpose())
phi = slp * dphi - dlp * phi
phi = phi.reshape(-1,1) + self.G(X)

elif flag=='solvent':
slp = self.bempp.api.operators.potential.helmholtz_modified.single_layer(self.neumann_space, X.numpy().transpose(),self.kappa)
dlp = self.bempp.api.operators.potential.helmholtz_modified.double_layer(self.dirichl_space, X.numpy().transpose(),self.kappa)
phi = slp * dphi - dlp * phi
phi = phi.reshape(-1,1)

elif flag=='interface':

Expand Down Expand Up @@ -480,8 +503,7 @@ def get_r(self,mesh,model,X,SU,flag):
integrand = (self.PBE.G_L(R,X_c)*dphi_i - self.PBE.dG_L(R,X_c)*phi_i)*self.areas
integral = tf.reduce_sum(integrand, axis=1, keepdims=True)
phi = self.PBE.get_phi(R,flag,model,value=self.field)
G = self.PBE.G(R) if SU==None else SU
r = 0.5*phi - G - integral
r = 0.5*phi - self.PBE.G(R) - integral
return r


Expand Down
11 changes: 11 additions & 0 deletions xppbe/Model/PDE_Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ def get_loss(self, X_batches, model, validation=False):
loss_r = self.PDE_out.residual_loss(self.mesh,model,self.mesh.get_X(X),SU,flag)
L['R2'] += loss_r

if 'IB1' in X_batches:
((X,N),flag) = X_batches['I']
loss_r = self.PDE_in.residual_loss(self.mesh,model,self.mesh.get_X(X),N,flag)
L['IB1'] += loss_r

if 'IB2' in X_batches:
((X,N),flag) = X_batches['I']
loss_r = self.PDE_out.residual_loss(self.mesh,model,self.mesh.get_X(X),N,flag)
L['IB2'] += loss_r


if 'Q1' in X_batches:
((X,SU),flag) = X_batches['Q1']
loss_q = self.PDE_in.residual_loss(self.mesh,model,self.mesh.get_X(X),SU,flag)
Expand Down
10 changes: 5 additions & 5 deletions xppbe/NN/PINN_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class PINN_utils():
def __init__(self, results_path):

self.results_path = results_path
self.losses_names = ['TL','TL1','TL2','vTL1','vTL2','R1','D1','N1','K1','Q1','R2','D2','N2','K2','G','Iu','Id','Ir','E2','P1','P2']
self.losses_names_1 = ['TL1','R1','D1','N1','K1','Q1','Iu','Id','Ir','G','P1']
self.losses_names_2 = ['TL2','R2','D2','N2','K2','Iu','Id','Ir','E2','G','P2']
self.validation_names = ['TL','TL1','TL2','R1','D1','N1','Q1','R2','D2','N2','Iu','Id','Ir']
self.losses_names = ['TL','TL1','TL2','vTL1','vTL2','R1','D1','N1','K1','Q1','R2','D2','N2','K2','G','Iu','Id','Ir','E2','P1','P2','IB1','IB2']
self.losses_names_1 = ['TL1','R1','D1','N1','K1','Q1','Iu','Id','Ir','G','P1','IB1']
self.losses_names_2 = ['TL2','R2','D2','N2','K2','Iu','Id','Ir','E2','G','P2','IB2']
self.validation_names = ['TL','TL1','TL2','R1','D1','N1','Q1','R2','D2','N2','Iu','Id','Ir','IB1','IB2']
self.w_names = self.losses_names[5:]
self.losses_names_list = [self.losses_names_1,self.losses_names_2]

Expand Down Expand Up @@ -79,7 +79,7 @@ def set_points_methods(self, sample_method='batches', N_batches=1, sample_size=1
def adapt_datasets(self):
self.L_X_domain = dict()
for t in self.mesh.domain_mesh_names:
if t in ('Iu','Id','Ir'):
if t in ('Iu','Id','Ir','IB1','IB2'):
self.L_X_domain['I'] = self.mesh.domain_mesh_data['I']
else:
self.L_X_domain[t] = self.mesh.domain_mesh_data[t]
Expand Down
4 changes: 2 additions & 2 deletions xppbe/Simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def create_simulation(self):
meshes_domain['E2'] = {'domain': 'solvent', 'type': 'E2', 'file': f'experimental_data_{self.domain_properties["molecule"]}.dat', 'method': self.experimental_method}

elif self.pinns_method == 'DBM':
meshes_domain['R1'] = {'domain': 'interface', 'type':'IR1', 'fun':lambda X: self.PBE_model.G(X)}
meshes_domain['R2'] = {'domain': 'interface', 'type':'IR2', 'value':0.0}
meshes_domain['IB1'] = {'domain': 'interface', 'type':'IB1', 'value':0.0}
meshes_domain['IB2'] = {'domain': 'interface', 'type':'IB2', 'value':0.0}

if self.num_networks==2 and self.pinns_method!='DBM':
meshes_domain['Iu'] = {'domain':'interface', 'type':'Iu'}
Expand Down

0 comments on commit 9fb6cdc

Please sign in to comment.