Skip to content

Commit

Permalink
moe (#162)
Browse files Browse the repository at this point in the history
Throwing this up as a pr to make it easier to view

---------

Co-authored-by: archana-ramalingam <archana.ramalingam@amd.com>
Co-authored-by: Ian <ian.nordeng@amd.com>
  • Loading branch information
3 people authored Sep 5, 2024
1 parent e051c37 commit aead69f
Show file tree
Hide file tree
Showing 20 changed files with 1,412 additions and 228 deletions.
6 changes: 5 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
from ..models.mixtral.mixtral import *


def main():
Expand Down Expand Up @@ -52,7 +53,10 @@ def main():
llama_config = LlamaModelConfig(hp)
llama_config.static_tables = False # Rely on the compiler for hoisting tables.
llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
if llama_config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
return {
Expand Down
7 changes: 6 additions & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..types import *

# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
from ..models.llama.llama import *
from ..utils.debugging import trace_tensor
from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
Expand Down Expand Up @@ -236,7 +237,11 @@ def main():
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
)
model = PagedLlamaModelV1(dataset.root_theta, config)

if config.hp.expert_count:
model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)
if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch

Expand Down
144 changes: 144 additions & 0 deletions sharktank/sharktank/examples/validate_direct_mixtral_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.mixtral.mixtral import *


def main(args: list[str]):
from ..utils import cli

torch.no_grad().__enter__()

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
args = cli.parse(parser)

dataset = cli.get_input_dataset(args)
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
llama_config = LlamaModelConfig(hp)
llama_config.kv_cache_type = "direct"
llama_config.activation_dtype = torch.float16
model = PagedMixtralModelV1(dataset.root_theta, llama_config)

# bs ("batch size") == 1
cache_state = model.cache.allocate(bs=1)

start_index = 0
tokens = torch.tensor(
[
[
1,
1059,
31871,
1217,
322,
266,
3682,
6075,
31902,
13,
31849,
31871,
0,
0,
0,
0,
]
+ 48 * [0],
]
)
assert tokens.shape[1] % model.cache.block_seq_stride == 0
seq_block_ids = torch.tensor(
[
[127, 0, 0, 0],
]
)

# Important: Do not use a sequence length of 0 for empty batch slots
# as it will cause softmax to nan due to a mask of all -inf. This then
# propagates and causes badness.
seq_lens = torch.tensor([12])

attention_mask = model.attention_mask(
model.input_mask(seq_lens, tokens.shape[1]),
)

