diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 6afbaace1..d73ed6d7e 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -112,11 +112,13 @@ for t in $*; do BAZEL_TARGET="${BAZEL_TARGET} $t" done +TEST_TAG_FILTER_ARRAY=() +TEST_TAG_FILTER_ARRAY+=('-multiaccelerator') + COMMON_FLAGS=$(cat << EOF --@local_config_cuda//:enable_cuda --cache_test_results=${CACHE_TEST_RESULTS} --test_timeout=600 ---test_tag_filters=-multiaccelerator --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=JAX_ACCELERATOR_COUNT=${NGPUS} --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform @@ -138,7 +140,11 @@ case "${BATTERY}" in JOBS_PER_GPU=8 JOBS=$((NGPUS * JOBS_PER_GPU)) EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow" - BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests" + # collect from all tests subdirectories recursively, + # use jax_test_gpu tag generated by jax_multiplatform_test rule: + # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265 + TEST_TAG_FILTER_ARRAY+=('jax_test_gpu') + BAZEL_TARGET="${BAZEL_TARGET} //tests/..." ;; backend-independent) JOBS_PER_GPU=4 @@ -157,6 +163,8 @@ case "${BATTERY}" in ;; esac +TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") + print_var NCPUS print_var NGPUS print_var BATTERY @@ -165,6 +173,7 @@ print_var JOBS_PER_GPU print_var JOBS print_var BUILD_JAXLIB print_var BAZEL_TARGET +print_var TEST_TAG_FILTERS print_var COMMON_FLAGS print_var EXTRA_FLAGS @@ -182,4 +191,4 @@ pip install matplotlib cd `jax_source_dir` python build/build.py --configure_only -bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS} +bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}