Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into vkozlov/move-to-ubunt…
Browse files Browse the repository at this point in the history
…u24.04
  • Loading branch information
DwarKapex committed Nov 27, 2024
2 parents 3c2ec97 + de72dd8 commit d279373
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/container/Dockerfile.triton
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ sed -i -e 's|-Werror||g' CMakeLists.txt
sed -i -e 's|\(LLVMAMDGPU.*\)|# \1|g' CMakeLists.txt
# Do not build tests
sed -i -e 's|^\s*add_subdirectory(unittest)|# unit tests disabled|' CMakeLists.txt
# Avoid building GPUHello.cpp, which does not compile with cl693214397 of OpenXLA Triton
sed -i -e 's|^add_subdirectory(Instrumentation)|# Instrumentation test disabled|' test/lib/CMakeLists.txt
# Avoid error due to forward declaration of Module
sed -i -e '/#include "llvm\/IR\/IRBuilder.h"/a #include "llvm/IR/Module.h"' test/lib/Instrumentation/GPUHello.cpp
# Google has patches that mess with include paths in source files
Expand Down
36 changes: 18 additions & 18 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -274,34 +274,34 @@ if [[ ! -e "/usr/local/cuda/lib" ]]; then
fi

if ! grep 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then
echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
echo -e '\ntry-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
fi
cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF
build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
build --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
EOF
time python "${SRC_PATH_JAX}/build/build.py" \

pushd ${SRC_PATH_JAX}
time python "${SRC_PATH_JAX}/build/build.py" build \
--editable \
--use_clang \
--enable_cuda \
--build_gpu_plugin \
--gpu_plugin_cuda_version=$TF_CUDA_MAJOR_VERSION \
--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt \
--cuda_compute_capabilities=$TF_CUDA_COMPUTE_CAPABILITIES \
--enable_nccl=true \
--bazel_options=--linkopt=-fuse-ld=lld \
--bazel_options=--override_repository=xla=$SRC_PATH_XLA \
--local_xla_path=$SRC_PATH_XLA \
--output_path=${BUILD_PATH_JAXLIB} \
$BUILD_PARAM
popd

# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib"
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
pushd "${SRC_PATH_JAX}"
echo "${line}" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" >> build/requirements.in
echo "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax-cuda-plugin" >> build/requirements.in
PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
popd
Expand All @@ -316,13 +316,13 @@ else
fi

# install jax and jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax_gpu_pjrt -e ${BUILD_PATH_JAXLIB}/jax_gpu_plugin -e "${SRC_PATH_JAX}"
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}"

# after installation (example)
# jax 0.4.32.dev20240808+9c2caedab /opt/jax
# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_pjrt
# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_plugin
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
## after installation (example)
# jax 0.4.36.dev20241125+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
Expand Down
2 changes: 1 addition & 1 deletion .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles

apt-get update
apt-get install -y nsight-compute nsight-systems-cli
apt-get install -y nsight-compute nsight-systems-cli-2024.6.1
apt-get clean

rm -rf /var/lib/apt/lists/*
Expand Down
7 changes: 4 additions & 3 deletions .github/container/nsys-jax-combine
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ with zipfile.ZipFile(args.output, "w") as ofile:
)
# Gather output files of the scrpt
for path in working_dir.rglob("*"):
with open(working_dir / path, "rb") as src, ofile.open(
str(path.relative_to(mirror_dir)), "w"
) as dst:
with (
open(working_dir / path, "rb") as src,
ofile.open(str(path.relative_to(mirror_dir)), "w") as dst,
):
# https://github.com/python/mypy/issues/15031 ?
shutil.copyfileobj(src, dst) # type: ignore
11 changes: 6 additions & 5 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jax_source_dir() {

query_tests() {
cd `jax_source_dir`
python build/build.py --configure_only
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel query tests/... 2>&1 | grep -F '//tests:'
exit
}
Expand Down Expand Up @@ -113,8 +113,6 @@ for t in $*; do
done

TEST_TAG_FILTER_ARRAY=()
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
Expand Down Expand Up @@ -163,7 +161,10 @@ case "${BATTERY}" in
;;
esac

TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")
TEST_TAG_FILTERS=""
if [[ ${#TEST_TAG_FILTER_ARRAY[@]} > 0 ]]; then
TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")
fi

print_var NCPUS
print_var NGPUS
Expand All @@ -190,5 +191,5 @@ pip install matplotlib
## Run tests

cd `jax_source_dir`
python build/build.py --configure_only
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}
6 changes: 4 additions & 2 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ jobs:
docker run -i --shm-size=1g --gpus all \
${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
bash <<"EOF" |& tee tee test-gpu.log
nvidia-cuda-mps-control -d
test-jax.sh -b gpu
EOF
STATISTICS_SCRIPT: |
Expand Down Expand Up @@ -463,8 +464,9 @@ jobs:
docker run -i --shm-size=1g --gpus all --volume $PWD:/output \
${{ needs.build-triton.outputs.DOCKER_TAG_FINAL }} \
bash <<"EOF" |& tee test-triton.log
# autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch...
pip install --no-deps torch --index-url https://download.pytorch.org/whl/cpu
# autotuner tests from jax-triton now hit a triton code path that uses utilities from pytorch; this relies on
# actually having a CUDA backend for pytoch
pip install --no-deps torch
python /opt/jax-triton/tests/triton_call_test.py --xml_output_file /output/triton_test.xml
EOF
STATISTICS_SCRIPT: |
Expand Down

0 comments on commit d279373

Please sign in to comment.