Skip to content

Commit

Permalink
bugfix credit.output on diagnostic variable creation
Browse files Browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Oct 20, 2024
1 parent 1788656 commit 7dd44cb
Showing 1 changed file with 71 additions and 12 deletions.
83 changes: 71 additions & 12 deletions credit/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def split_and_reshape(tensor, conf):

# get number of channels
channels = len(conf["data"]["variables"])
single_level_channels = len(conf["data"]["surface_variables"])
single_level_channels = len(conf["data"]["surface_variables"]) + len(conf["data"]["diagnostic_variables"])

# subset upper air variables
tensor_upper_air = tensor[:, : int(channels * levels), :, :]
Expand All @@ -71,16 +71,45 @@ def split_and_reshape(tensor, conf):


def make_xarray(pred, forecast_datetime, lat, lon, conf):

"""
Create two xarray.DataArray objects for upper air and surface variables.
Parameters:
-----------
pred : torch.Tensor or np.ndarray
Prediction tensor containing both upper air and surface variables.
forecast_datetime : datetime
The forecast initialization datetime.
lat : np.ndarray or list
Latitude values.
lon : np.ndarray or list
Longitude values.
conf : dict
Configuration dictionary containing details about the data structure
and variables.
Returns:
--------
darray_upper_air : xarray.DataArray
DataArray containing upper air variables with dimensions
[time, vars, level, latitude, longitude].
darray_single_level : xarray.DataArray
DataArray containing surface variables with dimensions
[time, vars, latitude, longitude].
"""

# subset upper air and surface variables
tensor_upper_air, tensor_single_level = split_and_reshape(pred, conf)

# save upper air variables
varname_upper = conf['data']['variables']

# make xr.DatasArray
darray_upper_air = xr.DataArray(
tensor_upper_air,
dims=["time", "vars", "level", "lat", "lon"],
dims=["time", "vars", "level", "latitude", "longitude"],
coords=dict(
vars=conf["data"]["variables"],
vars=varname_upper,
time=[forecast_datetime],
level=range(conf["model"]["levels"]),
lat=lat,
Expand All @@ -89,26 +118,55 @@ def make_xarray(pred, forecast_datetime, lat, lon, conf):
)

# save surface variables
# !!! need to add diagnostic vars !!!
varname_single_level = conf['data']['surface_variables'] + conf['data']['diagnostic_variables']

# make xr.DatasArray
darray_single_level = xr.DataArray(
tensor_single_level.squeeze(2),
dims=["time", "vars", "lat", "lon"],
dims=["time", "vars", "latitude", "longitude"],
coords=dict(
vars=conf["data"]["surface_variables"],
vars=varname_single_level,
time=[forecast_datetime],
lat=lat,
lon=lon,
),
)

# return x-arrays as outputs
return darray_upper_air, darray_single_level


def save_netcdf_increment(darray_upper_air, darray_single_level, nc_filename, forecast_hour, meta_data, conf):
def save_netcdf_increment(darray_upper_air,
darray_single_level,
nc_filename,
forecast_hour,
meta_data,
conf):
"""
Save forecast increments to a unique NetCDF file using Dask for parallel processing.
Parameters:
-----------
darray_upper_air : xarray.DataArray
DataArray containing upper air variables.
darray_single_level : xarray.DataArray
DataArray containing surface level variables.
nc_filename : str
Base name of the NetCDF file to be saved.
forecast_hour : int
The forecast hour corresponding to the data being saved.
meta_data : dict or bool
Metadata information for the variables in the dataset. If False, no metadata is applied.
conf : dict
Configuration dictionary containing paths and parameters for saving the NetCDF files.
Returns:
--------
None
The function saves the merged upper air and surface datasets to a unique NetCDF file.
"""
try:
"""
Save increment to a unique NetCDF file using Dask for parallel processing.
"""
# Convert DataArrays to Datasets
ds_upper = darray_upper_air.to_dataset(dim="vars")
ds_single = darray_single_level.to_dataset(dim="vars")
Expand Down Expand Up @@ -159,4 +217,5 @@ def save_netcdf_increment(darray_upper_air, darray_single_level, nc_filename, fo

logger.info(f"Saved forecast hour {forecast_hour} to {unique_filename}")
except Exception:
print(traceback.format_exc())
print(traceback.format_exc())

0 comments on commit 7dd44cb

Please sign in to comment.