Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resize fixes #1236

Merged
merged 4 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 131 additions & 68 deletions onnx/src/ops/resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,68 @@ use std::hash::Hash;
use tract_hir::internal::*;

pub fn resize(
_ctx: &ParsingContext,
ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let coord_transformer =
match node.get_attr_opt("coordinate_transformation_mode")?.unwrap_or("half_pixel") {
"align_corners" => CoordTransformer::AlignCorners,
"half_pixel" => CoordTransformer::HalfPixel,
"asymmetric" => CoordTransformer::Asymmetric,
s => todo!("coordinate_transformation_mode: {}", s),
};
let interpolator = match node.get_attr_opt("mode")?.unwrap_or("nearest") {
"nearest" => Interpolator::Nearest,
"linear" => Interpolator::Linear,
s => todo!("mode: {}", s),
};
let nearest = match node.get_attr_opt("nearest_mode")?.unwrap_or("round_prefer_floor") {
"floor" => Nearest::Floor,
"ceil" => Nearest::Ceil,
"round_prefer_floor" => Nearest::RoundPreferFloor,
"round_prefer_ceil" => Nearest::RoundPreferCeil,
s => todo!("nearest_mode: {}", s),
let op = match ctx.onnx_operator_set_version {
10 => resize_10(node)?,
11..=12 => resize_11(node)?,
13..=17 => resize_13(node)?,
18.. => resize_18(node)?,
v => bail!("Unsupported operator set for Resize operator ({v})"),
};
let mut options = crate::model::optional_inputs(node).skip(2);
Ok((
Box::new(Resize {
optional_scales_input: options.next().unwrap(),
optional_sizes_input: options.next().unwrap(),
coord_transformer,
interpolator,
nearest,
}),
vec![],
))
Ok((Box::new(op), vec![]))
}

fn resize_10(node: &NodeProto) -> TractResult<Resize> {
Ok(Resize {
axes: None,
optional_roi_input: None,
optional_scales_input: Some(1),
optional_sizes_input: None,
coord_transformer: CoordTransformer::from_node(node)?,
interpolator: Interpolator::from_node(node)?,
nearest: Nearest::from_node(node)?,
})
}

fn resize_11(node: &NodeProto) -> TractResult<Resize> {
let mut options = crate::model::optional_inputs(node).skip(3);
Ok(Resize {
axes: None,
optional_roi_input: Some(1),
optional_scales_input: Some(2),
optional_sizes_input: options.next().unwrap(),
coord_transformer: CoordTransformer::from_node(node)?,
interpolator: Interpolator::from_node(node)?,
nearest: Nearest::from_node(node)?,
})
}

fn resize_13(node: &NodeProto) -> TractResult<Resize> {
let mut options = crate::model::optional_inputs(node).skip(1);
Ok(Resize {
axes: None,
optional_roi_input: options.next().unwrap(),
optional_scales_input: options.next().unwrap(),
optional_sizes_input: options.next().unwrap(),
coord_transformer: CoordTransformer::from_node(node)?,
interpolator: Interpolator::from_node(node)?,
nearest: Nearest::from_node(node)?,
})
}

fn resize_18(node: &NodeProto) -> TractResult<Resize> {
let mut options = crate::model::optional_inputs(node).skip(1);
Ok(Resize {
axes: node.get_attr_opt_vec("axes")?,
optional_roi_input: options.next().unwrap(),
optional_scales_input: options.next().unwrap(),
optional_sizes_input: options.next().unwrap(),
coord_transformer: CoordTransformer::from_node(node)?,
interpolator: Interpolator::from_node(node)?,
nearest: Nearest::from_node(node)?,
})
}

#[derive(Clone, Debug, Hash)]
Expand All @@ -56,6 +85,15 @@ impl CoordTransformer {
CoordTransformer::Asymmetric => (x_out as f32) / scale,
}
}

fn from_node(node: &NodeProto) -> TractResult<CoordTransformer> {
Ok(match node.get_attr_opt("coordinate_transformation_mode")?.unwrap_or("half_pixel") {
"align_corners" => CoordTransformer::AlignCorners,
"half_pixel" => CoordTransformer::HalfPixel,
"asymmetric" => CoordTransformer::Asymmetric,
s => bail!("coordinate_transformation_mode: {}", s),
})
}
}

#[derive(Clone, Debug, Hash)]
Expand Down Expand Up @@ -88,6 +126,14 @@ impl Interpolator {
},
}
}

fn from_node(node: &NodeProto) -> TractResult<Interpolator> {
Ok(match node.get_attr_opt("mode")?.unwrap_or("nearest") {
"nearest" => Interpolator::Nearest,
"linear" => Interpolator::Linear,
s => bail!("mode: {}", s),
})
}
}

#[derive(Clone, Copy, Debug, Hash)]
Expand All @@ -98,11 +144,25 @@ enum Nearest {
RoundPreferCeil,
}

impl Nearest {
fn from_node(node: &NodeProto) -> TractResult<Nearest> {
Ok(match node.get_attr_opt("nearest_mode")?.unwrap_or("round_prefer_floor") {
"floor" => Nearest::Floor,
"ceil" => Nearest::Ceil,
"round_prefer_floor" => Nearest::RoundPreferFloor,
"round_prefer_ceil" => Nearest::RoundPreferCeil,
s => bail!("nearest_mode: {}", s),
})
}
}

