Add torch.clamp to trainers + add drop_path in FuXi + fix FuXi bug when use_spectral_norm: False + minor updates to credit.parser #130

Merged: 8 commits, Nov 18, 2024
18 changes: 17 additions & 1 deletion applications/rollout_to_netcdf.py
@@ -496,7 +496,16 @@ def predict(rank, world_size, conf, p):
+ len(conf["data"]["forcing_variables"])
+ len(conf["data"]["static_variables"])
)


# ------------------------------------------------------- #
# clamp to remove outliers
if conf["data"]["data_clamp"] is None:
    flag_clamp = False
else:
    flag_clamp = True
    clamp_min = float(conf["data"]["data_clamp"][0])
    clamp_max = float(conf["data"]["data_clamp"][1])

# ====================================================== #
# postblock opts outside of model
post_conf = conf["model"]["post_conf"]
@@ -646,6 +655,13 @@ def predict(rank, world_size, conf, p):

# -------------------------------------------------------------------------------------- #
# start prediction

# --------------------------------------------- #
# clamp
if flag_clamp:
    x = torch.clamp(x, min=clamp_min, max=clamp_max)
    # y = torch.clamp(y, min=clamp_min, max=clamp_max)

y_pred = model(x)

# ============================================= #
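A note on the clamp step: with the example config value `data_clamp: [-16, 16]`, any standardized input falling outside that range is pinned to the boundary before the forward pass. A minimal sketch of the behavior (the tensor values below are made up):

```python
import torch

clamp_min, clamp_max = -16.0, 16.0  # from conf["data"]["data_clamp"]

# A standardized input with two outliers (e.g., corrupted grid points).
x = torch.tensor([-250.0, -3.2, 0.0, 4.7, 1.0e4])
print(torch.clamp(x, min=clamp_min, max=clamp_max))
# tensor([-16.0000, -3.2000, 0.0000, 4.7000, 16.0000])
```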
6 changes: 2 additions & 4 deletions config/example_physics_single.yml
@@ -1,8 +1,5 @@
# --------------------------------------------------------------------------------------------------------------------- #
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# the FuXi architecture has been modified to reduce the overall model size
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# example
# --------------------------------------------------------------------------------------------------------------------- #
save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_conserve/'
seed: 1000
@@ -47,6 +44,7 @@ data:

# data workflow
scaler_type: 'std_new'
data_clamp: [-16, 16]

# number of input states
# FuXi has 2 input states
46 changes: 18 additions & 28 deletions credit/loss.py
@@ -369,7 +369,7 @@ def latitude_weights(conf):
return L


def variable_weights(conf, channels, surface_channels, frames):
def variable_weights(conf, channels, frames):
"""Create variable-specific weights for different atmospheric
and surface channels.

@@ -382,7 +382,6 @@ def variable_weights(conf, channels, surface_channels, frames):
conf (dict): Configuration dictionary containing the
variable weights.
channels (int): Number of channels for atmospheric variables.
surface_channels (int): Number of channels for surface variables.
frames (int): Number of time frames.

Returns:
@@ -393,41 +392,25 @@
varname_upper_air = conf["data"]["variables"]
varname_surface = conf["data"]["surface_variables"]
varname_diagnostics = conf["data"]["diagnostic_variables"]
# N_levels = conf['data']['levels']

# weights_UVTQ = torch.tensor([
# conf["loss"]["variable_weights"]["U"],
# conf["loss"]["variable_weights"]["V"],
# conf["loss"]["variable_weights"]["T"],
# conf["loss"]["variable_weights"]["Q"]
# ]).view(1, channels * frames, 1, 1)

weights_UVTQ = torch.tensor(
# surface + diag channels
N_channels_single = len(varname_surface) + len(varname_diagnostics)

weights_upper_air = torch.tensor(
    [conf["loss"]["variable_weights"][var] for var in varname_upper_air]
).view(1, channels * frames, 1, 1)

# Load weights for SP, t2m, V500, U500, T500, Z500, Q500
# weights_sfc = torch.tensor([
# conf["loss"]["variable_weights"]["SP"],
# conf["loss"]["variable_weights"]["t2m"],
# conf["loss"]["variable_weights"]["V500"],
# conf["loss"]["variable_weights"]["U500"],
# conf["loss"]["variable_weights"]["T500"],
# conf["loss"]["variable_weights"]["Z500"],
# conf["loss"]["variable_weights"]["Q500"]
# ]).view(1, surface_channels, 1, 1)

weights_sfc = torch.tensor(

weights_single = torch.tensor(
[
conf["loss"]["variable_weights"][var]
for var in (varname_surface + varname_diagnostics)
]
).view(1, surface_channels, 1, 1)
).view(1, N_channels_single, 1, 1)

# Combine all weights along the color channel
variable_weights = torch.cat([weights_UVTQ, weights_sfc], dim=1)
var_weights = torch.cat([weights_upper_air, weights_single], dim=1)

return variable_weights
return var_weights
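A minimal sketch of the refactored weight assembly (the variable names, level counts, and weight values below are hypothetical): upper-air variables contribute one channel per model level, single-level (surface + diagnostic) variables one channel each, concatenated along the channel dimension in the order upper air -> surface -> diagnostics.

```python
import torch

# Hypothetical config: list-valued weights for upper-air variables (one entry
# per model level), scalar weights for surface/diagnostic variables.
variable_weights = {"U": [0.5, 1.0], "V": [0.5, 1.0], "SP": 0.1, "OLR": 0.05}
varname_upper_air = ["U", "V"]
varname_single = ["SP", "OLR"]  # surface + diagnostics

channels, frames = 4, 1  # 2 upper-air vars x 2 levels, 1 input frame

weights_upper_air = torch.tensor(
    [variable_weights[v] for v in varname_upper_air]
).view(1, channels * frames, 1, 1)

weights_single = torch.tensor(
    [variable_weights[v] for v in varname_single]
).view(1, len(varname_single), 1, 1)

# Concatenate along the channel dimension.
var_weights = torch.cat([weights_upper_air, weights_single], dim=1)
print(var_weights.shape)  # torch.Size([1, 6, 1, 1])
```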


class VariableTotalLoss2D(torch.nn.Module):
@@ -471,18 +454,25 @@ def __init__(self, conf, validation=False):
logger.info("Using latitude weights in loss calculations")
self.lat_weights = latitude_weights(conf)[:, 10].unsqueeze(0).unsqueeze(-1)

# ------------------------------------------------------------- #
# variable weights
# order: upper air --> surface --> diagnostics
self.var_weights = None
if conf["loss"]["use_variable_weights"]:
logger.info("Using variable weights in loss calculations")

var_weights = [
value if isinstance(value, list) else [value]
for value in conf["loss"]["variable_weights"].values()
]

var_weights = np.array(
[item for sublist in var_weights for item in sublist]
)

self.var_weights = torch.from_numpy(var_weights)

# ------------------------------------------------------------- #

self.use_spectral_loss = conf["loss"]["use_spectral_loss"]
if self.use_spectral_loss:
self.spectral_lambda_reg = conf["loss"]["spectral_lambda_reg"]
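The scalar/list normalization in the new `var_weights` block, isolated with hypothetical values: per-level list entries pass through, scalars are wrapped in one-element lists, and everything is flattened into a single per-channel vector.

```python
import numpy as np

# Hypothetical conf["loss"]["variable_weights"]: upper-air entries are
# per-level lists, single-level entries are plain scalars.
weights_conf = {"U": [0.5, 1.0], "T": [0.9, 1.1], "SP": 0.1}

as_lists = [v if isinstance(v, list) else [v] for v in weights_conf.values()]
flat = np.array([item for sublist in as_lists for item in sublist])
print(flat)  # [0.5 1.  0.9 1.1 0.1]
```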
45 changes: 39 additions & 6 deletions credit/models/fuxi.py
@@ -226,7 +226,15 @@ class UTransformer(nn.Module):
"""

def __init__(
self, embed_dim, num_groups, input_resolution, num_heads, window_size, depth
self, embed_dim,
num_groups,
input_resolution,
num_heads,
window_size,
depth,
proj_drop,
attn_drop,
drop_path
):
super().__init__()
num_groups = to_2tuple(num_groups)
@@ -248,7 +256,15 @@ def __init__(

# SwinT block
self.layer = SwinTransformerV2Stage(
embed_dim, embed_dim, input_resolution, depth, num_heads, window_size[0]
embed_dim,
embed_dim,
input_resolution,
depth,
num_heads,
window_size[0],
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
) # <--- window_size[0] takes the int window size from the tuple

# up-sampling block
@@ -315,6 +331,9 @@ def __init__(
window_size=7,
use_spectral_norm=True,
interp=True,
proj_drop=0,
attn_drop=0,
drop_path=0,
padding_conf=None,
post_conf=None,
**kwargs,
@@ -323,11 +342,15 @@

self.use_interp = interp
self.use_spectral_norm = use_spectral_norm

if padding_conf is None:
padding_conf = {"activate": False}

self.use_padding = padding_conf["activate"]

if post_conf is None:
post_conf = {"activate": False}

self.use_post_block = post_conf["activate"]

# input tensor size (time, lat, lon)
@@ -362,8 +385,17 @@ def __init__(
self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)

# Downsampling --> SwinTransformerV2 stacks --> Upsampling
logger.info(f"Define UTransforme with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}")

self.u_transformer = UTransformer(
dim, num_groups, input_resolution, num_heads, window_size, depth=depth
dim, num_groups,
input_resolution,
num_heads,
window_size,
depth=depth,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
)

# dense layer applied on the channel dimension
@@ -383,11 +415,12 @@ def __init__(
if self.use_padding:
self.padding_opt = TensorPadding(**padding_conf)

# Move the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(device)

if self.use_spectral_norm:
logger.info("Adding spectral norm to all conv and linear layers")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the device
self.to(device)
apply_spectral_norm(self)

if self.use_post_block:
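Context for the new `drop_path` argument: it is forwarded to the `SwinTransformerV2Stage` blocks and enables stochastic depth, i.e. randomly skipping a residual branch per sample during training. A self-contained sketch of the technique (not the timm implementation the model actually uses):

```python
import torch

def drop_path(x: torch.Tensor, drop_prob: float, training: bool) -> torch.Tensor:
    """Stochastic depth: zero the residual branch for a random subset of
    samples, rescaling survivors so the expected output is unchanged."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dims.
    mask = torch.empty((x.shape[0],) + (1,) * (x.ndim - 1), device=x.device)
    mask.bernoulli_(keep_prob)
    return x * mask / keep_prob

# Each of the 8 samples keeps its branch with probability 0.9 during training.
branch = torch.randn(8, 96, 16, 16)
out = drop_path(branch, drop_prob=0.1, training=True)
```

Since `proj_drop`, `attn_drop`, and `drop_path` all default to 0, existing configs are unaffected.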
47 changes: 43 additions & 4 deletions credit/parser.py
@@ -250,6 +250,9 @@ def credit_main_parser(
)

## I/O data sizes

conf["data"].setdefault("data_clamp", None)

if parse_training:
assert (
"train_years" in conf["data"]
@@ -264,7 +267,7 @@
assert (
"forecast_len" in conf["data"]
), "Number of time frames for loss compute ('forecast_len') is missing from conf['data']"

if "valid_history_len" not in conf["data"]:
# use "history_len" for "valid_history_len"
conf["data"]["valid_history_len"] = conf["data"]["history_len"]
@@ -420,8 +423,8 @@
)

# # debug only
# conf['model']['post_conf']['varname_input'] = varname_input
# conf['model']['post_conf']['varname_output'] = varname_output
conf['model']['post_conf']['varname_input'] = varname_input
conf['model']['post_conf']['varname_output'] = varname_output
# --------------------------------------------------------------------- #

# SKEBS
@@ -478,6 +481,7 @@
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("denorm", True)
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("simple_demo", False)
conf["model"]["post_conf"]["global_mass_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_mass_fixer'].setdefault('grid_type', 'pressure')

assert (
"fix_level_num" in conf["model"]["post_conf"]["global_mass_fixer"]
@@ -487,7 +491,11 @@
assert (
"lon_lat_level_name" in conf["model"]["post_conf"]["global_mass_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"


if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
    assert 'surface_pressure_name' in conf['model']['post_conf']['global_mass_fixer'], (
        'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')

q_inds = [
i_var
for i_var, var in enumerate(varname_output)
Expand All @@ -498,6 +506,13 @@ def credit_main_parser(
]
conf["model"]["post_conf"]["global_mass_fixer"]["q_inds"] = q_inds

if conf['model']['post_conf']['global_mass_fixer']['grid_type'] == 'sigma':
    sp_inds = [
        i_var for i_var, var in enumerate(varname_output)
        if var in conf['model']['post_conf']['global_mass_fixer']['surface_pressure_name']
    ]
    conf['model']['post_conf']['global_mass_fixer']['sp_inds'] = sp_inds[0]

# --------------------------------------------------------------------- #
# global water fixer
flag_water = (
@@ -518,12 +533,17 @@
"simple_demo", False
)
conf["model"]["post_conf"]["global_water_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_water_fixer'].setdefault('grid_type', 'pressure')

if conf["model"]["post_conf"]["global_water_fixer"]["simple_demo"] is False:
assert (
"lon_lat_level_name" in conf["model"]["post_conf"]["global_water_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"

if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
    assert 'surface_pressure_name' in conf['model']['post_conf']['global_water_fixer'], (
        'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')

q_inds = [
i_var
for i_var, var in enumerate(varname_output)
@@ -551,6 +571,13 @@
conf["model"]["post_conf"]["global_water_fixer"]["precip_ind"] = precip_inds[0]
conf["model"]["post_conf"]["global_water_fixer"]["evapor_ind"] = evapor_inds[0]

if conf['model']['post_conf']['global_water_fixer']['grid_type'] == 'sigma':
    sp_inds = [
        i_var for i_var, var in enumerate(varname_output)
        if var in conf['model']['post_conf']['global_water_fixer']['surface_pressure_name']
    ]
    conf['model']['post_conf']['global_water_fixer']['sp_inds'] = sp_inds[0]

# --------------------------------------------------------------------- #
# global energy fixer
flag_energy = (
@@ -571,13 +598,18 @@
"simple_demo", False
)
conf["model"]["post_conf"]["global_energy_fixer"].setdefault("midpoint", False)
conf['model']['post_conf']['global_energy_fixer'].setdefault('grid_type', 'pressure')

if conf["model"]["post_conf"]["global_energy_fixer"]["simple_demo"] is False:
assert (
"lon_lat_level_name"
in conf["model"]["post_conf"]["global_energy_fixer"]
), "Must specifiy var names for lat/lon/level in physics reference file"

if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
    assert 'surface_pressure_name' in conf['model']['post_conf']['global_energy_fixer'], (
        'Must specify surface pressure var name when using hybrid sigma-pressure coordinates')

T_inds = [
i_var
for i_var, var in enumerate(varname_output)
@@ -645,6 +677,13 @@
surf_flux_inds
)

if conf['model']['post_conf']['global_energy_fixer']['grid_type'] == 'sigma':
    sp_inds = [
        i_var for i_var, var in enumerate(varname_output)
        if var in conf['model']['post_conf']['global_energy_fixer']['surface_pressure_name']
    ]
    conf['model']['post_conf']['global_energy_fixer']['sp_inds'] = sp_inds[0]
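The same `sp_inds` lookup now appears in all three fixers. A compact sketch with hypothetical variable names showing what gets stored:

```python
# Hypothetical output channel list and fixer config.
varname_output = ["U", "V", "T", "Qtot", "SP", "t2m"]
fixer_conf = {"grid_type": "sigma", "surface_pressure_name": ["SP"]}

if fixer_conf["grid_type"] == "sigma":
    sp_inds = [
        i_var for i_var, var in enumerate(varname_output)
        if var in fixer_conf["surface_pressure_name"]
    ]
    fixer_conf["sp_inds"] = sp_inds[0]  # -> 4, the channel index of SP
```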

# --------------------------------------------------------- #
# conf['trainer'] section
