
Commit

Merge branch 'main' into sharded-kvcache-fix
stbaione authored Jan 16, 2025
2 parents f0fb408 + 0da9f25 commit 56261dc
Showing 15 changed files with 305 additions and 209 deletions.
6 changes: 6 additions & 0 deletions sharktank/requirements-tests.txt
@@ -1,3 +1,9 @@
# Needed to load the Hugging Face Flux transformer with low_cpu_mem_usage=True.
# This is the only way to load this model, which is split across multiple
# safetensors files.
# See https://github.com/huggingface/diffusers/issues/9343
accelerate

datasets==3.0.0
diffusers
parameterized
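
For context, a minimal sketch (not part of this diff) of the load path that motivates the new `accelerate` dependency: loading the multi-file safetensors Flux transformer through diffusers with `low_cpu_mem_usage=True`. The repo id and dtype are illustrative assumptions.

```
import torch
from diffusers import FluxTransformer2DModel

# low_cpu_mem_usage=True requires `accelerate`; it is the only load path that
# handles transformer weights split across multiple safetensors files.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # illustrative repo id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
```
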
38 changes: 24 additions & 14 deletions sharktank/sharktank/layers/mmdit.py
@@ -18,7 +18,7 @@
from .linear import LinearLayer
from .modulation import ModulationLayer
from .norm import RMSNormLayer
from .paged_llama_attention_block import PagedLlamaAttentionBlock
import functools


def qk_norm(q, k, v, rms_q, rms_k):
@@ -37,20 +37,19 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe) # todo

x = ops.scaled_dot_product_attention(
q=q, k=k, v=v, a=None, is_causal=True, scale=None
)
x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None)
x = ops.permute(x, (0, 2, 1, 3))
x = x.reshape(x.shape[0], x.shape[1], -1)

return x


class MMDITDoubleBlock(ThetaLayer):
def __init__(self, theta, num_heads: int):
def __init__(self, theta, num_heads: int, hidden_size: int):
super().__init__(theta)

self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True))
self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv")))
self.add_module(
@@ -96,7 +95,9 @@ def forward(
txt_mod1, txt_mod2 = self.txt_mod(vec)

# prepare image for attention
img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
img_modulated = ops.layer_norm(
img, normalized_shape=(self.hidden_size,), weight=None, bias=None, eps=1e-6
)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn_qkv(img_modulated)
img_qkv_2 = img_qkv.view(
@@ -109,7 +110,9 @@ def forward(
)

# prepare text for attention
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
txt_modulated = ops.layer_norm(
txt, normalized_shape=(self.hidden_size,), weight=None, bias=None, eps=1e-6
)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_qkv_2 = txt_qkv.view(
@@ -133,32 +136,37 @@ def forward(
# TODO: Refactor this for code reuse with the txt blocks
img = img + img_mod1.gate * self.img_attn_proj(img_attn)
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
img, None, None, eps=1e-6
img, normalized_shape=(self.hidden_size,), weight=None, bias=None, eps=1e-6
) + img_mod2.shift
img_mlp_out1 = self.img_mlp1(img_mlp_in)
img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
img_mlp_out2 = ops.elementwise(
functools.partial(F.gelu, approximate="tanh"), img_mlp_out1
)
img_mlp_out3 = self.img_mlp2(img_mlp_out2)
img = img + img_mod2.gate * img_mlp_out3

# calculate the text blocks
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
txt, None, None, eps=1e-6
txt, normalized_shape=(self.hidden_size,), weight=None, bias=None, eps=1e-6
) + txt_mod2.shift
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
# TODO: Unify with modulation layer by taking act_fn as an arg
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
txt_mlp_out2 = ops.elementwise(
functools.partial(F.gelu, approximate="tanh"), txt_mlp_out1
)
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
txt = txt + txt_mod2.gate * txt_mlp_out3

return img, txt


class MMDITSingleBlock(ThetaLayer):
def __init__(self, theta, num_heads: int):
def __init__(self, theta, num_heads: int, hidden_size: int):
super().__init__(theta)

self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_module("mod", ModulationLayer(theta("modulation"), double=False))
self.add_module(
"attn_norm_q",
@@ -177,7 +185,9 @@ def __init__(self, theta, num_heads: int):

def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
mod, _ = self.mod(vec)
x_norm = ops.layer_norm(x, None, None, eps=1e-6)
x_norm = ops.layer_norm(
x, normalized_shape=(self.hidden_size,), weight=None, bias=None, eps=1e-6
)
x_mod = (1 + mod.scale) * x_norm + mod.shift
x_lin = self.linear1(x_mod)
qkv, mlp = torch.split(
@@ -192,6 +202,6 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
gelu = ops.elementwise(F.gelu, mlp)
gelu = ops.elementwise(functools.partial(F.gelu, approximate="tanh"), mlp)
output = self.linear2(torch.cat((attn, gelu), 2))
return x + mod.gate * output
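
Taken together, the edits in this file make three behavioral changes: the joint attention drops its causal mask, layer normalization now names the normalized shape explicitly, and the MLP activation switches to the tanh-approximate GELU. A minimal plain-torch sketch of each, assuming `ops.scaled_dot_product_attention` and `ops.layer_norm` mirror their `torch.nn.functional` counterparts; sizes are illustrative.

```
import functools
import torch
import torch.nn.functional as F

hidden_size, num_heads, seq = 3072, 24, 16        # illustrative Flux-like sizes
x = torch.randn(1, seq, hidden_size)              # (batch, seq, hidden)
q = k = v = torch.randn(1, num_heads, seq, hidden_size // num_heads)

# 1. Attention: MMDiT joint image/text attention is bidirectional, so the
#    previously passed is_causal=True masked interactions it should not have.
attn = F.scaled_dot_product_attention(q, k, v)    # no causal mask

# 2. Layer norm: the normalized shape is explicit and weight/bias stay None,
#    i.e. pure normalization over the hidden dimension.
x_norm = F.layer_norm(x, normalized_shape=(hidden_size,), eps=1e-6)

# 3. MLP activation: tanh-approximate GELU instead of the exact GELU.
gelu_tanh = functools.partial(F.gelu, approximate="tanh")
y = gelu_tanh(x_norm)
```
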
3 changes: 0 additions & 3 deletions sharktank/sharktank/layers/testing.py
@@ -147,9 +147,6 @@ def make_mmdit_single_block_random_theta(
mlp_ratio: float = 4.0,
dtype: torch.dtype | None = None,
) -> Theta:
in_channels = 128
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
mlp_hidden_size2 = int((mlp_ratio + 1) * hidden_size)
mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
18 changes: 13 additions & 5 deletions sharktank/sharktank/models/flux/export.py
@@ -51,11 +51,8 @@ def export_flux_transformer(
)


def export_flux_transformer_from_hugging_face(
repo_id: str,
mlir_output_path: PathLike,
parameters_output_path: PathLike,
batch_sizes: list[int] = flux_transformer_default_batch_sizes,
def import_flux_transformer_dataset_from_hugging_face(
repo_id: str, parameters_output_path: PathLike
):
hf_dataset = get_dataset(
repo_id,
@@ -67,6 +64,17 @@ def export_flux_transformer_from_hugging_face(
output_irpa_file=parameters_output_path,
)


def export_flux_transformer_from_hugging_face(
repo_id: str,
mlir_output_path: PathLike,
parameters_output_path: PathLike,
batch_sizes: list[int] = flux_transformer_default_batch_sizes,
):
import_flux_transformer_dataset_from_hugging_face(
repo_id=repo_id, parameters_output_path=parameters_output_path
)

dataset = Dataset.load(parameters_output_path)
model = FluxModelV1(
theta=dataset.root_theta,
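
A hedged usage sketch of the refactor above: the Hugging Face dataset import is now callable on its own, while the full export still runs it as its first step. The repo id and output paths are illustrative.

```
from sharktank.models.flux.export import (
    export_flux_transformer_from_hugging_face,
    import_flux_transformer_dataset_from_hugging_face,
)

# Fetch the Hugging Face parameters and convert them to an IRPA file only:
import_flux_transformer_dataset_from_hugging_face(
    repo_id="black-forest-labs/FLUX.1-dev",              # illustrative repo id
    parameters_output_path="/tmp/flux_transformer.irpa",
)

# Or run the full export, which imports the dataset and then emits MLIR as well:
export_flux_transformer_from_hugging_face(
    repo_id="black-forest-labs/FLUX.1-dev",
    mlir_output_path="/tmp/flux_transformer.mlir",
    parameters_output_path="/tmp/flux_transformer.irpa",
)
```
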
14 changes: 8 additions & 6 deletions sharktank/sharktank/models/flux/flux.py
@@ -128,7 +128,8 @@ def __init__(self, theta: Theta, params: FluxParams):
[
MMDITDoubleBlock(
theta("double_blocks", i),
self.num_heads,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
)
for i in range(params.depth)
]
@@ -138,7 +139,8 @@ def __init__(self, theta: Theta, params: FluxParams):
[
MMDITSingleBlock(
theta("single_blocks", i),
self.num_heads,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
)
for i in range(params.depth_single_blocks)
]
@@ -305,10 +307,10 @@ def timestep_embedding(
return embedding


def layer_norm(inp):
weight = torch.ones(inp.shape)
bias = torch.zeros(inp.shape)
return ops.layer_norm(inp, weight, bias, eps=1e-6)
def layer_norm(inp: torch.Tensor):
return ops.layer_norm(
inp, normalized_shape=(inp.shape[-1],), weight=None, bias=None, eps=1e-6
)


def qk_norm(q, k, v, rms_q, rms_k):
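
The old `layer_norm` helper above built ones/zeros weight and bias tensors spanning the full input shape; assuming `ops.layer_norm` follows `torch.nn.functional.layer_norm` semantics, that normalized over every dimension rather than only the hidden one, which the rewritten helper now targets explicitly. A small plain-torch illustration of the difference, with made-up shapes:

```
import torch
import torch.nn.functional as F

inp = torch.randn(2, 4, 8)

old_style = F.layer_norm(inp, normalized_shape=tuple(inp.shape), eps=1e-6)   # all dims
new_style = F.layer_norm(inp, normalized_shape=(inp.shape[-1],), eps=1e-6)   # last dim only

print(torch.allclose(old_style, new_style))  # generally False
```
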
137 changes: 35 additions & 102 deletions sharktank/sharktank/models/flux/testing.py
@@ -6,15 +6,40 @@

import torch
from os import PathLike
from collections import OrderedDict

from .flux import FluxParams, FluxModelV1
from .export import export_flux_transformer, flux_transformer_default_batch_sizes
from ...types import DefaultPrimitiveTensor, Theta, save_load_theta
from ...layers.testing import (
make_rand_torch,
make_mmdit_double_block_random_theta,
make_mmdit_single_block_random_theta,
)


def convert_flux_transformer_input_for_hugging_face_model(
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
return OrderedDict(
[
("hidden_states", img),
("encoder_hidden_states", txt),
("pooled_projections", y),
("timestep", timesteps),
("img_ids", img_ids.reshape(img_ids.shape[1:])),
("txt_ids", txt_ids.reshape(txt_ids.shape[1:])),
("guidance", guidance),
]
)


def make_random_theta(config: FluxParams, dtype: torch.dtype):
# TODO: do not hardcode values.

@@ -69,108 +94,6 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
"vector_in.out_layer.bias": DefaultPrimitiveTensor( #
data=make_rand_torch((hidden_size,), dtype=dtype)
),
"double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size,), dtype=dtype)
),
"double_blocks.0.img_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"double_blocks.0.img_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"double_blocks.0.img_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
"double_blocks.0.img_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
),
"double_blocks.0.img_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
),
"double_blocks.0.img_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size), dtype=dtype)
),
"double_blocks.0.img_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"double_blocks.0.img_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
),
"double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
"double_blocks.0.txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size,), dtype=dtype)
),
"double_blocks.0.txt_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"double_blocks.0.txt_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"double_blocks.0.txt_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
"double_blocks.0.txt_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
),
"double_blocks.0.txt_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
),
"double_blocks.0.txt_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size), dtype=dtype)
),
"double_blocks.0.txt_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
),
"double_blocks.0.txt_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
),
"double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
),
"single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( #
data=make_rand_torch((in_channels2,), dtype=dtype)
),
"single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size,), dtype=dtype)
),
"single_blocks.0.attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
),
"single_blocks.0.linear1.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size5,), dtype=dtype)
),
"single_blocks.0.linear1.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size5, hidden_size), dtype=dtype)
),
"single_blocks.0.linear2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size), dtype=dtype)
),
"single_blocks.0.linear2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype)
),
"single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
),
"single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
),
"final_layer.linear.weight": DefaultPrimitiveTensor( #
data=make_rand_torch(
(patch_size * patch_size * out_channels, hidden_size), dtype=dtype
@@ -187,6 +110,16 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
),
}

for i in range(config.depth):
tensor_dict[f"double_blocks.{i}"] = make_mmdit_double_block_random_theta(
in_channels=in_channels, hidden_size=hidden_size, mlp_ratio=mlp_ratio
).flatten()

for i in range(config.depth_single_blocks):
tensor_dict[f"single_blocks.{i}"] = make_mmdit_single_block_random_theta(
in_channels=in_channels2, hidden_size=hidden_size, mlp_ratio=mlp_ratio
).flatten()

if config.guidance_embed:
tensor_dict["guidance_in.in_layer.weight"] = DefaultPrimitiveTensor( #
data=make_rand_torch(
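
A hedged sketch of exercising the new `convert_flux_transformer_input_for_hugging_face_model` helper with dummy tensors. The shapes follow common FLUX.1 transformer conventions (packed latent patches, T5 text embeddings, pooled CLIP projection) but are assumptions, not values taken from this diff.

```
import torch
from sharktank.models.flux.testing import (
    convert_flux_transformer_input_for_hugging_face_model,
)

batch, img_seq, txt_seq = 1, 1024, 512
img = torch.rand(batch, img_seq, 64)       # packed latent tokens (assumed width)
img_ids = torch.zeros(batch, img_seq, 3)   # positional ids
txt = torch.rand(batch, txt_seq, 4096)     # T5 embeddings (assumed width)
txt_ids = torch.zeros(batch, txt_seq, 3)
timesteps = torch.rand(batch)
y = torch.rand(batch, 768)                 # pooled projection (assumed width)
guidance = torch.full((batch,), 3.5)

hf_kwargs = convert_flux_transformer_input_for_hugging_face_model(
    img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
    timesteps=timesteps, y=y, guidance=guidance,
)
# hf_kwargs now maps the diffusers argument names (hidden_states,
# encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids,
# guidance) onto these tensors, with the id tensors reshaped to drop the
# batch dimension.
```
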
8 changes: 6 additions & 2 deletions sharktank/sharktank/models/vae/README.md
@@ -16,10 +16,14 @@ python -m sharktank.models.punet.tools.import_hf_dataset \
```

# Run the VAE decoder model in eager mode
# Sample SDXL command
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --dtype=float32
```
# Sample Flux command to run through IREE and compare against the Hugging Face diffusers torch model
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --compare_vs_torch --dtype=float32 --sharktank_config=flux --torch_model=black-forest-labs/FLUX.1-dev
```

## License

Significant portions of this implementation were derived from diffusers,