Skip to content

Commit

Permalink
Merge pull request #110 from NCAR/physics
Browse files Browse the repository at this point in the history
`credit.postblock` major updates on `GlobalMassFixer` and `GlobalEnergyFixer`
  • Loading branch information
yingkaisha authored Oct 16, 2024
2 parents 9fba515 + f1a7a9b commit 7357df0
Show file tree
Hide file tree
Showing 8 changed files with 692 additions and 123 deletions.
95 changes: 56 additions & 39 deletions config/example_physics_single.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,42 @@
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# --------------------------------------------------------------------------------------------------------------------- #
save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_physics_base/'
save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_conserve/'
seed: 1000

data:
# upper-air variables
variables: ['specific_total_water']
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/all_in_one/ERA5_plevel_6h_*zarr'
variables: ['U', 'V', 'T', 'Z', 'specific_total_water']
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/all_in_one/ERA5_plevel_1deg_6h_*_conserve.zarr'

# surface variables
surface_variables: ['SKT']
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/all_in_one/ERA5_plevel_6h_*zarr'
surface_variables: ['MSL', 'VAR_2T', 'SKT']
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/all_in_one/ERA5_plevel_1deg_6h_*_conserve.zarr'

# dynamic forcing variables
dynamic_forcing_variables: ['toa_incident_solar_radiation', 'land_sea_CI_mask']
save_loc_dynamic_forcing: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/all_in_one/ERA5_plevel_6h_*zarr'
save_loc_dynamic_forcing: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/all_in_one/ERA5_plevel_1deg_6h_*_conserve.zarr'

# diagnostic variables
diagnostic_variables: ['total_precipitation']
save_loc_diagnostic: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/all_in_one/ERA5_plevel_6h_*zarr'
diagnostic_variables: ['evaporation', 'total_precipitation', 'TCC',
'surface_net_solar_radiation',
'surface_net_thermal_radiation',
'surface_sensible_heat_flux',
'surface_latent_heat_flux',
'top_net_solar_radiation',
'top_net_thermal_radiation']
save_loc_diagnostic: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/all_in_one/ERA5_plevel_1deg_6h_*_conserve.zarr'

# static variables
static_variables: ['z_norm']
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/static/ERA5_plevel_6h_static.zarr'
static_variables: ['z_norm', 'soil_type']
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/static/ERA5_plevel_1deg_6h_conserve_static.zarr'

# physics file
save_loc_physics: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/static/ERA5_plevel_1deg_6h_conserve_static.zarr'

# mean / std path
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/mean_std/mean_6h_1979_2019_13lev_0.25deg.nc'
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/mean_std/std_residual_6h_1979_2019_13lev_0.25deg.nc'
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/mean_std/mean_6h_1979_2019_conserve_1deg.nc'
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/mean_std/std_residual_6h_1979_2019_conserve_1deg.nc'

# train / validation split
train_years: [1979, 2019]
Expand Down Expand Up @@ -79,29 +88,29 @@ trainer:
load_optimizer: False
load_scaler: False
load_sheduler: False

num_epoch: 2

skip_validation: False
update_learning_rate: False

save_backup_weights: True
save_best_weights: True

save_metric_vars: ['Z500', 'total_precipitation']
save_metric_vars: ['Z_21', 'specific_total_water_30', 'total_precipitation']

learning_rate: 1.0e-03 # <-- change to your lr
weight_decay: 0

train_batch_size: 1
valid_batch_size: 1

batches_per_epoch: 3
valid_batches_per_epoch: 3
batches_per_epoch: 2
valid_batches_per_epoch: 1
stopping_patience: 50

start_epoch: 0
num_epoch: 2

reload_epoch: True
epochs: &epochs 70
epochs: &epochs 100

use_scheduler: True
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
Expand All @@ -122,29 +131,29 @@ model:
type: "fuxi"

frames: 2 # number of input states
image_height: 721 # number of latitude grids
image_width: 1440 # number of longitude grids
levels: 13 # number of upper-air variable levels
channels: 1 # upper-air variable channels
surface_channels: 1 # surface variable channels
input_only_channels: 3 # dynamic forcing, forcing, static channels
output_only_channels: 1 # diagnostic variable channels
image_height: 181 # number of latitude grids
image_width: 360 # number of longitude grids
levels: 37 # number of upper-air variable levels
channels: 5 # upper-air variable channels
surface_channels: 3 # surface variable channels
input_only_channels: 4 # dynamic forcing, forcing, static channels
output_only_channels: 9 # diagnostic variable channels

