Skip to content

Commit

Permalink
modifying transforms for backpropability
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Oct 24, 2024
1 parent 7ba2972 commit 67f43f3
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 30 deletions.
58 changes: 58 additions & 0 deletions applications/convert_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from argparse import ArgumentParser
import logging
import os
from os.path import join
import shutil

import yaml
import torch

def model_checkpoint_to_checkpoint(save_loc):
checkpoint_to_convert = join(save_loc, "convert_checkpoint.pt")

if not os.path.isfile(checkpoint_to_convert):
logging.info("copying model_checkpoint.pt to stage for conversion\n")
ckpt = join(save_loc, "model_checkpoint.pt")
shutil.copy(ckpt, checkpoint_to_convert)

state_dict = torch.load(checkpoint_to_convert)
dst_checkpoint = join(save_loc, "checkpoint.pt")
torch.save({"model_state_dict": state_dict},
dst_checkpoint)
logging.info(f"converted model_checkpoint to {dst_checkpoint}")

if __name__ == "__main__":
description = "Train a segmengation model on a hologram data set"
parser = ArgumentParser(description=description)
parser.add_argument(
"-c",
"--config",
dest="model_config",
type=str,
default=False,
help="Path to the model configuration (yml) containing your inputs.",
)

args = parser.parse_args()
args_dict = vars(args)
config = args_dict.pop("model_config")

# Set up logger to print stuff
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")

# Stream output to stdout
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

with open(config) as cf:
conf = yaml.load(cf, Loader=yaml.FullLoader)

# Create directories if they do not exist and copy yml file
save_loc = os.path.expandvars(conf["save_loc"])

model_checkpoint_to_checkpoint(save_loc)

3 changes: 0 additions & 3 deletions credit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,6 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ
assert "level_list" in conf['data'], (
'need to specify hybrid sigma level indices for skebs'
)
assert "timestep" in conf['data'], (
'need to specify timestep in seconds for skebs'
)
assert conf['trainer']["train_batch_size"] == conf['trainer']["valid_batch_size"], (
'train and valid batch sizes need to be the same for skebs'
)
Expand Down
9 changes: 7 additions & 2 deletions credit/physics_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ def __init__(self,
b_vals,
plev_dim=1):
super().__init__()
self.a_vals = a_vals
self.b_vals = b_vals
self.register_buffer('a_vals',
a_vals,
persistent=False)
self.register_buffer('b_vals',
b_vals,
persistent=False)

self.plev_dim = plev_dim
self.is_fully_initialized = False
def compute_p(self, sp):
Expand Down
40 changes: 23 additions & 17 deletions credit/postblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,11 @@ def __init__(self,

def forward(self, x):
x = x.permute(0, 2, 3, 4, 1) # put channels last
logger.debug(x.shape)
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
logger.debug(x.shape)
x = x.permute(0, -1, 1, 2, 3) # put channels back to 1st dim
logger.debug(x.shape)
return x
# return torch.ones((x.shape[0], self.levels, 1, self.nlat, self.nlon))

class SKEBS(nn.Module):
"""
Expand Down Expand Up @@ -607,7 +603,7 @@ def __init__(self, post_conf):
self.sp_index = post_conf["skebs"]["SP_ind"]

# need this info
self.timestep = post_conf["data"]["timestep"]
self.timestep = post_conf["data"]["lead_time_periods"] * 3600
self.level_info = xr.open_dataset(post_conf["data"]["level_info_file"])
self.level_list = post_conf["data"]["level_list"]
self.surface_area = xr.open_dataset(post_conf["data"]["save_loc_static"])["surface_area"].to_numpy()
Expand All @@ -619,6 +615,9 @@ def __init__(self, post_conf):

num_channels = self.levels * self.channels + self.surface_channels + self.output_only_channels
self.backscatter_network = Backscatter_FCNN(num_channels, self.levels)

self.state_trans = load_transforms(post_conf, scaler_only=True)

def initialize_sht(self):
"""
Initialize spherical harmonics and inverse spherical harmonics transformations
Expand All @@ -637,7 +636,6 @@ def initialize_skebs_parameters(self):
torch.arange(1, self.lmax + 1).unsqueeze(1),
persistent=False) # (lmax, 1)
# assume (b, c, t, ,lat,lon)

