diff --git a/egnn_clean.py b/egnn_clean.py
new file mode 100644
index 0000000..ae39374
--- /dev/null
+++ b/egnn_clean.py
@@ -0,0 +1,211 @@
+from torch import nn
+import torch
+
+
+class E_GCL(nn.Module):
+    """
+    E(n) Equivariant Convolutional Layer
+    """
+
+    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False):
+        super(E_GCL, self).__init__()
+        input_edge = input_nf * 2
+        self.residual = residual
+        self.attention = attention
+        self.normalize = normalize
+        self.coords_agg = coords_agg
+        self.tanh = tanh
+        self.epsilon = 1e-8
+        edge_coords_nf = 1
+
+        self.edge_mlp = nn.Sequential(
+            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
+            act_fn,
+            nn.Linear(hidden_nf, hidden_nf),
+            act_fn)
+
+        self.node_mlp = nn.Sequential(
+            nn.Linear(hidden_nf + input_nf, hidden_nf),
+            act_fn,
+            nn.Linear(hidden_nf, output_nf))
+
+        layer = nn.Linear(hidden_nf, 1, bias=False)
+        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
+
+        coord_mlp = []
+        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
+        coord_mlp.append(act_fn)
+        coord_mlp.append(layer)
+        if self.tanh:
+            coord_mlp.append(nn.Tanh())
+        self.coord_mlp = nn.Sequential(*coord_mlp)
+
+        if self.attention:
+            self.att_mlp = nn.Sequential(
+                nn.Linear(hidden_nf, 1),
+                nn.Sigmoid())
+
+    def edge_model(self, source, target, radial, edge_attr):
+        if edge_attr is None:  # No edge features provided.
+            out = torch.cat([source, target, radial], dim=1)
+        else:
+            out = torch.cat([source, target, radial, edge_attr], dim=1)
+        out = self.edge_mlp(out)
+        if self.attention:
+            att_val = self.att_mlp(out)
+            out = out * att_val
+        return out
+
+    def node_model(self, x, edge_index, edge_attr, node_attr):
+        row, col = edge_index
+        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
+        if node_attr is not None:
+            agg = torch.cat([x, agg, node_attr], dim=1)
+        else:
+            agg = torch.cat([x, agg], dim=1)
+        out = self.node_mlp(agg)
+        if self.residual:
+            out = x + out
+        return out, agg
+
+    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
+        row, col = edge_index
+        trans = coord_diff * self.coord_mlp(edge_feat)
+        if self.coords_agg == 'sum':
+            agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
+        elif self.coords_agg == 'mean':
+            agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
+        else:
+            raise ValueError('Wrong coords_agg parameter: %s' % self.coords_agg)
+        coord = coord + agg  # out-of-place update so the caller's tensor is not mutated
+        return coord
+
+    def coord2radial(self, edge_index, coord):
+        row, col = edge_index
+        coord_diff = coord[row] - coord[col]
+        radial = torch.sum(coord_diff**2, 1).unsqueeze(1)
+
+        if self.normalize:
+            norm = torch.sqrt(radial).detach() + self.epsilon
+            coord_diff = coord_diff / norm
+
+        return radial, coord_diff
+
+    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
+        row, col = edge_index
+        radial, coord_diff = self.coord2radial(edge_index, coord)
+
+        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
+        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
+        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
+
+        return h, coord, edge_attr
+
+
+class EGNN(nn.Module):
+    def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False):
+        '''
+        :param in_node_nf: Number of features for 'h' at the input
+        :param hidden_nf: Number of hidden features
+        :param out_node_nf: Number of features for 'h' at the output
+        :param in_edge_nf: Number of features for the edge features
+        :param device: Device (e.g. 'cpu', 'cuda:0', ...)
+        :param act_fn: Non-linearity
+        :param n_layers: Number of layers for the EGNN
+        :param residual: Use residual connections; we recommend not changing this one
+        :param attention: Whether to use attention or not
+        :param normalize: Normalizes the coordinate messages such that,
+                    instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
+                    we get:     x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
+                    We noticed it may help stability or generalization in some settings.
+                    We did not use it in our paper.
+        :param tanh: Sets a tanh activation function at the output of phi_x(m_ij), i.e. it bounds the output of
+                    phi_x(m_ij), which improves stability but may reduce accuracy.
+                    We did not use it in our paper.
+        '''
+
+        super(EGNN, self).__init__()
+        self.hidden_nf = hidden_nf
+        self.device = device
+        self.n_layers = n_layers
+        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
+        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
+        for i in range(0, n_layers):
+            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
+                                                act_fn=act_fn, residual=residual, attention=attention,
+                                                normalize=normalize, tanh=tanh))
+        self.to(self.device)
+
+    def forward(self, h, x, edges, edge_attr):
+        h = self.embedding_in(h)
+        for i in range(0, self.n_layers):
+            h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr)
+        h = self.embedding_out(h)
+        return h, x
+
+
+def unsorted_segment_sum(data, segment_ids, num_segments):
+    result_shape = (num_segments, data.size(1))
+    result = data.new_full(result_shape, 0)  # Init result tensor with zeros.
+    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
+    result.scatter_add_(0, segment_ids, data)
+    return result
+
+
+def unsorted_segment_mean(data, segment_ids, num_segments):
+    result_shape = (num_segments, data.size(1))
+    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
+    result = data.new_full(result_shape, 0)  # Init result tensor with zeros.
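+    # Accumulate per-segment element counts alongside the sums; the clamp
+    # below guards against division by zero for segments with no edges.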
+    count = data.new_full(result_shape, 0)
+    result.scatter_add_(0, segment_ids, data)
+    count.scatter_add_(0, segment_ids, torch.ones_like(data))
+    return result / count.clamp(min=1)
+
+
+def get_edges(n_nodes):
+    rows, cols = [], []
+    for i in range(n_nodes):
+        for j in range(n_nodes):
+            if i != j:
+                rows.append(i)
+                cols.append(j)
+
+    edges = [rows, cols]
+    return edges
+
+
+def get_edges_batch(n_nodes, batch_size):
+    edges = get_edges(n_nodes)
+    edge_attr = torch.ones(len(edges[0]) * batch_size, 1)
+    edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
+    if batch_size == 1:
+        return edges, edge_attr
+    elif batch_size > 1:
+        rows, cols = [], []
+        for i in range(batch_size):
+            rows.append(edges[0] + n_nodes * i)
+            cols.append(edges[1] + n_nodes * i)
+        edges = [torch.cat(rows), torch.cat(cols)]
+    return edges, edge_attr
+
+
+if __name__ == "__main__":
+    # Dummy parameters
+    batch_size = 8
+    n_nodes = 4
+    n_feat = 1
+    x_dim = 3
+
+    # Dummy variables h, x and fully connected edges
+    h = torch.ones(batch_size * n_nodes, n_feat)
+    x = torch.ones(batch_size * n_nodes, x_dim)
+    edges, edge_attr = get_edges_batch(n_nodes, batch_size)
+
+    # Initialize EGNN
+    egnn = EGNN(in_node_nf=n_feat, hidden_nf=32, out_node_nf=1, in_edge_nf=1)
+
+    # Run EGNN
+    h, x = egnn(h, x, edges, edge_attr)
+
diff --git a/test_egnn.py b/test_egnn.py
new file mode 100644
index 0000000..4b60e65
--- /dev/null
+++ b/test_egnn.py
@@ -0,0 +1,230 @@
+# system imports
+import os
+
+# python imports
+import numpy as np
+from operator import itemgetter
+
+# torch imports
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+import torchvision
+
+from torchinfo import summary
+
+# egnn imports
+import egnn_clean as eg
+
+# plotting imports
+import seaborn as sns
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from matplotlib import animation
+
+# plotting defaults
+sns.set_theme()
+sns.set_context("paper")
+sns.set(font_scale=2)
+cmap = plt.get_cmap("twilight")
+# cmap_t = plt.get_cmap("turbo")
+# cmap = plt.get_cmap("hsv")
+color_plot = sns.cubehelix_palette(4, reverse=True, rot=-0.2)
+from matplotlib import cm, rc
+
+rc("text", usetex=True)
+rc("text.latex", preamble=r"\usepackage{amsmath}")
+
+
+def zeromean(X, mean=None, std=None):
+    "Expects data in NxCxWxH."
+    if mean is None:
+        mean = X.mean(axis=(0, 2, 3))
+        std = X.std(axis=(0, 2, 3))
+        std = torch.ones(std.shape)  # unit std: only the mean is removed
+
+    X = torchvision.transforms.Normalize(mean, std)(X)
+    return X, mean, std
+
+
+def standardize(X, mean=None, std=None):
+    "Expects data in NxCxWxH."
+    if mean is None:
+        mean = X.mean(axis=(0, 2, 3))
+        std = X.std(axis=(0, 2, 3))
+
+    X = torchvision.transforms.Normalize(mean, std)(X)
+    return X, mean, std
+
+
+def standardize_y(Y, mean=None, std=None):
+    "Expects data in Nx1. Min-max scaling: 'mean'/'std' hold the min and range."
+    if mean is None:
+        mean = Y.min()
+        std = Y.max() - Y.min()
+
+    Y = (Y - mean) / std
+    return Y, mean, std
+
+
+def whiten(X, zca=None, mean=None, eps=1e-8):
+    "Expects data in NxCxWxH."
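+    # ZCA whitening: project onto the covariance eigenbasis, rescale each
+    # direction by 1/sqrt(eigenvalue + eps), and rotate back.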
+    orig_shape = X.shape  # do not shadow the os module
+    X = X.reshape(orig_shape[0], -1)
+
+    if zca is None:
+        mean = X.mean(dim=0)
+        cov = np.cov(X, rowvar=False)
+        U, S, V = np.linalg.svd(cov)
+        zca = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + eps)), U.T))
+    X = torch.Tensor(np.dot(X - mean, zca.T).reshape(orig_shape))
+    return X, zca, mean
+
+
+def lattice_nbr(grid_size):
+    """Edge list of a periodic grid_size x grid_size lattice (4-neighbourhood)."""
+    edg = set()
+    for x in range(grid_size):
+        for y in range(grid_size):
+            v = x + grid_size * y
+            for i in [-1, 1]:
+                edg.add((v, ((x + i) % grid_size) + y * grid_size))
+                edg.add((v, x + ((y + i) % grid_size) * grid_size))
+    return torch.tensor(np.array(list(edg)), dtype=int)
+
+
+def get_edges_batch(edges, n_nodes, batch_size, device):
+    edge_attr = torch.ones(len(edges[0]) * batch_size, 1, device=device)
+    edges = [
+        torch.LongTensor(edges[0]).to(device),
+        torch.LongTensor(edges[1]).to(device),
+    ]
+    if batch_size == 1:
+        return edges, edge_attr
+    elif batch_size > 1:
+        rows, cols = [], []
+        for i in range(batch_size):
+            rows.append(edges[0] + n_nodes * i)
+            cols.append(edges[1] + n_nodes * i)
+        edges = [torch.cat(rows), torch.cat(cols)]
+    return edges, edge_attr
+
+
+def main():
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    seed = 13
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    data_norm = "y"
+    grid_size = 100
+    batch_size = 1
+
+    ### Data loading
+    data = np.load("data_n=10000.npy", allow_pickle=True)
+    # data = np.load("/Users/manos/data/gauge/data_n=10000.npy", allow_pickle=True)
+    X, Y = data.item()["x"], data.item()["y"]
+
+    tr_idx = np.random.choice(X.shape[0], int(0.8 * X.shape[0]), replace=False)
+    mask = np.zeros(X.shape[0], dtype=bool)
+    mask[tr_idx] = True
+    X_tr, Y_tr = X[mask], Y[mask]
+    X_te, Y_te = X[~mask], Y[~mask]
+
+    # reformat to (N, C, W, H)
+    X_tr = torch.Tensor(X_tr).view(-1, 1, grid_size, grid_size)
+    Y_tr = torch.Tensor(Y_tr).view(-1, 1)
+    X_te = torch.Tensor(X_te).view(-1, 1, grid_size, grid_size)
+    Y_te = torch.Tensor(Y_te).view(-1, 1)
+
+    if data_norm == "standard":
+        X_tr, mean, std = standardize(X_tr)
+        X_te, _, _ = standardize(X_te, mean, std)
+    elif data_norm == "zeromean":
+        X_tr, mean, std = zeromean(X_tr)
+        X_te, _, _ = zeromean(X_te, mean, std)
+    elif data_norm == "whiten":
+        X_tr, mean, std = standardize(X_tr)
+        X_te, _, _ = standardize(X_te, mean, std)
+
+        X_tr, zca, mean = whiten(X_tr)
+        X_te, _, _ = whiten(X_te, zca, mean)
+    elif data_norm == "y":
+        Y_tr, mean, std = standardize_y(Y_tr)
+        Y_te, _, _ = standardize_y(Y_te, mean, std)
+
+    train_dl = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(X_tr, Y_tr),
+        batch_size=batch_size,
+        num_workers=4,
+        shuffle=True,
+        pin_memory=True,
+    )
+    test_dl = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(X_te, Y_te),
+        batch_size=batch_size,
+        num_workers=4,
+        shuffle=True,
+        pin_memory=True,
+    )
+
+    epochs = 50
+    L = lattice_nbr(grid_size)
+    sL = sorted(L, key=itemgetter(0))
+    rows, cols = [], []
+    for item in sL:
+        rows.append(item[0])
+        cols.append(item[1])
+    edges_b = [rows, cols]
+
+    model = eg.EGNN(
+        in_node_nf=1,
+        hidden_nf=64,
+        out_node_nf=1,
+        in_edge_nf=1,
+        device=device,
+        n_layers=2,
+    )
+    model.load_state_dict(torch.load("best_model_egnn.pth", map_location=device))
+    model.to(device)
+
+    loss_func = torch.nn.MSELoss()
+
+    model.eval()
+    with torch.no_grad():
+        net_loss = 0.0
+        n_total = 0
+        # for idx, (x, y) in enumerate(train_dl):
+        for idx, (x, y) in enumerate(test_dl):
+            x, y = x.to(device), y.to(device)
+            batch_size_t = x.shape[0]
+            edges, edge_attr = get_edges_batch(
+                edges_b, grid_size * grid_size, batch_size_t, device
+            )
+
+            # EGNN expects data as (N * grid_size * grid_size, 2)
+            x = x.view(batch_size_t * grid_size * grid_size, 1)
+            s = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
+            h = torch.ones(batch_size_t * grid_size * grid_size, 1, device=device)
+
+            if idx == 0:
+                summary(model, input_data=[h, s, edges, edge_attr])
+            h_hat, s_hat = model(h, s, edges, edge_attr)
+
+            h_hat = h_hat.view(batch_size_t, grid_size * grid_size)
+            h_sum = torch.sum(h_hat, dim=1, keepdim=True)
+
+            loss = loss_func(h_sum, y)
+
+            if idx % 200 == 0:
+                print(f"actual energy: {y}\t estimated energy: {h_sum}")
+            net_loss += loss.item() * len(x)
+            n_total += len(x)
+        test_loss = net_loss / n_total
+        print(f"loss: {test_loss:.8f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_egnn.py b/train_egnn.py
new file mode 100644
index 0000000..a1984e5
--- /dev/null
+++ b/train_egnn.py
@@ -0,0 +1,352 @@
+# system imports
+import os
+import time
+
+# python imports
+import numpy as np
+from operator import itemgetter
+
+# torch imports
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+import torchvision
+
+from torchinfo import summary
+
+# egnn imports
+import egnn_clean as eg
+
+# plotting imports
+import seaborn as sns
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from matplotlib import animation
+
+# plotting defaults
+sns.set_theme()
+sns.set_context("paper")
+sns.set(font_scale=2)
+cmap = plt.get_cmap("twilight")
+# cmap_t = plt.get_cmap("turbo")
+# cmap = plt.get_cmap("hsv")
+color_plot = sns.cubehelix_palette(4, reverse=True, rot=-0.2)
+from matplotlib import cm, rc
+
+rc("text", usetex=True)
+rc("text.latex", preamble=r"\usepackage{amsmath}")
+
+
+def zeromean(X, mean=None, std=None):
+    "Expects data in NxCxWxH."
+    if mean is None:
+        mean = X.mean(axis=(0, 2, 3))
+        std = X.std(axis=(0, 2, 3))
+        std = torch.ones(std.shape)  # unit std: only the mean is removed
+
+    X = torchvision.transforms.Normalize(mean, std)(X)
+    return X, mean, std
+
+
+def standardize(X, mean=None, std=None):
+    "Expects data in NxCxWxH."
+    if mean is None:
+        mean = X.mean(axis=(0, 2, 3))
+        std = X.std(axis=(0, 2, 3))
+
+    X = torchvision.transforms.Normalize(mean, std)(X)
+    return X, mean, std
+
+
+def standardize_y(Y, mean=None, std=None):
+    "Expects data in Nx1. Min-max scaling: 'mean'/'std' hold the min and range."
+    if mean is None:
+        mean = Y.min()
+        std = Y.max() - Y.min()
+
+    Y = (Y - mean) / std
+    return Y, mean, std
+
+
+def whiten(X, zca=None, mean=None, eps=1e-8):
+    "Expects data in NxCxWxH."
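+    # Fit the ZCA transform (zca, mean) on the training set only; the test
+    # set is transformed with the statistics passed back in.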
+    orig_shape = X.shape  # do not shadow the os module
+    X = X.reshape(orig_shape[0], -1)
+
+    if zca is None:
+        mean = X.mean(dim=0)
+        cov = np.cov(X, rowvar=False)
+        U, S, V = np.linalg.svd(cov)
+        zca = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + eps)), U.T))
+    X = torch.Tensor(np.dot(X - mean, zca.T).reshape(orig_shape))
+    return X, zca, mean
+
+
+def lattice_nbr(grid_size):
+    """Edge list of a periodic grid_size x grid_size lattice (4-neighbourhood)."""
+    edg = set()
+    for x in range(grid_size):
+        for y in range(grid_size):
+            v = x + grid_size * y
+            for i in [-1, 1]:
+                edg.add((v, ((x + i) % grid_size) + y * grid_size))
+                edg.add((v, x + ((y + i) % grid_size) * grid_size))
+    return torch.tensor(np.array(list(edg)), dtype=int)
+
+
+def get_edges_batch(edges, n_nodes, batch_size, device):
+    edge_attr = torch.ones(len(edges[0]) * batch_size, 1, device=device)
+    edges = [
+        torch.LongTensor(edges[0]).to(device),
+        torch.LongTensor(edges[1]).to(device),
+    ]
+    if batch_size == 1:
+        return edges, edge_attr
+    elif batch_size > 1:
+        rows, cols = [], []
+        for i in range(batch_size):
+            rows.append(edges[0] + n_nodes * i)
+            cols.append(edges[1] + n_nodes * i)
+        edges = [torch.cat(rows), torch.cat(cols)]
+    return edges, edge_attr
+
+
+def main():
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    seed = 13
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+
+    data_norm = None  # "y"
+    grid_size = 100
+    batch_size = 64
+
+    ### Data loading
+    # data = np.load("data_n=10000.npy", allow_pickle=True)
+    data = np.load("data_n=10000_gauge.npy", allow_pickle=True)
+    # data = np.load("/Users/manos/data/gauge/data_n=10000.npy", allow_pickle=True)
+    X, Y = data.item()["x"], data.item()["y"]
+
+    tr_idx = np.random.choice(X.shape[0], int(0.8 * X.shape[0]), replace=False)
+    mask = np.zeros(X.shape[0], dtype=bool)
+    mask[tr_idx] = True
+    X_tr, Y_tr = X[mask], Y[mask]
+    X_te, Y_te = X[~mask], Y[~mask]
+
+    # reformat to (N, C, W, H)
+    X_tr = torch.Tensor(X_tr).view(-1, 1, grid_size, grid_size)
+    Y_tr = torch.Tensor(Y_tr).view(-1, 1)
+    X_te = torch.Tensor(X_te).view(-1, 1, grid_size, grid_size)
+    Y_te = torch.Tensor(Y_te).view(-1, 1)
+
+    if data_norm == "standard":
+        X_tr, mean, std = standardize(X_tr)
+        X_te, _, _ = standardize(X_te, mean, std)
+    elif data_norm == "zeromean":
+        X_tr, mean, std = zeromean(X_tr)
+        X_te, _, _ = zeromean(X_te, mean, std)
+    elif data_norm == "whiten":
+        X_tr, mean, std = standardize(X_tr)
+        X_te, _, _ = standardize(X_te, mean, std)
+
+        X_tr, zca, mean = whiten(X_tr)
+        X_te, _, _ = whiten(X_te, zca, mean)
+    elif data_norm == "y":
+        Y_tr, mean, std = standardize_y(Y_tr)
+        Y_te, _, _ = standardize_y(Y_te, mean, std)
+
+    train_dl = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(X_tr, Y_tr),
+        batch_size=batch_size,
+        num_workers=4,
+        shuffle=True,
+        pin_memory=True,
+    )
+    test_dl = torch.utils.data.DataLoader(
+        torch.utils.data.TensorDataset(X_te, Y_te),
+        batch_size=batch_size,
+        num_workers=4,
+        shuffle=True,
+        pin_memory=True,
+    )
+
+    epochs = 50
+    L = lattice_nbr(grid_size)
+    sL = sorted(L, key=itemgetter(0))
+    rows, cols = [], []
+    for item in sL:
+        rows.append(item[0])
+        cols.append(item[1])
+    edges_b = [rows, cols]
+
+    hidden_nf = 64
+    model = eg.EGNN(
+        in_node_nf=1,
+        hidden_nf=hidden_nf,
+        out_node_nf=1,
+        in_edge_nf=1,
+        device=device,
+        n_layers=2,
+    ).to(device)
+
+    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
+    # opt = optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.0001, momentum=0.9)
+    schd = optim.lr_scheduler.MultiStepLR(
+        opt, [int(1 / 2 * epochs), int(3 / 4 * epochs)], gamma=0.1
+    )
+    loss_func = torch.nn.MSELoss()
+
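+    # Track the lowest training loss seen so far; the matching weights are
+    # checkpointed at the end of each epoch below.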
+    best_loss = 1e10
+    best_epoch = None
+    loss_tr = []
+    loss_te = []
+    start = time.time()
+    prev = start
+
+    # generate gauge field
+    prod = 0.75  # 0.25
+    add = 0  # .25
+
+    # apply gauge
+    t = np.linspace(0, 1, grid_size)
+    sine_x = prod * np.cos(2 * np.pi * t) - add
+    sine_y = prod * np.cos(2 * np.pi * 2 * t) - add
+
+    field_x = sine_x.reshape(1, grid_size)
+    field_x = np.repeat(field_x, grid_size, axis=0).reshape(1, 1, grid_size, grid_size)
+
+    field_y = sine_y.reshape(grid_size, 1)
+    field_y = np.repeat(field_y, grid_size, axis=1).reshape(1, 1, grid_size, grid_size)
+    for epoch in range(1, epochs + 1):
+        net_loss = 0.0
+        n_total = 0
+
+        model.train()
+        for idx, (x, y) in enumerate(train_dl):
+            x, y = x.to(device), y.to(device)
+            batch_size_t = x.shape[0]
+            edges, edge_attr = get_edges_batch(
+                edges_b, grid_size * grid_size, batch_size_t, device
+            )
+
+            field_x_c = (
+                torch.tensor(np.repeat(field_x, batch_size_t, axis=0))
+                .float()
+                .to(device)
+            )
+            field_y_c = (
+                torch.tensor(np.repeat(field_y, batch_size_t, axis=0))
+                .float()
+                .to(device)
+            )
+            x_t = x  # + field_x_c + field_y_c
+
+            # EGNN expects data as (N * grid_size * grid_size, 2)
+            x = x_t.view(batch_size_t * grid_size * grid_size, 1)
+            s = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
+            h = torch.ones(batch_size_t * grid_size * grid_size, 1, device=device)
+
+            if idx == 0 and epoch == 1:
+                summary(model, input_data=[h, s, edges, edge_attr])
+
+            h_hat, s_hat = model(h, s, edges, edge_attr)
+
+            h_hat = h_hat.view(batch_size_t, grid_size * grid_size)
+            h_sum = torch.sum(h_hat, dim=1, keepdim=True)
+
+            loss = loss_func(h_sum, y)
+
+            opt.zero_grad(set_to_none=True)
+            loss.backward()
+            opt.step()
+
+            net_loss += loss.item() * len(x)
+            n_total += len(x)
+        train_loss = net_loss / n_total
+        loss_tr.append(train_loss)
+
+        current = time.time()
+
+        net_loss = 0.0
+        n_total = 0
+        model.eval()
+        with torch.no_grad():
+            for idx, (x, y) in enumerate(test_dl):
+                x, y = x.to(device), y.to(device)
+                batch_size_t = x.shape[0]
+
+                edges, edge_attr = get_edges_batch(
+                    edges_b, grid_size * grid_size, batch_size_t, device
+                )
+
+                field_x_c = (
+                    torch.tensor(np.repeat(field_x, batch_size_t, axis=0))
+                    .float()
+                    .to(device)
+                )
+                field_y_c = (
+                    torch.tensor(np.repeat(field_y, batch_size_t, axis=0))
+                    .float()
+                    .to(device)
+                )
+                x_t = x  # + field_x_c + field_y_c
+
+                # EGNN expects data as (N * grid_size * grid_size, 2)
+                x = x_t.view(batch_size_t * grid_size * grid_size, 1)
+                s = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
+                h = torch.ones(batch_size_t * grid_size * grid_size, 1, device=device)
+
+                h_hat, s_hat = model(h, s, edges, edge_attr)
+
+                h_hat = h_hat.view(batch_size_t, grid_size * grid_size)
+                h_sum = torch.sum(h_hat, dim=1, keepdim=True)
+
+                loss = loss_func(h_sum, y)
+
+                net_loss += loss.item() * len(x)
+                n_total += len(x)
+        test_loss = net_loss / n_total
+        loss_te.append(test_loss)
+        print(
+            f"Epoch {epoch} Loss: {train_loss} (train)\t{test_loss} (test)\t({current - prev:3.2f} s/epoch)"
+        )
+        prev = current
+
+        if train_loss <= best_loss:
+            best_loss = train_loss
+            best_epoch = epoch
+            torch.save(model.state_dict(), f"best_model_egnn{hidden_nf}.pth")
+
+        with open("log_loss_tr_none.txt", "a") as file:
+            file.write("Epoch " + str(epoch) + ":\t" + str(train_loss) + "\n")
+
+        with open("log_loss_te_none.txt", "a") as file:
+            file.write("Epoch " + str(epoch) + ":\t" + str(test_loss) + "\n")
+
+        schd.step()
+
+    plt.figure()
+    plt.yscale("log")
+    plt.plot(
+        range(len(loss_tr)),
+        loss_tr,
+        color=color_plot[0],
+        label="train",
+    )
+    plt.plot(
+        range(len(loss_te)),
+        loss_te,
+        color=color_plot[2],
+        label="test",
+    )
+    plt.ylabel("loss")
+    plt.xlabel("epoch")
+    plt.legend()
+    plt.savefig("loss.pdf", bbox_inches="tight")
+    plt.close()
+
+
+if __name__ == "__main__":
+    main()
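+
+
+# Minimal E(n)-equivariance sanity check (illustrative sketch, not called by
+# main). It assumes `model`, `h`, `x`, `edges`, `edge_attr` shaped as above;
+# the helper name and the tolerance `atol` are hypothetical. A random
+# orthogonal map Q applied to the input coordinates should carry through to
+# the output coordinates while leaving the invariant features h unchanged.
+def _check_equivariance(model, h, x, edges, edge_attr, atol=1e-4):
+    Q, _ = torch.linalg.qr(torch.randn(x.shape[1], x.shape[1], device=x.device))
+    h1, x1 = model(h, x, edges, edge_attr)
+    h2, x2 = model(h, x @ Q, edges, edge_attr)
+    assert torch.allclose(h1, h2, atol=atol), "h should be E(n)-invariant"
+    assert torch.allclose(x1 @ Q, x2, atol=atol), "x should be E(n)-equivariant"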