From cd023dd91c9b3e581aa6225691e601689fc8f565 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Mon, 20 Nov 2023 13:29:20 +0100 Subject: [PATCH] fix zeropoint rank --- core/src/ops/binary.rs | 2 +- core/src/ops/matmul/mir_quant.rs | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index 2ec98f063b..30260f9572 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -144,7 +144,7 @@ impl EvalOp for TypedBinOp { fn eval(&self, inputs: TVec) -> TractResult> { let (a, b) = args_2!(inputs); - debug_assert_eq!(a.rank(), b.rank()); + ensure!(a.rank() == b.rank()); Ok(tvec!(self.0.eval(a, b)?.into_tvalue())) } } diff --git a/core/src/ops/matmul/mir_quant.rs b/core/src/ops/matmul/mir_quant.rs index 9f8df6cdc0..562e2b8559 100644 --- a/core/src/ops/matmul/mir_quant.rs +++ b/core/src/ops/matmul/mir_quant.rs @@ -16,8 +16,7 @@ pub(crate) fn wire_offset_u8_as_i8( zero_point: &mut OutletId, zero_point_name: &str, ) -> TractResult { - let fact = model.outlet_fact(matrix)?; - if let DatumType::U8 = fact.datum_type.unquantized() { + if let DatumType::U8 = model.outlet_fact(matrix)?.datum_type.unquantized() { match model.outlet_fact(*zero_point)?.datum_type.unquantized() { DatumType::U8 => { *zero_point = model.wire_node( @@ -34,7 +33,7 @@ pub(crate) fn wire_offset_u8_as_i8( )?[0]; let cst = model.add_const( format!("{model_name}.offset_{zero_point_name}_as_i8.min"), - rctensor0(-128i32), + tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())? )?; *zero_point = model.wire_node( format!("{model_name}.offset_{zero_point_name}_as_i8"),