diff --git a/applications/rollout_to_netcdf.py b/applications/rollout_to_netcdf.py index 96475ae6..1457f53b 100644 --- a/applications/rollout_to_netcdf.py +++ b/applications/rollout_to_netcdf.py @@ -26,6 +26,7 @@ # credit from credit.models import load_model from credit.seed import seed_everything +from credit.distributed import get_rank_info from credit.data import ( concat_and_reshape, @@ -967,11 +968,11 @@ def predict(rank, world_size, conf, p): seed = 1000 if "seed" not in conf else conf["seed"] seed_everything(seed) + local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"]) + with mp.Pool(num_cpus) as p: if conf["predict"]["mode"] in ["fsdp", "ddp"]: # multi-gpu inference - _ = predict( - int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]), conf, p=p - ) + _ = predict(world_rank, world_size, conf, p=p) else: # single device inference _ = predict(0, 1, conf, p=p)