Skip to content

Commit

Permalink
Merge pull request #132 from NCAR/fix_rollout_netcdf
Browse files Browse the repository at this point in the history
Fix mpi in `rollout_to_netcdf.py`
  • Loading branch information
jsschreck authored Nov 24, 2024
2 parents 6fe8c7b + 1c3d77d commit 6814c8d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions applications/rollout_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# credit
from credit.models import load_model
from credit.seed import seed_everything
from credit.distributed import get_rank_info

from credit.data import (
concat_and_reshape,
Expand Down Expand Up @@ -967,11 +968,11 @@ def predict(rank, world_size, conf, p):
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"])

with mp.Pool(num_cpus) as p:
if conf["predict"]["mode"] in ["fsdp", "ddp"]: # multi-gpu inference
_ = predict(
int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]), conf, p=p
)
_ = predict(world_rank, world_size, conf, p=p)
else: # single device inference
_ = predict(0, 1, conf, p=p)

Expand Down

0 comments on commit 6814c8d

Please sign in to comment.