Skip to content

Commit

Permalink
changes postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinAchondo committed Jun 20, 2024
1 parent 5d5c878 commit 87552e6
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 138 deletions.
14 changes: 7 additions & 7 deletions tests/test_xppbe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_xppbe_solver(molecule):
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=False, pbj=False)
sim.postprocessing(run_all=True)
run_checkers(sim,sim_name,temp_dir)


Expand All @@ -100,7 +100,7 @@ def test_additional_losses(loss):
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=False, pbj=False)
sim.postprocessing(run_all=True)
run_checkers(sim,sim_name,temp_dir)


Expand All @@ -123,7 +123,7 @@ def test_other_architectures(arch):
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=False, pbj=False)
sim.postprocessing(run_all=True)
run_checkers(sim,sim_name,temp_dir)


Expand All @@ -148,7 +148,7 @@ def test_non_linear_and_schemes(model,scheme):
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=False, pbj=False)
sim.postprocessing(run_all=True)
run_checkers(sim,sim_name,temp_dir)


Expand All @@ -163,7 +163,7 @@ def test_mesh_post():
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=True, pbj=False)
sim.postprocessing(run_all=True, plot_mesh=True)
run_checkers(sim,sim_name,temp_dir)


Expand All @@ -178,7 +178,7 @@ def test_iteration_continuation():
sim.create_simulation()
sim.adapt_model()
sim.solve_model()
sim.postprocessing(run_all=True, mesh=False, pbj=True)
sim.postprocessing(run_all=True)

assert os.path.isdir(results_path)
assert len(os.listdir(results_path)) > 0
Expand All @@ -191,7 +191,7 @@ def test_iteration_continuation():
sim_2.create_simulation()
sim_2.adapt_model()
sim_2.solve_model()
sim_2.postprocessing(run_all=True, mesh=False, pbj=False)
sim_2.postprocessing(run_all=True)

last_iteration = os.path.join(iterations_path,f'iter_{sim_2.N_iters}')

Expand Down
28 changes: 1 addition & 27 deletions xppbe/Post/Post_Template.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"metadata": {},
"outputs": [],
"source": [
"Post.plot_G_solv_history(known=True, method=method);"
"Post.plot_G_solv_history(method);"
]
},
{
Expand Down Expand Up @@ -155,32 +155,6 @@
"source": [
"Post.plot_phi_line_aprox_known(method, value='react', theta=np.pi/2, phi=np.pi, N=300);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vertices = Post.mesh.mol_verts.astype(np.float32)\n",
"elements = Post.mesh.mol_faces.astype(np.float32)\n",
"phi_known = Post.phi_known(method,'react', vertices,'solvent')\n",
"Post.plot_interface_3D_known(phi_known, vertices, elements, jupyter=True);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"phi_known = Post.phi_known(method,'react', vertices, flag='solvent')\n",
"phi_xpinn = Post.get_phi(vertices,flag='molecule',model=Post.model,value='react')\n",
"\n",
"error = np.abs((phi_xpinn.numpy() - phi_known.numpy().reshape(-1,1))/phi_known.numpy().reshape(-1,1))\n",
"\n",
"Post.plot_interface_error(error, vertices, elements, scale='log', jupyter=True);"
]
}
],
"metadata": {
Expand Down
151 changes: 106 additions & 45 deletions xppbe/Post/Postcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ def plot_weights_history(self, domain=1):
return fig,ax


def plot_G_solv_history(self, known=False, method=None):
def plot_G_solv_history(self, known_method=None):
fig,ax = plt.subplots()
ax.plot(np.array(list(self.PINN.G_solv_hist.keys()), dtype=self.DTYPE), self.PINN.G_solv_hist.values(),'k-',label='PINN')
if known:
G_known = self.PDE.solvation_energy_phi_qs(self.to_V**-1*self.phi_known(method,'react',tf.constant(self.PDE.x_qs, dtype=self.DTYPE),'molecule'))
if not known_method is None:
G_known = self.PDE.solvation_energy_phi_qs(self.to_V**-1*self.phi_known(known_method,'react',tf.constant(self.PDE.x_qs, dtype=self.DTYPE),'molecule'))
G_known = np.ones(len(self.PINN.G_solv_hist))*G_known
label = method.replace('_',' ') if 'Born' not in method else 'Analytic'
label = known_method.replace('_',' ') if 'Born' not in known_method else 'Analytic'
ax.plot(np.array(list(self.PINN.G_solv_hist.keys()), dtype=self.DTYPE), G_known,'r--',label=f'{label}')
ax.legend()
n_label = r'$n$'
Expand All @@ -177,7 +177,7 @@ def plot_G_solv_history(self, known=False, method=None):
ax.grid()

if self.save:
path = 'Gsolv_history.png' if not known else f'Gsolv_history_{method}.png'
path = 'Gsolv_history.png' if known_method is None else f'Gsolv_history_{known_method}.png'
path_save = os.path.join(self.directory,self.path_plots_solution,path)
fig.savefig(path_save, bbox_inches='tight')
return fig,ax
Expand Down Expand Up @@ -320,10 +320,10 @@ def plot_interface_3D(self,variable='phi', value='phi', domain='interface', jupy

if variable == 'phi':
values,values_1,values_2 = self.get_phi_interface(self.PINN.model,value=value)
text_l = r'phi' if value == 'phi' else r'phi_react'
text_l = r'phi' if value == 'phi' else 'ϕ react'
elif variable == 'dphi':
values,values_1,values_2 = self.get_dphi_interface(self.PINN.model)
text_l = r'dphi' if value == 'phi' else r'dphi_react'
text_l = r'dphi' if value == 'phi' else '∂ϕ react'

if domain =='interface':
values = values.numpy().flatten()
Expand All @@ -338,34 +338,71 @@ def plot_interface_3D(self,variable='phi', value='phi', domain='interface', jupy
intensity=values, colorscale='RdBu_r',
colorbar=dict(title=f'{text_l} [V]')))

fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'),margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'),margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")

