From 886e29855de2c3001075d601805e067ccb8a1644 Mon Sep 17 00:00:00 2001 From: Emmanouil Theodosis Date: Sat, 13 Jul 2024 14:41:51 +0300 Subject: [PATCH] Update train_egnn.py --- train_egnn.py | 78 ++++----------------------------------------------- 1 file changed, 5 insertions(+), 73 deletions(-) diff --git a/train_egnn.py b/train_egnn.py index a1984e5..0cbc233 100644 --- a/train_egnn.py +++ b/train_egnn.py @@ -30,60 +30,12 @@ sns.set_context("paper") sns.set(font_scale=2) cmap = plt.get_cmap("twilight") -# cmap_t = plt.get_cmap("turbo") -# cmap = plt.get_cmap("hsv") color_plot = sns.cubehelix_palette(4, reverse=True, rot=-0.2) from matplotlib import cm, rc rc("text", usetex=True) rc("text.latex", preamble=r"\usepackage{amsmath}") - -def zeromean(X, mean=None, std=None): - "Expects data in NxCxWxH." - if mean is None: - mean = X.mean(axis=(0, 2, 3)) - std = X.std(axis=(0, 2, 3)) - std = torch.ones(std.shape) - - X = torchvision.transforms.Normalize(mean, std)(X) - return X, mean, std - - -def standardize(X, mean=None, std=None): - "Expects data in NxCxWxH." - if mean is None: - mean = X.mean(axis=(0, 2, 3)) - std = X.std(axis=(0, 2, 3)) - - X = torchvision.transforms.Normalize(mean, std)(X) - return X, mean, std - - -def standardize_y(Y, mean=None, std=None): - "Expects data in Nx1." - if mean is None: - mean = Y.min() - std = Y.max() - Y.min() - - Y = (Y - mean) / std - return Y, mean, std - - -def whiten(X, zca=None, mean=None, eps=1e-8): - "Expects data in NxCxWxH." - os = X.shape - X = X.reshape(os[0], -1) - - if zca is None: - mean = X.mean(dim=0) - cov = np.cov(X, rowvar=False) - U, S, V = np.linalg.svd(cov) - zca = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + eps)), U.T)) - X = torch.Tensor(np.dot(X - mean, zca.T).reshape(os)) - return X, zca, mean - - def lattice_nbr(grid_size): """dxd edge list (periodic)""" edg = set() @@ -124,9 +76,7 @@ def main(): batch_size = 64 ### Data loading - # data = np.load("data_n=10000.npy", allow_pickle=True) - data = np.load("data_n=10000_gauge.npy", allow_pickle=True) - # data = np.load("/Users/manos/data/gauge/data_n=10000.npy", allow_pickle=True) + data = np.load("data_n=10000.npy", allow_pickle=True) X, Y = data.item()["x"], data.item()["y"] tr_idx = np.random.choice(X.shape[0], int(0.8 * X.shape[0]), replace=False) @@ -141,22 +91,6 @@ def main(): X_te = torch.Tensor(X_te).view(-1, 1, grid_size, grid_size) Y_te = torch.Tensor(Y_te).view(-1, 1) - if data_norm == "standard": - X_tr, mean, std = standardize(X_tr) - X_te, _, _ = standardize(X_te, mean, std) - elif data_norm == "zeromean": - X_tr, mean, std = zeromean(X_tr) - X_te, _, _ = zeromean(X_te, mean, std) - elif data_norm == "whiten": - X_tr, mean, std = standardize(X_tr) - X_te, _, _ = standardize(X_te, mean, std) - - X_tr, zca, mean = whiten(X_tr) - X_te, _, _ = whiten(X_te, zca, mean) - elif data_norm == "y": - Y_tr, mean, std = standardize_y(Y_tr) - Y_te, _, _ = standardize_y(Y_te, mean, std) - train_dl = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(X_tr, Y_tr), batch_size=batch_size, @@ -192,7 +126,6 @@ def main(): ).to(device) opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0001) - # opt = optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.0001, momentum=0.9) schd = optim.lr_scheduler.MultiStepLR( opt, [int(1 / 2 * epochs), int(3 / 4 * epochs)], gamma=0.1 ) @@ -206,8 +139,8 @@ def main(): prev = start # generate gauge field - prod = 0.75 # 0.25 - add = 0 # .25 + prod = 0.75 + add = 0 # apply gauge t = np.linspace(0, 1, grid_size) @@ -241,7 +174,7 @@ def main(): .float() .to(device) ) - x_t = x # + field_x_c + field_y_c + x_t = x + field_x_c + field_y_c # EGNN expects data as (N * grid_size * grid_size, 2) x = x_t.view(batch_size_t * grid_size * grid_size, 1) @@ -291,7 +224,7 @@ def main(): .float() .to(device) ) - x_t = x # + field_x_c + field_y_c + x_t = x + field_x_c + field_y_c # EGNN expects data as (N * grid_size * grid_size, 2) x = x_t.view(batch_size_t * grid_size * grid_size, 1) @@ -347,6 +280,5 @@ def main(): plt.savefig("loss.pdf", bbox_inches="tight") plt.close() - if __name__ == "__main__": main()