Skip to content

Commit

Permalink
fix: add test and fix for i8 to u8 and u8 to i8 quantized type element…
Browse files Browse the repository at this point in the history
…wise ops
  • Loading branch information
JulienBalianSonos authored and kali committed Mar 25, 2024
1 parent 64315ab commit 51df802
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
8 changes: 8 additions & 0 deletions core/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ macro_rules! element_wise {
let input_dt = t.datum_type();
let sout_dt = out_dt.unwrap_or(input_dt);
if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
if input_dt.unquantized() != sout_dt.unquantized() {
// align unquantized input type to unquantized output type
*t = match input_dt.unquantized() {
DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
}.into_tensor();
}
unsafe { t.set_datum_type(sout_dt) } // force cast
let t: &mut[$typ_dt] = t.as_slice_mut::<$typ_dt>()?;
let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
Expand Down
18 changes: 18 additions & 0 deletions data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,24 @@ impl Tensor {
t.into_arc_tensor()
}

/// Re-encodes an i8 tensor as u8 by shifting every element by 128.
///
/// When the unquantized datum type is not `I8`, the same `Arc` is handed back
/// (cloned) without touching the data.
///
/// For a `QI8` tensor carrying `ZpScale` parameters the result is tagged `QU8`
/// with the zero point raised by 128 and the scale kept; any other `QI8`
/// parameters are carried over unchanged under `QU8`. For a non-quantized
/// `I8` input the datum type is left as produced by `into_tensor()`
/// (presumably plain `U8` — confirm against `into_tensor`).
pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
    // Nothing to offset unless the underlying storage is i8.
    if !matches!(self.dt.unquantized(), DatumType::I8) {
        return self.clone();
    }

    // Reinterpret each i8 as u8, then add 128 modulo 256.
    // NOTE(review): the unwrap relies on the i8 view being available once the
    // unquantized type has been checked above.
    let mut shifted = self
        .to_array_view::<i8>()
        .unwrap()
        .mapv(|v| (v as u8).wrapping_add(128))
        .into_tensor();

    // Carry the quantization parameters over to the unsigned type.
    match self.dt {
        DatumType::QI8(QParams::ZpScale { zero_point, scale }) => {
            shifted.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
        }
        DatumType::QI8(qp) => {
            shifted.dt = DatumType::QU8(qp);
        }
        _ => {}
    }

    shifted.into_arc_tensor()
}

pub fn to_aligned_default(&self) -> anyhow::Result<Self> {
if self.dt.is_copy() {
unsafe {
Expand Down
8 changes: 8 additions & 0 deletions test-rt/suite-unit/src/q_elmwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,13 @@ pub fn suite() -> TractResult<TestSuite> {
},
);

suite.add(
"cos_switch_qi8_to_qu8_case",
QElmWiseOpProblem {
operator: tract_core::ops::math::cos(),
tensor_input: qi8_tensor1(&[-16], 39, 0.5)?,
out_dt: qu8_dt(2, 0.5),
},
);
Ok(suite)
}
12 changes: 12 additions & 0 deletions test-rt/suite-unit/src/q_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ pub fn qu8_dt(zp: i32, scale: f32) -> DatumType {
u8::datum_type().with_zp_scale(zp, scale)
}

/// Returns the `i8` datum type with the given zero point and scale attached
/// through `with_zp_scale`.
pub fn qi8_dt(zp: i32, scale: f32) -> DatumType {
    let base = i8::datum_type();
    base.with_zp_scale(zp, scale)
}

/// Casts `tensor` to the datum type built by `qu8_dt(zp, scale)` and returns
/// it as an owned tensor.
///
/// # Errors
/// Propagates any error reported by `cast_to_dt`.
pub fn qu8_tensor(tensor: Tensor, zp: i32, scale: f32) -> TractResult<Tensor> {
    let target_dt = qu8_dt(zp, scale);
    let converted = tensor.cast_to_dt(target_dt)?;
    Ok(converted.into_owned())
}

/// Casts `tensor` to the datum type built by `qi8_dt(zp, scale)` and returns
/// it as an owned tensor.
///
/// # Errors
/// Propagates any error reported by `cast_to_dt`.
pub fn qi8_tensor(tensor: Tensor, zp: i32, scale: f32) -> TractResult<Tensor> {
    let target_dt = qi8_dt(zp, scale);
    let converted = tensor.cast_to_dt(target_dt)?;
    Ok(converted.into_owned())
}

/// Wraps a single `u8` value with `tensor0` and passes it to `qu8_tensor`
/// with the given zero point and scale.
pub fn qu8_tensor0(value: u8, zp: i32, scale: f32) -> TractResult<Tensor> {
    let scalar = tensor0(value);
    qu8_tensor(scalar, zp, scale)
}
Expand All @@ -53,6 +61,10 @@ pub fn qu8_tensor1(values: &[u8], zp: i32, scale: f32) -> TractResult<Tensor> {
qu8_tensor(tensor1(values), zp, scale)
}

/// Wraps a slice of `i8` values with `tensor1` and passes it to `qi8_tensor`
/// with the given zero point and scale.
pub fn qi8_tensor1(values: &[i8], zp: i32, scale: f32) -> TractResult<Tensor> {
    let vector = tensor1(values);
    qi8_tensor(vector, zp, scale)
}

pub trait QOpProblem {
fn reference_float_ops(&self) -> TractResult<Tensor>;

Expand Down

0 comments on commit 51df802

Please sign in to comment.