if not jupyter and self.save:
path_save = os.path.join(self.directory, self.path_plots_solution, f'Interface_{variable}_{value}_{domain}')
if self.save:
if ext=='html':
fig.write_html(os.path.join(self.directory, self.path_plots_solution, f'Interface_{variable}_{value}_{domain}.html'))
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(os.path.join(self.directory, self.path_plots_solution, f'Interface_{variable}_{value}_{domain}.png'), scale=3)
elif jupyter:
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig


def plot_interface_3D_known(self, method, cmin=None,cmax=None, jupyter=False, ext='html'):

vertices = self.mesh.mol_verts.astype(np.float32)
elements = self.mesh.mol_faces.astype(np.float32)
phi_known = self.phi_known(method,'react', vertices,'solvent')

fig = self.plot_interface_3D_known_by(phi_known, vertices, elements, jupyter=False)
path_save = os.path.join(self.directory, self.path_plots_solution, f'Interface_{method}')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig

@staticmethod
def plot_interface_3D_known(phi_known, vertices, elements, cmin=None,cmax=None, jupyter=True):
def plot_interface_3D_known_by(phi_known, vertices, elements, cmin=None,cmax=None, jupyter=True):
cmin = np.min(phi_known) if cmin is None else cmin
cmax = np.max(phi_known) if cmax is None else cmax
fig = go.Figure()
fig.add_trace(go.Mesh3d(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
i=elements[:, 0], j=elements[:, 1], k=elements[:, 2],
intensity=phi_known, colorscale='RdBu_r',cmin=cmin,cmax=cmax,
colorbar=dict(title='Phi react [V]')))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
colorbar=dict(title='ϕ react [V]')))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")
if jupyter:
fig.show()
return fig

@staticmethod
def plot_interface_error(error,vertices,elements,scale='log',jupyter=True):
def plot_interface_error(self,method, type_e='relative',scale='log',jupyter=False, ext='html'):
vertices = self.mesh.mol_verts.astype(np.float32)
elements = self.mesh.mol_faces.astype(np.float32)

phi_known = self.phi_known(method,'react', vertices, flag='solvent')
phi_pinn = self.get_phi(vertices,flag='molecule',model=self.model,value='react')

error = np.abs(phi_pinn.numpy() - phi_known.numpy().reshape(-1,1))
if type_e == 'relative':
error /= phi_known.numpy().reshape(-1,1)
fig = self.plot_interface_error_by(error,vertices,elements,scale,jupyter=False)

path_save = os.path.join(self.directory, self.path_plots_solution, f'Interface_error_{method}_{type_e}_{scale}')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)

@staticmethod
def plot_interface_error_by(error,vertices,elements,scale='log',jupyter=True):
fig = go.Figure()
if scale=='log':
epsilon = 1e-10
Expand All @@ -384,13 +421,14 @@ def plot_interface_error(error,vertices,elements,scale='log',jupyter=True):
i=elements[:, 0], j=elements[:, 1], k=elements[:, 2],
intensity=np.abs(error), colorscale='Plasma',
colorbar=dict(title='Error')))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")

if jupyter:
fig.show()
return fig


