improved train-test split
ThomasHelfer committed Jun 17, 2024
1 parent e44dc4e commit 38fd25a
Showing 1 changed file with 49 additions and 22 deletions.
71 changes: 49 additions & 22 deletions learn_error.py
@@ -15,6 +15,7 @@
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm, trange
import torch.nn.init as init
from torch.utils.data import DataLoader, TensorDataset, random_split


from GeneralRelativity.Utils import (
@@ -63,12 +64,22 @@ def main():
writer = SummaryWriter(f"{folder_name}")

# Loading small test data
filenamesX = "/home/thelfer1/scr4_tedwar42/thelfer1/data_gen_binary/outputXdata_level1_step0050.dat"
num_varsX = 100
# filenamesX = "/home/thelfer1/scr4_tedwar42/thelfer1/data_gen_binary/outputXdata_level1_step0050.dat"

filenamesX = "/home/thelfer1/scr4_tedwar42/thelfer1/high_end_data/outputXdata_level9_step000[02].dat"

num_varsX = 25
dataX = get_box_format(filenamesX, num_varsX)
# Cutting out extra values added for validation
dataX = dataX[:, :, :, :, :25]
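Note: with num_varsX now set to 25, the trailing slice dataX[..., :25] is effectively a no-op (assuming get_box_format returns one channel per variable); the comment dates from the 100-variable files, where the slice trimmed the extra validation channels.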

plt.imshow(
dataX[0, 8, :, :, 0], cmap="viridis"
) # 'viridis' is a colormap, you can choose others like 'plasma', 'inferno', etc.
plt.colorbar() # Add a colorbar to show the scale
plt.title("2D Array Plot")
plt.savefig("testarray.png")

class SuperResolution3DNet(torch.nn.Module):
def __init__(self, factor):
super(SuperResolution3DNet, self).__init__()
@@ -112,7 +123,7 @@ def __init__(self, factor):
)

# Initialize only the weights in self.encoder and self.decoder
self.initialize_encoder_decoder_weights()
# self.initialize_encoder_decoder_weights()

def initialize_encoder_decoder_weights(self):
for m in [self.encoder, self.decoder]:
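The body of this initializer is collapsed in the diff view. Given the torch.nn.init import above, a minimal sketch of what such a method typically looks like -- an assumption for illustration, not necessarily the author's exact code:

def initialize_encoder_decoder_weights(self):
    # Hypothetical sketch: Kaiming-normal init for conv layers only
    for m in [self.encoder, self.decoder]:
        for layer in m.modules():
            if isinstance(layer, torch.nn.Conv3d):
                init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    init.zeros_(layer.bias)

Note the call to this method is commented out in this commit, so the network now starts from PyTorch's default initialization.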
@@ -163,12 +174,13 @@ def forward(self, x):

losses_train = []
losses_val = []
losses_val_interp = []
steps_val = []

optimizerBFGS = torch.optim.LBFGS(
net.parameters(), lr=0.1
) # LBFGS really can do magic sometimes, though it's a bit of a diva
optimizerADAM = torch.optim.Adam(net.parameters(), lr=0.00001)
optimizerADAM = torch.optim.Adam(net.parameters(), lr=1e-4)

# Define the ratio for the split (e.g., 80% train, 20% test)
train_ratio = 0.8
@@ -179,21 +191,26 @@ def forward(self, x):
num_train = int(train_ratio * num_samples)
num_test = num_samples - num_train

train_torch = dataX[:num_train].permute(0, 4, 1, 2, 3).to(device)
test_torch = dataX[num_train:].permute(0, 4, 1, 2, 3).to(device)
# Permute data to put the channel in the second dimension: (N, C, D, H, W)
dataX = dataX.permute(0, 4, 1, 2, 3)

# Create a dataset from tensors
dataset = TensorDataset(dataX)

# Split the dataset into training and testing datasets
train_dataset, test_dataset = random_split(dataset, [num_train, num_test])
batch_size = 51

# Create DataLoaders for batching -- in case the data gets larger
train_loader = DataLoader(
dataset=TensorDataset(train_torch),
dataset=train_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=False,
num_workers=0,
)
test_loader = DataLoader(
dataset=TensorDataset(test_torch),
dataset=test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
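For reference, the split-and-batch pattern this commit moves to, as a self-contained sketch (the tensor shape, sizes, and names here are placeholders, not the real data):

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

data = torch.randn(100, 25, 16, 16, 16)  # placeholder (N, C, X, Y, Z) tensor
dataset = TensorDataset(data)
num_train = int(0.8 * len(dataset))
train_ds, test_ds = random_split(dataset, [num_train, len(dataset) - num_train])
train_loader = DataLoader(train_ds, batch_size=51, shuffle=False)

for (batch,) in train_loader:  # TensorDataset yields one-element tuples
    print(batch.shape)  # torch.Size([51, 25, 16, 16, 16])
    break

Since random_split already randomizes which samples land in each partition, shuffle=False here only fixes the batch order within an epoch; the split itself is still random.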
@@ -229,14 +246,14 @@ def __call__(self, output: torch.tensor, dummy: torch.tensor) -> torch.tensor:
if restart and os.path.exists(file_path):
net.load_state_dict(torch.load(file_path))

oneoverdx = 64.0 / 16.0
# oneoverdx = 64.0 / 16.0
oneoverdx = (64.0 * 2**9) / 512.0  # = 64.0 at refinement level 9, i.e. dx = 1/64
print(f"dx {1.0/oneoverdx}")
my_loss = Hamiltonian_loss(oneoverdx)

# Note: it will slow down significantly with BFGS steps -- they are ~10x slower, just be aware!
ADAMsteps = (
1000000 # Will perform this many ADAM steps and then switch over to L-BFGS
)
n_steps = 0 # Total number of steps
ADAMsteps = 21 # Will perform this many ADAM steps and then switch over to L-BFGS
n_steps = 23 # Total number of steps

net.train()
net.to(device)
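The training loop itself is collapsed below, but the variables above show the pattern: run Adam for the first ADAMsteps iterations, then hand the same parameters to L-BFGS through a closure. A minimal self-contained sketch of that switch (the model, data, and loss are placeholders):

import torch

net = torch.nn.Linear(4, 4)  # placeholder model
data = torch.randn(32, 4)
target = torch.randn(32, 4)
loss_fn = torch.nn.MSELoss()
optimizerADAM = torch.optim.Adam(net.parameters(), lr=1e-4)
optimizerBFGS = torch.optim.LBFGS(net.parameters(), lr=0.1)
ADAMsteps, n_steps = 21, 23

for step in range(n_steps):
    def closure():
        # L-BFGS may re-evaluate this several times per step, so it must
        # zero the gradients, recompute the loss, and backprop on each call
        optimizerADAM.zero_grad()
        loss = loss_fn(net(data), target)
        loss.backward()
        return loss

    if step < ADAMsteps:
        loss = closure()
        optimizerADAM.step()
    else:
        loss = optimizerBFGS.step(closure)  # roughly 10x slower per step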
@@ -284,7 +301,6 @@ def closure():

loss_train = closure()
total_loss_train += loss_train.item()

# Calculate the average training loss
average_loss_train = total_loss_train / len(train_loader)
# Log the average training loss
@@ -295,11 +311,10 @@ def closure():

# Validation

if counter % 4 == 0:
if counter % 1 == 0:
with torch.no_grad():
total_loss_val = (
0.0 # Initialize a variable to accumulate the total loss
)
total_loss_val = 0.0
interp_val = 0.0
for (y_val_batch,) in test_loader:
# for X_val_batch, y_val_batch in test_loader:
# Transfer batch to GPU
@@ -312,15 +327,21 @@ def closure():
diff - 1 : -diff - 1,
diff - 1 : -diff - 1,
]
y_val_interp = net.interpolation(X_val_batch)
y_val_pred = net(X_val_batch)
loss_val = my_loss(y_val_pred, y_val_batch)
total_loss_val += loss_val.item()
interp_val += my_loss(y_val_interp, y_val_batch).item()
# Calculate the average loss
average_loss_val = total_loss_val / len(test_loader)
average_interp_val = interp_val / len(test_loader)
losses_val_interp.append(average_interp_val)
losses_val.append(average_loss_val)
steps_val.append(counter)
writer.add_scalar("loss/test", loss_val.item(), counter)
if counter % 1000 == 0:
writer.add_scalar("loss/test", loss_val.item(), counter)
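losses_val_interp tracks the loss of the network's fixed interpolation path on the same validation batches, i.e. a pure-upsampling baseline the trained network should beat. If net.interpolation wraps standard trilinear upsampling (an assumption -- that module's body is not shown in this diff), it behaves like:

import torch
import torch.nn.functional as F

x = torch.randn(2, 25, 8, 8, 8)  # placeholder coarse batch (N, C, X, Y, Z)
baseline = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=False)
print(baseline.shape)  # torch.Size([2, 25, 16, 16, 16])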

if counter % 1 == 0:
# Writing out network and scaler
torch.save(
net.state_dict(),
@@ -332,9 +353,8 @@ def closure():
# Plotting the losses at the end
plt.figure(figsize=(9, 6))
plt.plot(np.array(losses_train), label="Train")
plt.plot(
steps_val, np.array(losses_val), label="Val with Relative loss", linewidth=0.5
)
plt.plot(steps_val, np.array(losses_val), label="Val", linewidth=0.5)
plt.plot(steps_val, np.array(losses_val_interp), label="baseline", linewidth=0.5)
plt.yscale("log")
plt.legend()
plt.savefig(f"{folder_name}/training.png")
@@ -518,6 +538,13 @@ def closure():
f"L1 loss interpolation {my_loss(y_interpolated.cpu(), y_batch[:, :, :, :, :].cpu())}\n"
)

print(
f"L1 loss Neural Network {my_loss(y_pred[:, :, :, :, :].cpu(), y_batch[:, :, :, :, :].cpu())}\n"
)
print(
f"L1 loss interpolation {my_loss(y_interpolated.cpu(), y_batch[:, :, :, :, :].cpu())}\n"
)


if __name__ == "__main__":
main()
