diff --git a/applications/rollout_metrics.py b/applications/rollout_metrics.py index 2610548..e847a21 100644 --- a/applications/rollout_metrics.py +++ b/applications/rollout_metrics.py @@ -281,6 +281,13 @@ def predict(rank, world_size, conf, p): # no y_surf y = reshape_only(batch["y"]).to(device).float() + # adding diagnostic vars to y + if "y_diag" in batch: + y_diag_batch = ( + batch["y_diag"].to(device).permute(0, 2, 1, 3, 4) + ) + y = torch.cat((y, y_diag_batch), dim=1).to(device).float() + # -------------------------------------------------------------------------------------- # # start prediction diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py index 1914cc2..167b0ae 100644 --- a/applications/rollout_to_netcdf.py +++ b/applications/rollout_to_netcdf.py @@ -284,7 +284,14 @@ def predict(rank, world_size, conf, p): else: # no y_surf y = reshape_only(batch["y"]).to(device).float() - + + # adding diagnostic vars to y + if "y_diag" in batch: + y_diag_batch = ( + batch["y_diag"].to(device).permute(0, 2, 1, 3, 4) + ) + y = torch.cat((y, y_diag_batch), dim=1).to(device).float() + # -------------------------------------------------------------------------------------- # # start prediction