Skip to content

Commit

Permalink
feat: updated working test suite for qbinary
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienBalianSonos committed Feb 21, 2024
1 parent 221f869 commit 90fb229
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 93 deletions.
7 changes: 6 additions & 1 deletion core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,12 @@ eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
crate::ndarray::Zip::from(view)
.and_broadcast(a)
.and_broadcast(b)
.for_each(|c,a,b| *c = (scale_by((*a as i32 - a_zp as i32) / (*b as i32 - b_zp as i32), multiplier) + c_zp as i32).clamp_cast());
// maintain division in f32 before rescale to maintain high accuracy
.for_each(|c,a,b| *c = (
scale_by(
(*a as i32 - a_zp as i32) as f32 / (*b as i32 - b_zp as i32) as f32, multiplier
) as i32 + c_zp as i32
).clamp_cast());
Ok(c)
} else {
Div.generic_eval(a, b, c_dt)
Expand Down
227 changes: 135 additions & 92 deletions test-rt/suite-unit/src/q_binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,35 @@ impl Arbitrary for QBinaryOpProblem {
(1..4usize),
(1..4usize),
(1..4usize),
(1..6usize),
)
.prop_flat_map(|(len, a_signed, b_signed, c_signed, a_scale, b_scale, c_scale)| {
let a_dt = Self::pick_signed_datum(a_signed);
let b_dt = Self::pick_signed_datum(b_signed);
let c_dt = Self::pick_signed_datum(c_signed);
fn just_scale(scale: usize) -> Just<f32> {
Just(scale as f32 * 0.5)
}
(
// tensor a
Just(a_dt),
qtensor(vec![1], a_dt),
just_scale(a_scale),
qtensor(vec![len], a_dt),
// tensor b
Just(b_dt),
qtensor(vec![1], b_dt),
just_scale(b_scale),
qtensor(vec![len], b_dt),
// dt of c
Just(c_dt),
qtensor(vec![1], c_dt),
just_scale(c_scale),
)
})
.prop_flat_map(
|(len, a_signed, b_signed, c_signed, a_scale, b_scale, c_scale, op_index)| {
let a_dt = Self::pick_signed_datum(a_signed);
let b_dt = Self::pick_signed_datum(b_signed);
let c_dt = Self::pick_signed_datum(c_signed);
fn just_scale(scale: usize) -> Just<f32> {
Just(scale as f32 * 0.5)
}
(
// tensor a
Just(a_dt),
qtensor(vec![1], a_dt),
just_scale(a_scale),
qtensor(vec![len], a_dt),
// tensor b
Just(b_dt),
qtensor(vec![1], b_dt),
just_scale(b_scale),
qtensor(vec![len], b_dt),
// dt of c
Just(c_dt),
qtensor(vec![1], c_dt),
just_scale(c_scale),
Just(op_index),
)
},
)
.prop_map(
|(
a_dt,
Expand All @@ -89,6 +93,7 @@ impl Arbitrary for QBinaryOpProblem {
c_dt,
c_zp,
c_scale,
op_index,
)| {
let tensor_a = Self::get_qtensor(
a_values.into_tensor(),
Expand All @@ -106,19 +111,40 @@ impl Arbitrary for QBinaryOpProblem {
zero_point: c_zp.into_tensor().cast_to_scalar::<i32>().unwrap(),
scale: c_scale,
});
let ops = [
tract_core::ops::math::mul(),
tract_core::ops::math::div(),
tract_core::ops::math::add(),
tract_core::ops::math::sub(),
tract_core::ops::math::min(),
tract_core::ops::math::max(),
];
QBinaryOpProblem {
operator: tract_core::ops::math::mul(),
operator: ops[op_index].to_owned(),
tensor_a,
tensor_b,
c_dt,
}
},
)
.prop_filter("div does not allow 0 divisor", |q_prob| {
!(q_prob.operator.name().to_string().as_str().to_lowercase() == "div"
&& q_prob
.tensor_b
.to_owned()
.cast_to_dt(DatumType::F32)
.unwrap()
.to_array_view()
.unwrap()
.iter()
.any(|x: &f32| *x == 0.0))
})
.boxed()
}
}

