diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 90e7c2488..96e5862f8 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -174,6 +174,9 @@ if [ $DTYPE == "fp8" ]; then fi GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') +if [ "$CUDA_VISIBLE_DEVICES" != "" ]; then + GPUS_PER_NODE=`python -c 'import os; x=os.environ.get("CUDA_VISIBLE_DEVICES", ""); print(len(x.split(",")))'` +fi NGPUS=$((GPUS_PER_NODE * NODES)) # Heuristic to figure out ici and dcn of DP