diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index afe2928c0..00b98cf3f 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -98,12 +98,14 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) outs(%result_fill : !accum_tensor_type) { ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): - %bmm_mul = arith.mulf %a_element, %b_element : !a_type {% if accum_type == a_type %} + %bmm_mul = arith.mulf %a_element, %b_element : !a_type %bmm_accum = arith.addf %bmm_mul, %out : !a_type {% else %} - %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type - %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + %a_ext = arith.extf %a_element : !a_type to !accum_type + %b_ext = arith.extf %b_element : !a_type to !accum_type + %bmm_mul = arith.mulf %a_ext, %b_ext : !accum_type + %bmm_accum = arith.addf %bmm_mul, %out : !accum_type {% endif %} linalg.yield %bmm_accum : !accum_type } -> !accum_tensor_type