Update train_pretrained.py
manosth authored Jul 13, 2024
1 parent 953031b commit 977a19a
Showing 1 changed file with 3 additions and 95 deletions.
98 changes: 3 additions & 95 deletions train_pretrained.py
@@ -18,8 +18,6 @@
sns.set_context("paper")
sns.set(font_scale=2)
cmap = plt.get_cmap("twilight")
# cmap_t = plt.get_cmap("turbo")
# cmap = plt.get_cmap("hsv")
color_plot = sns.cubehelix_palette(4, reverse=True, rot=-0.2)
from matplotlib import cm, rc

@@ -32,13 +30,10 @@
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision

# from torchsummary import summary
from torchinfo import summary


def conv3x3(
    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1
) -> nn.Conv2d:
@@ -54,15 +49,12 @@ def conv3x3(
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
@@ -105,7 +97,6 @@ def forward(self, x: Tensor) -> Tensor:

        out += identity
        out = self.relu(out)

        return out


@@ -115,9 +106,7 @@ class Bottleneck(nn.Module):
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
@@ -163,10 +152,8 @@ def forward(self, x: Tensor) -> Tensor:

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
@@ -304,68 +291,19 @@ def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


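# Centers the data per channel; the computed std is replaced with ones, so
# Normalize subtracts the mean without rescaling.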
def zeromean(X, mean=None, std=None):
"Expects data in NxCxWxH."
if mean is None:
mean = X.mean(axis=(0, 2, 3))
std = X.std(axis=(0, 2, 3))
std = torch.ones(std.shape)

X = torchvision.transforms.Normalize(mean, std)(X)
return X, mean, std


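# Per-channel standardization; reuses the provided mean/std (e.g., training
# statistics) when given.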
def standardize(X, mean=None, std=None):
"Expects data in NxCxWxH."
if mean is None:
mean = X.mean(axis=(0, 2, 3))
std = X.std(axis=(0, 2, 3))

X = torchvision.transforms.Normalize(mean, std)(X)
return X, mean, std


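# Min-max scales the targets to [0, 1]; here "mean" holds the minimum and
# "std" the range.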
def standardize_y(Y, mean=None, std=None):
"Expects data in Nx1."
if mean is None:
mean = Y.min()
std = Y.max() - Y.min()

Y = (Y - mean) / std
return Y, mean, std


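# ZCA whitening: rotate into the covariance eigenbasis, rescale by
# 1 / sqrt(S + eps), and rotate back.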
def whiten(X, zca=None, mean=None, eps=1e-8):
"Expects data in NxCxWxH."
os = X.shape
X = X.reshape(os[0], -1)

if zca is None:
mean = X.mean(dim=0)
cov = np.cov(X, rowvar=False)
U, S, V = np.linalg.svd(cov)
zca = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + eps)), U.T))
X = torch.Tensor(np.dot(X - mean, zca.T).reshape(os))
return X, zca, mean


def main():
device = "cuda:0" if torch.cuda.is_available() else "cpu"
seed = 13
torch.manual_seed(seed)
np.random.seed(seed)

data_norm = None # "y"
grid_size = 100
batch_size = 64
# data = np.load("/Users/manos/data/gauge/data_n=10000.npy", allow_pickle=True)
data = np.load("data_n=10000.npy", allow_pickle=True)
X, Y = data.item()["x"], data.item()["y"]

@@ -381,22 +319,6 @@ def main():
    X_te = torch.Tensor(X_te).view(-1, 1, grid_size, grid_size)
    Y_te = torch.Tensor(Y_te).view(-1, 1)

if data_norm == "standard":
X_tr, mean, std = standardize(X_tr)
X_te, _, _ = standardize(X_te, mean, std)
elif data_norm == "zeromean":
X_tr, mean, std = zeromean(X_tr)
X_te, _, _ = zeromean(X_te, mean, std)
elif data_norm == "whiten":
X_tr, mean, std = standardize(X_tr)
X_te, _, _ = standardize(X_te, mean, std)

X_tr, zca, mean = whiten(X_tr)
X_te, _, _ = whiten(X_te, zca, mean)
elif data_norm == "y":
Y_tr, mean, std = standardize_y(Y_tr)
Y_te, _, _ = standardize_y(Y_te, mean, std)

    train_dl = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_tr, Y_tr),
        batch_size=batch_size,
@@ -413,10 +335,7 @@
    )

    epochs = 50

    train = "pre"
    # train = "pre"
    # train = "fine"

if train == "train":
model = ResNet(Bottleneck, [2, 2, 2, 2], in_channels=1)
@@ -431,21 +350,13 @@
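        # Replace the classification head with a single-output regression head.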
        n_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(n_ftrs, 1)

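        # Disabled alternative: adapt the pretrained weights to 1-channel
        # input by summing conv1's RGB filters.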
        # model_pre = torchvision.models.resnet18(weights="IMAGENET1K_V1")

        # # sum the channels of the first layer
        # state_dict = model_pre.state_dict()
        # conv1_weight = state_dict["conv1.weight"]
        # state_dict["conv1.weight"] = conv1_weight.sum(dim=1, keepdim=True)
        # model.load_state_dict(state_dict)

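    # The from-scratch model takes 1-channel input; the pretrained backbone
    # expects 3 channels, hence the different summary shapes.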
    model = model.to(device)
    if train == "train":
        summary(model, input_size=(batch_size, 1, grid_size, grid_size))
    else:
        summary(model, input_size=(batch_size, 3, grid_size, grid_size))

    opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
    # opt = optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.0001, momentum=0.9)
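    # Step the learning rate down 10x at 50% and 75% of training.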
    schd = optim.lr_scheduler.MultiStepLR(
        opt, [int(1 / 2 * epochs), int(3 / 4 * epochs)], gamma=0.1
    )
@@ -459,8 +370,8 @@ def main():
    prev = start

    # generate gauge field
    prod = 0.75  # 0.25
    add = 0  # .25
    prod = 0.75
    add = 0

    # apply gauge
    t = np.linspace(0, 1, grid_size)
@@ -500,7 +411,6 @@ def main():
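            # Shift the batch by the fixed gauge field and replicate it to
            # 3 channels for the pretrained backbone.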
            s = x + field_x_c + field_y_c
            s = s.repeat(1, 3, 1, 1)

            # y_hat = model(x)
            y_hat = model(s)
            loss = loss_func(y_hat, y)

@@ -542,7 +452,6 @@ def main():
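            # Apply the same gauge shift and channel replication on the
            # held-out data.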
            s = x + field_x_c + field_y_c
            s = s.repeat(1, 3, 1, 1)

            # y_hat = model(x)
            y_hat = model(s)
            loss = loss_func(y_hat, y)

@@ -588,6 +497,5 @@ def main():
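    # Plot and save the loss curves.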
    plt.savefig(f"loss_resnet-{train}.pdf", bbox_inches="tight")
    plt.close()


if __name__ == "__main__":
    main()
