Skip to content

Commit

Permalink
fix zeropoint rank
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 20, 2023
1 parent 5e5c8ad commit e62c3d1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl EvalOp for TypedBinOp {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (a, b) = args_2!(inputs);
debug_assert_eq!(a.rank(), b.rank());
ensure!(a.rank() == b.rank());
Ok(tvec!(self.0.eval(a, b)?.into_tvalue()))
}
}
Expand Down
5 changes: 2 additions & 3 deletions core/src/ops/matmul/mir_quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ pub(crate) fn wire_offset_u8_as_i8(
zero_point: &mut OutletId,
zero_point_name: &str,
) -> TractResult<OutletId> {
let fact = model.outlet_fact(matrix)?;
if let DatumType::U8 = fact.datum_type.unquantized() {
if let DatumType::U8 = model.outlet_fact(matrix)?.datum_type.unquantized() {
match model.outlet_fact(*zero_point)?.datum_type.unquantized() {
DatumType::U8 => {
*zero_point = model.wire_node(
Expand All @@ -34,7 +33,7 @@ pub(crate) fn wire_offset_u8_as_i8(
)?[0];
let cst = model.add_const(
format!("{model_name}.offset_{zero_point_name}_as_i8.min"),
rctensor0(-128i32),
tensor0(-128i32).broadcast_into_rank(model.outlet_fact(*zero_point)?.rank())?
)?;
*zero_point = model.wire_node(
format!("{model_name}.offset_{zero_point_name}_as_i8"),
Expand Down

0 comments on commit e62c3d1

Please sign in to comment.