From 67f43f3137e478a6df46d0936b127c5ba296fe74 Mon Sep 17 00:00:00 2001
From: dkimpara
Date: Thu, 24 Oct 2024 15:28:27 -0600
Subject: [PATCH] modify transforms to support backpropagation

---
 applications/convert_checkpoint.py | 58 ++++++++++++++++++++++++++++++
 credit/parser.py                   |  3 --
 credit/physics_core.py             |  9 +++--
 credit/postblock.py                | 40 ++++++++++++---------
 credit/transforms.py               | 14 ++++----
 tests/test_postblock.py            |  2 +-
 6 files changed, 96 insertions(+), 30 deletions(-)
 create mode 100644 applications/convert_checkpoint.py

diff --git a/applications/convert_checkpoint.py b/applications/convert_checkpoint.py
new file mode 100644
index 0000000..c7b4f7b
--- /dev/null
+++ b/applications/convert_checkpoint.py
@@ -0,0 +1,58 @@
+from argparse import ArgumentParser
+import logging
+import os
+from os.path import join
+import shutil
+
+import yaml
+import torch
+
+
+def model_checkpoint_to_checkpoint(save_loc):
+    checkpoint_to_convert = join(save_loc, "convert_checkpoint.pt")
+
+    if not os.path.isfile(checkpoint_to_convert):
+        logging.info("copying model_checkpoint.pt to stage for conversion\n")
+        ckpt = join(save_loc, "model_checkpoint.pt")
+        shutil.copy(ckpt, checkpoint_to_convert)
+
+    state_dict = torch.load(checkpoint_to_convert)
+    dst_checkpoint = join(save_loc, "checkpoint.pt")
+    torch.save({"model_state_dict": state_dict},
+               dst_checkpoint)
+    logging.info(f"converted model_checkpoint to {dst_checkpoint}")
+
+
+if __name__ == "__main__":
+    description = "Convert a model_checkpoint.pt (raw state_dict) into a loadable checkpoint.pt"
+    parser = ArgumentParser(description=description)
+    parser.add_argument(
+        "-c",
+        "--config",
+        dest="model_config",
+        type=str,
+        default=False,
+        help="Path to the model configuration (yml) containing your inputs.",
+    )
+
+    args = parser.parse_args()
+    args_dict = vars(args)
+    config = args_dict.pop("model_config")
+
+    # Set up the root logger
+    root = logging.getLogger()
+    root.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
+
+    # Stream output to stdout
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    ch.setFormatter(formatter)
+    root.addHandler(ch)
+
+    with open(config) as cf:
+        conf = yaml.load(cf, Loader=yaml.FullLoader)
+
+    # Resolve the save location from the config
+    save_loc = os.path.expandvars(conf["save_loc"])
+
+    model_checkpoint_to_checkpoint(save_loc)
diff --git a/credit/parser.py b/credit/parser.py
index 5bc2051..ebe581f 100644
--- a/credit/parser.py
+++ b/credit/parser.py
@@ -372,9 +372,6 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ
         assert "level_list" in conf['data'], (
             'need to specify hybrid sigma level indices for skebs'
         )
-        assert "timestep" in conf['data'], (
-            'need to specify timestep in seconds for skebs'
-        )
         assert conf['trainer']["train_batch_size"] == conf['trainer']["valid_batch_size"], (
             'train and valid batch sizes need to be the same for skebs'
         )
diff --git a/credit/physics_core.py b/credit/physics_core.py
index 7dbd10f..72e4cb3 100644
--- a/credit/physics_core.py
+++ b/credit/physics_core.py
@@ -34,8 +34,13 @@ def __init__(self,
                  b_vals,
                  plev_dim=1):
         super().__init__()
-        self.a_vals = a_vals
-        self.b_vals = b_vals
+        self.register_buffer('a_vals',
+                             a_vals,
+                             persistent=False)
+        self.register_buffer('b_vals',
+                             b_vals,
+                             persistent=False)
+
         self.plev_dim = plev_dim
         self.is_fully_initialized = False
     def compute_p(self, sp):
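A note on the physics_core.py change above: registering a_vals and b_vals with register_buffer makes them follow the module across devices under .to()/.cuda(), which plain tensor attributes do not, while persistent=False keeps them out of saved checkpoints. A minimal standalone sketch of that behavior (the HybridSigmaCoeffs class and the shapes here are illustrative, not taken from the repo):

import torch
import torch.nn as nn

class HybridSigmaCoeffs(nn.Module):
    """Illustrative module holding hybrid sigma coefficients as buffers."""
    def __init__(self, a_vals, b_vals):
        super().__init__()
        # Buffers move with .to()/.cuda() but are not learnable parameters;
        # persistent=False excludes them from state_dict()/checkpoints.
        self.register_buffer('a_vals', a_vals, persistent=False)
        self.register_buffer('b_vals', b_vals, persistent=False)

    def compute_p(self, sp):
        # Hybrid sigma pressure on model levels: p = a + b * sp,
        # with a_vals/b_vals guaranteed to live on sp's device.
        return self.a_vals + self.b_vals * sp

coeffs = HybridSigmaCoeffs(torch.rand(16), torch.rand(16))
print(coeffs.state_dict())  # empty: non-persistent buffers are not saved
print(coeffs.compute_p(torch.tensor(101325.0)).shape)  # torch.Size([16])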
diff --git a/credit/postblock.py b/credit/postblock.py
index a4c8425..0bc285e 100644
--- a/credit/postblock.py
+++ b/credit/postblock.py
@@ -565,15 +565,11 @@ def __init__(self,

     def forward(self, x):
         x = x.permute(0, 2, 3, 4, 1)  # put channels last
-        logger.debug(x.shape)
         x = self.fc1(x)
         x = self.relu1(x)
         x = self.fc2(x)
-        logger.debug(x.shape)
         x = x.permute(0, -1, 1, 2, 3)  # put channels back to 1st dim
-        logger.debug(x.shape)
         return x
-        # return torch.ones((x.shape[0], self.levels, 1, self.nlat, self.nlon))


 class SKEBS(nn.Module):
     """
@@ -607,7 +603,7 @@ def __init__(self, post_conf):
         self.sp_index = post_conf["skebs"]["SP_ind"]

         # need this info
-        self.timestep = post_conf["data"]["timestep"]
+        self.timestep = post_conf["data"]["lead_time_periods"] * 3600
         self.level_info = xr.open_dataset(post_conf["data"]["level_info_file"])
         self.level_list = post_conf["data"]["level_list"]
         self.surface_area = xr.open_dataset(post_conf["data"]["save_loc_static"])["surface_area"].to_numpy()
@@ -619,6 +615,9 @@ def __init__(self, post_conf):
         num_channels = self.levels * self.channels + self.surface_channels + self.output_only_channels
         self.backscatter_network = Backscatter_FCNN(num_channels, self.levels)

+
+        self.state_trans = load_transforms(post_conf, scaler_only=True)
+
     def initialize_sht(self):
         """
         Initialize spherical harmonics and inverse spherical harmonics transformations
@@ -637,7 +636,6 @@ def initialize_skebs_parameters(self):
                              torch.arange(1, self.lmax + 1).unsqueeze(1),
                              persistent=False)  # (lmax, 1)
         # assume (b, c, t, lat, lon)
-
         # parameters we want to learn: (init to Berner 2009 values for now)
         self.alpha = Parameter(torch.tensor(0.5, requires_grad=True))
         self.variance = Parameter(torch.tensor(0.083, requires_grad=True))
@@ -653,26 +651,28 @@ def initialize_pattern(self, y_pred):
         """
         y_shape = y_pred.shape

-        self.register_buffer('spec_coef',
-                             torch.zeros(
+        self.spec_coef = torch.zeros(
                                  (y_shape[0], 1, 1, self.lmax, self.mmax),  # b, 1, 1, lmax, mmax
-                                 dtype=torch.cfloat,),
-                             persistent=False)
+                                 dtype=torch.cfloat,
+                                 device=y_pred.device)
+        logger.info(f"spec_coef device: {self.spec_coef.device}")
+        # self.register_buffer('spec_coef',
+        #                      torch.zeros(
+        #                          (y_shape[0], 1, 1, self.lmax, self.mmax),  # b, 1, 1, lmax, mmax
+        #                          dtype=torch.cfloat,),
+        #                      persistent=False)

         # initialize pattern (todo: how many iterations?)
-        for i in range(10):
-            logger.debug(f'cycle {i}')
+        for i in range(1):
             self.spec_coef = self.cycle_pattern(self.spec_coef)

     def cycle_pattern(self, spec_coef):
         Gamma = torch.sum(self.lrange * (self.lrange + 1.0) * (self.lrange + 2.0) * self.lrange ** (2.0 * self.p))  # scalar
         b = torch.sqrt((4.0 * PI * RAD_EARTH**2.0) / (self.variance * Gamma) * self.alpha * self.dE)  # scalar
         g_n = b * self.lrange ** self.p  # (lmax, 1)
-        logger.debug(f"g_n: {g_n.shape}")
-        device = spec_coef.device
-        logger.info(f"spec_coef device: {device}")
-        noise = self.variance * torch.randn(self.spec_coef.shape, device=device)  # (b, 1, 1, lmax, mmax) std normal noise, different for each n?
+        logger.debug(f"spec_coef device: {spec_coef.device}")
+        noise = self.variance * torch.randn(self.spec_coef.shape, device=spec_coef.device)  # (b, 1, 1, lmax, mmax) std normal noise, different for each n?
         new_coef = (1.0 - self.alpha) * spec_coef + g_n * torch.sqrt(self.alpha) * noise  # (lmax, mmax)
         return new_coef

@@ -688,7 +688,6 @@ def initialize_mass_calc(self):
         self.register_buffer('surface_area_tensor',
                              torch.from_numpy(self.surface_area).view(1, 1, 1, self.nlat, 1),
                              persistent=False)
-        logger.debug(f"surface_area_tensor: {self.surface_area_tensor.shape}")
         self.compute_plev_quantities = compute_pressure_on_mlevs(a_vals=self.a_tensor, b_vals=self.b_tensor, plev_dim=1)

     def calculate_mass(self, sp):
@@ -699,7 +698,11 @@ def calculate_mass(self, sp):
         )

     def forward(self, x):
+        logger.debug(f"lrange device: {self.lrange.device}")
+        logger.debug(f"sa tensor device: {self.surface_area_tensor.device}")
         x = x["y_pred"]
+        logger.debug(f"x device: {x.device}")
+        x = self.state_trans.inverse_transform(x)

         if not self.spec_coef_is_initialized:  # hacky way of doing LazyModuleMixin
             self.initialize_pattern(x)
@@ -716,7 +719,10 @@ def forward(self, x):
         total_forcing = self.timestep * backscatter_pred * pattern_on_grid  # shape (b, levels, t, lat, lon)

+
+        assert torch.min(x[:, self.sp_index]) >= 0., "sp less than 0"
         mlev_mass = self.calculate_mass(x[:, self.sp_index : self.sp_index + 1])  # slice to keep dims
+        assert torch.min(mlev_mass) >= 0., "mass is less than 0"
         # (b, levels, 1, lat, lon)
         u_squared, v_squared = x[:, self.U_inds] ** 2, x[:, self.V_inds] ** 2
         wind_squared = u_squared + v_squared
diff --git a/credit/transforms.py b/credit/transforms.py
index 788bcc9..456636e 100644
--- a/credit/transforms.py
+++ b/credit/transforms.py
@@ -340,23 +340,23 @@ def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:

         # Subset upper air
         tensor_upper_air = x[:, :self.num_upper_air, :, :]
-        transformed_upper_air = tensor_upper_air.clone()
+        transformed_upper_air = torch.empty_like(tensor_upper_air)

         # Surface variables
         if self.flag_surface:
             tensor_surface = x[:, self.num_upper_air:(self.num_upper_air + self.num_surface), :, :]
-            transformed_surface = tensor_surface.clone()
+            transformed_surface = torch.empty_like(tensor_surface)

         # Diagnostic variables (the very last of the stack)
         if self.flag_diagnostic:
             tensor_diagnostic = x[:, -self.num_diagnostic:, :, :]
-            transformed_diagnostic = tensor_diagnostic.clone()
+            transformed_diagnostic = torch.empty_like(tensor_diagnostic)

         # Reverse upper air variables
         k = 0
         for name in self.varname_upper_air:
-            mean_tensor = self.mean_tensors[name].to(device)
-            std_tensor = self.std_tensors[name].to(device)
+            mean_tensor = self.mean_tensors[name].to(device).view(1, self.levels, 1, 1, 1)  # stored as (levels,), e.g. (16,)
+            std_tensor = self.std_tensors[name].to(device).view(1, self.levels, 1, 1, 1)
             for level in range(self.levels):
-                mean = mean_tensor[level]
-                std = std_tensor[level]
+                mean = mean_tensor[:, level]
+                std = std_tensor[:, level]
@@ -366,8 +366,8 @@ def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
         # Reverse surface variables
         if self.flag_surface:
             for k, name in enumerate(self.varname_surface):
-                mean = self.mean_tensors[name].to(device)
-                std = self.std_tensors[name].to(device)
+                mean = self.mean_tensors[name].to(device)  # 0-d (scalar) tensor
+                std = self.std_tensors[name].to(device)  # 0-d (scalar) tensor
                 transformed_surface[:, k] = tensor_surface[:, k] * std + mean

         # Reverse diagnostic variables
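The .view(1, self.levels, 1, 1, 1) reshape in the transforms.py hunk above prepares the per-level mean and std for broadcasting against a (batch, levels, time, lat, lon) tensor, which is what lets gradients flow through the inverse transform without a long chain of per-level ops. A standalone sketch of the fully vectorised form this points toward (the function name and shapes are hypothetical, not from credit/transforms.py):

import torch

def inverse_transform_upper_air(x, mean, std):
    """x: (batch, levels, time, lat, lon); mean/std: (levels,)."""
    levels = x.shape[1]
    # One broadcast multiply-add replaces the per-level Python loop,
    # keeping the autograd graph small when backpropagating through it.
    mean = mean.view(1, levels, 1, 1, 1)
    std = std.view(1, levels, 1, 1, 1)
    return x * std + mean

x = torch.randn(2, 16, 1, 4, 8)
out = inverse_transform_upper_air(x, torch.randn(16), torch.rand(16) + 1.0)
print(out.shape)  # torch.Size([2, 16, 1, 4, 8])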
diff --git a/tests/test_postblock.py b/tests/test_postblock.py
index c4c40ae..6d5e915 100644
--- a/tests/test_postblock.py
+++ b/tests/test_postblock.py
@@ -46,7 +46,7 @@ def test_SKEBS_integration():
     model = CrossFormer(**conf["model"])

     device = torch.device(f"cuda:{1 % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.load_model(conf)
+    # model = model.load_model(conf)
    model.to(device)

     logging.info(f"model: {device}")
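For reference, the cycle_pattern hunk above implements a first-order autoregressive (AR(1)) update of complex spherical-harmonic coefficients, following Berner et al. (2009). Below is a self-contained sketch of that update; alpha and variance match the initialization shown in the patch, while dE and p are placeholder values, since their initialization lies outside this diff:

import torch

PI = 3.141592653589793
RAD_EARTH = 6.371e6  # metres

def cycle_pattern(spec_coef, lrange, alpha, variance, dE, p):
    """One AR(1) step of the spectral pattern, mirroring SKEBS.cycle_pattern."""
    # Wavenumber-dependent amplitude g_n, normalised so the pattern
    # injects the target backscatter energy dE (cf. Berner et al. 2009).
    Gamma = torch.sum(lrange * (lrange + 1.0) * (lrange + 2.0) * lrange ** (2.0 * p))
    b = torch.sqrt((4.0 * PI * RAD_EARTH ** 2.0) / (variance * Gamma) * alpha * dE)
    g_n = b * lrange ** p  # (lmax, 1), broadcasts over the m dimension
    noise = variance * torch.randn(spec_coef.shape, device=spec_coef.device)
    # Real noise added to a complex tensor promotes to complex.
    return (1.0 - alpha) * spec_coef + g_n * torch.sqrt(alpha) * noise

lmax, mmax = 32, 32
lrange = torch.arange(1, lmax + 1, dtype=torch.float32).unsqueeze(1)  # (lmax, 1)
spec_coef = torch.zeros(1, 1, 1, lmax, mmax, dtype=torch.cfloat)
spec_coef = cycle_pattern(spec_coef, lrange,
                          alpha=torch.tensor(0.5), variance=torch.tensor(0.083),
                          dE=torch.tensor(1e-5), p=torch.tensor(-1.27))
print(spec_coef.abs().mean())

Each forward pass applies one such update, so the pattern decorrelates on a timescale controlled by alpha, which the patch makes a learnable Parameter.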