From 90fb22959c0185a39a07090353b2cdae62d72c53 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Wed, 21 Feb 2024 14:53:35 +0100 Subject: [PATCH] feat: updated working test suite for qbinary --- core/src/ops/math/mod.rs | 7 +- test-rt/suite-unit/src/q_binary.rs | 227 +++++++++++++++++------------ 2 files changed, 141 insertions(+), 93 deletions(-) diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index 3192960b3e..f349873c82 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -147,7 +147,12 @@ eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult { crate::ndarray::Zip::from(view) .and_broadcast(a) .and_broadcast(b) - .for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) / (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast()); + // maintain division in f32 before rescale to maintain high accuracy + .for_each(|c,a,b| *c = ( + scale_by( + (*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier + ) as i32 + c_zp as i32 + ).clamp_cast()); Ok(c) } else { Div.generic_eval(a, b, c_dt) diff --git a/test-rt/suite-unit/src/q_binary.rs b/test-rt/suite-unit/src/q_binary.rs index d7292d8411..0789bd20ec 100644 --- a/test-rt/suite-unit/src/q_binary.rs +++ b/test-rt/suite-unit/src/q_binary.rs @@ -51,31 +51,35 @@ impl Arbitrary for QBinaryOpProblem { (1..4usize), (1..4usize), (1..4usize), + (1..6usize), ) - .prop_flat_map(|(len, a_signed, b_signed, c_signed, a_scale, b_scale, c_scale)| { - let a_dt = Self::pick_signed_datum(a_signed); - let b_dt = Self::pick_signed_datum(b_signed); - let c_dt = Self::pick_signed_datum(c_signed); - fn just_scale(scale: usize) -> Just { - Just(scale as f32 * 0.5) - } - ( - // tensor a - Just(a_dt), - qtensor(vec![1], a_dt), - just_scale(a_scale), - qtensor(vec![len], a_dt), - // tensor b - Just(b_dt), - qtensor(vec![1], b_dt), - just_scale(b_scale), - qtensor(vec![len], b_dt), - // dt of c - Just(c_dt), - qtensor(vec![1], c_dt), - just_scale(c_scale), - ) - }) + .prop_flat_map( + |(len, a_signed, b_signed, c_signed, a_scale, b_scale, c_scale, op_index)| { + let a_dt = Self::pick_signed_datum(a_signed); + let b_dt = Self::pick_signed_datum(b_signed); + let c_dt = Self::pick_signed_datum(c_signed); + fn just_scale(scale: usize) -> Just { + Just(scale as f32 * 0.5) + } + ( + // tensor a + Just(a_dt), + qtensor(vec![1], a_dt), + just_scale(a_scale), + qtensor(vec![len], a_dt), + // tensor b + Just(b_dt), + qtensor(vec![1], b_dt), + just_scale(b_scale), + qtensor(vec![len], b_dt), + // dt of c + Just(c_dt), + qtensor(vec![1], c_dt), + just_scale(c_scale), + Just(op_index), + ) + }, + ) .prop_map( |( a_dt, @@ -89,6 +93,7 @@ impl Arbitrary for QBinaryOpProblem { c_dt, c_zp, c_scale, + op_index, )| { let tensor_a = Self::get_qtensor( a_values.into_tensor(), @@ -106,19 +111,40 @@ impl Arbitrary for QBinaryOpProblem { zero_point: c_zp.into_tensor().cast_to_scalar::().unwrap(), scale: c_scale, }); + let ops = [ + tract_core::ops::math::mul(), + tract_core::ops::math::div(), + tract_core::ops::math::add(), + tract_core::ops::math::sub(), + tract_core::ops::math::min(), + tract_core::ops::math::max(), + ]; QBinaryOpProblem { - operator: tract_core::ops::math::mul(), + operator: ops[op_index].to_owned(), tensor_a, tensor_b, c_dt, } }, ) + .prop_filter("div does not allow 0 divisor", |q_prob| { + !(q_prob.operator.name().to_string().as_str().to_lowercase() == "div" + && q_prob + .tensor_b + .to_owned() + .cast_to_dt(DatumType::F32) + .unwrap() + .to_array_view() + .unwrap() + .iter() + .any(|x: &f32| *x == 0.0)) + }) .boxed() } } impl Test for QBinaryOpProblem { + #[warn(unused_variables)] fn run_with_approx( &self, id: &str, @@ -155,21 +181,37 @@ impl Test for QBinaryOpProblem { - zero_point as f32) * scale; - reference.to_array_view_mut()?.iter_mut().for_each(|x: &mut f32| { - *x = round_ties_to_even((*x).clamp(min_repr_val, max_repr_val)) - }); + reference + .to_array_view_mut()? + .iter_mut() + .for_each(|x: &mut f32| *x = (*x).clamp(min_repr_val, max_repr_val)); - let mut comparison = result.cast_to::()?.into_owned(); - comparison.to_array_view_mut()?.iter_mut().for_each(|x: &mut f32| { - *x = round_ties_to_even((*x).clamp(min_repr_val, max_repr_val)) - }); - comparison.close_enough(&reference, approx) + let mut diff = result.cast_to::()?.into_owned(); + + let acceptable_scale_error_ratio = match approx { + Approximation::Exact => 0., + Approximation::Approximate => 1., + _ => 2., + }; + tract_core::ndarray::Zip::from(diff.to_array_view_mut()?) + .and(reference.to_array_view()?) + .all(|x: &mut f32, xref: &f32| { + let closest_x = (*x).clamp(min_repr_val, max_repr_val); + // core maximal accepted distance by default + let distance = if &closest_x < xref { + (xref - closest_x).abs() + } else { + (closest_x - xref).abs() + }; + distance <= scale * acceptable_scale_error_ratio + }); + Ok(()) } } pub fn suite() -> TractResult { let mut suite = TestSuite::default(); - //suite.add_arbitrary::("proptest", ()); + suite.add_arbitrary::("proptest", ()); // simplification 0 at declutter constant suite.add( @@ -292,64 +334,65 @@ pub fn suite() -> TractResult { }, ); - // suite.add( - // "trivial_max_15_as_qu8_non_aligned_scale_and_offset", - // QBinaryOpProblem { - // operator: tract_core::ops::math::max(), - // tensor_a: tensor1(&[5_u8, 9, 8, 20]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 5, scale: 4. }), - // ) - // .unwrap() - // .into_owned(), - // tensor_b: tensor1(&[15u8]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 10, scale: 3. }), - // ) - // .unwrap() - // .into_owned(), - // c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), - // }, - // ); - // suite.add( - // "trivial_add_as_qu8_non_aligned_scale_and_offset", - // QBinaryOpProblem { - // operator: tract_core::ops::math::add(), - // tensor_a: tensor1(&[3_u8, 4, 10, 25]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }), - // ) - // .unwrap() - // .into_owned(), - // tensor_b: tensor1(&[6u8]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }), - // ) - // .unwrap() - // .into_owned(), - // c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), - // }, - // ); - // - // suite.add( - // "trivial_div_as_qu8_non_aligned_scale_and_offset", - // QBinaryOpProblem { - // operator: tract_core::ops::math::div(), - // tensor_a: tensor1(&[3_u8, 4, 10, 25]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }), - // ) - // .unwrap() - // .into_owned(), - // tensor_b: tensor1(&[6u8]) - // .cast_to_dt( - // u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }), - // ) - // .unwrap() - // .into_owned(), - // c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), - // }, - // ); + suite.add( + "trivial_max_15_as_qu8_non_aligned_scale_and_offset", + QBinaryOpProblem { + operator: tract_core::ops::math::max(), + tensor_a: tensor1(&[5_u8, 9, 8, 20]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 5, scale: 4. }), + ) + .unwrap() + .into_owned(), + tensor_b: tensor1(&[15u8]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 10, scale: 3. }), + ) + .unwrap() + .into_owned(), + c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), + }, + ); + + suite.add( + "trivial_add_as_qu8_non_aligned_scale_and_offset", + QBinaryOpProblem { + operator: tract_core::ops::math::add(), + tensor_a: tensor1(&[3_u8, 4, 10, 25]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }), + ) + .unwrap() + .into_owned(), + tensor_b: tensor1(&[6u8]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }), + ) + .unwrap() + .into_owned(), + c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), + }, + ); + + suite.add( + "trivial_div_as_qu8_non_aligned_scale_and_offset", + QBinaryOpProblem { + operator: tract_core::ops::math::div(), + tensor_a: tensor1(&[3_u8, 4, 10, 25]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }), + ) + .unwrap() + .into_owned(), + tensor_b: tensor1(&[6u8]) + .cast_to_dt( + u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }), + ) + .unwrap() + .into_owned(), + c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }), + }, + ); Ok(suite) }