Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect recursively and filter GPU tests using jax_test_gpu tag #1091

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ for t in $*; do
BAZEL_TARGET="${BAZEL_TARGET} $t"
done

TEST_TAG_FILTER_ARRAY=()
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
--test_env=JAX_SKIP_SLOW_TESTS=1
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
Expand All @@ -138,7 +140,11 @@ case "${BATTERY}" in
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
# collect from all tests subdirectories recursively,
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
TEST_TAG_FILTER_ARRAY+=('jax_test_gpu')
BAZEL_TARGET="${BAZEL_TARGET} //tests/..."
;;
backend-independent)
JOBS_PER_GPU=4
Expand All @@ -157,6 +163,8 @@ case "${BATTERY}" in
;;
esac

TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")

print_var NCPUS
print_var NGPUS
print_var BATTERY
Expand All @@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
print_var JOBS
print_var BUILD_JAXLIB
print_var BAZEL_TARGET
print_var TEST_TAG_FILTERS
print_var COMMON_FLAGS
print_var EXTRA_FLAGS

Expand All @@ -182,4 +191,4 @@ pip install matplotlib

cd `jax_source_dir`
python build/build.py --configure_only
bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}
Loading