From e98e4586058735eb16e2c9392f1c498e393a6c3a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 2 Jan 2025 15:40:09 -0800 Subject: [PATCH] Fix signed extension in q4_1 sharktank kernel (#726) --- .../templates/mmt_block_scaled_offset_q4_unsigned.mlir | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index afe2928c0..00b98cf3f 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -98,12 +98,14 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) outs(%result_fill : !accum_tensor_type) { ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): - %bmm_mul = arith.mulf %a_element, %b_element : !a_type {% if accum_type == a_type %} + %bmm_mul = arith.mulf %a_element, %b_element : !a_type %bmm_accum = arith.addf %bmm_mul, %out : !a_type {% else %} - %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type - %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + %a_ext = arith.extf %a_element : !a_type to !accum_type + %b_ext = arith.extf %b_element : !a_type to !accum_type + %bmm_mul = arith.mulf %a_ext, %b_ext : !accum_type + %bmm_accum = arith.addf %bmm_mul, %out : !accum_type {% endif %} linalg.yield %bmm_accum : !accum_type } -> !accum_tensor_type