diff --git a/core/src/ops/element_wise.rs b/core/src/ops/element_wise.rs index bbd2a4ea12..3d7cf43257 100644 --- a/core/src/ops/element_wise.rs +++ b/core/src/ops/element_wise.rs @@ -197,7 +197,7 @@ macro_rules! element_wise { $( $( $( - let input_dt = t.datum_type(); + let mut input_dt = t.datum_type(); let sout_dt = out_dt.unwrap_or(input_dt); if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() { if input_dt.unquantized() != sout_dt.unquantized() { @@ -207,6 +207,7 @@ macro_rules! element_wise { DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(), unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt) }.into_tensor(); + input_dt = t.datum_type(); // because zero_point change } unsafe { t.set_datum_type(sout_dt) } // force cast let t: &mut[$typ_dt] = t.as_slice_mut::<$typ_dt>()?; diff --git a/data/src/tensor.rs b/data/src/tensor.rs index 8e8609e544..7f3fb68504 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -1470,7 +1470,7 @@ impl Tensor { t.into_arc_tensor() } - /// Offsets the tensor as an u8 type if it's an u8 type, otherwise passes it unchanged. + /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged. pub fn offset_i8_as_u8(self: &Arc) -> Arc { let mut t = if let DatumType::I8 = self.dt.unquantized() { self.to_array_view::().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()