From 81545be12e4761eb916982aa6f8935f400cbf253 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Mon, 20 Jan 2025 11:16:12 -0800 Subject: [PATCH] Add flux transformer Dev test with IREE (#843) Remove the single layer with random weights test in favor of the pretrained Dev model variant. We test IREE f32 and bf16 against eager f32. The particular initialization parameters for the random values caused large intermediate values during execution, which deteriorated the model output. The pretrained variant does not suffer from this problem and the numerical error looks reasonable. --- sharktank/tests/models/flux/flux_test.py | 68 +++++++++++------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index 7e1842825..24c1ddb12 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -10,6 +10,7 @@ import torch import pytest import iree.compiler +import iree.runtime from collections import OrderedDict from diffusers import FluxTransformer2DModel from sharktank.models.flux.export import ( @@ -84,7 +85,10 @@ def testExportDevRandomSingleLayerBf16(self): ) def runCompareIreeAgainstTorchEager( - self, reference_model: FluxModelV1, target_dtype: torch.dtype + self, + reference_model: FluxModelV1, + target_dtype: torch.dtype, + atol: float, ): target_theta = reference_model.theta.transform( functools.partial(set_float_dtype, dtype=target_dtype) @@ -164,22 +168,30 @@ def runCompareIreeAgainstTorchEager( ops.to(iree_result[i], dtype=expected_outputs[i].dtype) for i in range(len(expected_outputs)) ] - # TODO: figure out a good metric. Probably per pixel comparison would be good - # enough. 
- torch.testing.assert_close(actual_outputs, expected_outputs) + torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0) - def runCompareDevRandomSingleLayerIreeAgainstTorchEager( - self, reference_dtype: torch.dtype, target_dtype: torch.dtype + def runTestCompareDevIreeAgainstHuggingFace( + self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float ): - config = make_dev_single_layer_config() + parameters_output_path = self._temp_dir / "parameters.irpa" - reference_theta = make_random_theta(config, reference_dtype) - reference_theta.rename_tensors_to_paths() + import_flux_transformer_dataset_from_hugging_face( + repo_id="black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + parameters_output_path=parameters_output_path, + ) + refrence_dataset = Dataset.load(parameters_output_path) + refrence_dataset.root_theta = Theta( + { + k: set_float_dtype(t, reference_dtype) + for k, t in refrence_dataset.root_theta.flatten().items() + } + ) reference_model = FluxModelV1( - theta=reference_theta, - params=config, + theta=refrence_dataset.root_theta, + params=FluxParams.from_hugging_face_properties(refrence_dataset.properties), ) - self.runCompareIreeAgainstTorchEager(reference_model, target_dtype) + + self.runCompareIreeAgainstTorchEager(reference_model, target_dtype, atol=atol) def runTestCompareTorchEagerAgainstHuggingFace( self, @@ -217,36 +229,16 @@ def runTestCompareTorchEagerAgainstHuggingFace( torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0) - @pytest.mark.xfail( - raises=AssertionError, - reason="Accuracy is not good enough. The observed absolute error is 8976.53.", - ) - @pytest.mark.skip( - reason=( - "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. " - "Without it IREE compilation enters an infinite loop." 
- ) - ) @with_flux_data - def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self): - self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( - reference_dtype=torch.float32, target_dtype=torch.bfloat16 + def testCompareDevIreeF32AgainstHuggingFaceF32(self): + self.runTestCompareDevIreeAgainstHuggingFace( + reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2 ) - @pytest.mark.xfail( - raises=AssertionError, - reason="Accuracy is probably not good enough. The observed absolute error is 73.25.", - ) - @pytest.mark.skip( - reason=( - "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. " - "Without it IREE compilation enters an infinite loop." - ) - ) @with_flux_data - def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self): - self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( - reference_dtype=torch.float32, target_dtype=torch.float32 + def testCompareDevIreeBf16AgainstHuggingFaceF32(self): + self.runTestCompareDevIreeAgainstHuggingFace( + reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1 ) @with_flux_data