From 4603d59bc298c2eb3b657cede2c540a844a94263 Mon Sep 17 00:00:00 2001 From: martinachondo Date: Tue, 2 Jul 2024 09:04:46 -0400 Subject: [PATCH] fix --- xppbe/NN/PINN.py | 4 ---- xppbe/Post/Postcode.py | 14 +++++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/xppbe/NN/PINN.py b/xppbe/NN/PINN.py index 2acb325..9599822 100644 --- a/xppbe/NN/PINN.py +++ b/xppbe/NN/PINN.py @@ -114,10 +114,6 @@ def main_loop(self, N=1000, N2=0): X_v = self.get_batches('full_batch', validation=True) X_d = self.get_batches(self.sample_method) - print(X_d) - print(self.w) - print(self.mesh.domain_mesh_names) - self.train_sgd(X_d, X_v) if self.use_optimizer_2: diff --git a/xppbe/Post/Postcode.py b/xppbe/Post/Postcode.py index 1f33a1d..6ec6d13 100644 --- a/xppbe/Post/Postcode.py +++ b/xppbe/Post/Postcode.py @@ -132,21 +132,21 @@ def run_all(self,plot_mesh,known_method): def plot_loss_history(self, domain=1, plot_w=False, loss='all'): fig,ax = plt.subplots() - c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon'} + c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon','IB1':'lime','IB2': 'aqua'} c2 = {'royalblue','springgreen','aqua', 'pink','yellowgreen','teal'} for i in ['1','2']: if int(i)==domain: if not plot_w: - w = {'R'+i: 1.0, 'D'+i: 1.0, 'N'+i: 1.0, 'K'+i: 1.0, 'E'+i: 1.0, 'Q'+i: 1.0, 'G': 1.0, 'Iu': 1.0, 'Id': 1.0, 'Ir': 1.0} + w = {'R'+i: 1.0, 'D'+i: 1.0, 'N'+i: 1.0, 'K'+i: 1.0, 'E'+i: 1.0, 'Q'+i: 1.0, 'G': 1.0, 'Iu': 1.0, 'Id': 1.0, 'Ir': 1.0,'IB1':1.0,'IB2':1.0} elif plot_w: w = self.PINN.w_hist if plot_w==False and (loss=='TL' or loss=='all'): ax.semilogy(range(1,len(self.PINN.losses['TL'+i])+1), self.PINN.losses['TL'+i],'k-',label='TL'+i) for t in self.PINN.losses_names_list[int(i)-1]: - t2 = t if t in ('Iu','Id','Ir','G') else t[0] + t2 = t if t in ('Iu','Id','Ir','G','IB1','IB2') else t[0] if (t2 in loss or loss=='all') and not t in 'TL' and t in self.mesh.domain_mesh_names: - cx = c[t] if t in ('Iu','Id','Ir','G') else c[t[0]] + cx = c[t] if t in ('Iu','Id','Ir','G','IB1','IB2') else c[t[0]] ax.semilogy(range(1,len(self.PINN.losses[t])+1), w[t]*self.PINN.losses[t],cx,label=f'{t}') ax.legend() @@ -173,7 +173,7 @@ def plot_loss_validation_history(self, domain=1, loss='TL'): ax.semilogy(range(1,len(self.PINN.losses['vTL'+i])+1), self.PINN.losses['vTL'+i],'b-',label=f'Training {i}') ax.semilogy(range(1,len(self.PINN.validation_losses['TL'+i])+1), self.PINN.validation_losses['TL'+i],'r-',label=f'Validation {i}') else: - t = loss if loss in ('Iu','Id','Ir','G') else loss+i + t = loss if loss in ('Iu','Id','Ir','G','IB1','IB2') else loss+i if t in self.mesh.domain_mesh_names : ax.semilogy(range(1,len(self.PINN.losses[t])+1), self.PINN.losses[t],'b-',label=f'{loss} training') ax.semilogy(range(1,len(self.PINN.validation_losses[t])+1), self.PINN.validation_losses[t],'r-',label=f'{loss} validation') @@ -197,13 +197,13 @@ def plot_loss_validation_history(self, domain=1, loss='TL'): def plot_weights_history(self, domain=1): fig,ax = plt.subplots() - c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon'} + c = {'TL': 'k','R':'r','D':'b','N':'g', 'K': 'gold','Q': 'c','Iu':'m','Id':'lime', 'Ir': 'aqua', 'E':'darkslategrey','G': 'salmon','IB1':'lime','IB2': 'aqua'} for i in ['1','2']: if int(i)==domain: w = self.PINN.w_hist for t in self.PINN.losses_names_list[int(i)-1]: if t in self.mesh.domain_mesh_names: - cx = c[t] if t in ('Iu','Id','Ir','G') else c[t[0]] + cx = c[t] if t in ('Iu','Id','Ir','G','IB1','IB2') else c[t[0]] ax.semilogy(range(1,len(w[t])+1), w[t], cx,label=f'{t}') ax.legend()