Skip to content

Commit

Permalink
cast qp into the right type (i32 or f32)
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 20, 2023
1 parent af31422 commit 55df767
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions hir/src/ops/cnn/conv.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::infer::*;
use crate::internal::*;
use crate::ops::cast::cast;

use tract_core::ops::cnn::conv::ConvUnary;
use tract_core::ops::cnn::conv::KernelFormat;
Expand Down Expand Up @@ -200,7 +201,7 @@ impl Expansion for Conv {
} else {
None
};
let mut wires = vec!(inputs[0]);
let mut wires = vec![inputs[0]];
let pool_spec = PoolSpec {
data_format: self.data_format,
padding: self.padding.clone(),
Expand All @@ -218,16 +219,23 @@ impl Expansion for Conv {
|| self.y_scale_input.is_some();
let output_type = self.override_output_datum_type.unwrap_or(input.datum_type);
if quantized {
let zero = model
.add_const(format!("{prefix}.zero"), Tensor::zero_scalar_dt(input.datum_type)?)?;
let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i32))?;
let one = model.add_const(format!("{prefix}.one"), tensor0(1f32))?;

wires.push(self.k_zero_point_input.map(|i| inputs[i]).unwrap_or(zero));
wires.push(self.k_scale_input.map(|i| inputs[i]).unwrap_or(one));
wires.push(self.x_zero_point_input.map(|i| inputs[i]).unwrap_or(zero));
wires.push(self.x_scale_input.map(|i| inputs[i]).unwrap_or(one));
wires.push(self.y_zero_point_input.map(|i| inputs[i]).unwrap_or(zero));
wires.push(self.y_scale_input.map(|i| inputs[i]).unwrap_or(one));
macro_rules! qp {
($id: ident, $def: expr, $ty: ty) => {
let wire = self.$id.map(|i| inputs[i]).unwrap_or($def);
let wire = model.wire_node(format!("{prefix}.cast_{}", stringify!($id)), cast(<$ty>::datum_type()), &[wire])?[0];
wires.push(wire);
}
}

qp!(k_zero_point_input, zero, i32);
qp!(k_scale_input, one, f32);
qp!(x_zero_point_input, zero, i32);
qp!(x_scale_input, one, f32);
qp!(y_zero_point_input, zero, i32);
qp!(y_scale_input, one, f32);
};

let reduced = ConvUnary::new(
Expand Down

0 comments on commit 55df767

Please sign in to comment.