def plot_collocation_points_3D(self, jupyter=False):
def plot_collocation_points_3D(self, jupyter=False, ext='html'):

color_dict = {
'Q1_verts': 'red',
Expand Down Expand Up @@ -425,15 +463,19 @@ def plot_collocation_points_3D(self, jupyter=False):
)
fig.add_trace(trace)

fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")

if not jupyter and self.save:
fig.write_html(os.path.join(self.directory,self.path_plots_meshes, 'collocation_points_plot_3d.html'))
elif jupyter:
path_save = os.path.join(self.directory,self.path_plots_meshes, 'collocation_points_plot_3d')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig

def plot_surface_mesh_3D(self, jupyter=False):
def plot_surface_mesh_3D(self, jupyter=False, ext='html'):

vertices = self.mesh.mol_verts
elements = self.mesh.mol_faces
Expand Down Expand Up @@ -467,15 +509,19 @@ def plot_surface_mesh_3D(self, jupyter=False):
)

fig = go.Figure(data=[element_trace,edge_trace])
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode='data', xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")

if not jupyter and self.save:
fig.write_html(os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_surf_3D.html'))
elif jupyter:
path_save = os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_surf_3D')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig

def plot_vol_mesh_3D(self, jupyter=False):
def plot_vol_mesh_3D(self, jupyter=False, ext='html'):
toRemove = []
ext_tetmesh = self.mesh.ext_tetmesh
for vertexID in ext_tetmesh.vertexIDs:
Expand Down Expand Up @@ -555,17 +601,20 @@ def plot_vol_mesh_3D(self, jupyter=False):
)

fig = go.Figure(data=[element_trace_ex,edge_trace_ex, element_trace_in, edge_trace_in])
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")


if not jupyter and self.save:
fig.write_html(os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_vol_3D.html'))
elif jupyter:
path_save = os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_vol_3D')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig


def plot_mesh_3D(self,domain,element_indices=None, jupyter=False):
def plot_mesh_3D(self,domain,element_indices=None, jupyter=False, ext='html'):

element_indices_input = element_indices

Expand Down Expand Up @@ -653,15 +702,19 @@ def plot_mesh_3D(self,domain,element_indices=None, jupyter=False):

fig = go.Figure(data=[trace_vertices, trace_edges, trace_elements])

fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")

if not jupyter and self.save:
fig.write_html(os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_3D_{domain}.html'))
elif jupyter:
path_save = os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_3D_{domain}')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig

def plot_surface_mesh_normals(self,plot='vertices',jupyter=False):
def plot_surface_mesh_normals(self,plot='vertices',jupyter=False, ext='html'):

mesh_obj = self.mesh

Expand Down Expand Up @@ -711,11 +764,16 @@ def plot_surface_mesh_normals(self,plot='vertices',jupyter=False):
)

fig = go.Figure(data=[mesh_trace, vertex_normals_trace, edge_trace])
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [A]', yaxis_title='Y [A]', zaxis_title='Z [A]'), margin=dict(l=30, r=40, t=20, b=20))
fig.update_layout(scene=dict(aspectmode="data", xaxis_title='X [Å]', yaxis_title='Y [Å]', zaxis_title='Z [Å]'), margin=dict(l=30, r=40, t=20, b=20),font_family="Times New Roman")


if not jupyter and self.save:
fig.write_html(os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_surface_normals_{plot}.html'))
elif jupyter:
path_save = os.path.join(self.directory,self.path_plots_meshes,f'mesh_plot_surface_normals_{plot}')
if self.save:
if ext=='html':
fig.write_html(path_save+'.html')
elif ext=='png':
fig.write_image(path_save+'.png', scale=3)
if jupyter:
fig.show()
return fig

Expand Down Expand Up @@ -792,7 +850,7 @@ def L2_interface_known(self,known_method):
error = np.sqrt(np.sum(phi_dif**2)/np.sum(phi_known.numpy()**2))
return error

def save_values_file(self,save=True):
def save_values_file(self, save=True, L2_err_method=None):

max_iter = max(map(int,list(self.PINN.G_solv_hist.keys())))
Gsolv_value = self.PINN.G_solv_hist[str(max_iter)]
Expand All @@ -812,6 +870,9 @@ def save_values_file(self,save=True):
'Loss_Data_K2': self.PINN.losses['K2'][-1]
}

if not L2_err_method is None:
dict_pre['L2_error'] = self.L2_interface_known(L2_err_method)

df_dict = {}
for key, value in dict_pre.items():
if key=='Gsolv_value':
Expand Down
Loading

0 comments on commit 87552e6

Please sign in to comment.