diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 96e5862f8..ebb2afcdc 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -174,7 +174,7 @@ if [ $DTYPE == "fp8" ]; then fi GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') -if [ "$CUDA_VISIBLE_DEVICES" != "" ]; then +if [ -n "${CUDA_VISIBLE_DEVICES:-}" ]; then GPUS_PER_NODE=`python -c 'import os; x=os.environ.get("CUDA_VISIBLE_DEVICES", ""); print(len(x.split(",")))'` fi NGPUS=$((GPUS_PER_NODE * NODES))