Add build and test workflow for PJRT plugin in pkgci (iree-org#19222)
These changes add a build workflow for the `iree-pjrt-plugin-*`
packages to the `build_tools/pkgci` scripts.

They also enable building the `iree-pjrt-plugin-cpu` package in GitHub
Actions (pkgci.yml).

Some other changes fix build problems along the way:
- `ninja` was missing from `pyproject.toml`
- `IREE_BUILD_COMPILER` needs to be enabled to make `IREELLVMIncludeSetup`
  available

Closes iree-org#19221.

ci-exactly: build_packages, test_pjrt
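
For reference, a minimal sketch of how such a trailer ends up on a commit so that pkgci runs only the named jobs (the trailer name is taken from the line above; exactly how the setup job parses it is outside this change):

# Hedged example: request only the build_packages and test_pjrt pkgci jobs.
git commit -m "Add PJRT pkgci workflow" \
           -m "ci-exactly: build_packages, test_pjrt"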

---------

Signed-off-by: PragmaTwice <twice@apache.org>
Signed-off-by: PragmaTwice <twice.mliu@gmail.com>
PragmaTwice authored Nov 22, 2024
1 parent 12476d9 commit 38d8d0a
Showing 10 changed files with 159 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/pkgci.yml
@@ -109,3 +109,9 @@ jobs:
    needs: [setup, build_packages]
    if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_tensorflow')
    uses: ./.github/workflows/pkgci_test_tensorflow.yml

  test_pjrt:
    name: Test PJRT plugin
    needs: [setup, build_packages]
    if: contains(fromJson(needs.setup.outputs.enabled-jobs), 'test_pjrt')
    uses: ./.github/workflows/pkgci_test_pjrt.yml
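
The `if:` gate above keys off the setup job's `enabled-jobs` output, a JSON array of job names. A rough shell illustration of the check (the `jq` usage below is only an illustration, not how the setup job is actually implemented):

# Hypothetical stand-in for needs.setup.outputs.enabled-jobs.
enabled_jobs='["build_packages", "test_pjrt"]'
if echo "${enabled_jobs}" | jq -e 'index("test_pjrt") != null' > /dev/null; then
  echo "test_pjrt is enabled, so the reusable workflow would run"
fi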
69 changes: 69 additions & 0 deletions .github/workflows/pkgci_test_pjrt.yml
@@ -0,0 +1,69 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: PkgCI Test PJRT plugin
on:
  workflow_call:
    inputs:
      artifact_run_id:
        type: string
        default: ""
  workflow_dispatch:
    inputs:
      artifact_run_id:
        type: string
        default: ""

jobs:
  build_and_test:
    strategy:
      matrix:
        include:
          - runner: ubuntu-20.04
            pjrt_platform: cpu
          # TODO: cuda runner is not available yet, refer to #18814
          # - runner: some-cuda-available-runner
          #   pjrt_platform: cuda
          # TODO: enable these AMD runners
          # - runner: nodai-amdgpu-w7900-x86-64
          #   pjrt_platform: rocm
          # - runner: nodai-amdgpu-w7900-x86-64
          #   pjrt_platform: vulkan
    name: Build and test
    runs-on: ${{ matrix.runner }}
    env:
      PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages
      VENV_DIR: ${{ github.workspace }}/.venv
    steps:
      - name: Checking out repository
        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
        with:
          submodules: true
      - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.1.0
        with:
          # Must match the subset of versions built in pkgci_build_packages.
          python-version: "3.11"
      - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
        with:
          name: linux_x86_64_release_packages
          path: ${{ env.PACKAGE_DOWNLOAD_DIR }}
      - name: Setup base venv
        run: |
          ./build_tools/pkgci/setup_venv.py ${VENV_DIR} \
            --artifact-path=${PACKAGE_DOWNLOAD_DIR} \
            --fetch-gh-workflow=${{ inputs.artifact_run_id }}
      - name: Build and install
        run: |
          # Install the plugin as an editable package into the venv.
          source ${VENV_DIR}/bin/activate
          python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
          # Install jax (must be no newer than 0.4.20, refer to #19223).
          # TODO: switch to the latest JAX after #19223 is fixed.
          python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2'
      - name: Run tests
        run: |
          source ${VENV_DIR}/bin/activate
          build_tools/testing/run_jax_tests.sh ${{ matrix.pjrt_platform }}
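
The same flow can be approximated locally for the CPU platform; a hedged sketch, run from the repository root, assuming a venv has already been populated by build_tools/pkgci/setup_venv.py (the venv path is arbitrary):

# Mirrors the "Build and install" and "Run tests" steps above with pjrt_platform=cpu.
source .venv/bin/activate
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_cpu_plugin
python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2'
build_tools/testing/run_jax_tests.sh cpu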
67 changes: 67 additions & 0 deletions build_tools/testing/run_jax_tests.sh
@@ -0,0 +1,67 @@
#!/bin/bash

# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

set -xeo pipefail

pjrt_platform=$1

if [ -z "${pjrt_platform}" ]; then
  set +x
  echo "Usage: run_jax_tests.sh <pjrt_platform>"
  echo "  <pjrt_platform> can be 'cpu', 'cuda', 'rocm' or 'vulkan'"
  exit 1
fi

# cd into the PJRT plugin dir
ROOT_DIR="${ROOT_DIR:-$(git rev-parse --show-toplevel)}"
cd "${ROOT_DIR}/integrations/pjrt"

# perform some differential testing
actual_jax_platform=iree_${pjrt_platform}
expected_jax_platform=cpu

# This function runs the given test Python script in both plain CPU mode
# and IREE PJRT mode, then compares the two outputs.
diff_jax_test() {
  local test_py_file=$1

  echo "executing ${test_py_file} in ${expected_jax_platform}.."
  local expected_tmp_out=$(mktemp /tmp/jax_test_result_expected.XXXXXX)
  JAX_PLATFORMS=${expected_jax_platform} python "${test_py_file}" > "${expected_tmp_out}"

  echo "executing ${test_py_file} in ${actual_jax_platform}.."
  local actual_tmp_out=$(mktemp /tmp/jax_test_result_actual.XXXXXX)
  JAX_PLATFORMS=${actual_jax_platform} python "${test_py_file}" > "${actual_tmp_out}"

  echo "comparing ${expected_tmp_out} and ${actual_tmp_out}.."
  diff --unified "${expected_tmp_out}" "${actual_tmp_out}"
  echo "no difference found"
}

# FIXME: due to #19223, we need a JAX version no newer than 0.4.20,
# but with that version the 'stablehlo.broadcast_in_dim' op is emitted
# without the 'broadcast_dimensions' attribute, which leads to an error
# in the IREE PJRT plugin. So currently any program that broadcasts will
# fail, e.g. test/test_simple.py.
# After #19223 is fixed, we can uncomment the line below.

# diff_jax_test test/test_simple.py

diff_jax_test test/test_add.py


# FIXME: we could also reuse the native test cases from JAX,
# e.g. `tests/nn_test.py` from the JAX repo, as below,
# but currently some test cases in that file fail.
# NOTE that `absl-py` is required to run these tests.

# jax_nn_test_file=$(mktemp /tmp/jax_nn_test.XXXXXX.py)
# wget https://raw.githubusercontent.com/jax-ml/jax/jax-v0.4.20/tests/nn_test.py -O $jax_nn_test_file
# JAX_PLATFORMS=$actual_jax_platform python $jax_nn_test_file
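
A usage sketch for the script above (the platform names come from its usage message; the matching IREE plugin must already be installed in the active Python environment):

# Compare plain CPU JAX output against the IREE CPU PJRT plugin output.
./build_tools/testing/run_jax_tests.sh cpu

# ROOT_DIR can be set explicitly when not invoking from inside the checkout
# (the path below is a placeholder).
ROOT_DIR=/path/to/iree /path/to/iree/build_tools/testing/run_jax_tests.sh cpu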
4 changes: 3 additions & 1 deletion integrations/pjrt/CMakeLists.txt
@@ -38,7 +38,9 @@ if(IREE_PJRT_ENABLE_LTO)
endif()

# Customize defaults.
-option(IREE_BUILD_COMPILER "Disable compiler for runtime-library build" OFF)
+# IREE_BUILD_COMPILER should be enabled to make target IREELLVMIncludeSetup available,
+# which is required by PJRT dylib targets
+option(IREE_BUILD_COMPILER "Enable compiler for runtime-library build" ON)
option(IREE_BUILD_SAMPLES "Disable samples for runtime-library build" OFF)

# Include IREE.
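A hedged configure sketch showing the effect of the new default: a standalone configure of the PJRT tree now builds the compiler, and therefore the IREELLVMIncludeSetup target the dylib targets depend on, without extra flags (the build directory name is arbitrary, and cache variables normally passed by the packaging script are omitted):

# -DIREE_BUILD_COMPILER=ON is now the default and is passed here only for
# emphasis; OFF would reintroduce the missing IREELLVMIncludeSetup target.
cmake -GNinja -S integrations/pjrt -B build-pjrt -DIREE_BUILD_COMPILER=ON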
@@ -140,7 +140,7 @@ def build_configuration(self, cmake_build_dir, extra_cmake_args=()):
        cmake_args = [
            "-GNinja",
            "--log-level=VERBOSE",
-            "-DIREE_BUILD_COMPILER=OFF",
+            "-DIREE_BUILD_COMPILER=ON",
            "-DIREE_BUILD_SAMPLES=OFF",
            "-DIREE_BUILD_TESTS=OFF",
            "-DIREE_HAL_DRIVER_DEFAULTS=OFF",
@@ -2,5 +2,6 @@
requires = [
    "setuptools>=42",
    "wheel",
+    "ninja",
]
build-backend = "setuptools.build_meta"
@@ -2,5 +2,6 @@
requires = [
    "setuptools>=42",
    "wheel",
+    "ninja",
]
build-backend = "setuptools.build_meta"
@@ -2,5 +2,6 @@
requires = [
    "setuptools>=42",
    "wheel",
+    "ninja",
]
build-backend = "setuptools.build_meta"
@@ -2,5 +2,6 @@
requires = [
    "setuptools>=42",
    "wheel",
+    "ninja",
]
build-backend = "setuptools.build_meta"
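
The effect of adding `ninja` to `build-system.requires` in each plugin package's `pyproject.toml` is that an isolated (PEP 517) build can provide Ninja on its own; a hedged sketch, using the CPU package path from the workflow above:

# pip installs ninja into the isolated build environment because it is now
# listed in build-system.requires; the wheel output directory is arbitrary.
python -m pip wheel integrations/pjrt/python_packages/iree_cpu_plugin -w /tmp/wheels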
9 changes: 9 additions & 0 deletions integrations/pjrt/test/test_add.py
@@ -0,0 +1,9 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import jax

a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])

print(a + a)
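
A hedged manual equivalent of what run_jax_tests.sh does with this file: run it once on plain CPU JAX and once through the IREE CPU plugin, then diff the printed arrays (temporary file paths are arbitrary):

# Both runs should print the same array, e.g. [ 2  4  6  8 10 12 14 16 18].
JAX_PLATFORMS=cpu      python integrations/pjrt/test/test_add.py > /tmp/expected.out
JAX_PLATFORMS=iree_cpu python integrations/pjrt/test/test_add.py > /tmp/actual.out
diff --unified /tmp/expected.out /tmp/actual.out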