# patchify layer
patch_height: 16 # number of latitude grids in each 3D patch
patch_width: 16 # number of longitude grids in each 3D patch
patch_height: 32 # number of latitude grids in each 3D patch
patch_width: 32 # number of longitude grids in each 3D patch
frame_patch_size: 2 # number of input states in each 3D patch

# hidden layers
dim: 64 # dimension (default: 1536)
num_groups: 2 # number of groups (default: 32)
num_heads: 2 # number of heads (default: 8)
dim: 128 # dimension (default: 1536)
num_groups: 2 # number of groups (default: 32)
num_heads: 8 # number of heads (default: 8)
window_size: 7 # window size (default: 7)
depth: 4 # number of swin transformers (default: 48)
depth: 2 # number of swin transformers (default: 48)

# map boundary padding
pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
pad_lon: 40 # number of grids to pad on 0 and 360 deg lon
pad_lat: 40 # number of grids to pad on -90 and 90 deg lat

# use spectral norm
use_spectral_norm: True
Expand All @@ -158,24 +167,32 @@ model:
tracer_fixer:
activate: True
denorm: True
tracer_name: ['specific_total_water', 'total_precipitation']
tracer_name: ['specific_total_water', 'total_precipitation', 'TCC']
tracer_thres: [0, 0, 0]

global_mass_fixer:
activate: False
activate: True
simple_demo: False
denorm: True
midpoint: False
fix_level_num: 14
lon_lat_level_name: ['lon2d', 'lat2d', 'p_level']
specific_total_water_name: ['specific_total_water']
precipitation_name: ['total_precipitation']
evaporation_name: ['evaporation']

global_energy_fixer:
activate: False
activate: True
simple_demo: False
denorm: True
midpoint: False
lon_lat_level_name: ['lon2d', 'lat2d', 'p_level']
air_temperature_name: ['T']
specific_total_water_name: ['specific_total_water']
u_wind_name: ['U']
v_wind_name: ['V']
surface_geopotential_name: ['z_norm']
lon_lat_level_name: ['lon2d', 'lat2d', 'p_level']
surface_geopotential_name: ['geopotential_at_surface']
TOA_net_radiation_flux_name: ['top_net_solar_radiation', 'top_net_thermal_radiation']
surface_net_radiation_flux_name: ['surface_net_solar_radiation', 'surface_net_thermal_radiation']
surface_energy_flux_name: ['surface_sensible_heat_flux', 'surface_latent_heat_flux',]
Expand All @@ -190,7 +207,7 @@ loss:

# use latitude weighting
use_latitude_weights: True
latitude_weights: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_base/static/ERA5_plevel_6h_static.zarr'
latitude_weights: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_plevel_1deg/static/ERA5_plevel_1deg_6h_conserve_static.zarr'

# turn-off variable weighting
use_variable_weights: False
Expand Down
8 changes: 4 additions & 4 deletions credit/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def __init__(
if self.filename_forcing is not None:
# drop variables if they are not in the config
ds = get_forward_data(filename_forcing)
ds_forcing = drop_var_from_dataset(ds, varname_forcing)
ds_forcing = drop_var_from_dataset(ds, varname_forcing).load() # <---- load in static

self.xarray_forcing = ds_forcing
else:
Expand All @@ -477,7 +477,7 @@ def __init__(
if self.filename_static is not None:
# drop variables if they are not in the config
ds = get_forward_data(filename_static)
ds_static = drop_var_from_dataset(ds, varname_static)
ds_static = drop_var_from_dataset(ds, varname_static).load() # <---- load in static

self.xarray_static = ds_static
else:
Expand Down Expand Up @@ -567,7 +567,7 @@ def __getitem__(self, index):
month_day_inputs = extract_month_day_hour(np.array(historical_ERA5_images['time'])) # <-- upper air
# indices to subset
ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs)
forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing).load() # <-- load into memory
forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing) #.load() # <-- loadded in init
# forcing and upper air have different years but the same mon/day/hour
# safely replace forcing time with upper air time
forcing_subset_input['time'] = historical_ERA5_images['time']
Expand All @@ -587,7 +587,7 @@ def __getitem__(self, index):

# slice + load to the GPU
static_subset_input = static_subset_input.isel(
time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory
time=slice(0, self.history_len, self.skip_periods)) #.load() # <-- loaded in init

# update
static_subset_input['time'] = historical_ERA5_images['time']
Expand Down
9 changes: 5 additions & 4 deletions credit/models/fuxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,15 +413,16 @@ def forward(self, x: torch.Tensor):
# if lat/lon grids (i.e., img_size) cannot be divided by the patche size completely
# this will preserve the output size
x = F.interpolate(x, size=img_size[1:], mode="bilinear")

# unfold the time dimension
x = x.unsqueeze(2)

