diff --git a/.github/container/Dockerfile.triton b/.github/container/Dockerfile.triton index ee0cbc4d9..1a7e90d58 100644 --- a/.github/container/Dockerfile.triton +++ b/.github/container/Dockerfile.triton @@ -65,6 +65,8 @@ sed -i -e 's|-Werror||g' CMakeLists.txt sed -i -e 's|\(LLVMAMDGPU.*\)|# \1|g' CMakeLists.txt # Do not build tests sed -i -e 's|^\s*add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt +# Avoid building GPUHello.cpp, which does not compile with cl693214397 of OpenXLA Triton +sed -i -e 's|^add_subdirectory(Instrumentation)|# Instrumentation test disabled|' test/lib/CMakeLists.txt # Avoid error due to forward declaration of Module sed -i -e '/#include "llvm\/IR\/IRBuilder.h"/a #include "llvm/IR/Module.h"' test/lib/Instrumentation/GPUHello.cpp # Google has patches that mess with include paths in source files diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index db01bccff..95cf5246b 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -274,25 +274,25 @@ if [[ ! -e "/usr/local/cuda/lib" ]]; then fi if ! grep 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then - echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc" + echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc" fi cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF -build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda" -build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn" -build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl" +build --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda" +build --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn" +build --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl" EOF -time python "${SRC_PATH_JAX}/build/build.py" \ + +pushd ${SRC_PATH_JAX} +time python "${SRC_PATH_JAX}/build/build.py" build \ --editable \ --use_clang \ - --enable_cuda \ - --build_gpu_plugin \ - --gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \ + --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \ --cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \ - --enable_nccl=true \ --bazel_options=--linkopt=-fuse-ld=lld \ - --bazel_options=--override_repository=xla=$SRC_PATH_XLA \ + --local_xla_path=$SRC_PATH_XLA \ --output_path=${BUILD_PATH_JAXLIB} \ $BUILD_PARAM +popd # Make sure that JAX depends on the local jaxlib installation # https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels @@ -300,8 +300,8 @@ line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib" if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then pushd "${SRC_PATH_JAX}" echo "${line}" >> build/requirements.in - echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in - echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in + echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" >> build/requirements.in + echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax-cuda-plugin" >> build/requirements.in PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))') bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}" popd @@ -316,13 +316,13 @@ else fi # install jax and jaxlib -pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}" +pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}" -# after installation (example) -# jax 0.4.32.dev20240808+9c2caedab /opt/jax -# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_pjrt -# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_plugin -# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib +## after installation (example) +# jax 0.4.36.dev20241125+f828f2d7d /opt/jax +# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt +# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin +# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib pip list | grep jax # Ensure directories are readable by all for non-root users diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh index f3e4e0715..4aa001cf1 100755 --- a/.github/container/install-nsight.sh +++ b/.github/container/install-nsight.sh @@ -12,7 +12,7 @@ export DEBIAN_FRONTEND=noninteractive export TZ=America/Los_Angeles apt-get update -apt-get install -y nsight-compute nsight-systems-cli +apt-get install -y nsight-compute nsight-systems-cli-2024.6.1 apt-get clean rm -rf /var/lib/apt/lists/* diff --git a/.github/container/nsys-jax-combine b/.github/container/nsys-jax-combine index 00e53efb6..36ef782af 100755 --- a/.github/container/nsys-jax-combine +++ b/.github/container/nsys-jax-combine @@ -192,8 +192,9 @@ with zipfile.ZipFile(args.output, "w") as ofile: ) # Gather output files of the scrpt for path in working_dir.rglob("*"): - with open(working_dir / path, "rb") as src, ofile.open( - str(path.relative_to(mirror_dir)), "w" - ) as dst: + with ( + open(working_dir / path, "rb") as src, + ofile.open(str(path.relative_to(mirror_dir)), "w") as dst, + ): # https://github.com/python/mypy/issues/15031 ? shutil.copyfileobj(src, dst) # type: ignore diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index d73ed6d7e..3a14c7a72 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -26,7 +26,7 @@ jax_source_dir() { query_tests() { cd `jax_source_dir` - python build/build.py --configure_only + python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel query tests/... 2>&1 | grep -F '//tests:' exit } @@ -113,8 +113,6 @@ for t in $*; do done TEST_TAG_FILTER_ARRAY=() -TEST_TAG_FILTER_ARRAY+=('-multiaccelerator') - COMMON_FLAGS=$(cat << EOF --@local_config_cuda//:enable_cuda --cache_test_results=${CACHE_TEST_RESULTS} @@ -163,7 +161,10 @@ case "${BATTERY}" in ;; esac -TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") +TEST_TAG_FILTERS="" +if [[ ${#TEST_TAG_FILTER_ARRAY[@]} > 0 ]]; then + TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") +fi print_var NCPUS print_var NGPUS @@ -190,5 +191,5 @@ pip install matplotlib ## Run tests cd `jax_source_dir` -python build/build.py --configure_only +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS} diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 2f77c07f4..4816c709e 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -295,6 +295,7 @@ jobs: docker run -i --shm-size=1g --gpus all \ ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ bash <<"EOF" |& tee tee test-gpu.log + nvidia-cuda-mps-control -d test-jax.sh -b gpu EOF STATISTICS_SCRIPT: | @@ -463,8 +464,9 @@ jobs: docker run -i --shm-size=1g --gpus all --volume $PWD:/output \ ${{ needs.build-triton.outputs.DOCKER_TAG_FINAL }} \ bash <<"EOF" |& tee test-triton.log - # autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch... - pip install --no-deps torch --index-url https://download.pytorch.org/whl/cpu + # autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch; this relies on + # actually having a CUDA backend for pytoch + pip install --no-deps torch python /opt/jax-triton/tests/triton_call_test.py --xml_output_file /output/triton_test.xml EOF STATISTICS_SCRIPT: |