impl Test for QBinaryOpProblem {
#[warn(unused_variables)]
fn run_with_approx(
&self,
id: &str,
Expand Down Expand Up @@ -155,21 +181,37 @@ impl Test for QBinaryOpProblem {
- zero_point as f32)
* scale;

reference.to_array_view_mut()?.iter_mut().for_each(|x: &mut f32| {
*x = round_ties_to_even((*x).clamp(min_repr_val, max_repr_val))
});
reference
.to_array_view_mut()?
.iter_mut()
.for_each(|x: &mut f32| *x = (*x).clamp(min_repr_val, max_repr_val));

let mut comparison = result.cast_to::<f32>()?.into_owned();
comparison.to_array_view_mut()?.iter_mut().for_each(|x: &mut f32| {
*x = round_ties_to_even((*x).clamp(min_repr_val, max_repr_val))
});
comparison.close_enough(&reference, approx)
let mut diff = result.cast_to::<f32>()?.into_owned();

let acceptable_scale_error_ratio = match approx {
Approximation::Exact => 0.,
Approximation::Approximate => 1.,
_ => 2.,
};
tract_core::ndarray::Zip::from(diff.to_array_view_mut()?)
.and(reference.to_array_view()?)
.all(|x: &mut f32, xref: &f32| {
let closest_x = (*x).clamp(min_repr_val, max_repr_val);
// core maximal accepted distance by default
let distance = if &closest_x < xref {
(xref - closest_x).abs()
} else {
(closest_x - xref).abs()
};
distance <= scale * acceptable_scale_error_ratio
});
Ok(())
}
}

pub fn suite() -> TractResult<TestSuite> {
let mut suite = TestSuite::default();
//suite.add_arbitrary::<QBinaryOpProblem>("proptest", ());
suite.add_arbitrary::<QBinaryOpProblem>("proptest", ());

// simplification 0 at declutter constant
suite.add(
Expand Down Expand Up @@ -292,64 +334,65 @@ pub fn suite() -> TractResult<TestSuite> {
},
);

// suite.add(
// "trivial_max_15_as_qu8_non_aligned_scale_and_offset",
// QBinaryOpProblem {
// operator: tract_core::ops::math::max(),
// tensor_a: tensor1(&[5_u8, 9, 8, 20])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 5, scale: 4. }),
// )
// .unwrap()
// .into_owned(),
// tensor_b: tensor1(&[15u8])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 10, scale: 3. }),
// )
// .unwrap()
// .into_owned(),
// c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
// },
// );
// suite.add(
// "trivial_add_as_qu8_non_aligned_scale_and_offset",
// QBinaryOpProblem {
// operator: tract_core::ops::math::add(),
// tensor_a: tensor1(&[3_u8, 4, 10, 25])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }),
// )
// .unwrap()
// .into_owned(),
// tensor_b: tensor1(&[6u8])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }),
// )
// .unwrap()
// .into_owned(),
// c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
// },
// );
//
// suite.add(
// "trivial_div_as_qu8_non_aligned_scale_and_offset",
// QBinaryOpProblem {
// operator: tract_core::ops::math::div(),
// tensor_a: tensor1(&[3_u8, 4, 10, 25])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }),
// )
// .unwrap()
// .into_owned(),
// tensor_b: tensor1(&[6u8])
// .cast_to_dt(
// u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }),
// )
// .unwrap()
// .into_owned(),
// c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
// },
// );
suite.add(
"trivial_max_15_as_qu8_non_aligned_scale_and_offset",
QBinaryOpProblem {
operator: tract_core::ops::math::max(),
tensor_a: tensor1(&[5_u8, 9, 8, 20])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 5, scale: 4. }),
)
.unwrap()
.into_owned(),
tensor_b: tensor1(&[15u8])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 10, scale: 3. }),
)
.unwrap()
.into_owned(),
c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
},
);

suite.add(
"trivial_add_as_qu8_non_aligned_scale_and_offset",
QBinaryOpProblem {
operator: tract_core::ops::math::add(),
tensor_a: tensor1(&[3_u8, 4, 10, 25])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }),
)
.unwrap()
.into_owned(),
tensor_b: tensor1(&[6u8])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }),
)
.unwrap()
.into_owned(),
c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
},
);

suite.add(
"trivial_div_as_qu8_non_aligned_scale_and_offset",
QBinaryOpProblem {
operator: tract_core::ops::math::div(),
tensor_a: tensor1(&[3_u8, 4, 10, 25])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 3, scale: 4.5 }),
)
.unwrap()
.into_owned(),
tensor_b: tensor1(&[6u8])
.cast_to_dt(
u8::datum_type().quantize(QParams::ZpScale { zero_point: 4, scale: 2.5 }),
)
.unwrap()
.into_owned(),
c_dt: DatumType::QU8(QParams::ZpScale { zero_point: 0, scale: 1. }),
},
);

Ok(suite)
}

0 comments on commit 90fb229

Please sign in to comment.