print(f"Step {start_index}")
logits = model.prefill(
tokens,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
# TODO: Normalize the output of extract_tokens_from_logits into tensor [bs, 1].
tokens = torch.tensor(model.extract_tokens_from_logits(logits, seq_lens)).unsqueeze(
1
)
print(f" : tokens = {tokens}")

# Decode a step.
print("Decoding...")
print(tokens.shape, tokens)
start_positions = torch.tensor([12])
seq_lens = seq_lens + 1
decode_attention_mask = model.decode_attention_mask(
model.input_mask(
seq_lens,
seq_block_ids.shape[1] * model.cache.block_seq_stride,
),
)
logits = model.decode(
tokens,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
tokens = torch.tensor(model.extract_tokens_from_logits(logits, [1])).unsqueeze(1)
print(f" : tokens = {tokens}")

def save_prefill_module(model):
from iree.compiler.extras.fx_importer import FxImporter
from iree.compiler.ir import AsmState

importer = FxImporter()

print("Generating FX graph")

class InferenceModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module("prefill", model)

def forward(self, tokens, attention_mask, seq_block_ids, *cache_state):
return self.prefill.prefill(
tokens,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=list(cache_state),
)

infmod = InferenceModule()
prog = torch.export.export(
infmod, (tokens, attention_mask, seq_block_ids) + tuple(cache_state)
)

print(f"FX prog:", prog)
importer.import_program(prog, func_name="prefill")
output_file = "/tmp/prefill.mlirbc"
print("Saving to:", output_file)
with open(output_file, "wb") as f:
importer.module_op.write_bytecode(f)


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
48 changes: 48 additions & 0 deletions sharktank/sharktank/examples/validate_mixtral_ref_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys

import torch

from sharktank.layers import *
from sharktank.types import *
from sharktank.models.mixtral.mixtral_ref import *


def main(args: list[str]):
from ..utils import cli

torch.no_grad().__enter__()

parser = cli.create_parser()
cli.add_input_dataset_options(parser)
args = cli.parse(parser)

dataset = cli.get_input_dataset(args)
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
ref_llama_config = RefLlamaModelConfig(hp)
ref_llama_config.activation_dtype = torch.float16
model = DirectCacheMixtralModelV1(dataset.root_theta, ref_llama_config)

kv_cache = model.create_cache(bs=1)
start_index = 0
next_tokens = [1, 1059, 31871, 1217, 322, 266, 3682, 6075, 31902, 13, 31849, 31871]
print(f"Step {start_index}")
tokens = model.forward(
torch.tensor([next_tokens]), start_index=start_index, local_kv_cache=kv_cache
)
print(f" : tokens = {tokens}")

# Decode a step.
print("Decoding...")
print(tokens.shape, tokens)
decode_token = model.forward(tokens, start_index=12, local_kv_cache=kv_cache)
print(f" : decode tokens = {decode_token}")


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
5 changes: 5 additions & 0 deletions sharktank/sharktank/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,10 @@
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .token_embedding import TokenEmbeddingLayer
from .llama_attention_block import LlamaAttentionBlock
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock

from . import configs
5 changes: 1 addition & 4 deletions sharktank/sharktank/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
from ..utils import debugging

__all__ = [
"LinearLayer",
"RotaryEmbeddingLayer",
"RMSNormLayer",
"BaseLayer",
"ThetaLayer",
"TokenEmbedding",
]


Expand Down
32 changes: 26 additions & 6 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import torch

__all__ = [
"LlamaHParams",
]
__all__ = ["LlamaHParams"]


@dataclass
Expand All @@ -36,14 +34,21 @@ class LlamaHParams:
block_count: int
feed_forward_length: int
rope_dimension_count: int
rope_freq_base: float
attention_head_count: int
attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int
expert_count: int
expert_used_count: int

@staticmethod
def from_gguf_props(p: dict[str, Any]):
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
attention_head_count = _int_prop(p, "llama.attention.head_count")

return LlamaHParams(
context_length=_int_prop(p, "llama.context_length"),
embedding_length=_int_prop(p, "llama.embedding_length"),
Expand All @@ -58,6 +63,15 @@ def from_gguf_props(p: dict[str, Any]):
attention_head_count_kv=_optional_int_prop(
p, "llama.attention.head_count_kv", attention_head_count
),
rope_freq_base=_optional_float_prop(
p, "llama.rope.freq_base", default_rope_freq_base
),
expert_count=_optional_int_prop(
p, "llama.expert_count", default_expert_count
),
expert_used_count=_optional_int_prop(
p, "llama.expert_used_count", default_expert_used_count
),
)


Expand All @@ -79,10 +93,16 @@ def _int_prop(p: dict[str, Any], name: str) -> int:
raise KeyError(f"Property '{name}' not found (among keys {p.keys()})")


def _optional_float_prop(p: dict[str, Any], name: str, default_value: float) -> float:
value = p.get(name, default_value)
try:
return float(value)
except ValueError as e:
raise ValueError(f"Property '{name}' expected to be a float and was not") from e


def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
value = p[name]
if value is None:
return default_value
value = p.get(name, default_value)
try:
return int(value)
except ValueError as e:
Expand Down
38 changes: 38 additions & 0 deletions sharktank/sharktank/layers/ffn_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

import torch
import torch.nn.functional as F

from .base import Theta, ThetaLayer
from .linear import LinearLayer

__all__ = [
"FFN",
]


class FFN(ThetaLayer):
def __init__(
self,
theta: Theta,
):
super().__init__(theta)

self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))

def forward(
self,
h: torch.Tensor,
):
ffn_gate = F.silu(self.ffn_gate(h))
ffn_up = self.ffn_up(h)
ffn_down = self.ffn_down(ffn_gate * ffn_up)
return ffn_down
Loading

0 comments on commit aead69f

Please sign in to comment.