From 554c15f1ebc2a9d93fc2e6d763b02a37073531ba Mon Sep 17 00:00:00 2001 From: Arnold Kazadi Date: Sat, 23 Nov 2024 18:49:41 -0700 Subject: [PATCH] Restore mpi in rollout_to_netcdf.py --- applications/rollout_to_netcdf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py index ab24f2c1..abc284b0 100644 --- a/applications/rollout_to_netcdf.py +++ b/applications/rollout_to_netcdf.py @@ -26,6 +26,7 @@ # credit from credit.models import load_model from credit.seed import seed_everything +from credit.distributed import get_rank_info from credit.data import ( concat_and_reshape, @@ -951,11 +952,11 @@ def predict(rank, world_size, conf, p): seed = 1000 if "seed" not in conf else conf["seed"] seed_everything(seed) + local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"]) + with mp.Pool(num_cpus) as p: if conf["predict"]["mode"] in ["fsdp", "ddp"]: # multi-gpu inference - _ = predict( - int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]), conf, p=p - ) + _ = predict(world_rank, world_size, conf, p=p) else: # single device inference _ = predict(0, 1, conf, p=p)