#[derive(Clone, new, Debug, Hash)]
struct Resize {
axes: Option<Vec<i64>>,
coord_transformer: CoordTransformer,
interpolator: Interpolator,
nearest: Nearest,
optional_roi_input: Option<usize>,
optional_scales_input: Option<usize>,
optional_sizes_input: Option<usize>,
}
Expand Down Expand Up @@ -171,34 +231,33 @@ impl EvalOp for Resize {
scales.map(|t| &**t),
sizes.map(|t| &**t),
)?;
let scales: TVec<f32> = if let Some(scales) = scales {
scales.as_slice::<f32>()?.into()
} else {
output_shape.iter().zip(inputs[0].shape()).map(|(o, i)| *o as f32 / *i as f32).collect()
};
let mut data = inputs.remove(0).into_tensor().into_array::<f32>()?;
for axis in 0..data.ndim() {
#[allow(clippy::comparison_chain)]
if output_shape[axis] == data.shape()[axis] {
continue;
} else if output_shape[axis] > data.shape()[axis] {
let scale = output_shape[axis] as f32 / data.shape()[axis] as f32;
let mut new_shape: TVec<usize> = data.shape().into();
new_shape[axis] = output_shape[axis];
data = tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
let x_out = co_o[axis];
let x_in = self.coord_transformer.transform(
x_out,
scale,
data.shape()[axis],
new_shape[axis],
);
let mut co_i = co_o;
let x_left = (x_in as usize).clamp(0, data.shape()[axis] - 1);
co_i[axis] = x_left;
let y_left = data[&co_i];
let x_right = (x_left + 1).min(data.shape()[axis] - 1);
co_i[axis] = x_right;
let y_right = data[&co_i];
let x_frac = x_in - x_left as f32;
self.interpolator.interpolate(y_left, y_right, x_frac, self.nearest)
})
}
for (axis, scale) in scales.into_iter().enumerate().filter(|(_, s)| *s != 1.0) {
let mut new_shape: TVec<usize> = data.shape().into();
new_shape[axis] = output_shape[axis];
data = tract_ndarray::ArrayD::from_shape_fn(&*new_shape, |co_o| -> f32 {
let x_out = co_o[axis];
let x_in = self.coord_transformer.transform(
x_out,
scale,
data.shape()[axis],
new_shape[axis],
);
let mut co_i = co_o;
let x_left = (x_in as usize).clamp(0, data.shape()[axis] - 1);
co_i[axis] = x_left;
let y_left = data[&co_i];
let x_right = (x_left + 1).min(data.shape()[axis] - 1);
co_i[axis] = x_right;
let y_right = data[&co_i];
let x_frac = x_in - x_left as f32;
self.interpolator.interpolate(y_left, y_right, x_frac, self.nearest)
})
}
Ok(tvec!(data.into_tvalue()))
}
Expand All @@ -214,23 +273,26 @@ impl InferenceRulesOp for Resize {
check_output_arity(outputs, 1)?;
s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
s.equals(&inputs[0].rank, &outputs[0].rank)?;
if inputs.len() == 3 && self.optional_scales_input == Some(2) {
if self.optional_scales_input.is_some() {
rules_with_scales(self, s, inputs, outputs)
} else if inputs.len() == 3 && self.optional_sizes_input == Some(2) {
} else if self.optional_sizes_input.is_some() {
rules_with_sizes(self, s, inputs, outputs)
} else {
/*
// bogus 4 inputs case
s.given_2(
&inputs[0].rank,
&inputs[self.optional_scales_input.unwrap()].shape,
move |s, input_rank, scale_shape| {
if scale_shape.len() == 0 || scale_shape[0] != input_rank.to_dim() {
rules_with_sizes(self, s, inputs, outputs)
} else {
rules_with_scales(self, s, inputs, outputs)
}
},
&inputs[0].rank,
&inputs[self.optional_scales_input.unwrap()].shape,
move |s, input_rank, scale_shape| {
if scale_shape.len() == 0 || scale_shape[0] != input_rank.to_dim() {
rules_with_sizes(self, s, inputs, outputs)
} else {
rules_with_scales(self, s, inputs, outputs)
}
},
)
*/
todo!()
}
}

Expand Down Expand Up @@ -283,6 +345,7 @@ impl TypedOp for Resize {
as_op!();

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let _roi = self.optional_roi_input.and_then(|ix| inputs.get(ix));
let scales = self.optional_scales_input.and_then(|ix| inputs.get(ix));
let sizes = self.optional_sizes_input.and_then(|ix| inputs.get(ix));
let output_shape = self.compute_output_shape(
Expand Down
1 change: 1 addition & 0 deletions test-rt/suite-onnx/node.txt
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ test_reshape_reordered_dims
test_reshape_reordered_last_dims input:data
test_reshape_zero_and_negative_dim input:data
test_reshape_zero_dim input:data
test_resize_downsample_scales_linear input:X
test_resize_upsample_scales_linear_align_corners input:X not-nnef
test_rnn_seq_length
test_round
Expand Down
1 change: 1 addition & 0 deletions test-rt/test-onnx-nnef-cycle/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ test_qlinearmatmul_2D
test_qlinearmatmul_3D
test_reshape_reordered_dims
test_resize_upsample_scales_linear_align_corners
test_resize_downsample_scales_linear
test_unsqueeze
"#
.trim()
Expand Down