Merge branch 'main' into barrier_support_no_test
rsuderman authored Jan 20, 2025
2 parents 8cfaa6e + 81545be commit 84612ea
Showing 1 changed file with 30 additions and 38 deletions.
sharktank/tests/models/flux/flux_test.py: 30 additions & 38 deletions
@@ -10,6 +10,7 @@
 import torch
 import pytest
 import iree.compiler
+import iree.runtime
 from collections import OrderedDict
 from diffusers import FluxTransformer2DModel
 from sharktank.models.flux.export import (
@@ -84,7 +85,10 @@ def testExportDevRandomSingleLayerBf16(self):
         )
 
     def runCompareIreeAgainstTorchEager(
-        self, reference_model: FluxModelV1, target_dtype: torch.dtype
+        self,
+        reference_model: FluxModelV1,
+        target_dtype: torch.dtype,
+        atol: float,
     ):
         target_theta = reference_model.theta.transform(
             functools.partial(set_float_dtype, dtype=target_dtype)
@@ -164,22 +168,30 @@ def runCompareIreeAgainstTorchEager(
             ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
             for i in range(len(expected_outputs))
         ]
-        # TODO: figure out a good metric. Probably per pixel comparison would be good
-        # enough.
-        torch.testing.assert_close(actual_outputs, expected_outputs)
+        torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)
 
-    def runCompareDevRandomSingleLayerIreeAgainstTorchEager(
-        self, reference_dtype: torch.dtype, target_dtype: torch.dtype
+    def runTestCompareDevIreeAgainstHuggingFace(
+        self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
     ):
-        config = make_dev_single_layer_config()
+        parameters_output_path = self._temp_dir / "parameters.irpa"
 
-        reference_theta = make_random_theta(config, reference_dtype)
-        reference_theta.rename_tensors_to_paths()
+        import_flux_transformer_dataset_from_hugging_face(
+            repo_id="black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
+            parameters_output_path=parameters_output_path,
+        )
+        reference_dataset = Dataset.load(parameters_output_path)
+        reference_dataset.root_theta = Theta(
+            {
+                k: set_float_dtype(t, reference_dtype)
+                for k, t in reference_dataset.root_theta.flatten().items()
+            }
+        )
         reference_model = FluxModelV1(
-            theta=reference_theta,
-            params=config,
+            theta=reference_dataset.root_theta,
+            params=FluxParams.from_hugging_face_properties(reference_dataset.properties),
         )
-        self.runCompareIreeAgainstTorchEager(reference_model, target_dtype)
+
+        self.runCompareIreeAgainstTorchEager(reference_model, target_dtype, atol=atol)
 
     def runTestCompareTorchEagerAgainstHuggingFace(
         self,
@@ -217,36 +229,16 @@ def runTestCompareTorchEagerAgainstHuggingFace(

         torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason="Accuracy is not good enough. The observed absolute error is 8976.53.",
-    )
-    @pytest.mark.skip(
-        reason=(
-            "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
-            "Without it IREE compilation enters an infinite loop."
-        )
-    )
     @with_flux_data
-    def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self):
-        self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
-            reference_dtype=torch.float32, target_dtype=torch.bfloat16
+    def testCompareDevIreeF32AgainstHuggingFaceF32(self):
+        self.runTestCompareDevIreeAgainstHuggingFace(
+            reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason="Accuracy is probably not good enough. The observed absolute error is 73.25.",
-    )
-    @pytest.mark.skip(
-        reason=(
-            "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
-            "Without it IREE compilation enters an infinite loop."
-        )
-    )
     @with_flux_data
-    def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self):
-        self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
-            reference_dtype=torch.float32, target_dtype=torch.float32
+    def testCompareDevIreeBf16AgainstHuggingFaceF32(self):
+        self.runTestCompareDevIreeAgainstHuggingFace(
+            reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1
         )
 
     @with_flux_data
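A note on the tolerance change: the updated assertions pass an explicit atol and set rtol=0, so the allowed error is a fixed absolute budget rather than one that scales with the magnitude of the reference values. A minimal standalone sketch of that behavior (the tensors below are illustrative, not taken from the test):

    import torch

    # With rtol=0, the comparison budget is exactly atol, independent of the
    # magnitude of the values being compared.
    reference = torch.tensor([0.001, 10.0, 1000.0])
    candidate = reference + 0.005  # constant absolute error of 5e-3

    # Passes: every element is within atol=1e-2 of its reference value.
    torch.testing.assert_close(candidate, reference, atol=1e-2, rtol=0)

    # The default tolerances are dominated by a relative term, so the same data
    # fails on the element near zero (0.006 vs 0.001 is a 500% relative error).
    try:
        torch.testing.assert_close(candidate, reference)
    except AssertionError:
        print("default tolerances reject the near-zero element")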

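On the magnitudes of the chosen tolerances, atol=1e-2 for the f32 target and atol=1 for bf16: these roughly track the precision of the arithmetic. bfloat16 keeps only an 8-bit significand, so a single f32 to bf16 rounding step already costs up to about 0.4% relative accuracy, and that error compounds across the transformer's layers. A rough illustration with random data (not measured model outputs):

    import torch

    # One f32 -> bf16 rounding step introduces a relative error of up to about
    # 2**-8 (~0.4%), because bfloat16 has an 8-bit significand.
    x = torch.randn(1 << 16, dtype=torch.float32)
    roundtrip = x.to(torch.bfloat16).to(torch.float32)
    print("max abs round-trip error:", (x - roundtrip).abs().max().item())

Per-layer rounding of this size is why the bf16 run is only held to a unit absolute tolerance while the f32 run must match the HuggingFace reference to within 1e-2.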