Commit
Added pressure interpolation to rollout_netcdf.py and added configuration info to example.yml. Also added ability to specify encoding dictionaries for compression and chunking in rollout.
djgagne committed Oct 26, 2024
1 parent b15c001 commit 44ae540
Showing 5 changed files with 142 additions and 49 deletions.
34 changes: 28 additions & 6 deletions config/example.yml
@@ -6,7 +6,7 @@
# the location to save your workspace, it will have
# (1) pbs script, (2) a copy of this config, (3) model weights, (4) training_log.csv
# if save_loc does not exist, it will be created automatically
save_loc: '/glade/work/$USER/CREDIT_runs/fuxi_6h/'
save_loc: '/glade/derecho/scratch/$USER/CREDIT_runs/fuxi_6h/'
seed: 1000 # random seed

data:
@@ -138,7 +138,7 @@ trainer:
# run train_multistep.py for multi-step
type: standard

# (optional) use torch.compile: False. May not be compatiable with custom models
# (optional) use torch.compile: False. May not be compatible with custom models
compile: False

# load existing weights / optimizer / mixed-precision grad scaler / learning rate scheduler state(s)
@@ -223,9 +223,9 @@ model:

# fuxi example
frames: 2 # number of input states
image_height: 640 # number of latitude grids
image_width: 1280 # number of longitude grids
levels: 15 # number of upper-air variable levels
image_height: &height 640 # number of latitude grids
image_width: &width 1280 # number of longitude grids
levels: &levels 16 # number of upper-air variable levels
channels: 4 # upper-air variable channels
surface_channels: 7 # surface variable channels
input_only_channels: 3 # dynamic forcing, forcing, static channels
@@ -368,10 +368,32 @@ predict:
# users can use $repo/credit/metadata/era5.yaml as an example to create their own
metadata: '/glade/u/home/ksha/miles-credit/credit/metadata/era5.yaml'

interp_pressure:
pressure_levels: [300.0, 500.0, 850.0, 925.0]

ua_var_encoding:
zlib: True
complevel: 1
shuffle: True
chunksizes: [1, *levels, *height, *width]

pressure_var_encoding:
zlib: True
complevel: 1
shuffle: True
chunksizes: [ 1, 4, *height, *width]

surface_var_encoding:
zlib: true
complevel: 1
shuffle: True
chunksizes: [1, *height, *width]


# credit.pbs supports NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
pbs:
# example for derecho
conda: "/glade/work/ksha/miniconda3/envs/credit"
conda: "/glade/u/home/dgagne/conda-envs/hcredit"
project: "NAML0001"
job_name: "train_model"
walltime: "12:00:00"
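The new encoding blocks above map one-to-one onto xarray's netCDF4 encoding options. A minimal sketch of how they are consumed once the YAML is loaded (PyYAML resolves the *levels/*height/*width aliases to plain integers; the file names and variable list here are illustrative only):

    import yaml
    import xarray as xr

    with open("config/example.yml") as f:
        conf = yaml.safe_load(f)  # aliases like *height resolve to ints during loading

    # Apply the upper-air encoding to each upper-air variable (names assumed).
    encoding = {var: conf["predict"]["ua_var_encoding"] for var in ("U", "V", "T", "Q")}

    ds = xr.open_dataset("pred_example_001.nc")  # hypothetical rollout output file
    ds.to_netcdf("compressed.nc", encoding=encoding)  # netCDF4 engine applies zlib + chunking
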
48 changes: 33 additions & 15 deletions credit/interp.py
@@ -3,14 +3,14 @@
import xarray as xr
from tqdm import tqdm
from .physics_constants import RDGAS, RVGAS
import os


def full_state_pressure_interpolation(
state_dataset: xr.Dataset,
pressure_levels: np.ndarray,
model_a: np.ndarray,
model_b: np.ndarray,
pressure_levels: np.ndarray = np.array([500.0, 850.0]),
interp_fields: tuple[str] = ("U", "V", "T", "Q"),
pres_ending: str = "_PRES",
temperature_var: str = "T",
q_var: str = "Q",
surface_pressure_var: str = "SP",
@@ -20,17 +20,18 @@ def full_state_pressure_interpolation(
lat_var: str = "latitude",
lon_var: str = "longitude",
pres_var: str = "pressure",
level_var: str = "level",
model_level_file: str = "../credit/metadata/ERA5_Lev_Info.nc",
verbose: int = 1,
) -> xr.Dataset:
"""
Interpolate full model state variables from model levels to pressure levels.
Args:
state_dataset (xr.Dataset): state variables being interpolated
pressure_levels (np.ndarray): pressure levels for interpolation in Pa.
model_a (np.ndarray): model level a coefficients.
model_b (np.ndarray): model level b coefficients.
pressure_levels (np.ndarray): pressure levels for interpolation in hPa.
interp_fields (tuple[str]): fields to be interpolated.
pres_ending (str): ending string to attach to pressure interpolated variables.
temperature_var (str): temperature variable to be interpolated (units K).
q_var (str): mixing ratio/specific humidity variable to be interpolated (units kg/kg).
surface_pressure_var (str): surface pressure variable (units Pa).
@@ -40,10 +41,17 @@
lat_var (str): latitude coordinate
lon_var (str): longitude coordinate
pres_var (str): pressure coordinate
level_var (str): name of level coordinate
model_level_file (str): relative path to file containing model levels.
verbose (int): verbosity level. If verbose > 0, print progress.
Returns:
pressure_ds (xr.Dataset): Dataset containing pressure interpolated variables.
"""
path_to_file = os.path.abspath(os.path.dirname(__file__))
model_level_file = os.path.join(path_to_file, model_level_file)
with xr.open_dataset(model_level_file) as mod_lev_ds:
model_a = mod_lev_ds["a_model"][state_dataset[level_var]].values
model_b = mod_lev_ds["b_model"][state_dataset[level_var]].values
pres_dims = (time_var, pres_var, lat_var, lon_var)
coords = {
time_var: state_dataset[time_var],
@@ -53,14 +61,17 @@
}
pressure_ds = xr.Dataset(
data_vars={
f: xr.DataArray(
coords=coords, dims=pres_dims, name=f, attrs=state_dataset[f].attrs
f + pres_ending: xr.DataArray(
coords=coords,
dims=pres_dims,
name=f + pres_ending,
attrs=state_dataset[f].attrs,
)
for f in interp_fields
},
coords=coords,
)
pressure_ds[geopotential_var] = xr.DataArray(
pressure_ds[geopotential_var + pres_ending] = xr.DataArray(
coords=coords, dims=pres_dims, name=geopotential_var
)
disable = False
@@ -79,11 +90,17 @@
model_b,
)
for interp_field in interp_fields:
pressure_ds[interp_field][t] = interp_hybrid_to_pressure_levels(
state_dataset[interp_field][t].values, pressure_grid, pressure_levels
pressure_ds[interp_field + pres_ending][t] = (
interp_hybrid_to_pressure_levels(
state_dataset[interp_field][t].values,
pressure_grid / 100.0,
pressure_levels,
)
)
pressure_ds[geopotential_var + pres_ending][t] = (
interp_hybrid_to_pressure_levels(
geopotential_grid, pressure_grid / 100.0, pressure_levels
)
pressure_ds[geopotential_var][t] = interp_hybrid_to_pressure_levels(
geopotential_grid, pressure_grid, pressure_levels
)
return pressure_ds
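
With the hybrid level coefficients now read internally from ERA5_Lev_Info.nc (indexed by the dataset's level coordinate), a call reduces to the dataset plus pressure levels in hPa. A minimal usage sketch, with an illustrative input file name:

    import numpy as np
    import xarray as xr
    from credit.interp import full_state_pressure_interpolation

    state_ds = xr.open_dataset("forecast_state.nc")  # hypothetical model-level state
    interp_ds = full_state_pressure_interpolation(
        state_ds,
        pressure_levels=np.array([300.0, 500.0, 850.0, 925.0]),  # hPa, as in example.yml
        lat_var="lat",
        lon_var="lon",
    )
    # Interpolated variables carry the "_PRES" suffix by default.
    print(interp_ds["U_PRES"].dims)  # (time, pressure, lat, lon)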

@@ -132,11 +149,12 @@ def create_pressure_grid(surface_pressure, model_a, model_b):
def interp_hybrid_to_pressure_levels(model_var, model_pressure, interp_pressures):
"""
Interpolate data field from hybrid sigma-pressure vertical coordinates to pressure levels.
`model_pressure` and `interp_pressure` should have consistent units with each other.
Args:
model_var (np.ndarray): 3D field on hybrid sigma-pressure levels with shape (levels, y, x).
model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa
interp_pressures: (np.ndarray): pressure levels for interpolation in units Pa.
model_pressure (np.ndarray): 3D pressure field with shape (levels, y, x) in units Pa or hPa
interp_pressures (np.ndarray): pressure levels for interpolation in units Pa or hPa.
Returns:
pressure_var (np.ndarray): 3D field on pressure levels with shape (len(interp_pressures), y, x).
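
For intuition, the hybrid sigma-pressure coordinate defines the pressure of model level k as p_k = a_k + b_k * p_s, and the interpolation then proceeds column by column onto the requested levels. A rough sketch of that idea, assuming plain linear interpolation in pressure (the actual implementation may differ, e.g. interpolating in log-pressure or extrapolating below ground):

    import numpy as np

    def naive_hybrid_to_pressure(var, a, b, surface_pressure, interp_pressures):
        # var: (levels, y, x); a, b: (levels,); surface_pressure: (y, x)
        pressure = a[:, None, None] + b[:, None, None] * surface_pressure  # hybrid formula
        out = np.zeros((interp_pressures.size,) + surface_pressure.shape)
        for i in range(surface_pressure.shape[0]):
            for j in range(surface_pressure.shape[1]):
                # np.interp needs increasing x; ERA5 model-level pressure increases downward
                out[:, i, j] = np.interp(interp_pressures, pressure[:, i, j], var[:, i, j])
        return out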
88 changes: 74 additions & 14 deletions credit/output.py
@@ -13,15 +13,19 @@
import traceback
import xarray as xr
from credit.data import drop_var_from_dataset

from credit.interp import full_state_pressure_interpolation
import numpy as np

logger = logging.getLogger(__name__)



def load_metadata(conf):
def load_metadata(conf: dict):
"""
Load metadata attributes from yaml file in credit/metadata directory
Args:
conf (dict): Configuration dictionary
"""
# set priorities for user-specified metadata
if conf["predict"]["metadata"]:
@@ -72,17 +76,38 @@ def split_and_reshape(tensor, conf):


def make_xarray(pred, forecast_datetime, lat, lon, conf):
"""
Convert prediction tensor to xarray DataArrays for later saving.
Args:
pred (torch.Tensor): full tensor containing output of the AI NWP model
forecast_datetime (pd.Timestamp or datetime.datetime): valid time of the forecast
lat: latitude coordinate array
lon: longitude coordinate array
conf (dict): config dictionary for training/rollout
Returns:
xr.DataArray: upper air predictions, xr.DataArray: surface variable predictions
"""
# subset upper air and surface variables
tensor_upper_air, tensor_single_level = split_and_reshape(pred, conf)

if "level_ids" in conf["data"].keys():
level_ids = conf["data"]["level_ids"]
else:
level_ids = np.array(
[10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137],
dtype=np.int64,
)

# save upper air variables
darray_upper_air = xr.DataArray(
tensor_upper_air,
dims=["time", "vars", "level", "lat", "lon"],
coords=dict(
vars=conf["data"]["variables"],
time=[forecast_datetime],
level=range(conf["model"]["levels"]),
level=level_ids,
lat=lat,
lon=lon,
),
@@ -100,13 +125,30 @@ def make_xarray(pred, forecast_datetime, lat, lon, conf):
lon=lon,
),
)
# return x-arrays as outputs
# return DataArrays as outputs
return darray_upper_air, darray_single_level
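
One consequence of the level_ids change above: the upper-air DataArray's level coordinate now holds ERA5 model-level IDs rather than 0..levels-1, so label-based selection works on the real level numbers. A small illustrative example:

    # Select the lowest ERA5 model level (ID 137) from the upper-air predictions.
    lowest = darray_upper_air.sel(level=137)
    # The positional equivalent would be darray_upper_air.isel(level=-1).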


def save_netcdf_increment(
darray_upper_air, darray_single_level, nc_filename, forecast_hour, meta_data, conf
darray_upper_air: xr.DataArray,
darray_single_level: xr.DataArray,
nc_filename: str,
forecast_hour: int,
meta_data: dict,
conf: dict,
):
"""
Save CREDIT model prediction output to netCDF file. Also performs pressure level
interpolation on the output if you wish.
Args:
darray_upper_air (xr.DataArray): upper air variable predictions
darray_single_level (xr.DataArray): surface variable predictions
nc_filename (str): file description to go into output filenames
forecast_hour (int): how many hours since the initialization of the model.
meta_data (dict): metadata dictionary for output variables
conf (dict): configuration dictionary for training and/or rollout
"""
try:
"""
Save increment to a unique NetCDF file using Dask for parallel processing.
@@ -124,8 +166,21 @@ def save_netcdf_increment(
# Add CF convention version
ds_merged.attrs["Conventions"] = "CF-1.11"

# Add model config file parameters (x)
# ds_merged.attrs.update(conf)
if "interp_pressure" in conf["predict"].keys():
if "surface_geopotential_var" in conf["predict"]["interp_pressure"].keys():
surface_geopotential_var = conf["predict"]["interp_pressure"][
"surface_geopotential_var"
]
else:
surface_geopotential_var = "Z_GDS4_SFC"
with xr.open_dataset(conf["data"]["save_loc_static"]) as static_ds:
ds_merged[surface_geopotential_var] = static_ds[
surface_geopotential_var
]
pressure_interp = full_state_pressure_interpolation(
ds_merged, **conf["predict"]["interp_pressure"]
)
ds_merged = xr.merge([ds_merged, pressure_interp])

logger.info(f"Trying to save forecast hour {forecast_hour} to {nc_filename}")

@@ -135,7 +190,6 @@
unique_filename = os.path.join(
save_location, f"pred_{nc_filename}_{forecast_hour:03d}.nc"
)

# ---------------------------------------------------- #
# If conf['predict']['save_vars'] provided --> drop useless vars
if "save_vars" in conf["predict"]:
@@ -158,12 +212,18 @@
ds_merged.time.encoding[metadata_time] = meta_data["time"][
metadata_time
]

# Convert to Dask array if not already
ds_merged = ds_merged.chunk({"time": 1})

encoding_dict = {}
if "ua_var_encoding" in conf["predict"].keys():
for ua_var in conf["data"]["variables"]:
encoding_dict[ua_var] = conf["predict"]["ua_var_encoding"]
if "surface_var_encoding" in conf["predict"].keys():
for surface_var in conf["data"]["variables"]:
encoding_dict[surface_var] = conf["predict"]["surface_var_encoding"]
if "pressure_var_encoding" in conf["predict"].keys():
for pres_var in conf["data"]["variables"]:
encoding_dict[pres_var] = conf["predict"]["pressure_var_encoding"]
# Use Dask to write the dataset in parallel
ds_merged.to_netcdf(unique_filename, mode="w")
ds_merged.to_netcdf(unique_filename, encoding=encoding_dict)

logger.info(f"Saved forecast hour {forecast_hour} to {unique_filename}")
except Exception:
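A practical note on the encoding dictionaries passed to to_netcdf above: each chunksizes list must have exactly as many entries as the target variable has dimensions, which is why example.yml defines separate upper-air, pressure-level, and surface blocks. A small sanity-check helper (hypothetical, not part of the repo):

    def check_encoding(ds, encoding):
        # netCDF4 raises at write time if chunksizes does not match a variable's rank.
        for var, enc in encoding.items():
            if var in ds and "chunksizes" in enc:
                assert len(enc["chunksizes"]) == ds[var].ndim, f"bad chunksizes for {var}"
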
6 changes: 3 additions & 3 deletions credit/transforms.py
@@ -544,7 +544,7 @@ def __init__(self, conf):
self.level_ids = conf["data"]["level_ids"]
else:
self.level_ids = np.array(
[10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136],
[10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137],
dtype=np.int64,
)
self.n_levels = int(conf["model"]["levels"])
@@ -889,7 +889,7 @@ def __call__(self, sample: Sample) -> Sample:
else:
try:
arr = DSD[sv].squeeze()
except:
except KeyError:
continue
arrs.append(arr)

@@ -1341,7 +1341,7 @@ def __call__(self, sample: Sample) -> Sample:
else:
try:
arr = DSD[sv].squeeze()
except:
except KeyError:
continue
arrs.append(arr)

15 changes: 4 additions & 11 deletions tests/test_interp.py
@@ -7,17 +7,10 @@
def test_full_state_pressure_interpolation():
path_to_test = os.path.abspath(os.path.dirname(__file__))
input_file = os.path.join(path_to_test, "data/test_interp.nc")
model_level_file = os.path.join(path_to_test, "../credit/metadata/ERA5_Lev_Info.nc")
ds = xr.open_dataset(input_file)
model_levels = xr.open_dataset(model_level_file)
pressure_levels = np.array([200.0, 500.0, 700.0, 850.0, 1000.0]) * 100.0
model_a = model_levels["a_model"].loc[ds["level"]].values
model_b = model_levels["b_model"].loc[ds["level"]].values
interp_ds = full_state_pressure_interpolation(ds,
pressure_levels,
model_a,
model_b,
lat_var="lat",
lon_var="lon")
pressure_levels = np.array([200.0, 500.0, 700.0, 850.0, 1000.0])
interp_ds = full_state_pressure_interpolation(
ds, pressure_levels=pressure_levels, lat_var="lat", lon_var="lon"
)
assert interp_ds["U_PRES"].shape[1] == pressure_levels.size, "Pressure level mismatch"
return
