Skip to content

Commit

Permalink
Only set CUDA_DEVICE_MAX_CONNECTIONS=1 for Hopper/cc9.0 runs
Browse files Browse the repository at this point in the history
Cherry-pick of #1236 to main.
  • Loading branch information
olupton committed Jan 14, 2025
1 parent 34066f5 commit 7339a81
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
1 change: 0 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ ENV BUILD_DATE=${BUILD_DATE}
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0

COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
Expand Down
7 changes: 6 additions & 1 deletion .github/container/test-maxtext.sh
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,12 @@ pushd ${MAXTEXT_DIR}

export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

local_arch=$(local_cuda_arch)
if [[ "${local_arch}" == "9.0" ]]; then
echo "Setting CUDA_DEVICE_MAX_CONNECTIONS=1 for cc${local_arch} devices"
export CUDA_DEVICE_MAX_CONNECTIONS=1
fi

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
Expand Down

0 comments on commit 7339a81

Please sign in to comment.