diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py
index ab24f2c1..96475ae6 100644
--- a/applications/rollout_to_netcdf.py
+++ b/applications/rollout_to_netcdf.py
@@ -496,7 +496,16 @@ def predict(rank, world_size, conf, p):
         + len(conf["data"]["forcing_variables"])
         + len(conf["data"]["static_variables"])
     )
-
+
+    # ------------------------------------------------------- #
+    # clamp to remove outliers
+    if conf["data"]["data_clamp"] is None:
+        flag_clamp = False
+    else:
+        flag_clamp = True
+        clamp_min = float(conf["data"]["data_clamp"][0])
+        clamp_max = float(conf["data"]["data_clamp"][1])
+
     # ====================================================== #
     # postblock opts outside of model
     post_conf = conf["model"]["post_conf"]
@@ -646,6 +655,13 @@ def predict(rank, world_size, conf, p):
 
                 # -------------------------------------------------------------------------------------- #
                 # start prediction
+
+                # --------------------------------------------- #
+                # clamp
+                if flag_clamp:
+                    x = torch.clamp(x, min=clamp_min, max=clamp_max)
+                    #y = torch.clamp(y, min=clamp_min, max=clamp_max)
+
                 y_pred = model(x)
 
                 # ============================================= #
diff --git a/config/example_physics_single.yml b/config/example_physics_single.yml
index 653d948f..506ba318 100644
--- a/config/example_physics_single.yml
+++ b/config/example_physics_single.yml
@@ -1,8 +1,5 @@
 # --------------------------------------------------------------------------------------------------------------------- #
-# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
-# the FuXi architecture has been modified to reduce the overall model size
-# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
-# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
+# example
 # --------------------------------------------------------------------------------------------------------------------- #
 save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_conserve/'
 seed: 1000
@@ -47,6 +44,7 @@ data:
 
     # data workflow
     scaler_type: 'std_new'
+    data_clamp: [-16, 16]
 
     # number of input states
     # FuXi has 2 input states
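Note on the clamp feature: with `scaler_type: 'std_new'` the channels are z-score normalized, so `data_clamp: [-16, 16]` bounds every input to roughly +/-16 standard deviations and keeps a single corrupt grid point from blowing up a forecast. A minimal sketch of the flag/threshold logic above (the `conf` fragment and tensor shape are illustrative only):

    import torch

    conf = {"data": {"data_clamp": [-16, 16]}}  # illustrative config fragment

    # same setup as added to rollout_to_netcdf.py and the trainers
    if conf["data"]["data_clamp"] is None:
        flag_clamp = False
    else:
        flag_clamp = True
        clamp_min = float(conf["data"]["data_clamp"][0])
        clamp_max = float(conf["data"]["data_clamp"][1])

    x = torch.randn(1, 4, 2, 8, 8) * 20.0  # fake normalized batch with outliers
    if flag_clamp:
        x = torch.clamp(x, min=clamp_min, max=clamp_max)
    assert float(x.abs().max()) <= 16.0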
diff --git a/credit/loss.py b/credit/loss.py
index 10fd701e..067deafc 100644
--- a/credit/loss.py
+++ b/credit/loss.py
@@ -369,7 +369,7 @@ def latitude_weights(conf):
     return L
 
 
-def variable_weights(conf, channels, surface_channels, frames):
+def variable_weights(conf, channels, frames):
     """Create variable-specific weights for different atmospheric and surface channels.
 
@@ -382,7 +382,6 @@ def variable_weights(conf, channels, surface_channels, frames):
         conf (dict): Configuration dictionary containing the variable weights.
         channels (int): Number of channels for atmospheric variables.
-        surface_channels (int): Number of channels for surface variables.
         frames (int): Number of time frames.
 
     Returns:
@@ -393,41 +392,25 @@ def variable_weights(conf, channels, surface_channels, frames):
     varname_upper_air = conf["data"]["variables"]
     varname_surface = conf["data"]["surface_variables"]
     varname_diagnostics = conf["data"]["diagnostic_variables"]
-    # N_levels = conf['data']['levels']
-
-    # weights_UVTQ = torch.tensor([
-    #     conf["loss"]["variable_weights"]["U"],
-    #     conf["loss"]["variable_weights"]["V"],
-    #     conf["loss"]["variable_weights"]["T"],
-    #     conf["loss"]["variable_weights"]["Q"]
-    # ]).view(1, channels * frames, 1, 1)
-
-    weights_UVTQ = torch.tensor(
+    # surface + diag channels
+    N_channels_single = len(varname_surface) + len(varname_diagnostics)
+
+    weights_upper_air = torch.tensor(
         [conf["loss"]["variable_weights"][var] for var in varname_upper_air]
     ).view(1, channels * frames, 1, 1)
-
-    # Load weights for SP, t2m, V500, U500, T500, Z500, Q500
-    # weights_sfc = torch.tensor([
-    #     conf["loss"]["variable_weights"]["SP"],
-    #     conf["loss"]["variable_weights"]["t2m"],
-    #     conf["loss"]["variable_weights"]["V500"],
-    #     conf["loss"]["variable_weights"]["U500"],
-    #     conf["loss"]["variable_weights"]["T500"],
-    #     conf["loss"]["variable_weights"]["Z500"],
-    #     conf["loss"]["variable_weights"]["Q500"]
-    # ]).view(1, surface_channels, 1, 1)
-
-    weights_sfc = torch.tensor(
+
+    weights_single = torch.tensor(
         [
             conf["loss"]["variable_weights"][var]
             for var in (varname_surface + varname_diagnostics)
         ]
-    ).view(1, surface_channels, 1, 1)
+    ).view(1, N_channels_single, 1, 1)
 
     # Combine all weights along the color channel
-    variable_weights = torch.cat([weights_UVTQ, weights_sfc], dim=1)
+    var_weights = torch.cat([weights_upper_air, weights_single], dim=1)
 
-    return variable_weights
+    return var_weights
 
 
 class VariableTotalLoss2D(torch.nn.Module):
@@ -471,18 +454,25 @@ def __init__(self, conf, validation=False):
             logger.info("Using latitude weights in loss calculations")
             self.lat_weights = latitude_weights(conf)[:, 10].unsqueeze(0).unsqueeze(-1)
 
+        # ------------------------------------------------------------- #
+        # variable weights
+        # order: upper air --> surface --> diagnostics
         self.var_weights = None
         if conf["loss"]["use_variable_weights"]:
             logger.info("Using variable weights in loss calculations")
+
             var_weights = [
                 value if isinstance(value, list) else [value]
                 for value in conf["loss"]["variable_weights"].values()
             ]
+
             var_weights = np.array(
                 [item for sublist in var_weights for item in sublist]
             )
+
             self.var_weights = torch.from_numpy(var_weights)
-
+        # ------------------------------------------------------------- #
+
         self.use_spectral_loss = conf["loss"]["use_spectral_loss"]
         if self.use_spectral_loss:
             self.spectral_lambda_reg = conf["loss"]["spectral_lambda_reg"]
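Note on the `variable_weights` refactor: weights are now looked up by variable name from `conf["loss"]["variable_weights"]` and concatenated in a fixed order (upper air, then surface, then diagnostics), and the surface/diagnostic channel count is inferred instead of being passed in, so callers must drop the old `surface_channels` argument. A small sketch of the shape logic under an assumed four-variable config:

    import torch

    # assumed illustrative config
    conf = {
        "data": {
            "variables": ["U", "V"],           # upper air
            "surface_variables": ["SP"],
            "diagnostic_variables": ["PRECT"],
        },
        "loss": {"variable_weights": {"U": 1.0, "V": 1.0, "SP": 0.1, "PRECT": 0.05}},
    }
    channels, frames = 2, 1  # upper-air channels x time frames

    w_upper = torch.tensor(
        [conf["loss"]["variable_weights"][v] for v in conf["data"]["variables"]]
    ).view(1, channels * frames, 1, 1)

    single = conf["data"]["surface_variables"] + conf["data"]["diagnostic_variables"]
    w_single = torch.tensor(
        [conf["loss"]["variable_weights"][v] for v in single]
    ).view(1, len(single), 1, 1)

    print(torch.cat([w_upper, w_single], dim=1).flatten())
    # tensor([1.0000, 1.0000, 0.1000, 0.0500])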
diff --git a/credit/models/fuxi.py b/credit/models/fuxi.py
index 162e8807..7627ef5c 100644
--- a/credit/models/fuxi.py
+++ b/credit/models/fuxi.py
@@ -226,7 +226,15 @@ class UTransformer(nn.Module):
     """
 
     def __init__(
-        self, embed_dim, num_groups, input_resolution, num_heads, window_size, depth
+        self, embed_dim,
+        num_groups,
+        input_resolution,
+        num_heads,
+        window_size,
+        depth,
+        proj_drop,
+        attn_drop,
+        drop_path
     ):
         super().__init__()
         num_groups = to_2tuple(num_groups)
@@ -248,7 +256,15 @@ def __init__(
 
         # SwinT block
         self.layer = SwinTransformerV2Stage(
-            embed_dim, embed_dim, input_resolution, depth, num_heads, window_size[0]
+            embed_dim,
+            embed_dim,
+            input_resolution,
+            depth,
+            num_heads,
+            window_size[0],
+            proj_drop=proj_drop,
+            attn_drop=attn_drop,
+            drop_path=drop_path
         )  # <--- window_size[0] get window_size[int] from tuple
 
         # up-sampling block
@@ -315,6 +331,9 @@ def __init__(
         window_size=7,
         use_spectral_norm=True,
         interp=True,
+        proj_drop=0,
+        attn_drop=0,
+        drop_path=0,
         padding_conf=None,
         post_conf=None,
         **kwargs,
@@ -323,11 +342,15 @@ def __init__(
 
         self.use_interp = interp
         self.use_spectral_norm = use_spectral_norm
+
         if padding_conf is None:
             padding_conf = {"activate": False}
+
         self.use_padding = padding_conf["activate"]
+
         if post_conf is None:
             post_conf = {"activate": False}
+
         self.use_post_block = post_conf["activate"]
 
         # input tensor size (time, lat, lon)
@@ -362,8 +385,17 @@ def __init__(
         self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)
 
         # Downsampling --> SwinTransformerV2 stacks --> Upsampling
+        logger.info(f"Define UTransformer with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}")
+
         self.u_transformer = UTransformer(
-            dim, num_groups, input_resolution, num_heads, window_size, depth=depth
+            dim, num_groups,
+            input_resolution,
+            num_heads,
+            window_size,
+            depth=depth,
+            proj_drop=proj_drop,
+            attn_drop=attn_drop,
+            drop_path=drop_path
         )
 
         # dense layer applied on channel dmension
@@ -383,11 +415,12 @@ def __init__(
         if self.use_padding:
             self.padding_opt = TensorPadding(**padding_conf)
 
+        # Move the model to the device
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.to(device)
+
         if self.use_spectral_norm:
             logger.info("Adding spectral norm to all conv and linear layers")
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            # Move the model to the device
-            self.to(device)
             apply_spectral_norm(self)
 
         if self.use_post_block:
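Note on the FuXi changes: `proj_drop`, `attn_drop`, and `drop_path` are plumbed from the `Fuxi` constructor through `UTransformer` into timm's `SwinTransformerV2Stage` (all default to 0, so existing configs behave as before), and the model is now moved to the device regardless of whether spectral norm is enabled. A stand-alone sketch of the stage call with the new knobs; the sizes are illustrative, not FuXi's:

    import torch
    from timm.models.swin_transformer_v2 import SwinTransformerV2Stage

    stage = SwinTransformerV2Stage(
        96, 96, (16, 16),  # dim, out_dim, input_resolution
        depth=2,
        num_heads=4,
        window_size=8,
        proj_drop=0.1,
        attn_drop=0.1,
        drop_path=0.2,
    )
    x = torch.randn(1, 16, 16, 96)  # (B, H, W, C), the layout the stage expects
    print(stage(x).shape)           # torch.Size([1, 16, 16, 96])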
diff --git a/credit/parser.py b/credit/parser.py
index c50e7a45..b7db64ef 100644
--- a/credit/parser.py
+++ b/credit/parser.py
@@ -250,6 +250,9 @@ def credit_main_parser(
     )
 
     ## I/O data sizes
+
+    conf["data"].setdefault("data_clamp", None)
+
     if parse_training:
         assert (
             "train_years" in conf["data"]
@@ -264,7 +267,7 @@ def credit_main_parser(
         assert (
             "forecast_len" in conf["data"]
         ), "Number of time frames for loss compute ('forecast_len') is missing from conf['data']"
-
+
         if "valid_history_len" not in conf["data"]:
             # use "history_len" for "valid_history_len"
             conf["data"]["valid_history_len"] = conf["data"]["history_len"]
@@ -420,8 +423,8 @@ def credit_main_parser(
         )
 
     # # debug only
-    # conf['model']['post_conf']['varname_input'] = varname_input
-    # conf['model']['post_conf']['varname_output'] = varname_output
+    conf['model']['post_conf']['varname_input'] = varname_input
+    conf['model']['post_conf']['varname_output'] = varname_output
 
     # --------------------------------------------------------------------- #
     # SKEBS
@@ -478,6 +481,7 @@ def credit_main_parser(
         conf["model"]["post_conf"]["global_mass_fixer"].setdefault("denorm", True)
         conf["model"]["post_conf"]["global_mass_fixer"].setdefault("simple_demo", False)
         conf["model"]["post_conf"]["global_mass_fixer"].setdefault("midpoint", False)
+        conf['model']['post_conf']['global_mass_fixer'].setdefault('grid_type', 'pressure')
 
         assert (
             "fix_level_num" in conf["model"]["post_conf"]["global_mass_fixer"]
@@ -487,7 +491,11 @@ def credit_main_parser(
         assert (
             "lon_lat_level_name" in conf["model"]["post_conf"]["global_mass_fixer"]
         ), "Must specifiy var names for lat/lon/level in physics reference file"
-
+
+        if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
+            assert 'surface_pressure_name' in conf['model']['post_conf']['global_mass_fixer'], (
+                'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')
+
         q_inds = [
             i_var
             for i_var, var in enumerate(varname_output)
@@ -498,6 +506,13 @@ def credit_main_parser(
         ]
         conf["model"]["post_conf"]["global_mass_fixer"]["q_inds"] = q_inds
 
+        if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
+            sp_inds = [
+                i_var for i_var, var in enumerate(varname_output)
+                if var in conf['model']['post_conf']['global_mass_fixer']['surface_pressure_name']
+            ]
+            conf['model']['post_conf']['global_mass_fixer']['sp_inds'] = sp_inds[0]
+
     # --------------------------------------------------------------------- #
     # global water fixer
     flag_water = (
@@ -518,12 +533,17 @@ def credit_main_parser(
             "simple_demo", False
         )
         conf["model"]["post_conf"]["global_water_fixer"].setdefault("midpoint", False)
+        conf['model']['post_conf']['global_water_fixer'].setdefault('grid_type', 'pressure')
 
         if conf["model"]["post_conf"]["global_water_fixer"]["simple_demo"] is False:
             assert (
                 "lon_lat_level_name" in conf["model"]["post_conf"]["global_water_fixer"]
             ), "Must specifiy var names for lat/lon/level in physics reference file"
 
+        if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
+            assert 'surface_pressure_name' in conf['model']['post_conf']['global_water_fixer'], (
+                'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')
+
         q_inds = [
             i_var
             for i_var, var in enumerate(varname_output)
@@ -551,6 +571,13 @@ def credit_main_parser(
         conf["model"]["post_conf"]["global_water_fixer"]["precip_ind"] = precip_inds[0]
         conf["model"]["post_conf"]["global_water_fixer"]["evapor_ind"] = evapor_inds[0]
 
+        if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
+            sp_inds = [
+                i_var for i_var, var in enumerate(varname_output)
+                if var in conf['model']['post_conf']['global_water_fixer']['surface_pressure_name']
+            ]
+            conf['model']['post_conf']['global_water_fixer']['sp_inds'] = sp_inds[0]
+
     # --------------------------------------------------------------------- #
     # global energy fixer
     flag_energy = (
@@ -571,6 +598,7 @@ def credit_main_parser(
             "simple_demo", False
         )
         conf["model"]["post_conf"]["global_energy_fixer"].setdefault("midpoint", False)
+        conf['model']['post_conf']['global_energy_fixer'].setdefault('grid_type', 'pressure')
 
         if conf["model"]["post_conf"]["global_energy_fixer"]["simple_demo"] is False:
             assert (
@@ -578,6 +606,10 @@ def credit_main_parser(
                 "lon_lat_level_name"
                 in conf["model"]["post_conf"]["global_energy_fixer"]
             ), "Must specifiy var names for lat/lon/level in physics reference file"
 
+        if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
+            assert 'surface_pressure_name' in conf['model']['post_conf']['global_energy_fixer'], (
+                'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')
+
         T_inds = [
             i_var
             for i_var, var in enumerate(varname_output)
@@ -645,6 +677,13 @@ def credit_main_parser(
             surf_flux_inds
         )
 
+        if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
+            sp_inds = [
+                i_var for i_var, var in enumerate(varname_output)
+                if var in conf['model']['post_conf']['global_energy_fixer']['surface_pressure_name']
+            ]
+            conf['model']['post_conf']['global_energy_fixer']['sp_inds'] = sp_inds[0]
+
     # --------------------------------------------------------- #
     # conf['trainer'] section
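Note on the parser changes: `data_clamp` now defaults to None (clamping off), the `varname_input`/`varname_output` assignments into `post_conf` are no longer debug-only, and each conservation fixer gains a `grid_type` option ('pressure' by default, 'sigma' for hybrid sigma-pressure coordinates). In sigma mode the parser also records which output channel holds surface pressure as `sp_inds`; a sketch of that lookup with an assumed output-variable list:

    # illustrative output ordering and fixer config; names are assumptions
    varname_output = ["U", "V", "T", "Qtot", "SP", "t2m"]
    fixer_conf = {"grid_type": "sigma", "surface_pressure_name": ["SP"]}

    sp_inds = [
        i_var for i_var, var in enumerate(varname_output)
        if var in fixer_conf["surface_pressure_name"]
    ]
    fixer_conf["sp_inds"] = sp_inds[0]  # -> 4, a single channel index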
diff --git a/credit/trainers/trainerERA5_multistep_grad_accum.py b/credit/trainers/trainerERA5_multistep_grad_accum.py
index 3d2c840e..7b88f3ad 100644
--- a/credit/trainers/trainerERA5_multistep_grad_accum.py
+++ b/credit/trainers/trainerERA5_multistep_grad_accum.py
@@ -112,6 +112,15 @@ def train_one_epoch(
         ):
             scheduler.step()
 
+        # ------------------------------------------------------- #
+        # clamp to remove outliers
+        if conf["data"]["data_clamp"] is None:
+            flag_clamp = False
+        else:
+            flag_clamp = True
+            clamp_min = float(conf["data"]["data_clamp"][0])
+            clamp_max = float(conf["data"]["data_clamp"][1])
+
         # ====================================================== #
         # postblock opts outside of model
         post_conf = conf["model"]["post_conf"]
@@ -195,6 +204,11 @@ def train_one_epoch(
                     # concat on var dimension
                     x = torch.cat((x, x_forcing_batch), dim=1)
 
+                # --------------------------------------------- #
+                # clamp
+                if flag_clamp:
+                    x = torch.clamp(x, min=clamp_min, max=clamp_max)
+
                 # predict with the model
                 y_pred = self.model(x)
 
@@ -246,6 +260,11 @@ def train_one_epoch(
                     # concat on var dimension
                     y = torch.cat((y, y_diag_batch), dim=1)
 
+                # --------------------------------------------- #
+                # clamp
+                if flag_clamp:
+                    y = torch.clamp(y, min=clamp_min, max=clamp_max)
+
                 loss = criterion(y.to(y_pred.dtype), y_pred).mean()
 
                 # track the loss
@@ -405,6 +424,15 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
             else len(valid_loader)
         )
 
+        # ------------------------------------------------------- #
+        # clamp to remove outliers
+        if conf["data"]["data_clamp"] is None:
+            flag_clamp = False
+        else:
+            flag_clamp = True
+            clamp_min = float(conf["data"]["data_clamp"][0])
+            clamp_max = float(conf["data"]["data_clamp"][1])
+
         # ====================================================== #
         # postblock opts outside of model
         post_conf = conf["model"]["post_conf"]
@@ -466,7 +494,11 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
                 # concat on var dimension
                 x = torch.cat((x, x_forcing_batch), dim=1)
 
-                # logger.info('k = {}; x.shape() = {}'.format(forecast_step, x.shape))
+                # --------------------------------------------- #
+                # clamp
+                if flag_clamp:
+                    x = torch.clamp(x, min=clamp_min, max=clamp_max)
+
                 y_pred = self.model(x)
 
                 # ============================================= #
@@ -517,6 +549,11 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
                 # concat on var dimension
                 y = torch.cat((y, y_diag_batch), dim=1)
 
+                # --------------------------------------------- #
+                # clamp
+                if flag_clamp:
+                    y = torch.clamp(y, min=clamp_min, max=clamp_max)
+
                 # ----------------------------------------------------------------------- #
                 # calculate rolling loss
                 loss = criterion(y.to(y_pred.dtype), y_pred).mean()
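Note on the multistep trainer: the clamp is applied at two points per step, to `x` just before the forward pass and to `y` just before the loss, in both `train_one_epoch` and `validate`. The pattern as a compact sketch (the helper name and defaults are illustrative):

    import torch

    def clamped_step(model, criterion, x, y, flag_clamp, clamp_min=-16.0, clamp_max=16.0):
        # clamp inputs before the forward pass ...
        if flag_clamp:
            x = torch.clamp(x, min=clamp_min, max=clamp_max)
        y_pred = model(x)
        # ... and targets before the loss
        if flag_clamp:
            y = torch.clamp(y, min=clamp_min, max=clamp_max)
        return criterion(y.to(y_pred.dtype), y_pred).mean()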
diff --git a/credit/trainers/trainerERA5_v2.py b/credit/trainers/trainerERA5_v2.py
index 02c7b947..2c261a6a 100644
--- a/credit/trainers/trainerERA5_v2.py
+++ b/credit/trainers/trainerERA5_v2.py
@@ -67,6 +67,15 @@ def train_one_epoch(
         ):
             scheduler.step()
 
+        # ------------------------------------------------------- #
+        # clamp to remove outliers
+        if conf["data"]["data_clamp"] is None:
+            flag_clamp = False
+        else:
+            flag_clamp = True
+            clamp_min = float(conf["data"]["data_clamp"][0])
+            clamp_max = float(conf["data"]["data_clamp"][1])
+
         # ====================================================== #
         # postblock opts outside of model
         post_conf = conf["model"]["post_conf"]
@@ -157,6 +166,12 @@ def train_one_epoch(
             # concat on var dimension
             y = torch.cat((y, y_diag_batch), dim=1)
 
+            # --------------------------------------------- #
+            # clamp
+            if flag_clamp:
+                x = torch.clamp(x, min=clamp_min, max=clamp_max)
+                y = torch.clamp(y, min=clamp_min, max=clamp_max)
+
             # single step predict
             y_pred = self.model(x)
 
@@ -340,6 +355,15 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
 
         results_dict = defaultdict(list)
 
+        # ------------------------------------------------------- #
+        # clamp to remove outliers
+        if conf["data"]["data_clamp"] is None:
+            flag_clamp = False
+        else:
+            flag_clamp = True
+            clamp_min = float(conf["data"]["data_clamp"][0])
+            clamp_max = float(conf["data"]["data_clamp"][1])
+
         # ====================================================== #
         # postblock opts outside of model
         post_conf = conf["model"]["post_conf"]
@@ -424,6 +448,12 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
             # concat on var dimension
             y = torch.cat((y, y_diag_batch), dim=1)
 
+            # --------------------------------------------- #
+            # clamp
+            if flag_clamp:
+                x = torch.clamp(x, min=clamp_min, max=clamp_max)
+                y = torch.clamp(y, min=clamp_min, max=clamp_max)
+
             y_pred = self.model(x)
 
             # ============================================= #
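Note on the single-step trainer: `trainerERA5_v2` has no rollout, so `x` and `y` are clamped together at one site just before the forward pass. In all of these code paths the targets are clamped before the loss, which means extreme target values stop contributing beyond the clamp bound; a toy illustration:

    import torch

    y = torch.tensor([0.0, 3.0, 42.0])          # one extreme outlier (normalized units)
    print(torch.clamp(y, min=-16.0, max=16.0))  # tensor([ 0.,  3., 16.])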