From b13bb05b4763939288eca07cb19630963777e71d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Sat, 27 Jan 2024 10:34:31 +0100 Subject: [PATCH] fix --- onnx/src/ops/resize.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/onnx/src/ops/resize.rs b/onnx/src/ops/resize.rs index 69715406f4..5a534189cc 100644 --- a/onnx/src/ops/resize.rs +++ b/onnx/src/ops/resize.rs @@ -2,12 +2,13 @@ use crate::model::ParsingContext; use crate::pb::*; use std::hash::Hash; use tract_hir::internal::*; +use tract_nnef::tract_num_traits::Zero; pub fn resize( ctx: &ParsingContext, node: &NodeProto, ) -> TractResult<(Box, Vec)> { - let op = match ctx.onnx_operator_set_version { + let op = match dbg!(ctx.onnx_operator_set_version) { 10 => resize_10(node)?, 11..=12 => resize_11(node)?, 13..=17 => resize_13(node)?, @@ -273,8 +274,14 @@ impl InferenceRulesOp for Resize { check_output_arity(outputs, 1)?; s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?; s.equals(&inputs[0].rank, &outputs[0].rank)?; - if self.optional_scales_input.is_some() { - rules_with_scales(self, s, inputs, outputs) + if let Some(scales) = self.optional_scales_input { + s.given(&inputs[scales].shape[0], move |s, len| { + if len.is_zero() { + rules_with_sizes(self, s, inputs, outputs) + } else { + rules_with_scales(self, s, inputs, outputs) + } + }) } else if self.optional_sizes_input.is_some() { rules_with_sizes(self, s, inputs, outputs) } else {