diff --git a/onnx/src/ops/resize.rs b/onnx/src/ops/resize.rs index 1a527cf596..69715406f4 100644 --- a/onnx/src/ops/resize.rs +++ b/onnx/src/ops/resize.rs @@ -4,39 +4,68 @@ use std::hash::Hash; use tract_hir::internal::*; pub fn resize( - _ctx: &ParsingContext, + ctx: &ParsingContext, node: &NodeProto, ) -> TractResult<(Box, Vec)> { - let coord_transformer = - match node.get_attr_opt("coordinate_transformation_mode")?.unwrap_or("half_pixel") { - "align_corners" => CoordTransformer::AlignCorners, - "half_pixel" => CoordTransformer::HalfPixel, - "asymmetric" => CoordTransformer::Asymmetric, - s => todo!("coordinate_transformation_mode: {}", s), - }; - let interpolator = match node.get_attr_opt("mode")?.unwrap_or("nearest") { - "nearest" => Interpolator::Nearest, - "linear" => Interpolator::Linear, - s => todo!("mode: {}", s), - }; - let nearest = match node.get_attr_opt("nearest_mode")?.unwrap_or("round_prefer_floor") { - "floor" => Nearest::Floor, - "ceil" => Nearest::Ceil, - "round_prefer_floor" => Nearest::RoundPreferFloor, - "round_prefer_ceil" => Nearest::RoundPreferCeil, - s => todo!("nearest_mode: {}", s), + let op = match ctx.onnx_operator_set_version { + 10 => resize_10(node)?, + 11..=12 => resize_11(node)?, + 13..=17 => resize_13(node)?, + 18.. => resize_18(node)?, + v => bail!("Unsupported operator set for Resize operator ({v})"), }; - let mut options = crate::model::optional_inputs(node).skip(2); - Ok(( - Box::new(Resize { - optional_scales_input: options.next().unwrap(), - optional_sizes_input: options.next().unwrap(), - coord_transformer, - interpolator, - nearest, - }), - vec![], - )) + Ok((Box::new(op), vec![])) +} + +fn resize_10(node: &NodeProto) -> TractResult { + Ok(Resize { + axes: None, + optional_roi_input: None, + optional_scales_input: Some(1), + optional_sizes_input: None, + coord_transformer: CoordTransformer::from_node(node)?, + interpolator: Interpolator::from_node(node)?, + nearest: Nearest::from_node(node)?, + }) +} + +fn resize_11(node: &NodeProto) -> TractResult { + let mut options = crate::model::optional_inputs(node).skip(3); + Ok(Resize { + axes: None, + optional_roi_input: Some(1), + optional_scales_input: Some(2), + optional_sizes_input: options.next().unwrap(), + coord_transformer: CoordTransformer::from_node(node)?, + interpolator: Interpolator::from_node(node)?, + nearest: Nearest::from_node(node)?, + }) +} + +fn resize_13(node: &NodeProto) -> TractResult { + let mut options = crate::model::optional_inputs(node).skip(1); + Ok(Resize { + axes: None, + optional_roi_input: options.next().unwrap(), + optional_scales_input: options.next().unwrap(), + optional_sizes_input: options.next().unwrap(), + coord_transformer: CoordTransformer::from_node(node)?, + interpolator: Interpolator::from_node(node)?, + nearest: Nearest::from_node(node)?, + }) +} + +fn resize_18(node: &NodeProto) -> TractResult { + let mut options = crate::model::optional_inputs(node).skip(1); + Ok(Resize { + axes: node.get_attr_opt_vec("axes")?, + optional_roi_input: options.next().unwrap(), + optional_scales_input: options.next().unwrap(), + optional_sizes_input: options.next().unwrap(), + coord_transformer: CoordTransformer::from_node(node)?, + interpolator: Interpolator::from_node(node)?, + nearest: Nearest::from_node(node)?, + }) } #[derive(Clone, Debug, Hash)] @@ -56,6 +85,15 @@ impl CoordTransformer { CoordTransformer::Asymmetric => (x_out as f32) / scale, } } + + fn from_node(node: &NodeProto) -> TractResult { + Ok(match node.get_attr_opt("coordinate_transformation_mode")?.unwrap_or("half_pixel") { + "align_corners" => CoordTransformer::AlignCorners, + "half_pixel" => CoordTransformer::HalfPixel, + "asymmetric" => CoordTransformer::Asymmetric, + s => bail!("coordinate_transformation_mode: {}", s), + }) + } } #[derive(Clone, Debug, Hash)] @@ -88,6 +126,14 @@ impl Interpolator { }, } } + + fn from_node(node: &NodeProto) -> TractResult { + Ok(match node.get_attr_opt("mode")?.unwrap_or("nearest") { + "nearest" => Interpolator::Nearest, + "linear" => Interpolator::Linear, + s => bail!("mode: {}", s), + }) + } } #[derive(Clone, Copy, Debug, Hash)] @@ -98,11 +144,25 @@ enum Nearest { RoundPreferCeil, } +impl Nearest { + fn from_node(node: &NodeProto) -> TractResult { + Ok(match node.get_attr_opt("nearest_mode")?.unwrap_or("round_prefer_floor") { + "floor" => Nearest::Floor, + "ceil" => Nearest::Ceil, + "round_prefer_floor" => Nearest::RoundPreferFloor, + "round_prefer_ceil" => Nearest::RoundPreferCeil, + s => bail!("nearest_mode: {}", s), + }) + } +} + #[derive(Clone, new, Debug, Hash)] struct Resize { + axes: Option>, coord_transformer: CoordTransformer, interpolator: Interpolator, nearest: Nearest, + optional_roi_input: Option, optional_scales_input: Option, optional_sizes_input: Option, } @@ -171,34 +231,33 @@ impl EvalOp for Resize { scales.map(|t| &**t), sizes.map(|t| &**t), )?; + let scales: TVec = if let Some(scales) = scales { + scales.as_slice::()?.into() + } else { + output_shape.iter().zip(inputs[0].shape()).map(|(o, i)| *o as f32 / *i as f32).collect() + }; let mut data = inputs.remove(0).into_tensor().into_array::()?; - for axis in 0..data.ndim() { - #[allow(clippy::comparison_chain)] - if output_shape[axis] == data.shape()[axis] { - continue; - } else if output_shape[axis] > data.shape()[axis] { - let scale = output_shape[axis] as f32 / data.shape()[axis] as f32; - let mut new_shape: TVec = data.shape().into(); - new_shape[axis] = output_shape[axis]; - data = tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 { - let x_out = co_o[axis]; - let x_in = self.coord_transformer.transform( - x_out, - scale, - data.shape()[axis], - new_shape[axis], - ); - let mut co_i = co_o; - let x_left = (x_in as usize).clamp(0, data.shape()[axis] - 1); - co_i[axis] = x_left; - let y_left = data[&co_i]; - let x_right = (x_left + 1).min(data.shape()[axis] - 1); - co_i[axis] = x_right; - let y_right = data[&co_i]; - let x_frac = x_in - x_left as f32; - self.interpolator.interpolate(y_left, y_right, x_frac, self.nearest) - }) - } + for (axis, scale) in scales.into_iter().enumerate().filter(|(_, s)| *s != 1.0) { + let mut new_shape: TVec = data.shape().into(); + new_shape[axis] = output_shape[axis]; + data = tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 { + let x_out = co_o[axis]; + let x_in = self.coord_transformer.transform( + x_out, + scale, + data.shape()[axis], + new_shape[axis], + ); + let mut co_i = co_o; + let x_left = (x_in as usize).clamp(0, data.shape()[axis] - 1); + co_i[axis] = x_left; + let y_left = data[&co_i]; + let x_right = (x_left + 1).min(data.shape()[axis] - 1); + co_i[axis] = x_right; + let y_right = data[&co_i]; + let x_frac = x_in - x_left as f32; + self.interpolator.interpolate(y_left, y_right, x_frac, self.nearest) + }) } Ok(tvec!(data.into_tvalue())) } @@ -214,23 +273,26 @@ impl InferenceRulesOp for Resize { check_output_arity(outputs, 1)?; s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?; s.equals(&inputs[0].rank, &outputs[0].rank)?; - if inputs.len() == 3 && self.optional_scales_input == Some(2) { + if self.optional_scales_input.is_some() { rules_with_scales(self, s, inputs, outputs) - } else if inputs.len() == 3 && self.optional_sizes_input == Some(2) { + } else if self.optional_sizes_input.is_some() { rules_with_sizes(self, s, inputs, outputs) } else { + /* // bogus 4 inputs case s.given_2( - &inputs[0].rank, - &inputs[self.optional_scales_input.unwrap()].shape, - move |s, input_rank, scale_shape| { - if scale_shape.len() == 0 || scale_shape[0] != input_rank.to_dim() { - rules_with_sizes(self, s, inputs, outputs) - } else { - rules_with_scales(self, s, inputs, outputs) - } - }, + &inputs[0].rank, + &inputs[self.optional_scales_input.unwrap()].shape, + move |s, input_rank, scale_shape| { + if scale_shape.len() == 0 || scale_shape[0] != input_rank.to_dim() { + rules_with_sizes(self, s, inputs, outputs) + } else { + rules_with_scales(self, s, inputs, outputs) + } + }, ) + */ + todo!() } } @@ -283,6 +345,7 @@ impl TypedOp for Resize { as_op!(); fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let _roi = self.optional_roi_input.and_then(|ix| inputs.get(ix)); let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix)); let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix)); let output_shape = self.compute_output_shape( diff --git a/test-rt/suite-onnx/node.txt b/test-rt/suite-onnx/node.txt index 8941ff966c..9a169b11a6 100644 --- a/test-rt/suite-onnx/node.txt +++ b/test-rt/suite-onnx/node.txt @@ -664,6 +664,7 @@ test_reshape_reordered_dims test_reshape_reordered_last_dims input:data test_reshape_zero_and_negative_dim input:data test_reshape_zero_dim input:data +test_resize_downsample_scales_linear input:X test_resize_upsample_scales_linear_align_corners input:X not-nnef test_rnn_seq_length test_round diff --git a/test-rt/test-onnx-nnef-cycle/build.rs b/test-rt/test-onnx-nnef-cycle/build.rs index fb5230e1d6..452db1db5d 100644 --- a/test-rt/test-onnx-nnef-cycle/build.rs +++ b/test-rt/test-onnx-nnef-cycle/build.rs @@ -49,6 +49,7 @@ test_qlinearmatmul_2D test_qlinearmatmul_3D test_reshape_reordered_dims test_resize_upsample_scales_linear_align_corners +test_resize_downsample_scales_linear test_unsqueeze "# .trim()