Skip to content

Commit

Permalink
fix: correct elm wise mul with quantized tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienBalianSonos committed Feb 8, 2024
1 parent 632d25d commit ab78019
Showing 1 changed file with 61 additions and 3 deletions.
64 changes: 61 additions & 3 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ bin_to_super_type!(mul, Mul,
crate::ndarray::Zip::from(c)
.and_broadcast(a)
.and_broadcast(b)
.for_each(|c,a,b| *c = scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16) + zp as i16, scale).clamp_cast());
.for_each(|c,a,b| *c = (scale_by((*a as i16 - zp as i16) * (*b as i16 - zp as i16), scale) + zp as i16).clamp_cast());
Ok(true)
}
DatumType::QU8(params) => {
Expand All @@ -82,14 +82,16 @@ bin_to_super_type!(mul, Mul,
crate::ndarray::Zip::from(c)
.and_broadcast(a)
.and_broadcast(b)
.for_each(|c,a,b| *c = scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32) + zp as i32, scale).clamp_cast());
.for_each(|c,a,b| *c = (scale_by((*a as i32 - zp as i32) * (*b as i32 - zp as i32), scale) + zp as i32).clamp_cast());
Ok(true)
}
_ => Ok(false)
}
}
},
q: [i8, u8, i32] => |c, a, b, _, _| *c = a.clone() * b;
q: [i8, u8, i32] => |c, a, b, zp, scale| {
*c = (scale_by((a.clone() as i32 - zp as i32) * (*b as i32 - zp as i32) , scale) + zp as i32).clamp_cast()
};
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
);

Expand Down Expand Up @@ -628,6 +630,62 @@ mod tests {
Ok(())
}

/// Fixture describing one quantized (QU8) multiplication scenario: a four-element
/// input multiplied by a single scalar, with both operands tagged with the same
/// zero point and scale.
struct TestMulAsQU8 {
// Raw u8 values fed to the model as a 2x2 QU8 input tensor.
tensor_mul_input1: [u8; 4],
// Raw u8 value of the scalar constant, broadcast to rank 2 before use.
scalar_mul_input_2: u8,
// Zero point of the QU8 datum type applied to the input and the constant.
zero_point: i32,
// Scale of the QU8 datum type applied to the input and the constant.
scale: f32,
// Raw u8 values the model output is compared against, viewed as a 2x2 tensor.
expected_output: [u8; 4],
}
impl TestMulAsQU8 {
    /// Runs the scenario end to end: builds a model multiplying a 2x2 QU8 source
    /// by a QU8 scalar constant, executes it on `tensor_mul_input1`, and asserts
    /// that the raw `u8` output equals `expected_output`.
    ///
    /// Returns an error if building or running the model fails; panics (through
    /// `assert_eq!`) when the output differs from the expectation.
    fn check(&self) -> TractResult<()> {
        // Both operands carry the very same quantization parameters: this test
        // assumes multiplication only happens between quantized tensors that are
        // already aligned with the output tensor's zero point and scale.
        let qu8 = DatumType::QU8(QParams::ZpScale {
            zero_point: self.zero_point,
            scale: self.scale,
        });

        let mut model = TypedModel::default();
        let x = model.add_source("x", TypedFact::dt_shape(qu8, [2_usize, 2]))?;

        let mut scalar = tensor0(self.scalar_mul_input_2).broadcast_into_rank(2)?;
        // SAFETY: presumably sound because the tensor holds plain u8 data and only
        // its datum type tag is switched to QU8 -- TODO confirm against the
        // contract of `Tensor::set_datum_type`.
        unsafe { scalar.set_datum_type(qu8) };
        let a = model.add_const("a", scalar.into_arc_tensor())?;

        let product = model.wire_node("y", mul(), &[x, a])?[0];
        model.set_output_outlets(&[product])?;

        let mut input = Tensor::from_shape(&[2, 2], &self.tensor_mul_input1)?;
        // SAFETY: same assumption as for the scalar constant above (u8 storage
        // retagged as QU8) -- TODO confirm.
        unsafe { input.set_datum_type(qu8) };

        let outputs = SimplePlan::new(&model)?.run(tvec!(input.into()))?;

        // Compare raw quantized values, not dequantized reals.
        let actual = outputs[0].to_array_view::<u8>()?;
        let expected = Tensor::from_shape(&[2, 2], &self.expected_output)?;
        assert_eq!(actual, expected.to_array_view::<u8>()?);
        Ok(())
    }
}

#[test]
fn mul_as_qu8_overflow_clamp() -> TractResult<()> {
    // Neutral quantization (zero point 0, scale 1): raw values multiply directly.
    // The last product, 128 * 4 = 512, does not fit in a u8, so the expected
    // output holds the clamped value 255.
    let case = TestMulAsQU8 {
        tensor_mul_input1: [1, 2, 3, 128],
        scalar_mul_input_2: 4,
        zero_point: 0,
        scale: 1.0,
        expected_output: [4, 8, 12, 255],
    };
    case.check()
}

#[test]
fn mul_as_qu8_non_neutral_scale_and_offset() -> TractResult<()> {
    // Non-neutral quantization: zero point 2, scale 3, so a raw value q stands
    // for the real value (q - 2) * 3.
    //
    //   input  raw [1, 2, 3, 128] -> real [-3, 0, 3, 378]
    //   scalar raw 4              -> real 6
    //
    // The exact real products would be [-18, 0, 18, 2268]. The expected raw
    // output [0, 2, 8, 255] dequantizes to [-6, 0, 18, 759]: the first and last
    // entries are approximations because the result is clamped to the u8 range.
    let case = TestMulAsQU8 {
        tensor_mul_input1: [1, 2, 3, 128],
        scalar_mul_input_2: 4,
        zero_point: 2,
        scale: 3.0,
        expected_output: [0, 2, 8, 255],
    };
    case.check()
}

#[test]
fn div_as_shift() -> TractResult<()> {
let mut model = TypedModel::default();
Expand Down

0 comments on commit ab78019

Please sign in to comment.