From 38d8d0abd90bbc134c2cce8186222b3fccd68725 Mon Sep 17 00:00:00 2001 From: Twice Date: Sat, 23 Nov 2024 01:16:27 +0800 Subject: [PATCH] Add build and test workflow for PJRT plugin in pkgci (#19222) This changes are for adding the build workflow for `iree-pjrt-plugin-*` packages in `build_tools/pkgci` scripts. Also, it enables the build of package `iree-pjrt-plugin-cpu` in Github Actions (pkgci.yaml). Some other changes are for fixing some build problems: - ninja is missing from pyproject.toml - IREE_BUILD_COMPILER need to be enabled to make IREELLVMIncludeSetup available It closes #19221. ci-exactly: build_packages, test_pjrt --------- Signed-off-by: PragmaTwice Signed-off-by: PragmaTwice --- .github/workflows/pkgci.yml | 6 ++ .github/workflows/pkgci_test_pjrt.yml | 69 +++++++++++++++++++ build_tools/testing/run_jax_tests.sh | 67 ++++++++++++++++++ integrations/pjrt/CMakeLists.txt | 4 +- .../_setup_support/iree_pjrt_setup.py | 2 +- .../iree_cpu_plugin/pyproject.toml | 1 + .../iree_cuda_plugin/pyproject.toml | 1 + .../iree_rocm_plugin/pyproject.toml | 1 + .../iree_vulkan_plugin/pyproject.toml | 1 + integrations/pjrt/test/test_add.py | 9 +++ 10 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/pkgci_test_pjrt.yml create mode 100755 build_tools/testing/run_jax_tests.sh create mode 100644 integrations/pjrt/test/test_add.py diff --git a/.github/workflows/pkgci.yml b/.github/workflows/pkgci.yml index 69d2f43dda8e..e4017cedeb06 100644 --- a/.github/workflows/pkgci.yml +++ b/.github/workflows/pkgci.yml @@ -109,3 +109,9 @@ jobs: needs: [setup, build_packages] if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_tensorflow') uses: ./.github/workflows/pkgci_test_tensorflow.yml + + test_pjrt: + name: Test PJRT plugin + needs: [setup, build_packages] + if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_pjrt') + uses: ./.github/workflows/pkgci_test_pjrt.yml diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml new file mode 100644 index 000000000000..8237648b8540 --- /dev/null +++ b/.github/workflows/pkgci_test_pjrt.yml @@ -0,0 +1,69 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: PkgCI Test PJRT plugin +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + workflow_dispatch: + inputs: + artifact_run_id: + type: string + default: "" + +jobs: + build_and_test: + strategy: + matrix: + include: + - runner: ubuntu-20.04 + pjrt_platform: cpu + # TODO: cuda runner is not available yet, refer to #18814 + # - runner: some-cuda-available-runner + # pjrt_platform: cuda + # TODO: enable these AMD runners + # - runner: nodai-amdgpu-w7900-x86-64 + # pjrt_platform: rocm + # - runner: nodai-amdgpu-w7900-x86-64 + # pjrt_platform: vulkan + name: Build and test + runs-on: ${{ matrix.runner }} + env: + PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages + VENV_DIR: ${{ github.workspace }}/.venv + steps: + - name: Checking out repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + submodules: true + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.1.0 + with: + # Must match the subset of versions built in pkgci_build_packages. + python-version: "3.11" + - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + name: linux_x86_64_release_packages + path: ${{ env.PACKAGE_DOWNLOAD_DIR }} + - name: Setup base venv + run: | + ./build_tools/pkgci/setup_venv.py ${VENV_DIR} \ + --artifact-path=${PACKAGE_DOWNLOAD_DIR} \ + --fetch-gh-workflow=${{ inputs.artifact_run_id }} + - name: Build and install + run: | + # install editable into venv + source ${VENV_DIR}/bin/activate + python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin + # install jax (must be no larger than 0.4.20, refer to #19223) + # TODO: switch to the latest JAX after #19223 is fixed + python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2' + - name: Run tests + run: | + source ${VENV_DIR}/bin/activate + build_tools/testing/run_jax_tests.sh ${{ matrix.pjrt_platform }} diff --git a/build_tools/testing/run_jax_tests.sh b/build_tools/testing/run_jax_tests.sh new file mode 100755 index 000000000000..36e502274d1f --- /dev/null +++ b/build_tools/testing/run_jax_tests.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set -xeo pipefail + +pjrt_platform=$1 + +if [ -z "${pjrt_platform}" ]; then + set +x + echo "Usage: run_jax_tests.sh " + echo " can be 'cpu', 'cuda', 'rocm' or 'vulkan'" + exit 1 +fi + +# cd into the PJRT plugin dir +ROOT_DIR="${ROOT_DIR:-$(git rev-parse --show-toplevel)}" +cd "${ROOT_DIR}/integrations/pjrt" + +# perform some differential testing +actual_jax_platform=iree_${pjrt_platform} +expected_jax_platform=cpu + +# this function will execute the test python script in +# both cpu mode and the IREE PJRT mode, +# and then compare the difference in the output +diff_jax_test() { + local test_py_file=$1 + + echo "executing ${test_py_file} in ${expected_jax_platform}.." + local expected_tmp_out=$(mktemp /tmp/jax_test_result_expected.XXXXXX) + JAX_PLATFORMS=$expected_jax_platform python $test_py_file > $expected_tmp_out + + echo "executing ${test_py_file} in ${actual_jax_platform}.." + local actual_tmp_out=$(mktemp /tmp/jax_test_result_actual.XXXXXX) + JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out + + echo "comparing ${expected_tmp_out} and ${actual_tmp_out}.." + diff --unified $expected_tmp_out $actual_tmp_out + echo "no difference found" +} + +# FIXME: due to #19223, we need to use jax no higher than 0.4.20, +# but in such version of jax, 'stablehlo.broadcast_in_dim' op +# will be emitted without attribute 'broadcast_dimensions', +# which leads to an error in IREE PJRT plugin. +# So currently any program with broadcast will fail, +# e.g. test/test_simple.py. +# After #19223 is fixed, we can uncomment the line below. + +# diff_jax_test test/test_simple.py + +diff_jax_test test/test_add.py + + +# FIXME: we can also utilize the native test cases from JAX, +# e.g. `tests/nn_test.py` from the JAX repo, as below, +# but currently some test cases in this file will fail. +# NOTE that `absl-py` is required to run these tests. + +# local jax_nn_test_file=$(mktemp /tmp/jax_nn_test.XXXXXX.py) +# wget https://github.com/jax-ml/jax/blob/jax-v0.4.20/tests/nn_test.py -O $jax_nn_test_file +# JAX_PLATFORMS=$actual_jax_platform python $jax_nn_test_file diff --git a/integrations/pjrt/CMakeLists.txt b/integrations/pjrt/CMakeLists.txt index 5b42f5bfa49d..9d904300ba99 100644 --- a/integrations/pjrt/CMakeLists.txt +++ b/integrations/pjrt/CMakeLists.txt @@ -38,7 +38,9 @@ if(IREE_PJRT_ENABLE_LTO) endif() # Customize defaults. -option(IREE_BUILD_COMPILER "Disable compiler for runtime-library build" OFF) +# IREE_BUILD_COMPILER should be enabled to make target IREELLVMIncludeSetup available, +# which is required by PJRT dylib targets +option(IREE_BUILD_COMPILER "Enable compiler for runtime-library build" ON) option(IREE_BUILD_SAMPLES "Disable samples for runtime-library build" OFF) # Include IREE. diff --git a/integrations/pjrt/python_packages/_setup_support/iree_pjrt_setup.py b/integrations/pjrt/python_packages/_setup_support/iree_pjrt_setup.py index a43bdc2852bc..940ee0793c7d 100644 --- a/integrations/pjrt/python_packages/_setup_support/iree_pjrt_setup.py +++ b/integrations/pjrt/python_packages/_setup_support/iree_pjrt_setup.py @@ -140,7 +140,7 @@ def build_configuration(self, cmake_build_dir, extra_cmake_args=()): cmake_args = [ "-GNinja", "--log-level=VERBOSE", - "-DIREE_BUILD_COMPILER=OFF", + "-DIREE_BUILD_COMPILER=ON", "-DIREE_BUILD_SAMPLES=OFF", "-DIREE_BUILD_TESTS=OFF", "-DIREE_HAL_DRIVER_DEFAULTS=OFF", diff --git a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml index f6c16894020a..f539b95a4ede 100644 --- a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml @@ -2,5 +2,6 @@ requires = [ "setuptools>=42", "wheel", + "ninja", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml index f6c16894020a..f539b95a4ede 100644 --- a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml @@ -2,5 +2,6 @@ requires = [ "setuptools>=42", "wheel", + "ninja", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml index f6c16894020a..f539b95a4ede 100644 --- a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml @@ -2,5 +2,6 @@ requires = [ "setuptools>=42", "wheel", + "ninja", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml index f6c16894020a..f539b95a4ede 100644 --- a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml @@ -2,5 +2,6 @@ requires = [ "setuptools>=42", "wheel", + "ninja", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/test/test_add.py b/integrations/pjrt/test/test_add.py new file mode 100644 index 000000000000..233c1ef7edce --- /dev/null +++ b/integrations/pjrt/test/test_add.py @@ -0,0 +1,9 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import jax + +a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]) + +print(a + a)