if self.use_post_block:
x = {
"y_pred": x,
"x": x_copy,
}
x = self.postblock(x)
# unfold the time dimension
return x.unsqueeze(2)
x = self.postblock(x)
return x


if __name__ == "__main__":
Expand Down
44 changes: 28 additions & 16 deletions credit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ

# --------------------------------------------------------------------- #
# tracer fixer
flag_tracer = conf['model']['post_conf']['tracer_fixer']['activate']
flag_tracer = conf['model']['post_conf']['activate'] and conf['model']['post_conf']['tracer_fixer']['activate']

if flag_tracer:
# when tracer fixer is on, get tensor indices of tracers
Expand Down Expand Up @@ -357,14 +357,23 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ

# --------------------------------------------------------------------- #
# global mass fixer
flag_mass = conf['model']['post_conf']['global_mass_fixer']['activate']
flag_mass = conf['model']['post_conf']['activate'] and conf['model']['post_conf']['global_mass_fixer']['activate']

if flag_mass:
# when global mass fixer is on, get tensor indices of q, precip, evapor
# these variables must be outputs

# global mass fixer runs on de-normalized variables by default
# global mass fixer defaults
conf['model']['post_conf']['global_mass_fixer'].setdefault('denorm', True)
conf['model']['post_conf']['global_mass_fixer'].setdefault('simple_demo', False)
conf['model']['post_conf']['global_mass_fixer'].setdefault('midpoint', False)

assert 'fix_level_num' in conf['model']['post_conf']['global_mass_fixer'], (
'Must specifiy what level to fix on specific total water')

if conf['model']['post_conf']['global_mass_fixer']['simple_demo'] is False:
assert 'lon_lat_level_name' in conf['model']['post_conf']['global_mass_fixer'], (
'Must specifiy var names for lat/lon/level in physics reference file')

q_inds = [
i_var for i_var, var in enumerate(varname_output)
Expand All @@ -387,13 +396,20 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ

# --------------------------------------------------------------------- #
# global energy fixer
flag_energy = conf['model']['post_conf']['global_energy_fixer']['activate']
flag_energy = conf['model']['post_conf']['activate'] and conf['model']['post_conf']['global_energy_fixer']['activate']

if flag_energy:
# when global energy fixer is on, get tensor indices of energy components
# geopotential at surface is input, others are outputs

# global energy fixer runs on de-normalized variables by default
# global energy fixer defaults
conf['model']['post_conf']['global_energy_fixer'].setdefault('denorm', True)
conf['model']['post_conf']['global_energy_fixer'].setdefault('simple_demo', False)
conf['model']['post_conf']['global_energy_fixer'].setdefault('midpoint', False)

if conf['model']['post_conf']['global_mass_fixer']['simple_demo'] is False:
assert 'lon_lat_level_name' in conf['model']['post_conf']['global_energy_fixer'], (
'Must specifiy var names for lat/lon/level in physics reference file')

T_inds = [
i_var for i_var, var in enumerate(varname_output)
Expand All @@ -415,10 +431,10 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ
if var in conf['model']['post_conf']['global_energy_fixer']['v_wind_name']
]

Phi_inds = [
i_var for i_var, var in enumerate(varname_input)
if var in conf['model']['post_conf']['global_energy_fixer']['surface_geopotential_name']
]
# Phi_inds = [
# i_var for i_var, var in enumerate(varname_input)
# if var in conf['model']['post_conf']['global_energy_fixer']['surface_geopotential_name']
# ]

TOA_rad_inds = [
i_var for i_var, var in enumerate(varname_output)
Expand All @@ -439,7 +455,7 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ
conf['model']['post_conf']['global_energy_fixer']['q_inds'] = q_inds
conf['model']['post_conf']['global_energy_fixer']['U_inds'] = U_inds
conf['model']['post_conf']['global_energy_fixer']['V_inds'] = V_inds
conf['model']['post_conf']['global_energy_fixer']['Phi_ind'] = Phi_inds[0]
#conf['model']['post_conf']['global_energy_fixer']['Phi_ind'] = Phi_inds[0]
conf['model']['post_conf']['global_energy_fixer']['TOA_rad_inds'] = TOA_rad_inds
conf['model']['post_conf']['global_energy_fixer']['surf_rad_inds'] = surf_rad_inds
conf['model']['post_conf']['global_energy_fixer']['surf_flux_inds'] = surf_flux_inds
Expand Down Expand Up @@ -1185,8 +1201,4 @@ def predict_data_check(conf, print_summary=False):
print('Coordinate checking passed')
print("All input files, zscore files, and the lat/lon file share the same lat, lon, level coordinate name and values")

return True




return True
Loading

0 comments on commit 7357df0

Please sign in to comment.