Skip to content

Commit

Permalink
added config file
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Jun 24, 2024
1 parent 964f2a2 commit af4deae
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions learn_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def main():
filenamesX = config["filenamesX"].format(res_level=res_level)
restart = config["restart"]
file_path = config["file_path"]
lambda_fac = config["lambda_fac"]

num_varsX = 25
dataX = get_box_format(filenamesX, num_varsX)
Expand Down Expand Up @@ -225,8 +226,9 @@ def forward(self, x):

# Magical loss coming from General Relativity
class Hamiltonian_loss:
def __init__(self, oneoverdx: float):
def __init__(self, oneoverdx: float, lambda_fac: float = 0):
self.oneoverdx = oneoverdx
self.lambda_fac = float(lambda_fac)

def __call__(
self, output: torch.tensor, y_interp: torch.tensor
Expand All @@ -253,7 +255,7 @@ def __call__(
if y_interp is not None:
diff = torch.abs(torch.mean(y_interp - output))
hamloss = torch.mean(out["Ham"] * out["Ham"])
loss = hamloss + diff
loss = hamloss + diff * self.lambda_fac
return loss

if restart and os.path.exists(file_path):
Expand All @@ -262,7 +264,7 @@ def __call__(
# oneoverdx = 64.0 / 16.0
oneoverdx = (64.0 * 2**res_level) / 512.0
print(f"dx {1.0/oneoverdx}")
my_loss = Hamiltonian_loss(oneoverdx)
my_loss = Hamiltonian_loss(oneoverdx, lambda_fac)

net.train()
net.to(device)
Expand Down Expand Up @@ -290,7 +292,7 @@ def closure():
optimizerBFGS.zero_grad()
y_pred, y_interp = net(X_batch)

loss_train = my_loss(y_pred, y_batch)
loss_train = my_loss(y_pred, y_interp)
if loss_train.requires_grad:
loss_train.backward()
return loss_train
Expand All @@ -299,7 +301,7 @@ def closure():
if counter < ADAMsteps:
y_pred, y_interp = net(X_batch)

loss_train = my_loss(y_pred, y_batch)
loss_train = my_loss(y_pred, y_interp)
optimizerADAM.zero_grad()
loss_train.backward()
optimizerADAM.step()
Expand Down

0 comments on commit af4deae

Please sign in to comment.