Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinAchondo committed Jul 2, 2024
1 parent 7909037 commit 4603d59
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
4 changes: 0 additions & 4 deletions xppbe/NN/PINN.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,6 @@ def main_loop(self, N=1000, N2=0):
X_v = self.get_batches('full_batch', validation=True)
X_d = self.get_batches(self.sample_method)

print(X_d)
print(self.w)
print(self.mesh.domain_mesh_names)

self.train_sgd(X_d, X_v)

if self.use_optimizer_2:
Expand Down
14 changes: 7 additions & 7 deletions xppbe/Post/Postcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,21 @@ def run_all(self,plot_mesh,known_method):

def plot_loss_history(self, domain=1, plot_w=False, loss='all'):
fig,ax = plt.subplots()
c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon'}
c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon','IB1':'lime','IB2': 'aqua'}
c2 = {'royalblue','springgreen','aqua', 'pink','yellowgreen','teal'}
for i in ['1','2']:
if int(i)==domain:
if not plot_w:
w = {'R'+i: 1.0, 'D'+i: 1.0, 'N'+i: 1.0, 'K'+i: 1.0, 'E'+i: 1.0, 'Q'+i: 1.0, 'G': 1.0, 'Iu': 1.0, 'Id': 1.0, 'Ir': 1.0}
w = {'R'+i: 1.0, 'D'+i: 1.0, 'N'+i: 1.0, 'K'+i: 1.0, 'E'+i: 1.0, 'Q'+i: 1.0, 'G': 1.0, 'Iu': 1.0, 'Id': 1.0, 'Ir': 1.0,'IB1':1.0,'IB2':1.0}
elif plot_w:
w = self.PINN.w_hist

if plot_w==False and (loss=='TL' or loss=='all'):
ax.semilogy(range(1,len(self.PINN.losses['TL'+i])+1), self.PINN.losses['TL'+i],'k-',label='TL'+i)
for t in self.PINN.losses_names_list[int(i)-1]:
t2 = t if t in ('Iu','Id','Ir','G') else t[0]
t2 = t if t in ('Iu','Id','Ir','G','IB1','IB2') else t[0]
if (t2 in loss or loss=='all') and not t in 'TL' and t in self.mesh.domain_mesh_names:
cx = c[t] if t in ('Iu','Id','Ir','G') else c[t[0]]
cx = c[t] if t in ('Iu','Id','Ir','G','IB1','IB2') else c[t[0]]
ax.semilogy(range(1,len(self.PINN.losses[t])+1), w[t]*self.PINN.losses[t],cx,label=f'{t}')

ax.legend()
Expand All @@ -173,7 +173,7 @@ def plot_loss_validation_history(self, domain=1, loss='TL'):
ax.semilogy(range(1,len(self.PINN.losses['vTL'+i])+1), self.PINN.losses['vTL'+i],'b-',label=f'Training {i}')
ax.semilogy(range(1,len(self.PINN.validation_losses['TL'+i])+1), self.PINN.validation_losses['TL'+i],'r-',label=f'Validation {i}')
else:
t = loss if loss in ('Iu','Id','Ir','G') else loss+i
t = loss if loss in ('Iu','Id','Ir','G','IB1','IB2') else loss+i
if t in self.mesh.domain_mesh_names :
ax.semilogy(range(1,len(self.PINN.losses[t])+1), self.PINN.losses[t],'b-',label=f'{loss} training')
ax.semilogy(range(1,len(self.PINN.validation_losses[t])+1), self.PINN.validation_losses[t],'r-',label=f'{loss} validation')
Expand All @@ -197,13 +197,13 @@ def plot_loss_validation_history(self, domain=1, loss='TL'):

def plot_weights_history(self, domain=1):
fig,ax = plt.subplots()
c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon'}
c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon','IB1':'lime','IB2': 'aqua'}
for i in ['1','2']:
if int(i)==domain:
w = self.PINN.w_hist
for t in self.PINN.losses_names_list[int(i)-1]:
if t in self.mesh.domain_mesh_names:
cx = c[t] if t in ('Iu','Id','Ir','G') else c[t[0]]
cx = c[t] if t in ('Iu','Id','Ir','G','IB1','IB2') else c[t[0]]
ax.semilogy(range(1,len(w[t])+1), w[t], cx,label=f'{t}')

ax.legend()
Expand Down

0 comments on commit 4603d59

Please sign in to comment.