From 55df767c674b8e9696591e7865f7ef4cd9f09db3 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 31 Oct 2023 16:01:07 +0100 Subject: [PATCH] cast qp into the right type (i32 or f32) --- hir/src/ops/cnn/conv.rs | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/hir/src/ops/cnn/conv.rs b/hir/src/ops/cnn/conv.rs index 33f7aa8c9a..ecf7d2a727 100644 --- a/hir/src/ops/cnn/conv.rs +++ b/hir/src/ops/cnn/conv.rs @@ -1,5 +1,6 @@ use crate::infer::*; use crate::internal::*; +use crate::ops::cast::cast; use tract_core::ops::cnn::conv::ConvUnary; use tract_core::ops::cnn::conv::KernelFormat; @@ -200,7 +201,7 @@ impl Expansion for Conv { } else { None }; - let mut wires = vec!(inputs[0]); + let mut wires = vec![inputs[0]]; let pool_spec = PoolSpec { data_format: self.data_format, padding: self.padding.clone(), @@ -218,16 +219,23 @@ impl Expansion for Conv { || self.y_scale_input.is_some(); let output_type = self.override_output_datum_type.unwrap_or(input.datum_type); if quantized { - let zero = model - .add_const(format!("{prefix}.zero"), Tensor::zero_scalar_dt(input.datum_type)?)?; + let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i32))?; let one = model.add_const(format!("{prefix}.one"), tensor0(1f32))?; - wires.push(self.k_zero_point_input.map(|i| inputs[i]).unwrap_or(zero)); - wires.push(self.k_scale_input.map(|i| inputs[i]).unwrap_or(one)); - wires.push(self.x_zero_point_input.map(|i| inputs[i]).unwrap_or(zero)); - wires.push(self.x_scale_input.map(|i| inputs[i]).unwrap_or(one)); - wires.push(self.y_zero_point_input.map(|i| inputs[i]).unwrap_or(zero)); - wires.push(self.y_scale_input.map(|i| inputs[i]).unwrap_or(one)); + macro_rules! qp { + ($id: ident, $def: expr, $ty: ty) => { + let wire = self.$id.map(|i| inputs[i]).unwrap_or($def); + let wire = model.wire_node(format!("{prefix}.cast_{}", stringify!($id)), cast(<$ty>::datum_type()), &[wire])?[0]; + wires.push(wire); + } + } + + qp!(k_zero_point_input, zero, i32); + qp!(k_scale_input, one, f32); + qp!(x_zero_point_input, zero, i32); + qp!(x_scale_input, one, f32); + qp!(y_zero_point_input, zero, i32); + qp!(y_scale_input, one, f32); }; let reduced = ConvUnary::new(