Skip to content

Commit

Permalink
Update gen_dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
manosth authored Jul 13, 2024
1 parent b156eb6 commit e5e1843
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions data_gen/gen_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
sns.set_context("paper")
sns.set(font_scale=1.4)
cmap = plt.get_cmap("twilight")
# cmap = plt.get_cmap("hsv")

# torch imports
import torch
Expand Down Expand Up @@ -61,7 +60,7 @@ def forward(self):
return self.grid_list


def energy_loss_nima(grid_list, nbr):
def energy_loss(grid_list, nbr):
"""
Computes the energy of the configuration in the XY model.
Expand Down Expand Up @@ -116,19 +115,11 @@ def energy_loss_nima(grid_list, nbr):
gamma = (final_lr / lr) ** (1 / updates)
step_size = epochs // updates
schd = optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
# if lr >= 10:
# schd = optim.lr_scheduler.MultiStepLR(
# opt, [int(0.7 * epochs), int(0.85 * epochs)], gamma=0.1
# )
# elif lr >= 1:
# schd = optim.lr_scheduler.MultiStepLR(opt, [int(0.9 * epochs)], gamma=0.1)
# else:
# schd = optim.lr_scheduler.MultiStepLR(opt, [int(0.9 * epochs)], gamma=1)


for epoch in range(1, epochs + 1):
model.train()
s_e = model()
loss = energy_loss_nima(s_e, model.get_nbr())
loss = energy_loss(s_e, model.get_nbr())

opt.zero_grad()
loss.backward()
Expand All @@ -145,5 +136,3 @@ def energy_loss_nima(grid_list, nbr):

data = {"x": fields, "y": energies, "lr": rates}
np.save(f"data_n={N}.npy", data)
# sns.histplot(energies)
# plt.show()

0 comments on commit e5e1843

Please sign in to comment.