diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 1d621e506..58509753d 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -52,6 +52,8 @@ def main():
     )
     args = cli.parse(parser)
 
+    dataset_type = cli.get_input_data_files(args)
+    dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
     dataset = cli.get_input_dataset(args)
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 47d281565..6538890f6 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -8,6 +8,8 @@ from typing import Optional
 
+from safetensors import safe_open
+
 import math
 import sys
 
@@ -20,7 +22,7 @@ from ..models.mixtral.mixtral import *
 from ..models.llama.llama import *
 from ..utils.debugging import trace_tensor
-from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
+from ..utils.tokenizer import InferenceTokenizer
 
 
 class TorchGenerator:
@@ -51,6 +53,7 @@ def begin_batch(self, prompts: list[str]):
         token_ids, seq_lens = self.tokenizer.encode(
             prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
         )
+        token_ids = torch.tensor(token_ids, device=self.model.device)
         seq_lens = torch.tensor(seq_lens, device=self.model.device)
         if self.shared_cache_state is not None:
@@ -218,13 +221,25 @@ def main():
         help="DType to use for activations in the model",
         default="float32",
     )
+    parser.add_argument(
+        "--attention-dtype",
+        help="DType to use for attention in the model",
+        default="float16",
+    )
+    parser.add_argument(
+        "--use-hf",
+        action="store_true",
+        default=False,
+    )
     cli.add_input_dataset_options(parser)
     cli.add_tokenizer_options(parser)
     args = cli.parse(parser)
     device = torch.device(args.device) if args.device else None
     activation_dtype = getattr(torch, args.activation_dtype)
+    attention_dtype = getattr(torch, args.attention_dtype)
     assert isinstance(activation_dtype, torch.dtype)
+    assert isinstance(attention_dtype, torch.dtype)
     dataset = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)
     prompts = args.prompt
@@ -235,7 +250,8 @@ def main():
         kv_cache_type=args.kv_cache_type,
         device=device,
         activation_dtype=activation_dtype,
-        attention_dtype=activation_dtype,
+        attention_dtype=attention_dtype,
+        use_hf=args.use_hf,
     )
 
     if config.hp.expert_count:
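The new --attention-dtype flag is resolved the same way as --activation-dtype: the string is looked up as an attribute of the torch module and then checked to actually be a dtype. A tiny self-contained sketch of that pattern (the helper name below is illustrative, not part of the patch):

    import torch

    def resolve_dtype(name: str) -> torch.dtype:
        # Same getattr-then-assert pattern the script applies to the flag values.
        dtype = getattr(torch, name)
        assert isinstance(dtype, torch.dtype), f"{name} is not a torch dtype"
        return dtype

    print(resolve_dtype("float32"), resolve_dtype("float16"))
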
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index 86e43d715..c5e2ea330 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -7,10 +7,8 @@ from typing import Optional
 
 import torch
-
 from .. import ops
 from .base import Theta, ThetaLayer
-from ..types.layout_utils import saturate_cast
 from ..types import (
     DynamicScaledQuantizer,
     QuantizedTensor,
@@ -54,19 +52,21 @@ def __init__(
         self.qdq_input: Optional[QuantizedTensor] = theta.optional_tensor("qdq_input")
         if self.q_input is not None and self.qdq_input is not None:
             raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
+        self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")
 
     def forward(self, x):
         weight = self.weight
         bias = self.bias
         q_input = self.q_input
         qdq_input = self.qdq_input
-
+        qdq_output = self.qdq_output
         if self.premul_input is not None:
             x = ops.elementwise(torch.mul, x, self.premul_input)
 
         if q_input is not None:
             x = q_input.quantize(x)
         elif qdq_input is not None:
+            # TODO: probably need a way to only do q_input if exporting.
            x = qdq_input.quantize(x).unpack().dequant()
 
         y = ops.linear(x, weight, bias)
@@ -76,4 +76,7 @@ def forward(self, x):
         # the QuantizedTensor escape.
         if isinstance(y, QuantizedTensor):
             y = y.unpack().dequant()
+        if qdq_output is not None:
+            # TODO: same as above.
+            y = qdq_output.quantize(y).unpack().dequant()
         return y
diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py
index d062f1ffb..b9ae4b41a 100644
--- a/sharktank/sharktank/layers/norm.py
+++ b/sharktank/sharktank/layers/norm.py
@@ -25,6 +25,7 @@ def __init__(
         weight_name: str = "weight",
         epsilon: float = 1e-6,
         dtype: torch.dtype = torch.float32,
+        debug_save_file=None,
     ):
         super().__init__(theta)
         self.weight = self.theta_tensor(weight_name)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 8266872ad..36fbe1b60 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -8,7 +8,6 @@ from dataclasses import dataclass
 
 import math
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -151,7 +150,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
             ),
         )
         self.add_module("output_lm_head", LinearLayer(theta("output")))
-
         self.attn_blocks = nn.ModuleList(
             [
                 AttentionFFNBlock(
@@ -349,6 +347,7 @@ def forward(
             xk_temp=xk_temp,
             xv_temp=xv_temp,
         )
+
         # Feed forward network.
         ffn_input = self.ffn_norm(h)
         ffn_down = self.ffn(ffn_input)
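The qdq_input/qdq_output path added to LinearLayer above is a quantize-dequantize ("fake quant") round trip: values are cast to float8 and immediately brought back so the rest of the graph keeps running in higher precision while still seeing float8 rounding. A rough standalone sketch of that round trip, using a plain torch cast instead of sharktank's StaticScaledQuantizer and ignoring the saturation handling that saturate_cast provides:

    import torch

    def qdq(x: torch.Tensor, scale: torch.Tensor, dtype=torch.float8_e4m3fn) -> torch.Tensor:
        # Quantize to float8, then dequantize: keeps the rounding error, returns x.dtype.
        q = (x / scale).to(dtype)
        return q.to(x.dtype) * scale

    x = torch.randn(4, 8, dtype=torch.float16)
    scale = torch.tensor(0.05, dtype=torch.float16)
    y = qdq(x, scale)
    print((x - y).abs().max())  # error introduced by the fake-quant round trip
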
diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
new file mode 100644
index 000000000..0d869932e
--- /dev/null
+++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -0,0 +1,371 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Imports quark pre-processed weights and quantization config into a
+Dataset that uses gguf-style tensor names.
+
+Usage:
+  python -m sharktank.models.llama.tools.import_quark_dataset \
+    --params=llama2-7b-fp8.safetensors --output-irpa-file=new.irpa \
+    --config-json=../llama2/config.json
+
+"""
+from typing import Optional
+
+from safetensors.torch import save_file
+import json
+from pathlib import Path
+import safetensors
+import sys
+import torch
+
+from sharktank.types import *
+from sharktank.layers.configs.llm_configs import (
+    _int_prop,
+    _float_prop,
+    _optional_int_prop,
+    _int_prop,
+)
+
+
+def _load_json(p: Path):
+    print(f"Loading {p}")
+    with open(p, "rb") as f:
+        return json.load(f)
+
+
+def _get_dataset_props(config_json_struct) -> dict:
+    # Separate meta parameters (prefixed with _) from hparams.
+    meta_params = {k: v for k, v in config_json_struct.items() if k.startswith("_")}
+    hparams = {k: v for k, v in config_json_struct.items() if not k.startswith("_")}
+    return {
+        "meta": meta_params,
+        "hparams": hparams,
+    }
+
+
+def _load_theta(st_source) -> Theta:
+    tensors = [
+        DefaultPrimitiveTensor(name=name, data=st_source.get_tensor(name))
+        for name in st_source.keys()
+    ]
+    return Theta(tensors)
+
+
+def as_torch_or_none(tensor: Optional[InferenceTensor]) -> Optional[torch.Tensor]:
+    if tensor is None:
+        return None
+    return tensor.as_torch()
+
+
+def hf_to_gguf(layer_name: str) -> str:
+    assert layer_name.startswith("model.layers")
+    mapping = {
+        "input_layernorm": "attn_norm",
+        "self_attn.q_proj": "attn_q",
+        "self_attn.k_proj": "attn_k",
+        "self_attn.v_proj": "attn_v",
+        "self_attn.o_proj": "attn_output",
+        "post_attention_layernorm": "ffn_norm",
+        "mlp.gate_proj": "ffn_gate",
+        "mlp.up_proj": "ffn_up",
+        "mlp.down_proj": "ffn_down",
+    }
+
+    # Split the input string.
+    parts = layer_name.split(".")
+
+    # Extract the layer index and the key to be mapped.
+    numerical_value = parts[2]  # The layer index that follows "model.layers".
+    key_to_map = ".".join(parts[3:])
+
+    # Map the key.
+    if key_to_map in mapping:
+        mapped_value = mapping[key_to_map]
+    else:
+        raise ValueError(f"Mapping for '{key_to_map}' not found.")
+
+    # Construct the output string.
+    output_str = f"blk.{numerical_value}.{mapped_value}"
+    return output_str
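For clarity, this is how the name mapping above behaves on representative HF parameter names; a small illustrative check, not part of the patch, assuming the new module is importable:

    from sharktank.models.llama.tools.import_quark_dataset import hf_to_gguf

    # HF-style names become gguf-style block names.
    assert hf_to_gguf("model.layers.0.self_attn.q_proj") == "blk.0.attn_q"
    assert hf_to_gguf("model.layers.31.mlp.down_proj") == "blk.31.ffn_down"
    # Unmapped suffixes raise ValueError.
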
+
+
+def apply_per_layer_quant(
+    root_theta: Theta,
+    layer_name: str,
+    updated_tensors: dict[str, InferenceTensor],
+    n_head: int,
+    split_sizes: list[int],
+):
+    """Take the quantization parameters and hf weights from the imported Theta
+    and create InferenceTensors out of them, converting their names to gguf format
+    in the process.
+    """
+
+    layer_theta = root_theta(layer_name)
+
+    weight_quant_scale = layer_theta.tensor("weight_quant_scale").as_torch()
+
+    weight = layer_theta.tensor("weight").as_torch()
+
+    # This round trip through float8 looks redundant, but it is required for
+    # numerical correctness against quark.
+    weight = weight.view(torch.float8_e4m3fn)
+    weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16)
+
+    weight_quant_zero_point = layer_theta.optional_tensor("weight_quant_zero_point")
+    if weight_quant_zero_point is None:
+        weight_quant_zero_point = torch.zeros(1, dtype=torch.float32)
+    else:
+        weight_quant_zero_point = weight_quant_zero_point.as_torch()
+    input_quant_scale = as_torch_or_none(
+        layer_theta.optional_tensor("input_quant_scale")
+    )
+    output_quant_scale = as_torch_or_none(
+        layer_theta.optional_tensor("output_quant_scale")
+    )
+
+    if weight_quant_scale is None:
+        print("weight quant scale not found for layer ", layer_name)
+        return
+
+    layer_parent = ".".join(layer_name.split(".")[:-1])
+
+    def quantize_weight(
+        weight_name: str,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        weight_zp: Optional[torch.Tensor],
+    ):
+        # Our scale is the reciprocal of the quark scale.
+        # We multiply the scale by two to account for the difference between fnuz and fn.
+        weight_quantizer = StaticScaledQuantizer(
+            scale=1.0 / (weight_scale * 2.0),
+            reciprocal_scale=(weight_scale * 2.0),
+            offset=None
+            if (weight_zp is None or torch.count_nonzero(weight_zp) == 0)
+            else weight_zp,
+            dtype=torch.float8_e4m3fnuz,
+        )
+        weight_quant = weight_quantizer.quantize(weight, name=weight_name)
+        updated_tensors[weight_quant.name] = weight_quant
+
+    if "qkv" in layer_name:
+        # The qkv layer is fused in the quark model; decompose it back into
+        # individual q, k, and v weights.
+        q_weight, k_weight, v_weight = torch.split(weight, split_sizes)
+        q_weight = (
+            q_weight.reshape(
+                n_head, 2, q_weight.shape[0] // n_head // 2, *q_weight.shape[1:]
+            )
+            .swapaxes(1, 2)
+            .reshape(q_weight.shape)
+        )
+        k_weight = (
+            k_weight.reshape(
+                n_head, 2, k_weight.shape[0] // n_head // 2, *k_weight.shape[1:]
+            )
+            .swapaxes(1, 2)
+            .reshape(k_weight.shape)
+        )
+        q_name = hf_to_gguf(layer_parent + ".q_proj")
+        k_name = hf_to_gguf(layer_parent + ".k_proj")
+        v_name = hf_to_gguf(layer_parent + ".v_proj")
+        quantize_weight(
+            q_name + ".weight", q_weight, weight_quant_scale, weight_quant_zero_point
+        )
+        quantize_weight(
+            k_name + ".weight", k_weight, weight_quant_scale, weight_quant_zero_point
+        )
+        quantize_weight(
+            v_name + ".weight", v_weight, weight_quant_scale, weight_quant_zero_point
+        )
+        # The output and input quantizers are duplicated for each of the q, k, and v weights.
+        names = [f"{i}.qdq_output" for i in [q_name, k_name, v_name]]
+        for name in names:
+            updated_tensors[name] = StaticScaledQuantizer(
+                name=name,
+                scale=1.0 / (output_quant_scale * 2.0),
+                reciprocal_scale=output_quant_scale * 2.0,
+                dtype=torch.float8_e4m3fnuz,
+            )
+        names = [f"{i}.qdq_input" for i in [q_name, k_name, v_name]]
+        for name in names:
+            updated_tensors[name] = StaticScaledQuantizer(
+                name=name,
+                scale=1.0 / (input_quant_scale * 2.0),
+                reciprocal_scale=input_quant_scale * 2.0,
+                dtype=torch.float8_e4m3fnuz,
+            )
+        # Remove the updated tensors from the original tree.
+        root_theta.pop(layer_parent + ".q_proj")
+        root_theta.pop(layer_parent + ".k_proj")
+        root_theta.pop(layer_parent + ".v_proj")
+        root_theta.pop(layer_name)
+
+    else:
+        new_layer_name = hf_to_gguf(layer_name)
+        quantize_weight(
+            new_layer_name + ".weight",
+            weight,
+            weight_quant_scale,
+            weight_quant_zero_point,
+        )
+        # We explicitly provide the reciprocal scale because converting from float16
+        # to float8 after computing 1/scale introduces significant numerical error.
+        if input_quant_scale is not None:
+            updated_tensors[new_layer_name + ".qdq_input"] = StaticScaledQuantizer(
+                name=new_layer_name + ".qdq_input",
+                scale=1.0 / input_quant_scale,
+                reciprocal_scale=input_quant_scale,
+                dtype=torch.float8_e4m3fn,
+            )
+        if output_quant_scale is not None:
+            updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer(
+                name=new_layer_name + ".qdq_output",
+                scale=1.0 / output_quant_scale,
+                reciprocal_scale=output_quant_scale,
+                dtype=torch.float8_e4m3fn,
+            )
+
+        # Remove the updated tensor from the original tree.
+        root_theta.pop(layer_name)
+
+
+def convert_hf_hparams_to_gguf(hf_hparams: dict[str, any]) -> dict[str, any]:
+    hp = hf_hparams["hparams"]
+    attention_head_count = _int_prop(hp, "num_attention_heads")
+    attn_head_dim = int(
+        _int_prop(hp, "hidden_size") // _int_prop(hp, "num_attention_heads")
+    )
+
+    return {
+        "llama.context_length": _int_prop(hp, "max_position_embeddings"),
+        "llama.embedding_length": _int_prop(hp, "hidden_size"),
+        "llama.block_count": _int_prop(hp, "num_hidden_layers"),
+        "llama.feed_forward_length": _int_prop(hp, "intermediate_size"),
+        "llama.rope.dimension_count": attn_head_dim,
+        "llama.attention.head_count": attention_head_count,
+        "llama.attention.layer_norm_rms_epsilon": _float_prop(hp, "rms_norm_eps"),
+        "llama.attention.head_count_kv": _optional_int_prop(
+            hp, "num_key_value_heads", attention_head_count
+        ),
+    }
+
+
+def update_norm_layer(
+    quant_theta: Theta, layer_name: str, updated_tensors: dict[str, InferenceTensor]
+):
+    """Convert layer names of non-quantized tensors and add them to the updated_tensors dict."""
+    for sub in ["input_layernorm", "post_attention_layernorm"]:
+        sub_name = layer_name + "." + sub
+        new_name = hf_to_gguf(sub_name) + ".weight"
+        single_replace(quant_theta, sub_name, new_name, updated_tensors)
+    kv_cache_scale = (
+        quant_theta(layer_name).tensor("kv_cache_scaling_factor").as_torch()
+    )
+    layer_idx = layer_name.split(".")[-1]
+    new_name = f"blk.{layer_idx}.kv_cache"
+    kv_cache_scale = DefaultPrimitiveTensor(
+        name=new_name + ".kv_cache_scaling_factor", data=kv_cache_scale
+    )
+    updated_tensors[new_name] = kv_cache_scale
+
+
+def single_replace(
+    quant_theta: Theta,
+    layer_name: str,
+    gguf_name: str,
+    updated_tensors: dict[str, InferenceTensor],
+):
+    data = quant_theta(layer_name).tensor("weight").as_torch()
+    updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data)
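The split_sizes consumed by apply_per_layer_quant above come from the --model-base flag defined in main() below. As a small sketch of what the fused qkv decomposition does for the 70b setting (random data, shapes only, purely illustrative):

    import torch

    # Hypothetical fused qkv weight: 8192 q rows plus 1024 k and 1024 v rows (70b split sizes).
    fused = torch.randn(8192 + 1024 + 1024, 8192, dtype=torch.float16)
    q_weight, k_weight, v_weight = torch.split(fused, [8192, 1024, 1024])
    assert q_weight.shape == (8192, 8192)
    assert k_weight.shape == (1024, 8192) and v_weight.shape == (1024, 8192)
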
+
+
+def main(argv):
+    from sharktank.utils import cli
+
+    parser = cli.create_parser()
+    cli.add_output_dataset_options(parser)
+    parser.add_argument(
+        "--config-json", type=Path, required=True, help="Path to the config.json file"
+    )
+    parser.add_argument(
+        "--params",
+        type=Path,
+        default=Path("params.safetensors"),
+        help="Parameter file name, relative to config.json",
+    )
+    parser.add_argument(
+        "--model-base",
+        type=str,
+        default="7b",
+        help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b; 70b is also supported.",
+        choices=["7b", "70b"],
+    )
+    args = cli.parse(parser, args=argv)
+
+    config_json_path: Path = args.config_json
+    params_path: Path = args.params
+    # TODO: find a way to get this programmatically so we don't have to flag for it.
+    split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024]
+    num_layers = 32 if args.model_base == "7b" else 80
+
+    # Construct the pre-transform dataset.
+    dataset_props = _get_dataset_props(_load_json(config_json_path))
+    with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
+        quant_theta = _load_theta(st)
+    ds = Dataset(dataset_props, quant_theta)
+
+    # Convert hyperparams to gguf format.
+    updated_properties = convert_hf_hparams_to_gguf(ds.properties)
+
+    head_count = (updated_properties["llama.attention.head_count"],)
+
+    updated_tensors: dict[str, InferenceTensor] = {}
+    model_layers = [f"model.layers.{i}" for i in range(num_layers)]
+
+    sub_layers = [
+        "mlp.gate_proj",
+        "mlp.down_proj",
+        "mlp.up_proj",
+        "self_attn.o_proj",
+        "self_attn.qkv",
+    ]
+    for layer in model_layers:
+        for sub in sub_layers:
+            layer_name = layer + "." + sub
+            apply_per_layer_quant(
+                quant_theta,
+                layer_name,
+                updated_tensors,
+                n_head=head_count[0],
+                split_sizes=split_sizes,
+            )
+
+    # Update the non-quantized weights (norm layers).
+    for layer_idx in model_layers:
+        update_norm_layer(
+            quant_theta,
+            layer_idx,
+            updated_tensors,
+        )
+
+    # The stragglers.
+    stragglers = [
+        ("model.embed_tokens", "token_embd.weight"),
+        ("model.norm", "output_norm.weight"),
+        ("lm_head", "output.weight"),
+    ]
+    for layer, new_name in stragglers:
+        single_replace(quant_theta, layer, new_name, updated_tensors)
+
+    new_theta = Theta(updated_tensors)
+    # Make a new Dataset from the updated properties and tensors.
+    new_ds = Dataset(updated_properties, new_theta)
+
+    new_ds.save(args.output_irpa_file, io_report_callback=print)
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index 0ab6053c2..b464c7199 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -249,7 +249,8 @@ def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
     weight = unbox_tensor(weight)
     variance = x.pow(2).mean(-1, keepdim=True)
     output = x * torch.rsqrt(variance + epsilon)
-    output = output * weight
+    # The cast here matches the HF implementation and affects numerics.
+    output = weight * output.to(weight.dtype)
     return output
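To make the rms_norm numerics change above concrete, here is a small standalone comparison of the two orderings; a sketch only, with dtypes chosen to mirror a float32 activation and a float16 weight, which is where the difference shows up:

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8, dtype=torch.float32)
    weight = torch.randn(8, dtype=torch.float16)

    variance = x.pow(2).mean(-1, keepdim=True)
    normed = x * torch.rsqrt(variance + 1e-6)

    old = normed * weight                           # multiply in float32, weight upcast
    new = weight * normed.to(weight.dtype)          # HF-style: downcast activations first
    print((old - new.to(torch.float32)).abs().max())  # typically small but nonzero
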
diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py
index 75189bdf3..21f1c89ec 100644
--- a/sharktank/sharktank/types/quantizers.py
+++ b/sharktank/sharktank/types/quantizers.py
@@ -139,14 +139,15 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
         if axis is None:
             # Per tensor.
             if offset is None:
+                # Divide by the stored reciprocal scale; computing 1/scale in narrow float types loses precision.
                 qs = saturate_cast(
-                    t * self._scale,
+                    t / self._reciprocal_scale,
                     dtype=self.dtype,
                     disable_saturate=self._disable_saturate,
                 )
             else:
                 qs = saturate_cast(
-                    t * self._scale + offset,
+                    t / self._reciprocal_scale + offset,
                     dtype=self.dtype,
                     disable_saturate=self._disable_saturate,
                 )
diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py
index 975f54d24..3537726ec 100644
--- a/sharktank/sharktank/types/theta.py
+++ b/sharktank/sharktank/types/theta.py
@@ -110,6 +110,18 @@ def transform(self, *transforms: InferenceTensorTransform) -> "Theta":
     def to(self, *, device: Optional[Union[str, torch.device]] = None) -> "Theta":
         return self.transform(InferenceTensorTransforms.to_device(device))
 
+    def pop(self, *name_path: str | int) -> "Theta":
+        """Removes the subtree rooted at the given name path and returns it as a new Theta."""
+        name_path = ".".join(_norm_name_path(name_path))
+        flat = self.flatten()
+        accum = {}
+        key_list = list(flat.keys())
+        for key in key_list:
+            if key.startswith(name_path):
+                accum[key] = flat.pop(key)
+        self._tree = flat_to_nested_dict(flat)
+        return Theta(flat_to_nested_dict(accum))
+
     def flatten(self) -> dict[str, InferenceTensor]:
         results = {}
diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py
index 0d79785f6..1164fdbcf 100644
--- a/sharktank/tests/types/dataset_test.py
+++ b/sharktank/tests/types/dataset_test.py
@@ -77,6 +77,22 @@ def testTransform(self):
         self.assertIsNot(pt1, pt2)
         torch.testing.assert_close(pt1, pt2)
 
+    def testPop(self):
+        t1 = Theta(
+            _flat_t_dict(
+                _t("a.b.c", 1, 2),
+                _t("a.c.d", 10, 11),
+                _t("a.b.3", 3, 4),
+            )
+        )
+        popped = t1.pop("a.b").flatten()
+        t1 = t1.flatten()
+
+        self.assertIn("a.c.d", t1.keys())
+        self.assertNotIn("a.b.c", t1.keys())
+        self.assertNotIn("a.b.3", t1.keys())
+        self.assertIn("a.b.3", popped.keys())
+
 
 class DatasetTest(unittest.TestCase):
     def setUp(self):
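The new Theta.pop both removes the named subtree from the source Theta and returns it as a new Theta, which is what the test above exercises. A minimal usage sketch mirroring that test, assuming the sharktank.types exports used elsewhere in this patch:

    import torch
    from sharktank.types import Theta, DefaultPrimitiveTensor

    theta = Theta(
        [
            DefaultPrimitiveTensor(name="a.b.c", data=torch.zeros(1, 2)),
            DefaultPrimitiveTensor(name="a.c.d", data=torch.zeros(3, 4)),
        ]
    )
    popped = theta.pop("a.b")
    assert "a.b.c" in popped.flatten()
    assert "a.b.c" not in theta.flatten()
    assert "a.c.d" in theta.flatten()
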
diff --git a/sharktank/tests/types/quantizers_test.py b/sharktank/tests/types/quantizers_test.py
index 787725e88..b712da06a 100644
--- a/sharktank/tests/types/quantizers_test.py
+++ b/sharktank/tests/types/quantizers_test.py
@@ -9,6 +9,7 @@ import torch
 
 from sharktank.types import *
+from sharktank.types.layout_utils import saturate_cast
 from sharktank.utils.testing import TempDirTestBase
 
@@ -164,6 +165,80 @@ def testQuantDequantf8fnuz(self):
         dequant_value = layout.dequant()
         torch.testing.assert_close(orig_value, dequant_value, atol=1e-1, rtol=1e-1)
 
+    def testQuarkF8Hell(self):
+        # We use hardcoded values here because they're representative of actual values from a quark model.
+        scale = torch.tensor(0.0118, dtype=torch.float64)
+        orig = torch.tensor(
+            [
+                -58,
+                -48,
+                -70,
+                53,
+                -53,
+                76,
+                -71,
+                -90,
+                50,
+                77,
+                62,
+                -98,
+                66,
+                -54,
+                55,
+                -80,
+                -66,
+                -62,
+                -61,
+                -56,
+                56,
+                -67,
+                79,
+                -60,
+                -71,
+                42,
+                72,
+                -73,
+                91,
+                63,
+                124,
+                -128,
+            ],
+            dtype=torch.int8,
+        )
+        # Mirrors the dequant logic in quark and our importer.
+        orig = orig.view(torch.float8_e4m3fn)
+        orig = (orig.to(torch.float64) * scale).to(torch.float16)
+        # Note that for fnuz we have to use scale*2 to account for the difference between the types.
+        # We specify the reciprocal scale explicitly to avoid adding more floating point error noise.
+        fnuz = StaticScaledQuantizer(
+            name="dopoo",
+            scale=1.0 / (scale * 2),
+            reciprocal_scale=scale * 2,
+            offset=None,
+            dtype=torch.float8_e4m3fnuz,
+        )
+        fn = StaticScaledQuantizer(
+            name="poodoo",
+            scale=1.0 / scale,
+            reciprocal_scale=scale,
+            offset=None,
+            dtype=torch.float8_e4m3fn,
+        )
+        fnuz_quant = fnuz.quantize(orig)
+        fn_quant = fn.quantize(orig)
+
+        dequant_fnuz = fnuz_quant.unpack().dequant()
+        dequant_fn = fn_quant.unpack().dequant()
+
+        # Redundant asserts for sanity.
+        torch.testing.assert_close(
+            orig.to(torch.float16), dequant_fnuz, atol=1e-3, rtol=1e-3
+        )
+        torch.testing.assert_close(
+            orig.to(torch.float16), dequant_fn, atol=1e-3, rtol=1e-3
+        )
+        torch.testing.assert_close(dequant_fnuz, dequant_fn, atol=1e-3, rtol=1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
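As background for the scale*2 convention used in the importer and the test above: float8_e4m3fnuz uses an exponent bias of 8 versus 7 for float8_e4m3fn, so the same bit pattern represents half the value, and doubling the scale when quantizing to fnuz lands on the same encoding as quantizing to fn with the original scale. A small sketch, assuming a PyTorch build that ships both float8 dtypes:

    import torch

    # 1.0 encoded as float8_e4m3fn...
    x = torch.tensor(1.0).to(torch.float8_e4m3fn)
    # ...reinterpreted as float8_e4m3fnuz reads back as 0.5 (same bits, bias differs by one).
    y = x.view(torch.float8_e4m3fnuz)
    print(x.to(torch.float32), y.to(torch.float32))  # expected: 1.0 and 0.5
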