From ab78019bb0fef51c7358a0fff3a78889563f6996 Mon Sep 17 00:00:00 2001
From: Julien Balian
Date: Thu, 8 Feb 2024 17:17:08 +0100
Subject: [PATCH] fix: correct elm wise mul with quantized tensors

---
 core/src/ops/math/mod.rs | 64 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 61 insertions(+), 3 deletions(-)

diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs
index baf03431d4..8ea4e7dcd5 100644
--- a/core/src/ops/math/mod.rs
+++ b/core/src/ops/math/mod.rs
@@ -71,7 +71,7 @@ bin_to_super_type!(mul, Mul,
                     crate::ndarray::Zip::from(c)
                         .and_broadcast(a)
                         .and_broadcast(b)
-                        .for_each(|c,a,b| *c = scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16) + zp as i16, scale).clamp_cast());
+                        .for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
                     Ok(true)
                 }
                 DatumType::QU8(params) => {
@@ -82,14 +82,16 @@ bin_to_super_type!(mul, Mul,
                     crate::ndarray::Zip::from(c)
                         .and_broadcast(a)
                         .and_broadcast(b)
-                        .for_each(|c,a,b| *c = scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32) + zp as i32, scale).clamp_cast());
+                        .for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
                     Ok(true)
                 }
                 _ => Ok(false)
             }
         }
     },
-    q: [i8, u8, i32] => |c, a, b, _, _| *c = a.clone() * b;
+    q: [i8, u8, i32] => |c, a, b, zp, scale| {
+        *c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast()
+    };
     [f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
 );
 
@@ -628,6 +630,62 @@ mod tests {
         Ok(())
     }
 
+    struct TestMulAsQU8 {
+        tensor_mul_input1: [u8; 4],
+        scalar_mul_input_2: u8,
+        zero_point: i32,
+        scale: f32,
+        expected_output: [u8; 4],
+    }
+    impl TestMulAsQU8 {
+        fn check(&self) -> TractResult<()> {
+            // here we assume we can only mul quantized tensors
+            // already aligned with output tensor zp and scale
+            let mut model = TypedModel::default();
+            let input_dt =
+                DatumType::QU8(QParams::ZpScale { zero_point: self.zero_point, scale: self.scale });
+            let x = model.add_source("x", TypedFact::dt_shape(input_dt, [2_usize, 2]))?;
+            let mut a_tensor = tensor0(self.scalar_mul_input_2).broadcast_into_rank(2)?;
+            unsafe { a_tensor.set_datum_type(input_dt) };
+            let a = model.add_const("a", a_tensor.into_arc_tensor())?;
+            let y = model.wire_node("y", mul(), &[x, a])?[0];
+            model.set_output_outlets(&[y])?;
+            let mut input_data = Tensor::from_shape(&[2, 2], &self.tensor_mul_input1)?;
+            unsafe { input_data.set_datum_type(input_dt) };
+            let result = SimplePlan::new(&model)?.run(tvec!(input_data.into()))?;
+            let arr = result[0].to_array_view::<u8>()?;
+            assert_eq!(arr, Tensor::from_shape(&[2, 2], &self.expected_output)?.to_array_view()?);
+            Ok(())
+        }
+    }
+
+    #[test]
+    fn mul_as_qu8_overflow_clamp() -> TractResult<()> {
+        // last value in output tensor overflow hence is clamped
+        TestMulAsQU8 {
+            tensor_mul_input1: [1_u8, 2, 3, 128],
+            scalar_mul_input_2: 4_u8,
+            zero_point: 0,
+            scale: 1.,
+            expected_output: [4_u8, 8, 12, 255],
+        }
+        .check()
+    }
+
+    #[test]
+    fn mul_as_qu8_non_neutral_scale_and_offset() -> TractResult<()> {
+        // attempt with non neutral scale and offset
+        TestMulAsQU8 {
+            tensor_mul_input1: [1_u8, 2, 3, 128], // real: -3, 0, 3, 378
+            scalar_mul_input_2: 4_u8, // real: 6
+            zero_point: 2,
+            scale: 3.,
+            // optima in non quantized output real: -18, 0, 18, 2268
+            expected_output: [0_u8, 2, 8, 255], // approx obtained: -6, 0, 18, 759
+        }
+        .check()
+    }
+
     #[test]
     fn div_as_shift() -> TractResult<()> {
         let mut model = TypedModel::default();
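
The corrected order of operations can be sanity-checked outside of tract. The sketch below is only an illustration, not tract's `scale_by` or `clamp_cast`: `requant_mul_qu8` is a hypothetical helper that, like the TestMulAsQU8 tests above, assumes both inputs and the output share the same zero point and scale. It centers both operands on the zero point, multiplies, scales the centered product, and only then re-applies the zero point before clamping; the bug fixed by this patch was adding the zero point before scaling.

// Standalone sketch of the requantization order established by the patch.
// `requant_mul_qu8` is a hypothetical helper for illustration only; it assumes
// inputs and output share one zero point and scale, as TestMulAsQU8 does.
fn requant_mul_qu8(a: u8, b: u8, zero_point: i32, scale: f32) -> u8 {
    // Center both operands on the zero point and multiply in integer space.
    let centered = (a as i32 - zero_point) * (b as i32 - zero_point);
    // Scale the centered product first, then re-apply the zero point
    // (the old code added the zero point before scaling), then clamp to u8.
    let requantized = (centered as f32 * scale).round() as i32 + zero_point;
    requantized.clamp(0, 255) as u8
}

fn main() {
    // Values from mul_as_qu8_non_neutral_scale_and_offset: zp = 2, scale = 3.
    let out: Vec<u8> = [1u8, 2, 3, 128]
        .iter()
        .map(|&q| requant_mul_qu8(q, 4, 2, 3.0))
        .collect();
    assert_eq!(out, vec![0, 2, 8, 255]);
    println!("{:?}", out);
}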