Add flux transformer Dev test with IREE (#843)
Remove the single-layer test with random weights in favor of the
pretrained Dev model variant.
We test IREE f32 and bf16 against eager f32.
The particular initialization parameters for the random values caused
large intermediate values during execution, which degraded the model
output.
The pretrained variant does not suffer from this problem, and the
numerical error looks reasonable.
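
A minimal sketch of the comparison pattern the message describes (torch.testing.assert_close and the tolerances are taken from the diff below; the helper name and list handling are illustrative, not the actual sharktank code):

import torch

def assert_outputs_close(
    actual_outputs: list[torch.Tensor],
    expected_outputs: list[torch.Tensor],
    atol: float,
) -> None:
    # Cast the target-precision results (e.g. bf16 coming back from IREE) to the
    # reference dtype before comparing, then check with an absolute tolerance
    # only (rtol=0), matching the assert used in the test.
    actual = [a.to(e.dtype) for a, e in zip(actual_outputs, expected_outputs)]
    torch.testing.assert_close(actual, expected_outputs, atol=atol, rtol=0)

# The test uses atol=1e-2 when the target dtype is f32 and atol=1 for bf16.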
sogartar authored Jan 20, 2025
1 parent 3d36fe8 commit 81545be
Showing 1 changed file with 30 additions and 38 deletions: sharktank/tests/models/flux/flux_test.py
@@ -10,6 +10,7 @@
import torch
import pytest
import iree.compiler
import iree.runtime
from collections import OrderedDict
from diffusers import FluxTransformer2DModel
from sharktank.models.flux.export import (
@@ -84,7 +85,10 @@ def testExportDevRandomSingleLayerBf16(self):
)

def runCompareIreeAgainstTorchEager(
self, reference_model: FluxModelV1, target_dtype: torch.dtype
self,
reference_model: FluxModelV1,
target_dtype: torch.dtype,
atol: float,
):
target_theta = reference_model.theta.transform(
functools.partial(set_float_dtype, dtype=target_dtype)
@@ -164,22 +168,30 @@ def runCompareIreeAgainstTorchEager(
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
for i in range(len(expected_outputs))
]
# TODO: figure out a good metric. Probably per pixel comparison would be good
# enough.
torch.testing.assert_close(actual_outputs, expected_outputs)
torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)

def runCompareDevRandomSingleLayerIreeAgainstTorchEager(
self, reference_dtype: torch.dtype, target_dtype: torch.dtype
def runTestCompareDevIreeAgainstHuggingFace(
self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
):
config = make_dev_single_layer_config()
parameters_output_path = self._temp_dir / "parameters.irpa"

reference_theta = make_random_theta(config, reference_dtype)
reference_theta.rename_tensors_to_paths()
import_flux_transformer_dataset_from_hugging_face(
repo_id="black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
parameters_output_path=parameters_output_path,
)
reference_dataset = Dataset.load(parameters_output_path)
reference_dataset.root_theta = Theta(
{
k: set_float_dtype(t, reference_dtype)
for k, t in reference_dataset.root_theta.flatten().items()
}
)
reference_model = FluxModelV1(
theta=reference_theta,
params=config,
theta=reference_dataset.root_theta,
params=FluxParams.from_hugging_face_properties(reference_dataset.properties),
)
self.runCompareIreeAgainstTorchEager(reference_model, target_dtype)

self.runCompareIreeAgainstTorchEager(reference_model, target_dtype, atol=atol)

def runTestCompareTorchEagerAgainstHuggingFace(
self,
@@ -217,36 +229,16 @@ def runTestCompareTorchEagerAgainstHuggingFace(

torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)

@pytest.mark.xfail(
raises=AssertionError,
reason="Accuracy is not good enough. The observed absolute error is 8976.53.",
)
@pytest.mark.skip(
reason=(
"Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
"Without it IREE compilation enters an infinite loop."
)
)
@with_flux_data
def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self):
self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
reference_dtype=torch.float32, target_dtype=torch.bfloat16
def testCompareDevIreeF32AgainstHuggingFaceF32(self):
self.runTestCompareDevIreeAgainstHuggingFace(
reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2
)

@pytest.mark.xfail(
raises=AssertionError,
reason="Accuracy is probably not good enough. The observed absolute error is 73.25.",
)
@pytest.mark.skip(
reason=(
"Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
"Without it IREE compilation enters an infinite loop."
)
)
@with_flux_data
def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self):
self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
reference_dtype=torch.float32, target_dtype=torch.float32
def testCompareDevIreeBf16AgainstHuggingFaceF32(self):
self.runTestCompareDevIreeAgainstHuggingFace(
reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1
)

@with_flux_data
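As a rough illustration of why the bf16 target gets a much looser absolute tolerance (atol=1) than the f32 target (atol=1e-2), a single f32 → bf16 round trip of a moderately sized value already exceeds 1e-2 (the specific value below is arbitrary):

import torch

# bfloat16 keeps only 7 explicit mantissa bits, so rounding a moderately sized
# f32 value to bf16 and back loses more than 1e-2 of absolute accuracy.
x = torch.tensor(123.456789, dtype=torch.float32)
err = (x.to(torch.bfloat16).to(torch.float32) - x).abs().item()
print(err)  # ~0.04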
