Extend VaeDecoderModel for flux compatibility (#750)

Co-authored-by: Kyle Herndon <kyle.herndon@amd.com>
Co-authored-by: Kyle Herndon <kherndon@amd.com>

3 people authored Jan 13, 2025
1 parent 498193c · commit 088f006

Showing 10 changed files with 289 additions and 45 deletions.
7 changes: 6 additions & 1 deletion — sharktank/requirements.txt

@@ -5,6 +5,11 @@ gguf>=0.11.0
 numpy
 
 # Model deps.
-huggingface-hub==0.22.2
+huggingface-hub
 transformers==4.40.0
+datasets
+einops
 
+# Serving deps.
+fastapi>=0.112.2
+uvicorn>=0.30.6
2 changes: 2 additions & 0 deletions — sharktank/sharktank/models/vae/config.py

@@ -38,6 +38,8 @@ class HParams:
     layers_per_block: int = 2
     norm_num_groups: int = 32
     scaling_factor: float = 0.13025
+    use_post_quant_conv: bool = True
+    shift_factor: float = 0.0
 
     def assert_default_values(self, attr_names: Sequence[str]):
         for name in attr_names:
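The two new fields let one HParams instance describe both model families: SDXL keeps the defaults, while a flux config flips them. A minimal sketch, assuming the remaining HParams fields carry defaults; the flux values below come from the published FLUX.1-dev VAE config.json and are an assumption here, not part of this commit:

    from sharktank.models.vae.config import HParams

    # SDXL-style: post_quant_conv present, no shift (all defaults).
    sdxl_hp = HParams()

    # Flux-style (assumed values from the FLUX.1-dev VAE config.json):
    flux_hp = HParams(
        use_post_quant_conv=False,  # the flux VAE ships no post_quant_conv weights
        scaling_factor=0.3611,
        shift_factor=0.1159,
    )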
3 changes: 1 addition & 2 deletions — sharktank/sharktank/models/vae/layers.py

@@ -17,7 +17,6 @@
     ResnetBlock2D,
     Upsample2D,
     GroupNormLayer,
-    AttentionLayer,
 )
 from .config import *

@@ -84,7 +83,6 @@ def forward(
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
             1, 2
         )
-
        query = self.to_q(hidden_states)
 
        if encoder_hidden_states is None:

@@ -110,6 +108,7 @@ def forward(
         hidden_states = hidden_states.transpose(1, 2).reshape(
             batch_size, -1, self.heads * head_dim
         )
+        hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
         hidden_states = self.to_out(hidden_states)
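The one functional change here, hidden_states.to(query.dtype), mirrors the equivalent cast in diffusers' AttnProcessor2_0: if attention ran in a wider dtype than the projections, the output is cast back before the to_out projection. A standalone sketch of the situation it guards against (assumed scenario, not code from this repo):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 64, 64, dtype=torch.bfloat16)
    k, v = q.clone(), q.clone()

    # If an upstream op upcast the operands, SDPA returns float32 here.
    out = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
    print(out.dtype)       # torch.float32

    out = out.to(q.dtype)  # the cast the diff adds, restoring bfloat16
    print(out.dtype)       # torch.bfloat16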
28 changes: 23 additions & 5 deletions — sharktank/sharktank/models/vae/model.py

@@ -15,6 +15,8 @@
 from .layers import *
 from sharktank.models.punet.layers import UpDownBlock2D, GroupNormLayer
 from typing import Optional
+from einops import rearrange
+import math
 
 
 class VaeDecoderModel(ThetaLayer):

@@ -23,12 +25,13 @@ def from_dataset(cls, ds: Dataset) -> "VaeDecoderModel":
         hp = HParams.from_dict(ds.properties["hparams"])
         return cls(hp, ds.root_theta)
 
-    def __init__(self, hp: HParams, theta: Theta):
+    def __init__(self, hp, theta: Theta):
         super().__init__(theta)
         self.hp = hp
 
         # input conv
-        self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0))
+        if hp.use_post_quant_conv:
+            self.post_quant_conv = Conv2DLayer(theta("post_quant_conv"), padding=(0, 0))
         self.conv_in = Conv2DLayer(theta("decoder")("conv_in"), padding=(1, 1))
         # Mid
         self.mid_block = self._create_mid_block(theta("decoder")("mid_block"))

@@ -71,9 +74,20 @@ def forward(
                 "latent_embeds": latent_embeds,
             },
         )
-        sample = 1 / self.hp.scaling_factor * sample
+        if not self.hp.use_post_quant_conv:
+            sample = rearrange(
+                sample,
+                "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+                h=math.ceil(1024 / 16),
+                w=math.ceil(1024 / 16),
+                ph=2,
+                pw=2,
+            )
+        sample = sample / self.hp.scaling_factor + self.hp.shift_factor
 
-        sample = self.post_quant_conv(sample)
+        if self.hp.use_post_quant_conv:
+            sample = self.post_quant_conv(sample)
+
         sample = self.conv_in(sample)
         self.trace_golden("conv_in", sample)
         # TODO add training and gradient checkpointing support

@@ -90,7 +104,11 @@
 
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
-        sample = (sample / 2 + 0.5).clamp(0, 1)
+
+        if not self.hp.use_post_quant_conv:
+            sample = sample.clamp(-1, 1)
+        else:
+            sample = (sample / 2 + 0.5).clamp(0, 1)
         return sample
 
     def _create_mid_block(self, mid_block_theta: Theta) -> nn.Module:
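Two details worth noting. The rearrange un-packs flux's token-major latents into NCHW before the first convolution (the hardcoded 1024s restrict this path to 1024×1024 outputs for now), and the rewritten scaling line reduces to the old SDXL behavior because shift_factor defaults to 0.0. A quick equivalence check (sketch):

    import torch

    scaling_factor, shift_factor = 0.13025, 0.0   # SDXL defaults from config.py
    sample = torch.randn(1, 4, 128, 128)

    old = 1 / scaling_factor * sample             # behavior before this commit
    new = sample / scaling_factor + shift_factor  # behavior after
    torch.testing.assert_close(old, new)          # identical when shift is zero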
32 changes: 31 additions & 1 deletion — sharktank/sharktank/models/vae/tools/diffuser_ref.py

@@ -7,6 +7,8 @@
 import argparse
 import torch
 from diffusers import AutoencoderKL
+from einops import rearrange
+import math
 
 
 class VaeModel(torch.nn.Module):

@@ -51,6 +53,34 @@ def decode(self, inp):
 
 
 def run_torch_vae(hf_model_name, example_input):
-
     vae_model = VaeModel(hf_model_name)
     return vae_model.decode(example_input)
+
+
+# TODO Remove and integrate with VaeModel
+class FluxAEWrapper(torch.nn.Module):
+    def __init__(self, height=1024, width=1024):
+        super().__init__()
+        self.ae = AutoencoderKL.from_pretrained(
+            "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
+        )
+        self.height = height
+        self.width = width
+
+    def forward(self, z):
+        d_in = rearrange(
+            z,
+            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+            h=math.ceil(self.height / 16),
+            w=math.ceil(self.width / 16),
+            ph=2,
+            pw=2,
+        )
+        d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor
+        return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1)
+
+
+def run_flux_vae(example_input, dtype):
+    # TODO add support for other height/width sizes
+    vae_model = FluxAEWrapper(1024, 1024).to(dtype)
+    return vae_model.forward(example_input)
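A shape walk-through of FluxAEWrapper.forward at the default 1024×1024 (sketch): 64×64 = 4096 latent tokens, each carrying 64 values that unpack as c·ph·pw = 16·2·2, so the decoder sees a standard (1, 16, 128, 128) latent:

    import math
    import torch
    from einops import rearrange

    z = torch.randn(1, 64 * 64, 64)  # b=1, (h w)=4096 tokens, 64 packed channels
    d_in = rearrange(
        z,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(1024 / 16),
        w=math.ceil(1024 / 16),
        ph=2,
        pw=2,
    )
    print(d_in.shape)  # torch.Size([1, 16, 128, 128])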
28 changes: 22 additions & 6 deletions — sharktank/sharktank/models/vae/tools/run_vae.py

@@ -21,6 +21,7 @@
 from iree.turbine.dynamo.passes import (
     DEFAULT_DECOMPOSITIONS,
 )
+import numpy as np
 
 
 def export_vae(model, sample_inputs, decomp_attn):

@@ -81,6 +82,19 @@ def main(argv):
         action="store_true",
         help="Compares results vs HF diffusers reference model",
     )
+
+    parser.add_argument(
+        "--torch_model",
+        default="stabilityai/stable-diffusion-xl-base-1.0",
+        help="HF reference model id",
+    )
+
+    parser.add_argument(
+        "--sharktank_config",
+        default="sdxl",
+        help="Sharktank config providing hyperparameters [sdxl or flux]",
+    )
+
     parser.add_argument(
         "--decomp_attn",
         action="store_true",

@@ -95,12 +109,13 @@
     ds.to(device=device)
 
     mdl = VaeDecoderModel.from_dataset(ds)
-
     # Run a step for debugging.
     if args.inputs:
         inputs = load_inputs(args.inputs, dtype=dtype, device=device, bs=args.bs)
     else:
-        inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)
+        inputs = get_random_inputs(
+            dtype=dtype, device=device, bs=args.bs, config=args.sharktank_config
+        )
 
     if args.export:
         # TODO move export from a run_vae file

@@ -126,11 +141,12 @@
         intermediates_saver.save_file(args.save_intermediates_path)
 
     if args.compare_vs_torch:
-        from .diffuser_ref import run_torch_vae
+        from .diffuser_ref import run_torch_vae, run_flux_vae
 
-        diffusers_results = run_torch_vae(
-            "stabilityai/stable-diffusion-xl-base-1.0", inputs
-        )
+        if args.sharktank_config == "flux":
+            diffusers_results = run_flux_vae(inputs, torch.bfloat16)
+        elif args.sharktank_config == "sdxl":
+            diffusers_results = run_torch_vae(args.torch_model, inputs)
         print("diffusers results:", diffusers_results)
         torch.testing.assert_close(diffusers_results, results)
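Pieced together, the new flux compare path boils down to the following (a sketch that skips the CLI and IRPA plumbing; both helpers are added in this commit, and FluxAEWrapper downloads the gated FLUX.1-dev VAE, so Hugging Face access is assumed):

    import torch
    from sharktank.models.vae.tools.sample_data import get_random_inputs
    from sharktank.models.vae.tools.diffuser_ref import run_flux_vae

    inputs = get_random_inputs(
        dtype=torch.bfloat16, device="cpu", bs=1, config="flux"
    )
    diffusers_results = run_flux_vae(inputs, torch.bfloat16)
    print(diffusers_results.shape)  # expected: torch.Size([1, 3, 1024, 1024])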
12 changes: 10 additions & 2 deletions — sharktank/sharktank/models/vae/tools/sample_data.py

@@ -11,7 +11,15 @@
 import torch
 
 
-def get_random_inputs(dtype, device, bs: int = 2):
+def get_random_inputs(dtype, device, bs: int = 2, config: str = "sdxl"):
     height = 1024
     width = 1024
-    return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)
+    if config == "sdxl":
+        print("sdxl returning inputs")
+        return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)
+    elif config == "flux":
+        print("flux returning inputs")
+        return torch.rand(bs, int(width * height / 256), 64, dtype=dtype).to(device)
+    else:
+        print("config: ", config)
+        raise AssertionError(f"{config} config not implemented; implemented configs: [sdxl, flux]")
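The two layouts differ in rank, not just size: SDXL latents are NCHW at 1/8 resolution, while flux latents are packed tokens, (1024/16)·(1024/16) = 4096 of them with 64 channels each (1024·1024/256 = 4096·64 values). For example:

    import torch
    from sharktank.models.vae.tools.sample_data import get_random_inputs

    sdxl = get_random_inputs(torch.float32, "cpu", bs=2, config="sdxl")
    flux = get_random_inputs(torch.bfloat16, "cpu", bs=2, config="flux")
    print(sdxl.shape)  # torch.Size([2, 4, 128, 128])
    print(flux.shape)  # torch.Size([2, 4096, 64])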
5 changes: 4 additions & 1 deletion — sharktank/sharktank/tools/import_hf_dataset.py

@@ -38,6 +38,7 @@ def import_hf_dataset(
     config_json_path: PathLike,
     param_paths: list[PathLike],
     output_irpa_file: Optional[PathLike] = None,
+    target_dtype=None,
 ) -> Optional[Dataset]:
     import safetensors
 

@@ -50,7 +51,9 @@
     for params_path in param_paths:
         with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
             tensors = [
-                DefaultPrimitiveTensor(name=name, data=st.get_tensor(name))
+                DefaultPrimitiveTensor(
+                    name=name, data=st.get_tensor(name).to(target_dtype)
+                )
                 for name in st.keys()
             ]
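The new target_dtype threads straight into Tensor.to, so passing torch.bfloat16 casts every safetensors parameter while it is imported into the IRPA file. A hedged usage sketch (file paths are placeholders, not from this commit):

    import torch
    from sharktank.tools.import_hf_dataset import import_hf_dataset

    import_hf_dataset(
        "vae/config.json",                            # placeholder path
        ["vae/diffusion_pytorch_model.safetensors"],  # placeholder path
        output_irpa_file="flux_vae_bf16.irpa",
        target_dtype=torch.bfloat16,
    )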
5 changes: 5 additions & 0 deletions — sharktank/sharktank/types/tensors.py

@@ -388,6 +388,11 @@ def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
 
         return squeeze(self, dim)
 
+    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
+        from ..ops import squeeze
+
+        return squeeze(self, dim)
+
     def transpose(self, dim0: int, dim1: int) -> "AnyTensor":
         from ..ops import transpose
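The added squeeze matches the definition already visible in the context above it; in either case it delegates to sharktank.ops.squeeze with torch-like semantics. Roughly (assumed behavior, not a test from this commit):

    import torch
    from sharktank.types.tensors import DefaultPrimitiveTensor

    t = DefaultPrimitiveTensor(name="x", data=torch.rand(1, 4, 1, 8))
    print(t.squeeze(2).shape)  # [1, 4, 8] — size-1 dim removed, as in torch.squeeze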
(1 more changed file not shown.)