add some fixes to run
dan-garvey committed Jul 19, 2024
1 parent c8104e9 commit 4ed3c9d
Showing 4 changed files with 69 additions and 106 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/examples/paged_llm_v1.py
@@ -225,7 +225,7 @@ def main():
prompts = args.prompt

config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
hp=configs.LlamaHParams.from_hf_props(dataset.properties),
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
44 changes: 27 additions & 17 deletions sharktank/sharktank/models/llama/llama.py
@@ -35,7 +35,7 @@ class LlamaModelConfig:
block_seq_stride: int = 16

# Either "paged" or "direct".
kv_cache_type: str = "paged"
kv_cache_type: str = "direct"

# The device on which to place intermediate state.
device: Optional[torch.device] = None
@@ -113,7 +113,9 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.hp = hp
self.cache = config.create_kv_cache()
self.activation_dtype = config.activation_dtype

print(self.hp.block_count)
print("block count")

key = "token_embd"
if key not in list(theta.keys):
self.hf = True
@@ -133,11 +135,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
key = "output_norm" if "output_norm" in list(theta.keys) else "model.norm"
self.add_module(
"output_norm",
RMSNormLayer(
theta(key), epsilon=self.hp.attention_layer_norm_rms_epsilon
),
RMSNormLayer(theta(key), epsilon=self.hp.attention_layer_norm_rms_epsilon),
)
print(theta.keys)
key = "output" if "output" in list(theta.keys) else "lm_head"
self.add_module("output_lm_head", LinearLayer(theta(key)))
key = "blk" if "blk" in list(theta.keys) else "model.layers"
@@ -151,7 +150,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
hf = self.hf,
hf=self.hf,
)
for n in range(hp.block_count)
]
@@ -168,6 +167,8 @@ def prefill(
seq_block_ids: torch.Tensor,
cache_state: list[torch.Tensor],
):
print("tokens.shape: ")
print(tokens.shape)
self._assert_device(tokens)
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(seq_block_ids)
@@ -285,20 +286,25 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
hf = False
):
hf=False,
):
super().__init__(theta)
if hf:
#tensor = theta("self_attn.qkv.weight").tensor
#tensor = tensor.reshape(head_count_kv, head_count // head_count_kv + 2, head_dim, head_dim * head_count)
#print(tensor)
self.add_module("attn_norm", RMSNormLayer(theta("input_layernorm"), epsilon=rms_epsilon))
#self.add_module("attn_qkv", LinearLayer(theta("self_attn.qkv")))
# tensor = theta("self_attn.qkv.weight").tensor
# tensor = tensor.reshape(head_count_kv, head_count // head_count_kv + 2, head_dim, head_dim * head_count)
# print(tensor)
self.add_module(
"attn_norm", RMSNormLayer(theta("input_layernorm"), epsilon=rms_epsilon)
)
# self.add_module("attn_qkv", LinearLayer(theta("self_attn.qkv")))
self.add_module("attn_q", LinearLayer(theta("self_attn.q_proj")))
self.add_module("attn_k", LinearLayer(theta("self_attn.k_proj")))
self.add_module("attn_v", LinearLayer(theta("self_attn.v_proj")))
self.add_module("attn_output", LinearLayer(theta("self_attn.o_proj")))
self.add_module("ffn_norm", RMSNormLayer(theta("post_attention_layernorm"), epsilon=rms_epsilon))
self.add_module(
"ffn_norm",
RMSNormLayer(theta("post_attention_layernorm"), epsilon=rms_epsilon),
)
self.add_module("ffn_gate", LinearLayer(theta("mlp.gate_proj")))
self.add_module("ffn_up", LinearLayer(theta("mlp.up_proj")))
self.add_module("ffn_down", LinearLayer(theta("mlp.down_proj")))
@@ -319,8 +325,7 @@ def __init__(

self.block_index = block_index
self.cache = cache
print(head_count)
assert(isinstance(head_count, int))
assert isinstance(head_count, int)
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
@@ -348,9 +353,14 @@ def forward(
assert feature_dim == self.head_count * self.head_dim

xq = self.attn_q(x)
print("xq shape: ")
print(xq.shape)
xk = self.attn_k(x)
xv = self.attn_v(x)

print("batch_seq_len: ", batch_seq_len)
print("head count: ", self.head_count)
print("head_dim: ", self.head_dim)
xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim)
xk = xk.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
xv = xv.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
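
Context for the hf flag introduced above: it selects between GGUF-style and HF-style tensor names (for example "output_norm" vs. "model.norm", "blk" vs. "model.layers"). A minimal sketch of that selection pattern, using only names visible in the diff; the resolve_key helper is hypothetical and not part of the repository:

def resolve_key(theta_keys, gguf_name: str, hf_name: str) -> str:
    # Return whichever tensor name is actually present in the loaded theta.
    return gguf_name if gguf_name in theta_keys else hf_name

# Usage mirroring the lines in the diff above:
# key = resolve_key(list(theta.keys), "output_norm", "model.norm")
# key = resolve_key(list(theta.keys), "output", "lm_head")
# key = resolve_key(list(theta.keys), "blk", "model.layers")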
77 changes: 41 additions & 36 deletions sharktank/sharktank/models/llama/tools/import_brevitas_dataset.py
@@ -62,11 +62,13 @@ def _load_theta(st_source) -> Theta:
]
return Theta(tensors)


def as_torch_or_none(tensor: Optional[InferenceTensor]) -> Optional[torch.Tensor]:
if tensor is None:
return None
return tensor.as_torch()


def apply_per_layer_quant(
root_theta: Theta, layer_name: str, updated_tensors: dict[str, InferenceTensor]
):
@@ -79,80 +81,83 @@ def apply_per_layer_quant(
weight_quant_zero_point = torch.zeros(1, dtype=torch.float32)
else:
weight_quant_zero_point = weight_quant_zero_point.as_torch()
input_quant_scale = as_torch_or_none(layer_theta.optional_tensor("input_quant_scale"))
input_quant_scale = as_torch_or_none(
layer_theta.optional_tensor("input_quant_scale")
)

if weight_quant_scale is None:
print("weight quant scale not found for layer ", layer_name)
return

layer_parent = ".".join(layer_name.split('.')[:-1])
layer_parent = ".".join(layer_name.split(".")[:-1])
if "qkv" in layer_name:
print("qkv layer found")
print("weight_quant_scale shape: ", weight_quant_scale.shape)
print("layer_parent: ", layer_parent)
torch_weight = weight
print("torch weight shape: ", torch_weight.shape)
q_weight = torch_weight[:8192, :]
k_weight = torch_weight[8192:9216, :]
v_weight = torch_weight[9216:, :]

if "qkv" in layer_name:
q_weight_quant = PlanarQuantizedTensor(
shape=q_weight.shape,
name=layer_parent + ".q_proj.weight",
layout=TensorScaledLayout(
shape=q_weight.shape,
name=layer_parent + '.q_proj.weight',
layout=TensorScaledLayout(
shape=q_weight.shape,
d=1.0/weight_quant_scale,
qs=q_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
)
d=1.0 / weight_quant_scale,
qs=q_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
),
)
k_weight_quant = PlanarQuantizedTensor(
shape=k_weight.shape,
name=layer_parent + ".k_proj.weight",
layout=TensorScaledLayout(
shape=k_weight.shape,
name=layer_parent + '.k_proj.weight',
layout=TensorScaledLayout(
shape=k_weight.shape,
d=1.0/weight_quant_scale,
qs=k_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
)
d=1.0 / weight_quant_scale,
qs=k_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
),
)
v_weight_quant = PlanarQuantizedTensor(
shape=v_weight.shape,
name=layer_parent + ".v_proj.weight",
layout=TensorScaledLayout(
shape=v_weight.shape,
name=layer_parent + '.v_proj.weight',
layout=TensorScaledLayout(
shape=v_weight.shape,
d=1.0/weight_quant_scale,
qs=v_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
)
d=1.0 / weight_quant_scale,
qs=v_weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
),
)
print(q_weight_quant.name)
print(k_weight_quant.name)
print(v_weight_quant.name)
updated_tensors[q_weight_quant.name] = q_weight_quant
updated_tensors[k_weight_quant.name] = k_weight_quant
updated_tensors[v_weight_quant.name] = v_weight_quant
#updated_tensors[layer_name] = None
# updated_tensors[layer_name] = None
else:
weight_quant = PlanarQuantizedTensor(
shape=weight.shape,
name=layer_name + '.weight',
name=layer_name + ".weight",
layout=TensorScaledLayout(
shape=weight.shape,
d=1.0/weight_quant_scale,
d=1.0 / weight_quant_scale,
qs=weight.to(dtype=torch.float8_e4m3fnuz),
m=weight_quant_zero_point,
dtype=torch.float16, # Original dtype.
)
),
)
print(weight_quant.name)
updated_tensors[weight_quant.name] = weight_quant
# Spot check that things look sane.
#weight_dequant = weight_quant.unpack().dequant()
#torch.testing.assert_close(weight, weight_dequant, atol=3, rtol=3)
# weight_dequant = weight_quant.unpack().dequant()
# torch.testing.assert_close(weight, weight_dequant, atol=3, rtol=3)


def main(argv):
@@ -195,13 +200,13 @@ def main(argv):
# layer name. We process each of these in turn to produce a per-layer
# quantization scheme where no quantized tensors escape their layer.
updated_tensors: dict[str, InferenceTensor] = {}
model_layers = ["model.layers."+str(i) for i in range(80)]
sub_layers = ['mlp.down_proj', 'mlp.up_proj', 'self_attn.o_proj', 'self_attn.qkv' ]
model_layers = ["model.layers." + str(i) for i in range(32)]
sub_layers = ["mlp.down_proj", "mlp.up_proj", "self_attn.o_proj", "self_attn.qkv"]
for layer in model_layers:
for sub in sub_layers:

layer_name = layer + '.' + sub
#if layern_name not in ["quantization", "decoder_type",
layer_name = layer + "." + sub
# if layern_name not in ["quantization", "decoder_type",
print(f"Applying per-layer quants: {layer_name}")
apply_per_layer_quant(quant_theta, layer_name, updated_tensors)
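
Context for the qkv split above: the row offsets 8192 and 9216 are hardcoded and appear to correspond to head_count * head_dim and head_count * head_dim + head_count_kv * head_dim (for example 64 * 128 = 8192 and 8192 + 8 * 128 = 9216). A hedged sketch that derives the same split from those hyperparameters; the helper and the example head counts are assumptions, not the script's actual code:

import torch

def split_fused_qkv(qkv: torch.Tensor, head_count: int, head_count_kv: int, head_dim: int):
    # Split a fused [q; k; v] weight along dim 0 into per-projection weights.
    q_rows = head_count * head_dim      # e.g. 64 * 128 = 8192
    kv_rows = head_count_kv * head_dim  # e.g. 8 * 128 = 1024
    q_weight = qkv[:q_rows, :]
    k_weight = qkv[q_rows : q_rows + kv_rows, :]
    v_weight = qkv[q_rows + kv_rows :, :]
    return q_weight, k_weight, v_weight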

52 changes: 0 additions & 52 deletions sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
@@ -71,7 +71,6 @@ def apply_per_layer_quant(
layer_theta = root_theta(layer_name)
weight = layer_theta.tensor("weight")
weight_dtype = weight.as_torch().dtype
<<<<<<< HEAD
bias = layer_theta.optional_tensor("bias")

# The config file has tensors as 1d and a _shape suffixed array with the
@@ -91,15 +90,6 @@ def _get_json_tensor(
dtype_str = qp[f"{name}_dtype"]
dtype = _dtype_str_to_dtype(dtype_str)

=======

# The config file has tensors as 1d and a _shape suffixed array with the
# concrete shape.
def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:
data_1d = qp.get(name)
if data_1d is None:
return None
>>>>>>> 08af8e6 (re-add original)
shape = qp[f"{name}_shape"]
return torch.tensor(data_1d, dtype=dtype).reshape(shape)

@@ -115,26 +105,11 @@ def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:

input_scale = _get_json_tensor("input_scale", torch.float32)
weight_scale = _get_json_tensor("weight_scale", torch.float32)
<<<<<<< HEAD
weight_zp = _get_json_tensor("weight_zp", dtype=None)

# In the current version, we assume that the input is per-tensor quantized
# for signed arithmetic.
input_zp = _get_json_tensor("input_zp", dtype=None)
=======
# In the JSON file, zero points are purely positive numbers representing
# the zero point assuming a uint8 datatype. Since we prefer to use
# signed arithmetic, since that is better accelerated, we offset these by
# -128. By decoding them as uint8, it validates our range assumption.
# Then we widen/cast to offset.
weight_zp = _get_json_tensor("weight_zp", torch.uint8)
if weight_zp is not None:
weight_zp = (weight_zp.to(dtype=torch.int32) - 128).to(torch.int8)

# In the current version, we assume that the input is not offset and is
# per-tensor quantized for signed arithmetic.
input_zp = _get_json_tensor("input_zp", torch.int8)
>>>>>>> 08af8e6 (re-add original)
if input_zp is not None:
assert torch.count_nonzero(input_zp) == 0

@@ -149,17 +124,12 @@ def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:

# Quantized layer must have all quantization info.
assert (
<<<<<<< HEAD
weight_scale is not None
=======
weight_scale is not None and weight_zp is not None
>>>>>>> 08af8e6 (re-add original)
), f"Could not find weight scale (in {qp.keys()}) for {layer_name}"
assert (
input_scale is not None
), f"Could not find input scale (in {qp.keys()}) for {layer_name}"

<<<<<<< HEAD
def quantize_weight(
weight_name: str,
weight: torch.Tensor,
@@ -188,28 +158,6 @@ def quantize_bias(
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
=======
# Weight scaling.
# There is an implicit assumption that the weight is asym (uint8) quantized.
# Our quantizer uses scale/offset nomenclature. The offset maps to
# zero-point, and the scale maps to the *dequant* scale (so terms differ
# by reciprocal).
weight_quantizer = StaticScaledQuantizer(
scale=1.0 / weight_scale,
reciprocal_scale=weight_scale,
offset=None if torch.count_nonzero(weight_zp) == 0 else weight_zp,
dtype=torch.int8,
)
weight_quant = weight_quantizer.quantize(weight, name=weight.name)
updated_tensors[weight_quant.name] = weight_quant
# Spot check that things look sane.
weight_dequant = weight_quant.unpack().dequant()
weight_diff = weight.as_torch() - weight_dequant

# Bias/output scaling.
bias = layer_theta.optional_tensor("bias")
if QUANTIZE_BIAS and bias is not None:
>>>>>>> 08af8e6 (re-add original)
# If the bias is present, it dictates the overall output quantization
# and will not be checked for correct parameters at runtime. It must
# be quantized to match properly.
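
Context for the bias comment above: a common convention, not necessarily this script's implementation, is to quantize the bias with the product of the input and weight scales so that it matches the int32 accumulator of the quantized matmul. A minimal sketch under that assumption; quantize_bias_sketch is hypothetical:

import torch

def quantize_bias_sketch(
    bias: torch.Tensor, input_scale: torch.Tensor, weight_scale: torch.Tensor
) -> torch.Tensor:
    # The bias shares the accumulator scale (input_scale * weight_scale) of the matmul.
    bias_scale = input_scale * weight_scale
    return torch.round(bias / bias_scale).to(torch.int32)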