Update train_egnn_ours.py
manosth authored Jul 13, 2024
1 parent 9a9e945 commit ebe296b
Showing 1 changed file with 16 additions and 205 deletions.
train_egnn_ours.py: 221 changes (16 additions & 205 deletions)
@@ -27,60 +27,12 @@
sns.set_context("paper")
sns.set(font_scale=2)
cmap = plt.get_cmap("twilight")
# cmap_t = plt.get_cmap("turbo")
# cmap = plt.get_cmap("hsv")
color_plot = sns.cubehelix_palette(4, reverse=True, rot=-0.2)
from matplotlib import cm, rc

rc("text", usetex=True)
rc("text.latex", preamble=r"\usepackage{amsmath}")


def zeromean(X, mean=None, std=None):
"Expects data in NxCxWxH."
if mean is None:
mean = X.mean(axis=(0, 2, 3))
std = X.std(axis=(0, 2, 3))
std = torch.ones(std.shape)  # keep std at 1 so only the mean is subtracted

X = torchvision.transforms.Normalize(mean, std)(X)
return X, mean, std


def standardize(X, mean=None, std=None):
"Expects data in NxCxWxH."
if mean is None:
mean = X.mean(axis=(0, 2, 3))
std = X.std(axis=(0, 2, 3))

X = torchvision.transforms.Normalize(mean, std)(X)
return X, mean, std


def standardize_y(Y, mean=None, std=None):
"Expects data in Nx1."
if mean is None:
mean = Y.min()
std = Y.max() - Y.min()

Y = (Y - mean) / std
return Y, mean, std


def whiten(X, zca=None, mean=None, eps=1e-8):
"Expects data in NxCxWxH."
os = X.shape
X = X.reshape(os[0], -1)

if zca is None:
mean = X.mean(dim=0)
cov = np.cov(X, rowvar=False)
U, S, V = np.linalg.svd(cov)
zca = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + eps)), U.T))
X = torch.Tensor(np.dot(X - mean, zca.T).reshape(os))
return X, zca, mean


def lattice_nbr(grid_size):
"""dxd edge list (periodic)"""
edg = set()
@@ -153,32 +105,6 @@ def forward(self, x):
out = x + out
return out


def energy_loss_nima(grid_list, nbr):
"""
Computes the energy of the configuration in the XY model.
Parameters
----------
grid_list: PyTorch tensor of size (batch_size, grid_size ** 2)
Spin configuration (angle) of each lattice point.
nbr: PyTorch tensor of size (n_edges, 2)
Edge list containing the index pairs of neighboring lattice points.
Returns
-------
loss: PyTorch tensor
Energy of the configuration.
"""

loss = (
-1
/ len(nbr)
* torch.sum(torch.cos(grid_list[:, nbr[:, 0]] - grid_list[:, nbr[:, 1]]), dim=1)
)
return loss
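
# Illustrative sanity check, not part of the original script (the helper name
# and the toy 4-site ring are hypothetical): for a uniform spin configuration
# every cos(theta_i - theta_j) term equals 1, so the normalized XY energy
# computed above should be exactly -1 for any edge list.
def _energy_uniform_check():
    edges = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])  # toy periodic ring
    spins = torch.zeros(1, 4)  # one configuration, all angles equal
    energy = energy_loss_nima(spins, edges)
    assert torch.isclose(energy, torch.tensor(-1.0)).all()
    return energy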


class GaugeNet(nn.Module):
def __init__(
self, in_dim, grid_size, hid_dim=64, out_dim=1, n_layers=2, device="cpu"
@@ -201,20 +127,17 @@ def __init__(
)
self.down = torch.sparse_coo_tensor(
down.t(),
# up.t(),
torch.ones(len(down), device=device),
(grid_size * grid_size, grid_size * grid_size),
device=device,
)
self.left = torch.sparse_coo_tensor(
# up.t(),
left.t(),
torch.ones(len(left), device=device),
(grid_size * grid_size, grid_size * grid_size),
device=device,
)
self.right = torch.sparse_coo_tensor(
# up.t(),
right.t(),
torch.ones(len(right), device=device),
(grid_size * grid_size, grid_size * grid_size),
@@ -223,49 +146,12 @@ def __init__(

self.H = torch.eye(2, device=device)

# lattice_list has 40,000 (4 * grid_size ** 2) elements
# self.A = -1 / (4 * grid_size**2) * torch.ones(4, 1, device=device)
# self.s = nn.Parameter(torch.randn(1, device=device))

### SAVED
# self.emb = nn.Linear(4, 1)
# self.act1 = nn.SiLU()
# self.post = nn.Linear(grid_size**2, out_dim)
### SAVED

### SAVED2
self.emb = nn.Linear(4, hid_dim)
self.act1 = nn.ReLU()
self.hidden = nn.Linear(hid_dim, hid_dim)
self.act2 = nn.ReLU()
self.post = nn.Linear(hid_dim, out_dim)
### SAVED2

# self.act1 = nn.Identity()
# self.hidden = nn.Linear(hid_dim, hid_dim)
# self.act2 = nn.ReLU()
# self.act2 = nn.Identity()
# self.final = nn.Linear(grid_size * grid_size, out_dim, bias=False)
# self.emb = nn.Linear(4, 1, bias=False)
# self.pre_mlp = nn.Sequential(
# nn.Linear(hid_dim, hid_dim, bias=False),
# nn.SiLU(),
# nn.Linear(hid_dim, hid_dim, bias=False),
# )
# list = []
# for i in range(n_layers):
# self.add_module("layer_%d" % i, GaugeLayer(hid_dim, device))
# # self.add_module("act_%d" % i, nn.SiLU())
# # list.append(nn.Linear(hid_dim, hid_dim))
# # list.append(nn.SiLU())
# # self.net = nn.Sequential(*list)
# self.post_mlp = nn.Sequential(
# nn.Dropout(0.4),
# nn.Linear(hid_dim, hid_dim),
# nn.SiLU(),
# nn.Linear(hid_dim, out_dim),
# )


def forward(self, x):
"""
Currently Torch doesn't support sparse bmm, so we need to do it manually; see the illustrative sketch after this class for the general pattern.
@@ -301,68 +187,24 @@ def forward(self, x):
s_left = torch.einsum("bim,bin,mn->bi", x, Ax_left, self.H)
s_right = torch.einsum("bim,bin,mn->bi", x, Ax_right, self.H)

# s_up = torch.einsum("bim,ij,bjn,mn->bi", x, self.up, x, self.H)
# s_down = torch.einsum("bim,ij,bjn,mn->bi", x, self.down, x, self.H)
# s_left = torch.einsum("bim,ij,bjn,mn->bi", x, self.left, x, self.H)
# s_right = torch.einsum("bim,ij,bjn,mn->bi", x, self.right, x, self.H)

# h is (B, grid_size ** 2, 4)
h = torch.stack([s_up, s_down, s_left, s_right], dim=2)
# h = h @ self.A
# h = h.sum(dim=1)
# h = self.s * h
# print(h.shape)

# h = self.pre_mlp(h) + h
# h = self.act(h)
# # print(h.shape)
# h = self.pre_mlp(h)
# print(h.shape)
# for layer in range(self.n_layers):
# h = self._modules["layer_%d" % layer](h)
# # print(h.shape)
# # h = self._modules["act_%d" % layer](h) + h
# # h = self.net(h)
# h = self.post_mlp(h)
# h = h.sum(dim=1)
# h = self.act(h)
# h = self.final(h.squeeze())
h = self.act1(self.emb(h))
h = self.act2(self.hidden(h))
# h = self.act2(self.hidden(h)) # + h
h = self.post(h)

# h = self.act2(h)
# print(h.shape)
return h.sum(dim=1)
# return h

### SAVED
# h = self.act1(self.emb(h)).squeeze()
# h = self.post(h)
# return h
### SAVED

### SAVED2
# h = self.act1(self.emb(h))
# h = self.post(h)
# return h.sum(dim=1)
### SAVED2


def main():
device = "cuda:1" if torch.cuda.is_available() else "cpu"
seed = 13
torch.manual_seed(seed)
np.random.seed(seed)

data_norm = None # "y"
grid_size = 100
batch_size = 64

### Data loading
data = np.load("data_n=10000.npy", allow_pickle=True)
# data = np.load("/Users/manos/data/gauge/data_n=10000.npy", allow_pickle=True)
X, Y = data.item()["x"], data.item()["y"]

tr_idx = np.random.choice(X.shape[0], int(0.8 * X.shape[0]), replace=False)
@@ -377,22 +219,6 @@ def main():
X_te = torch.Tensor(X_te).view(-1, 1, grid_size, grid_size)
Y_te = torch.Tensor(Y_te).view(-1, 1)

if data_norm == "standard":
Y_tr, mean, std = standardize(Y_tr)
Y_te, _, _ = standardize(Y_te, mean, std)
elif data_norm == "zeromean":
X_tr, mean, std = zeromean(X_tr)
X_te, _, _ = zeromean(X_te, mean, std)
elif data_norm == "whiten":
X_tr, mean, std = standardize(X_tr)
X_te, _, _ = standardize(X_te, mean, std)

X_tr, zca, mean = whiten(X_tr)
X_te, _, _ = whiten(X_te, zca, mean)
elif data_norm == "y":
Y_tr, mean, std = standardize_y(Y_tr)
Y_te, _, _ = standardize_y(Y_te, mean, std)

train_dl = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(X_tr, Y_tr),
batch_size=batch_size,
@@ -415,23 +241,19 @@ def main():
).to(device)

opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
# opt = optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.0001, momentum=0.9)
schd = optim.lr_scheduler.MultiStepLR(
opt,
[int(1 / 2 * epochs), int(3 / 4 * epochs)],
gamma=0.1,
# opt,
# [140, 180],
# gamma=0.1,
)
loss_func = torch.nn.MSELoss()

summary(model, input_size=(batch_size, grid_size * grid_size, 2))
edge_list = lattice_nbr(grid_size)

# generate gauge field
prod = 0.75 # 0.25
add = 0 # .25
prod = 0.75
add = 0

# apply gauge
t = np.linspace(0, 1, grid_size)
@@ -460,20 +282,13 @@ def main():

x = x.view(-1, grid_size * grid_size, 1)
bs = x.shape[0]
# field_x_c = torch.tensor(np.repeat(field_x, bs, axis=0)).float().to(device)
# field_y_c = torch.tensor(np.repeat(field_y, bs, axis=0)).float().to(device)
# x_t = x + field_x_c + field_y_c
# print(energy_loss_nima(x, edge_list)[:4])
# s = torch.cat((torch.cos(x_t), torch.sin(x_t)), dim=-1)
s = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
field_x_c = torch.tensor(np.repeat(field_x, bs, axis=0)).float().to(device)
field_y_c = torch.tensor(np.repeat(field_y, bs, axis=0)).float().to(device)
x_t = x + field_x_c + field_y_c
print(energy_loss_nima(x, edge_list)[:4])
s = torch.cat((torch.cos(x_t), torch.sin(x_t)), dim=-1)

h_hat = model(s)

# h_hat = model(s)
if epoch == 1 and idx == 0:
print(h_hat.shape)
# print(h_hat[:4])
# print(y[:4])
loss = loss_func(h_hat, y)

opt.zero_grad(set_to_none=True)
Expand All @@ -482,12 +297,10 @@ def main():

net_loss += loss.item() * len(x)
n_total += len(x)
# break
train_loss = net_loss / n_total
loss_tr.append(train_loss)

current = time.time()
# break
net_loss = 0.0
n_total = 0
model.eval()
@@ -496,15 +309,14 @@ def main():
x, y = x.to(device), y.to(device)
x = x.view(-1, grid_size * grid_size, 1)
bs = x.shape[0]
# field_x_c = (
# torch.tensor(np.repeat(field_x, bs, axis=0)).float().to(device)
# )
# field_y_c = (
# torch.tensor(np.repeat(field_y, bs, axis=0)).float().to(device)
# )
# x_t = x + field_x_c + field_y_c
# s = torch.cat((torch.cos(x_t), torch.sin(x_t)), dim=-1)
s = torch.cat((torch.cos(x), torch.sin(x)), dim=-1)
field_x_c = (
torch.tensor(np.repeat(field_x, bs, axis=0)).float().to(device)
)
field_y_c = (
torch.tensor(np.repeat(field_y, bs, axis=0)).float().to(device)
)
x_t = x + field_x_c + field_y_c
s = torch.cat((torch.cos(x_t), torch.sin(x_t)), dim=-1)

h_hat = model(s)
loss = loss_func(h_hat, y)
@@ -551,6 +363,5 @@ def main():
plt.savefig("loss_ours.pdf", bbox_inches="tight")
plt.close()


if __name__ == "__main__":
main()
