From 088f0069621eec1a6226734f691bc4bddbe6deae Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:52:26 -0800 Subject: [PATCH] Extend VaeDecoderModel for flux compatibility (#750) Co-authored-by: Kyle Herndon Co-authored-by: Kyle Herndon --- sharktank/requirements.txt | 7 +- sharktank/sharktank/models/vae/config.py | 2 + sharktank/sharktank/models/vae/layers.py | 3 +- sharktank/sharktank/models/vae/model.py | 28 ++- .../models/vae/tools/diffuser_ref.py | 32 ++- .../sharktank/models/vae/tools/run_vae.py | 28 ++- .../sharktank/models/vae/tools/sample_data.py | 12 +- .../sharktank/tools/import_hf_dataset.py | 5 +- sharktank/sharktank/types/tensors.py | 5 + sharktank/tests/models/vae/vae_test.py | 212 +++++++++++++++--- 10 files changed, 289 insertions(+), 45 deletions(-) diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 241b49f79..bd6699e54 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -5,6 +5,11 @@ gguf>=0.11.0 numpy # Model deps. -huggingface-hub==0.22.2 +huggingface-hub transformers==4.40.0 datasets +einops + +# Serving deps. +fastapi>=0.112.2 +uvicorn>=0.30.6 diff --git a/sharktank/sharktank/models/vae/config.py b/sharktank/sharktank/models/vae/config.py index 9ee2f0427..a27d30a19 100644 --- a/sharktank/sharktank/models/vae/config.py +++ b/sharktank/sharktank/models/vae/config.py @@ -38,6 +38,8 @@ class HParams: layers_per_block: int = 2 norm_num_groups: int = 32 scaling_factor: float = 0.13025 + use_post_quant_conv: bool = True + shift_factor: float = 0.0 def assert_default_values(self, attr_names: Sequence[str]): for name in attr_names: diff --git a/sharktank/sharktank/models/vae/layers.py b/sharktank/sharktank/models/vae/layers.py index 0d7033f4a..9ce57440e 100644 --- a/sharktank/sharktank/models/vae/layers.py +++ b/sharktank/sharktank/models/vae/layers.py @@ -17,7 +17,6 @@ ResnetBlock2D, Upsample2D, GroupNormLayer, - AttentionLayer, ) from .config import * @@ -84,7 +83,6 @@ def forward( hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose( 1, 2 ) - query = self.to_q(hidden_states) if encoder_hidden_states is None: @@ -110,6 +108,7 @@ def forward( hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, self.heads * head_dim ) + hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = self.to_out(hidden_states) diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index 1054108c7..d689aaf72 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -15,6 +15,8 @@ from .layers import * from sharktank.models.punet.layers import UpDownBlock2D, GroupNormLayer from typing import Optional +from einops import rearrange +import math class VaeDecoderModel(ThetaLayer): @@ -23,12 +25,13 @@ def from_dataset(cls, ds: Dataset) -> "VaeDecoderModel": hp = HParams.from_dict(ds.properties["hparams"]) return cls(hp, ds.root_theta) - def __init__(self, hp: HParams, theta: Theta): + def __init__(self, hp, theta: Theta): super().__init__(theta) self.hp = hp # input conv - self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0)) + if hp.use_post_quant_conv: + self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0)) self.conv_in = Conv2DLayer(theta("decoder")("conv_in"), padding=(1, 1)) # Mid self.mid_block = self._create_mid_block(theta("decoder")("mid_block")) @@ -71,9 +74,20 @@ def forward( "latent_embeds": latent_embeds, }, ) - sample = 1 / self.hp.scaling_factor * sample + if not self.hp.use_post_quant_conv: + sample = rearrange( + sample, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(1024 / 16), + w=math.ceil(1024 / 16), + ph=2, + pw=2, + ) + sample = sample / self.hp.scaling_factor + self.hp.shift_factor + + if self.hp.use_post_quant_conv: + sample = self.post_quant_conv(sample) - sample = self.post_quant_conv(sample) sample = self.conv_in(sample) self.trace_golden("conv_in", sample) # TODO add training and gradient checkpointing support @@ -90,7 +104,11 @@ def forward( sample = self.conv_act(sample) sample = self.conv_out(sample) - sample = (sample / 2 + 0.5).clamp(0, 1) + + if not self.hp.use_post_quant_conv: + sample = sample.clamp(-1, 1) + else: + sample = (sample / 2 + 0.5).clamp(0, 1) return sample def _create_mid_block(self, mid_block_theta: Theta) -> nn.Module: diff --git a/sharktank/sharktank/models/vae/tools/diffuser_ref.py b/sharktank/sharktank/models/vae/tools/diffuser_ref.py index c2c283197..8598573d0 100644 --- a/sharktank/sharktank/models/vae/tools/diffuser_ref.py +++ b/sharktank/sharktank/models/vae/tools/diffuser_ref.py @@ -7,6 +7,8 @@ import argparse import torch from diffusers import AutoencoderKL +from einops import rearrange +import math class VaeModel(torch.nn.Module): @@ -51,6 +53,34 @@ def decode(self, inp): def run_torch_vae(hf_model_name, example_input): - vae_model = VaeModel(hf_model_name) return vae_model.decode(example_input) + + +# TODO Remove and integrate with VaeModel +class FluxAEWrapper(torch.nn.Module): + def __init__(self, height=1024, width=1024): + super().__init__() + self.ae = AutoencoderKL.from_pretrained( + "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16 + ) + self.height = height + self.width = width + + def forward(self, z): + d_in = rearrange( + z, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(self.height / 16), + w=math.ceil(self.width / 16), + ph=2, + pw=2, + ) + d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor + return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) + + +def run_flux_vae(example_input, dtype): + # TODO add support for other height/width sizes + vae_model = FluxAEWrapper(1024, 1024).to(dtype) + return vae_model.forward(example_input) diff --git a/sharktank/sharktank/models/vae/tools/run_vae.py b/sharktank/sharktank/models/vae/tools/run_vae.py index 540436fd1..8c06fb788 100644 --- a/sharktank/sharktank/models/vae/tools/run_vae.py +++ b/sharktank/sharktank/models/vae/tools/run_vae.py @@ -21,6 +21,7 @@ from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) +import numpy as np def export_vae(model, sample_inputs, decomp_attn): @@ -81,6 +82,19 @@ def main(argv): action="store_true", help="Compares results vs HF diffusers reference model", ) + + parser.add_argument( + "--torch_model", + default="stabilityai/stable-diffusion-xl-base-1.0", + help="HF reference model id", + ) + + parser.add_argument( + "--sharktank_config", + default="sdxl", + help="Sharktank config providing hyperparamters [sdxl or flux]", + ) + parser.add_argument( "--decomp_attn", action="store_true", @@ -95,12 +109,13 @@ def main(argv): ds.to(device=device) mdl = VaeDecoderModel.from_dataset(ds) - # Run a step for debugging. if args.inputs: inputs = load_inputs(args.inputs, dtype=dtype, device=device, bs=args.bs) else: - inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs) + inputs = get_random_inputs( + dtype=dtype, device=device, bs=args.bs, config=args.sharktank_config + ) if args.export: # TODO move export from a run_vae file @@ -126,11 +141,12 @@ def main(argv): intermediates_saver.save_file(args.save_intermediates_path) if args.compare_vs_torch: - from .diffuser_ref import run_torch_vae + from .diffuser_ref import run_torch_vae, run_flux_vae - diffusers_results = run_torch_vae( - "stabilityai/stable-diffusion-xl-base-1.0", inputs - ) + if args.sharktank_config == "flux": + diffusers_results = run_flux_vae(inputs, torch.bfloat16) + elif args.sharktank_config == "sdxl": + run_torch_vae(args.torch_model, inputs) print("diffusers results:", diffusers_results) torch.testing.assert_close(diffusers_results, results) diff --git a/sharktank/sharktank/models/vae/tools/sample_data.py b/sharktank/sharktank/models/vae/tools/sample_data.py index cd946088e..49717d1a7 100644 --- a/sharktank/sharktank/models/vae/tools/sample_data.py +++ b/sharktank/sharktank/models/vae/tools/sample_data.py @@ -11,7 +11,15 @@ import torch -def get_random_inputs(dtype, device, bs: int = 2): +def get_random_inputs(dtype, device, bs: int = 2, config: str = "sdxl"): height = 1024 width = 1024 - return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device) + if config == "sdxl": + print("sdxl returning inputs") + return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device) + elif config == "flux": + print("flux returning inputs") + return torch.rand(bs, int(width * height / 256), 64, dtype=dtype).to(device) + else: + print("config: ", config) + raise AssertionError(f"{config} config not implmented [sdxl, flux] implemented") diff --git a/sharktank/sharktank/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py index 8b8feed9f..8fb5b3f16 100644 --- a/sharktank/sharktank/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -38,6 +38,7 @@ def import_hf_dataset( config_json_path: PathLike, param_paths: list[PathLike], output_irpa_file: Optional[PathLike] = None, + target_dtype=None, ) -> Optional[Dataset]: import safetensors @@ -50,7 +51,9 @@ def import_hf_dataset( for params_path in param_paths: with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: tensors = [ - DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) + DefaultPrimitiveTensor( + name=name, data=st.get_tensor(name).to(target_dtype) + ) for name in st.keys() ] diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 9470aba7e..ea0482e71 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -388,6 +388,11 @@ def squeeze(self, dim: Optional[int] = None) -> "AnyTensor": return squeeze(self, dim) + def squeeze(self, dim: Optional[int] = None) -> "AnyTensor": + from ..ops import squeeze + + return squeeze(self, dim) + def transpose(self, dim0: int, dim1: int) -> "AnyTensor": from ..ops import transpose diff --git a/sharktank/tests/models/vae/vae_test.py b/sharktank/tests/models/vae/vae_test.py index 9b77d835e..2812df95a 100644 --- a/sharktank/tests/models/vae/vae_test.py +++ b/sharktank/tests/models/vae/vae_test.py @@ -13,7 +13,7 @@ from sharktank.types import Dataset from sharktank.models.vae.model import VaeDecoderModel -from sharktank.models.vae.tools.diffuser_ref import run_torch_vae +from sharktank.models.vae.tools.diffuser_ref import run_torch_vae, run_flux_vae from sharktank.models.vae.tools.run_vae import export_vae from sharktank.models.vae.tools.sample_data import get_random_inputs @@ -32,55 +32,59 @@ call_torch_module_function, flatten_for_iree_signature, iree_to_torch, + device_array_to_host, ) import iree.compiler from collections import OrderedDict +from sharktank.utils.testing import TempDirTestBase + with_vae_data = pytest.mark.skipif("not config.getoption('with_vae_data')") @with_vae_data -class VaeSDXLDecoderTest(unittest.TestCase): +class VaeSDXLDecoderTest(TempDirTestBase): def setUp(self): + super().setUp() hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0" hf_hub_download( repo_id=hf_model_id, - local_dir="sdxl_vae", + local_dir="{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/config.json", ) hf_hub_download( repo_id=hf_model_id, - local_dir="sdxl_vae", + local_dir="{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/diffusion_pytorch_model.safetensors", ) hf_hub_download( repo_id="amd-shark/sdxl-quant-models", - local_dir="sdxl_vae", + local_dir="{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/vae.safetensors", ) torch.manual_seed(12345) f32_dataset = import_hf_dataset( - "sdxl_vae/vae/config.json", - ["sdxl_vae/vae/diffusion_pytorch_model.safetensors"], + "{self._temp_dir}/vae/config.json", + ["{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"], ) - f32_dataset.save("sdxl_vae/vae_f32.irpa", io_report_callback=print) + f32_dataset.save("{self._temp_dir}/vae_f32.irpa", io_report_callback=print) f16_dataset = import_hf_dataset( - "sdxl_vae/vae/config.json", ["sdxl_vae/vae/vae.safetensors"] + "{self._temp_dir}/vae/config.json", ["{self._temp_dir}/vae/vae.safetensors"] ) - f16_dataset.save("sdxl_vae/vae_f16.irpa", io_report_callback=print) + f16_dataset.save("{self._temp_dir}/vae_f16.irpa", io_report_callback=print) def testCompareF32EagerVsHuggingface(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae("{self._temp_dir}", inputs) - ds = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + ds = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") results = model.forward(inputs) @@ -91,9 +95,9 @@ def testCompareF32EagerVsHuggingface(self): def testCompareF16EagerVsHuggingface(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae("{self._temp_dir}", inputs) - ds = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") + ds = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") results = model.forward(inputs.to(torch.float16)) @@ -106,10 +110,10 @@ def testCompareF16EagerVsHuggingface(self): def testVaeIreeVsHuggingFace(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae("{self._temp_dir}", inputs) - ds_f16 = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") - ds_f32 = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + ds_f16 = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa") + ds_f32 = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa") model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu") model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") @@ -118,8 +122,8 @@ def testVaeIreeVsHuggingFace(self): module_f16 = export_vae(model_f16, inputs.to(torch.float16), True) module_f32 = export_vae(model_f32, inputs, True) - module_f16.save_mlir("sdxl_vae/vae_f16.mlir") - module_f32.save_mlir("sdxl_vae/vae_f32.mlir") + module_f16.save_mlir("{self._temp_dir}/vae_f16.mlir") + module_f32.save_mlir("{self._temp_dir}/vae_f32.mlir") extra_args = [ "--iree-hal-target-backends=rocm", "--iree-hip-target=gfx942", @@ -136,22 +140,22 @@ def testVaeIreeVsHuggingFace(self): ] iree.compiler.compile_file( - "sdxl_vae/vae_f16.mlir", - output_file="sdxl_vae/vae_f16.vmfb", + "{self._temp_dir}/vae_f16.mlir", + output_file="{self._temp_dir}/vae_f16.vmfb", extra_args=extra_args, ) iree.compiler.compile_file( - "sdxl_vae/vae_f32.mlir", - output_file="sdxl_vae/vae_f32.vmfb", + "{self._temp_dir}/vae_f32.mlir", + output_file="{self._temp_dir}/vae_f32.vmfb", extra_args=extra_args, ) iree_devices = get_iree_devices(driver="hip", device_count=1) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path="sdxl_vae/vae_f16.vmfb", + module_path="{self._temp_dir}/vae_f16.vmfb", devices=iree_devices, - parameters_path="sdxl_vae/vae_f16.irpa", + parameters_path="{self._temp_dir}/vae_f16.irpa", ) input_args = OrderedDict([("inputs", inputs.to(torch.float16))]) @@ -177,9 +181,9 @@ def testVaeIreeVsHuggingFace(self): ) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path="sdxl_vae/vae_f32.vmfb", + module_path="{self._temp_dir}/vae_f32.vmfb", devices=iree_devices, - parameters_path="sdxl_vae/vae_f32.irpa", + parameters_path="{self._temp_dir}/vae_f32.irpa", ) input_args = OrderedDict([("inputs", inputs)]) @@ -201,5 +205,159 @@ def testVaeIreeVsHuggingFace(self): ) +@with_vae_data +class VaeFluxDecoderTest(TempDirTestBase): + def setUp(self): + super().setUp() + hf_model_id = "black-forest-labs/FLUX.1-dev" + hf_hub_download( + repo_id=hf_model_id, + local_dir="{self._temp_dir}/flux_vae/", + local_dir_use_symlinks=False, + revision="main", + filename="vae/config.json", + ) + hf_hub_download( + repo_id=hf_model_id, + local_dir="{self._temp_dir}/flux_vae/", + local_dir_use_symlinks=False, + revision="main", + filename="vae/diffusion_pytorch_model.safetensors", + ) + torch.manual_seed(12345) + dataset = import_hf_dataset( + "{self._temp_dir}/flux_vae/vae/config.json", + ["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"], + ) + dataset.save("{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print) + dataset_f32 = import_hf_dataset( + "{self._temp_dir}/flux_vae/vae/config.json", + ["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"], + target_dtype=torch.float32, + ) + dataset_f32.save("{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print) + + def testCompareBF16EagerVsHuggingface(self): + dtype = torch.bfloat16 + inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux") + ref_results = run_flux_vae(inputs, dtype) + + ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa") + model = VaeDecoderModel.from_dataset(ds).to(device="cpu") + + results = model.forward(inputs) + # TODO: verify numerics + torch.testing.assert_close(ref_results, results, atol=3e-2, rtol=3e5) + + def testCompareF32EagerVsHuggingface(self): + dtype = torch.float32 + inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux") + ref_results = run_flux_vae(inputs, dtype) + + ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa") + model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype) + + results = model.forward(inputs) + torch.testing.assert_close(ref_results, results) + + def testVaeIreeVsHuggingFace(self): + dtype = torch.bfloat16 + inputs = get_random_inputs( + dtype=torch.float32, device="cpu", bs=1, config="flux" + ) + ref_results = run_flux_vae(inputs.to(dtype), dtype) + ref_results_f32 = run_flux_vae(inputs, torch.float32) + + ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa") + ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa") + + model = VaeDecoderModel.from_dataset(ds).to(device="cpu") + model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") + + # TODO: Decomposing attention due to https://github.com/iree-org/iree/issues/19286, remove once issue is resolved + module = export_vae(model, inputs, True) + module_f32 = export_vae(model_f32, inputs, True) + + module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir") + module_f32.save_mlir("{self._temp_dir}/flux_vae_f32.mlir") + + extra_args = [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx942", + "--iree-opt-const-eval=false", + "--iree-opt-strip-assertions=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-hip-waves-per-eu=2", + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-execution-model=async-external", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", + ] + + iree.compiler.compile_file( + "{self._temp_dir}/flux_vae_bf16.mlir", + output_file="{self._temp_dir}/flux_vae_bf16.vmfb", + extra_args=extra_args, + ) + iree.compiler.compile_file( + "{self._temp_dir}/flux_vae_f32.mlir", + output_file="{self._temp_dir}/flux_vae_f32.vmfb", + extra_args=extra_args, + ) + + iree_devices = get_iree_devices(driver="hip", device_count=1) + + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path="{self._temp_dir}/flux_vae_bf16.vmfb", + devices=iree_devices, + parameters_path="{self._temp_dir}/flux_vae_bf16.irpa", + ) + + input_args = OrderedDict([("inputs", inputs)]) + iree_args = flatten_for_iree_signature(input_args) + + iree_args = prepare_iree_module_function_args( + args=iree_args, devices=iree_devices + ) + iree_result = device_array_to_host( + run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name="forward", + )[0] + ) + + # TODO verify these numerics + torch.testing.assert_close(ref_results, iree_result, atol=3.3e-2, rtol=4e5) + + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path="{self._temp_dir}/flux_vae_f32.vmfb", + devices=iree_devices, + parameters_path="{self._temp_dir}/flux_vae_f32.irpa", + ) + + input_args = OrderedDict([("inputs", inputs)]) + iree_args = flatten_for_iree_signature(input_args) + + iree_args = prepare_iree_module_function_args( + args=iree_args, devices=iree_devices + ) + iree_result_f32 = device_array_to_host( + run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name="forward", + )[0] + ) + + torch.testing.assert_close(ref_results_f32, iree_result_f32) + + if __name__ == "__main__": unittest.main()