Skip to content

Commit

Permalink
Increase JAX unit test timeout and add option to enable/disable resul…
Browse files Browse the repository at this point in the history
…t caching (#723)
  • Loading branch information
yhtang committed Apr 16, 2024
1 parent 0524d0e commit f0c6772
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ usage() {
echo "Usage: $0 [OPTION]... TESTS"
echo " -b, --battery Specify predefined test batteries to run."
echo " --build-jaxlib Runs the JAX tests using jaxlib built form source."
echo " --cache-test-results yes|no|auto, passes through to bazel `--cache_test_results`"
echo " --reuse-jaxlib Runs the JAX tests using preinstalled jaxlib. (DEFAULT)"
echo " --disable-x64 Disable 64-bit floating point support in JAX (some tests may fail)"
echo " --enable-x64 Enable 64-bit floating point support in JAX (DEFAULT, required for some tests)"
Expand Down Expand Up @@ -35,11 +36,17 @@ print_var() {
echo "$1: ${!1}"
}

args=$(getopt -o b:qh --long build-jaxlib,disable-x64,enable-x64,reuse-jaxlib,query,help -- "$@")
args=$(getopt -o b:qh --long battery:,build-jaxlib,cache-test-results:,disable-x64,enable-x64,reuse-jaxlib,query,help -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi

## Set default arguments if not provided via command-line
BUILD_JAXLIB=0
CACHE_TEST_RESULTS=no
ENABLE_X64=-1

## parse arguments
eval set -- "$args"
while [ : ]; do
case "$1" in
Expand All @@ -51,6 +58,10 @@ while [ : ]; do
BUILD_JAXLIB=1
shift 1
;;
--cache-test-results)
CACHE_TEST_RESULTS="$2"
shift 2
;;
--enable-x64)
ENABLE_X64=1
shift 1
Expand Down Expand Up @@ -81,11 +92,6 @@ if [[ $# -eq 0 ]] && [[ -z "$BATTERY" ]]; then
exit 1
fi

## Set default arguments if not provided via command-line

BUILD_JAXLIB=${BUILD_JAXLIB:-0}
ENABLE_X64=${ENABLE_X64:-1}

## Set derived variables

NCPUS=$(grep -c '^processor' /proc/cpuinfo)
Expand All @@ -111,6 +117,8 @@ for t in $*; do
done

COMMON_FLAGS=$(cat << EOF
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
--test_env=JAX_SKIP_SLOW_TESTS=1
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
Expand Down

0 comments on commit f0c6772

Please sign in to comment.