# parameters we want to learn: (init to berner 2009 values for now)
self.alpha = Parameter(torch.tensor(0.5, requires_grad=True))
self.variance = Parameter(torch.tensor(0.083, requires_grad=True))
Expand All @@ -653,26 +651,28 @@ def initialize_pattern(self, y_pred):
"""
y_shape = y_pred.shape

self.register_buffer('spec_coef',
torch.zeros(
self.spec_coef = torch.zeros(
(y_shape[0], 1, 1, self.lmax, self.mmax), # b, 1, 1, lat, lon
dtype=torch.cfloat,),
persistent=False)
dtype=torch.cfloat,
device=y_pred.device)
logger.info(f"spec_coef device: {self.spec_coef.device}")
# self.register_buffer('spec_coef',
# torch.zeros(
# (y_shape[0], 1, 1, self.lmax, self.mmax), # b, 1, 1, lat, lon
# dtype=torch.cfloat,),
# persistent=False)

# initialize pattern todo: how many iters?
for i in range(10):
logger.debug(f'cycle {i}')
for i in range(1):
self.spec_coef = self.cycle_pattern(self.spec_coef)

def cycle_pattern(self, spec_coef):
Gamma = torch.sum(self.lrange * (self.lrange + 1.0) * (self.lrange + 2.0) * self.lrange ** (2.0 * self.p)) # scalar
b = torch.sqrt((4.0 * PI * RAD_EARTH**2.0) / (self.variance * Gamma) * self.alpha * self.dE) # scalar
g_n = b * self.lrange ** self.p # (lmax, 1)
logger.debug(f"g_n: {g_n.shape}")

device = spec_coef.device
logger.info(f"spec_coef device: {device}")
noise = self.variance * torch.randn(self.spec_coef.shape, device=device) # (b, 1, 1, lmax, mmax) std normal noise diff for all n?
logger.debug(f"spec_coef device: {spec_coef.device}")
noise = self.variance * torch.randn(self.spec_coef.shape, device=spec_coef.device) # (b, 1, 1, lmax, mmax) std normal noise diff for all n?
new_coef = (1.0 - self.alpha) * spec_coef + g_n * torch.sqrt(self.alpha) * noise # (lmax, mmax)
return new_coef

Expand All @@ -688,7 +688,6 @@ def initialize_mass_calc(self):
self.register_buffer('surface_area_tensor',
torch.from_numpy(self.surface_area).view(1, 1, 1, self.nlat, 1),
persistent=False)
logger.debug(f"surface_area_tensor: {self.surface_area_tensor.shape}")
self.compute_plev_quantities = compute_pressure_on_mlevs(a_vals=self.a_tensor, b_vals=self.b_tensor, plev_dim=1)

def calculate_mass(self, sp):
Expand All @@ -699,7 +698,11 @@ def calculate_mass(self, sp):
)

def forward(self, x):
logger.debug(f"lrange device: {self.lrange.device}")
logger.debug(f"sa tensor device: {self.surface_area_tensor.device}")
x = x["y_pred"]
logger.debug(f"x device: {x.device}")
x = self.state_trans.inverse_transform(x)

if not self.spec_coef_is_initialized: #hacky way of doing lazymodulemixin
self.initialize_pattern(x)
Expand All @@ -716,7 +719,10 @@ def forward(self, x):
total_forcing = self.timestep * backscatter_pred * pattern_on_grid
# shape (b, levels, t, lat, lon)


assert torch.min(x[:, self.sp_index]) >= 0., "sp less than 0"
mlev_mass = self.calculate_mass(x[:, self.sp_index : self.sp_index + 1]) # slice to keep dims
assert torch.min(mlev_mass) >= 0., "mass is less than 0"
# (b, levels, 1, lat, lon)
u_squared, v_squared = x[:, self.U_inds] ** 2, x[:, self.V_inds] ** 2
wind_squared = u_squared + v_squared
Expand Down
14 changes: 7 additions & 7 deletions credit/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,23 +340,23 @@ def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:

# Subset upper air
tensor_upper_air = x[:, :self.num_upper_air, :, :]
transformed_upper_air = tensor_upper_air.clone()
transformed_upper_air = torch.empty_like(tensor_upper_air)

# Surface variables
if self.flag_surface:
tensor_surface = x[:, self.num_upper_air:(self.num_upper_air+self.num_surface), :, :]
transformed_surface = tensor_surface.clone()
transformed_surface = torch.empty_like(tensor_surface)

# Diagnostic variables (the very last of the stack)
if self.flag_diagnostic:
tensor_diagnostic = x[:, -self.num_diagnostic:, :, :]
transformed_diagnostic = tensor_diagnostic.clone()
transformed_diagnostic = torch.empty_like(tensor_diagnostic)

# Reverse upper air variables
k = 0
for name in self.varname_upper_air:
mean_tensor = self.mean_tensors[name].to(device)
std_tensor = self.std_tensors[name].to(device)
mean_tensor = self.mean_tensors[name].to(device).view(1, self.levels, 1, 1, 1) #(16,)
std_tensor = self.std_tensors[name].to(device).view(1, self.levels, 1, 1, 1)
for level in range(self.levels):
mean = mean_tensor[level]
std = std_tensor[level]
Expand All @@ -366,8 +366,8 @@ def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
# Reverse surface variables
if self.flag_surface:
for k, name in enumerate(self.varname_surface):
mean = self.mean_tensors[name].to(device)
std = self.std_tensors[name].to(device)
mean = self.mean_tensors[name].to(device) #size none
std = self.std_tensors[name].to(device) #size none
transformed_surface[:, k] = tensor_surface[:, k] * std + mean

# Reverse diagnostic variables
Expand Down
2 changes: 1 addition & 1 deletion tests/test_postblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_SKEBS_integration():

model = CrossFormer(**conf["model"])
device = torch.device(f"cuda:{1 % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu")
model = model.load_model(conf)
# model = model.load_model(conf)
model.to(device)
logging.info(f"model: {device}")

Expand Down

0 comments on commit 67f43f3

Please sign in to comment.