-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[libshortfin] Add simple invocation test. (#170)
This test is not particularly inspired (and the API needs to be simplified) but it represents the first full system test in the repo. In order to run the test, it is downloading a mobilenet onnx file from the zoo, upgrading it, and compiling. In the future, I'd like to switch this to a simpler model like MNIST for basic functionality, but I had some issues getting that to work via ONNX import and punted. While a bit inefficient (it will fetch on each pytest run), this will keep things held together until we can do something more comprehensive. Note that my experience here prompted me to file iree-org/iree#18289, as this is way too much code and sharp edges to compile from ONNX (but it does work). Verifies numerics against a silly test image. Includes some fixes: * Reworked the system detect marker so that we only run system specific tests (like amdgpu) on opt-in via a `--system amdgpu` pytest arg. This refinement was prompted by an ASAN violation in the HIP runtime code which was tripping me up when enabled by default. Filed here: iree-org/iree#18449 * Fixed a bug revealed when writing the test where an exception thrown from main could trigger a use-after-free because we were clearing workers when shutting down (vs at destruction) when all objects owned at the system level need to have a lifetime no less than the system.
- Loading branch information
1 parent
a038133
commit 5a198e9
Showing
13 changed files
with
193 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,8 @@ | ||
leak:PyUnicode_New | ||
leak:_PyUnicodeWriter_PrepareInternal | ||
leak:_PyUnicodeWriter_Finish | ||
leak:numpy | ||
leak:_mlir_libs | ||
leak:google/_upb | ||
leak:import_find_and_load | ||
leak:ufunc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
pytest | ||
requests | ||
fastapi | ||
onnx | ||
uvicorn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption( | ||
"--system", | ||
action="store", | ||
metavar="NAME", | ||
nargs="*", | ||
help="Enable tests for system name ('amdgpu', ...)", | ||
) | ||
|
||
|
||
def pytest_configure(config): | ||
config.addinivalue_line( | ||
"markers", "system(name): mark test to run only on a named system" | ||
) | ||
|
||
|
||
def pytest_runtest_setup(item): | ||
required_system_names = [mark.args[0] for mark in item.iter_markers("system")] | ||
if required_system_names: | ||
available_system_names = item.config.getoption("--system") or [] | ||
if not all(name in available_system_names for name in required_system_names): | ||
pytest.skip( | ||
f"test requires system in {required_system_names!r} but has " | ||
f"{available_system_names!r} (set with --system arg)" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
import urllib.request | ||
|
||
|
||
def upgrade_onnx(original_path, converted_path): | ||
import onnx | ||
|
||
original_model = onnx.load_model(original_path) | ||
converted_model = onnx.version_converter.convert_version(original_model, 17) | ||
onnx.save(converted_model, converted_path) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def mobilenet_onnx_path(tmp_path_factory): | ||
try: | ||
import onnx | ||
except ModuleNotFoundError: | ||
raise pytest.skip("onnx python package not available") | ||
print("Downloading mobilenet.onnx") | ||
parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") | ||
orig_onnx_path = parent_dir / "mobilenet_orig.onnx" | ||
urllib.request.urlretrieve( | ||
"https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", | ||
orig_onnx_path, | ||
) | ||
upgraded_onnx_path = parent_dir / "mobilenet.onnx" | ||
upgrade_onnx(orig_onnx_path, upgraded_onnx_path) | ||
return upgraded_onnx_path | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def mobilenet_compiled_cpu_path(mobilenet_onnx_path): | ||
try: | ||
import iree.compiler.tools as tools | ||
import iree.compiler.tools.import_onnx.__main__ as import_onnx | ||
except ModuleNotFoundError: | ||
raise pytest.skip("iree.compiler packages not available") | ||
print("Compiling mobilenet") | ||
mlir_path = mobilenet_onnx_path.parent / "mobilenet.mlir" | ||
vmfb_path = mobilenet_onnx_path.parent / "mobilenet_cpu.vmfb" | ||
args = import_onnx.parse_arguments(["-o", str(mlir_path), str(mobilenet_onnx_path)]) | ||
import_onnx.main(args) | ||
tools.compile_file( | ||
str(mlir_path), | ||
output_file=str(vmfb_path), | ||
target_backends=["llvm-cpu"], | ||
input_type="onnx", | ||
) | ||
return vmfb_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import array | ||
import functools | ||
import pytest | ||
|
||
import shortfin as sf | ||
import shortfin.array as sfnp | ||
|
||
|
||
@pytest.fixture | ||
def lsys(): | ||
sc = sf.host.CPUSystemBuilder() | ||
lsys = sc.create_system() | ||
yield lsys | ||
lsys.shutdown() | ||
|
||
|
||
@pytest.fixture | ||
def scope(lsys): | ||
return lsys.create_scope() | ||
|
||
|
||
@pytest.fixture | ||
def device(scope): | ||
return scope.device(0) | ||
|
||
|
||
def test_invoke_mobilenet(lsys, scope, mobilenet_compiled_cpu_path): | ||
device = scope.device(0) | ||
dummy_data = array.array( | ||
"f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) | ||
) | ||
program_module = lsys.load_module(mobilenet_compiled_cpu_path) | ||
program = sf.Program([program_module], scope=scope) | ||
main_function = program["module.torch-jit-export"] | ||
|
||
async def main(): | ||
device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) | ||
staging_input = device_input.for_transfer() | ||
staging_input.storage.data = dummy_data | ||
device_input.copy_from(staging_input) | ||
(device_output,) = await main_function(device_input) | ||
host_output = device_output.for_transfer() | ||
host_output.copy_from(device_output) | ||
await device | ||
flat_output = array.array("f") | ||
flat_output.frombytes(host_output.storage.data) | ||
absmean = functools.reduce( | ||
lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 | ||
) | ||
assert absmean == pytest.approx(5.01964943873882) | ||
|
||
lsys.run(main()) |