From 8ae9522f43bfafa1e7ade46aa2f7dee4471c484f Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 17 Oct 2023 17:47:42 +0200 Subject: [PATCH 1/5] setup a mask for deconv pulsing mode --- core/src/ops/cnn/deconv/deconv_sum.rs | 12 +- pulse-opl/src/lib.rs | 3 + pulse-opl/src/mask.rs | 151 ++++++++++++++++++++++++++ pulse-opl/src/pad.rs | 54 ++++----- pulse/src/ops/array/mod.rs | 1 + pulse/src/ops/cnn/deconv.rs | 17 ++- pulse/src/ops/mask.rs | 0 pulse/src/ops/mod.rs | 1 + 8 files changed, 198 insertions(+), 41 deletions(-) create mode 100644 pulse-opl/src/mask.rs create mode 100644 pulse/src/ops/mask.rs diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 16f47d7b03..dc8a835f8f 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -188,9 +188,7 @@ impl DeconvSum { }; unsafe { let value = *n_o_hkwk_hw.uget((n, o, kx, gx)); - if !value.is_nan() { - *output.uget_mut(coord) += value; - } + *output.uget_mut(coord) += value; } } } @@ -377,9 +375,7 @@ impl DeconvSum { }; unsafe { let value = *n_o_hkwk_hw.uget((n, o, kix, gix)); - if !value.is_nan() { - *output.uget_mut(coord) += value; - } + *output.uget_mut(coord) += value; } } } @@ -425,9 +421,7 @@ impl DeconvSum { let ocoord = self.pool_spec.data_format.with_n().from_n_c_hw(n, o, ocoord)?; let value = n_o_hkwk_hw[(n, o, kix, gix)]; - if !value.is_nan() { - output[&*ocoord.shape] += value - } + output[&*ocoord.shape] += value } } } diff --git a/pulse-opl/src/lib.rs b/pulse-opl/src/lib.rs index 2cca6947e7..574797c621 100644 --- a/pulse-opl/src/lib.rs +++ b/pulse-opl/src/lib.rs @@ -3,6 +3,7 @@ use tract_nnef::internal::*; mod concat; mod deconv_delay; mod delay; +mod mask; mod pad; mod slice; @@ -17,6 +18,7 @@ pub mod prelude { pub mod ops { pub use super::deconv_delay::DeconvDelay; pub use super::delay::{ Delay, DelayState }; + pub use super::mask::PulseMask; pub use super::pad::PulsePad; pub use super::slice::PulsedAxisSlice; } @@ -41,6 +43,7 @@ pub fn tract_nnef_registry() -> Registry { let mut reg = Registry::new("tract_pulse"); reg.aliases.push("pulse".into()); delay::register(&mut reg); + mask::register(&mut reg); pad::register(&mut reg); reg } diff --git a/pulse-opl/src/mask.rs b/pulse-opl/src/mask.rs new file mode 100644 index 0000000000..b13c412edd --- /dev/null +++ b/pulse-opl/src/mask.rs @@ -0,0 +1,151 @@ +use tract_nnef::internal::*; +use tract_nnef::ser::tdim; +use tract_nnef::tract_core::trivial_op_state_freeeze; + +pub fn register(registry: &mut Registry) { + registry.register_primitive( + "tract_pulse_mask", + &[ + TypeName::Scalar.tensor().named("input"), + TypeName::Integer.named("axis"), + TypeName::Integer.named("begin"), + TypeName::Integer.named("end"), + TypeName::Scalar.named("value"), + ], + &[("output", TypeName::Scalar.tensor())], + deser, + ); + registry.register_dumper(TypeId::of::(), ser) +} + +fn ser(ast: &mut IntoAst, node: &TypedNode) -> TractResult>> { + let op = node.op_as::().unwrap(); + let wire = ast.mapping[&node.inputs[0]].clone(); + let params = vec![ + ("axis", numeric(op.axis)), + ("begin", numeric(op.begin)), + ("end", tdim(&op.end)), + ("value", numeric(op.value.cast_to_scalar::())), + ]; + Ok(Some(invocation("tract_pulse_mask", &[wire], ¶ms))) +} + +fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { + let wire = invocation.named_arg_as(builder, "input")?; + let axis = invocation.named_arg_as(builder, "axis")?; + let begin = invocation.named_arg_as(builder, "begin")?; + let value: Tensor = tensor0(invocation.named_arg_as::(builder, "value")?); + let end = builder.allowing_new_symbols(|builder| { + TractResult::Ok(invocation.named_arg_as(builder, "end")?) + })?; + let op = PulseMask { axis, begin, end, value }; + builder.wire(op, &[wire]) +} + +#[derive(Debug, Clone, Default, Hash)] +struct PulseMaskOpState { + current_pos: usize, +} + +impl OpState for PulseMaskOpState { + fn eval( + &mut self, + session: &mut SessionState, + op: &dyn Op, + inputs: TVec, + ) -> TractResult> { + let input = args_1!(inputs).into_tensor(); + let op = op.downcast_ref::().ok_or_else(|| format_err!("Wrong Op type"))?; + let tensor = self.pad(session, op, input)?; + Ok(tvec!(tensor.into_tvalue())) + } +} + +impl PulseMaskOpState { + fn pad( + &mut self, + session: &SessionState, + op: &PulseMask, + mut input: Tensor, + ) -> TractResult { + let pulse = input.shape()[op.axis]; + let pulse_begin = self.current_pos; + let pulse_end = self.current_pos + pulse; + self.current_pos += pulse; + let end = op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(std::usize::MAX); + + // pulse is entirely in valid input, just forward + if pulse_begin >= op.begin && pulse_end <= end { + return Ok(input); + } + + if pulse_begin < op.begin { + let fill_up_to = (op.begin - pulse_begin).min(pulse); + unsafe { + dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())( + &mut input, + &op.value, + op.axis, + 0..fill_up_to + )) + }; + } + if pulse_end > end { + let fill_from = pulse - (pulse_end - end).min(pulse); + unsafe { + dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())( + &mut input, + &op.value, + op.axis, + fill_from..pulse + )) + } + } + + Ok(input) + } +} + +#[derive(Debug, Clone, Default, Hash)] +pub struct PulseMask { + pub axis: usize, + pub begin: usize, + pub end: TDim, + pub value: Tensor, +} + +impl Op for PulseMask { + fn name(&self) -> Cow { + "PulseMask".into() + } + + fn info(&self) -> TractResult> { + Ok(vec![format!("axis: {} begin: {} end: {}", self.axis, self.begin, self.end,)]) + } + + op_as_typed_op!(); +} + +impl EvalOp for PulseMask { + fn is_stateless(&self) -> bool { + false + } + + fn state( + &self, + _session: &mut SessionState, + _node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::::default())) + } +} + +impl TypedOp for PulseMask { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + Ok(tvec!(inputs[0].clone())) + } + + as_op!(); +} + +trivial_op_state_freeeze!(PulseMaskOpState); diff --git a/pulse-opl/src/pad.rs b/pulse-opl/src/pad.rs index 987962eee3..0468bf9148 100644 --- a/pulse-opl/src/pad.rs +++ b/pulse-opl/src/pad.rs @@ -64,6 +64,29 @@ fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractRe builder.wire(op, &[wire]) } +pub(crate) unsafe fn fill_slice_constant( + data: &mut Tensor, + constant: &Tensor, + axis: usize, + range: std::ops::Range, +) { + let c = constant.to_scalar_unchecked::(); + data.to_array_view_mut_unchecked::().slice_axis_mut(Axis(axis), range.into()).fill(*c); +} + +unsafe fn fill_slice_with_frame( + data: &mut Tensor, + axis: usize, + valid: &Tensor, + range: std::ops::Range, +) { + let mut data = data.to_array_view_mut_unchecked::(); + let valid = valid.to_array_view_unchecked::(); + for i in range { + data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid); + } +} + #[derive(Debug, Clone, Default, Hash)] struct PulsePadOpState { current_pos: usize, @@ -91,29 +114,6 @@ impl PulsePadOpState { Some(data.index_axis(Axis(op.axis), frame).to_owned().into_tensor()); } - unsafe fn fill_slice_constant( - data: &mut Tensor, - constant: &Tensor, - axis: usize, - range: std::ops::Range, - ) { - let c = constant.to_scalar_unchecked::(); - data.to_array_view_mut_unchecked::().slice_axis_mut(Axis(axis), range.into()).fill(*c); - } - - unsafe fn fill_slice_with_frame( - data: &mut Tensor, - axis: usize, - valid: &Tensor, - range: std::ops::Range, - ) { - let mut data = data.to_array_view_mut_unchecked::(); - let valid = valid.to_array_view_unchecked::(); - for i in range { - data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid); - } - } - fn pad( &mut self, session: &SessionState, @@ -156,7 +156,7 @@ impl PulsePadOpState { let fill_up_to = (op.begin_input - pulse_begin).min(pulse); match &op.mode { PadMode::Constant(c) => unsafe { - dispatch_copy_by_size!(Self::fill_slice_constant(input.datum_type())( + dispatch_copy_by_size!(fill_slice_constant(input.datum_type())( &mut input, c, op.axis, @@ -166,7 +166,7 @@ impl PulsePadOpState { PadMode::Edge => { let frame = input.slice(op.axis, fill_up_to, fill_up_to + 1)?; unsafe { - dispatch_copy_by_size!(Self::fill_slice_with_frame(input.datum_type())( + dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())( &mut input, op.axis, &frame, @@ -181,7 +181,7 @@ impl PulsePadOpState { let fill_from = pulse - (pulse_end - end_input).min(pulse); match &op.mode { PadMode::Constant(c) => unsafe { - dispatch_copy_by_size!(Self::fill_slice_constant(input.datum_type())( + dispatch_copy_by_size!(fill_slice_constant(input.datum_type())( &mut input, c, op.axis, @@ -191,7 +191,7 @@ impl PulsePadOpState { PadMode::Edge => { let last_frame = self.last_valid_frame.as_ref().unwrap(); unsafe { - dispatch_copy_by_size!(Self::fill_slice_with_frame(input.datum_type())( + dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())( &mut input, op.axis, last_frame, diff --git a/pulse/src/ops/array/mod.rs b/pulse/src/ops/array/mod.rs index b47bacb7d6..e76aee6a1b 100644 --- a/pulse/src/ops/array/mod.rs +++ b/pulse/src/ops/array/mod.rs @@ -2,6 +2,7 @@ use crate::internal::*; mod broadcast; mod concat; +mod mask; mod pad; mod slice; diff --git a/pulse/src/ops/cnn/deconv.rs b/pulse/src/ops/cnn/deconv.rs index f92a09b799..3e3701f829 100644 --- a/pulse/src/ops/cnn/deconv.rs +++ b/pulse/src/ops/cnn/deconv.rs @@ -3,6 +3,7 @@ use tract_core::num_traits::Zero; use tract_core::ops::cnn::DeconvUnary; use tract_core::ops::cnn::PaddingSpec; use tract_pulse_opl::ops::DeconvDelay; +use tract_pulse_opl::ops::PulseMask; register_all!(DeconvUnary: pulsify); @@ -37,9 +38,15 @@ fn pulsify( let mut pulse_op = op.clone(); pulse_op.adjustments[geo_axis] = stride - 1; pulse_op.pool_spec.padding = PaddingSpec::Valid; - let deconv = - target.wire_node(format!("{}.deconv", node.name), pulse_op, &[mapping[&node.inputs[0]]])? - [0]; + let mut wire = tvec![mapping[&node.inputs[0]]]; + let mask = PulseMask { + axis: stream.axis, + begin: stream.delay, + end: stream.dim.clone() + stream.delay, + value: Tensor::zero_scalar_dt(fact.datum_type)?, + }; + wire = target.wire_node(format!("{}.mask", node.name), mask, &wire)?; + wire = target.wire_node(format!("{}.deconv", node.name), pulse_op, &wire)?; let overlap = overlap(stream.axis, op); let deconv_input_dim = (stream.dim.clone() - 1) * stride + 1; let output_shape = tract_core::ops::cnn::deconv::output_shape( @@ -56,7 +63,7 @@ fn pulsify( &op.pool_spec.strides(), &op.adjustments, )?; - let mut wire = target.wire_node( + wire = target.wire_node( &node.name, DeconvDelay { axis: stream.axis, @@ -67,7 +74,7 @@ fn pulsify( pulse: pulse.to_owned(), deconv_output_dim: output_shape[stream.axis].clone(), }, - &[deconv], + &wire, )?; for (geo_axis, padding) in paddings.iter().enumerate() { diff --git a/pulse/src/ops/mask.rs b/pulse/src/ops/mask.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pulse/src/ops/mod.rs b/pulse/src/ops/mod.rs index 3279771038..c3be10d78d 100644 --- a/pulse/src/ops/mod.rs +++ b/pulse/src/ops/mod.rs @@ -10,6 +10,7 @@ pub mod cnn; pub mod delay; pub mod downsample; pub mod dummy; +pub mod mask; pub mod scan; pub mod slice; pub mod source; From 1c7f30a97dfe2eb525761ac4c97a5f64467bab54 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 17 Oct 2023 17:56:16 +0200 Subject: [PATCH 2/5] clip --- pulse-opl/src/mask.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pulse-opl/src/mask.rs b/pulse-opl/src/mask.rs index b13c412edd..95a84d6b55 100644 --- a/pulse-opl/src/mask.rs +++ b/pulse-opl/src/mask.rs @@ -35,9 +35,7 @@ fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractRe let axis = invocation.named_arg_as(builder, "axis")?; let begin = invocation.named_arg_as(builder, "begin")?; let value: Tensor = tensor0(invocation.named_arg_as::(builder, "value")?); - let end = builder.allowing_new_symbols(|builder| { - TractResult::Ok(invocation.named_arg_as(builder, "end")?) - })?; + let end = builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "end"))?; let op = PulseMask { axis, begin, end, value }; builder.wire(op, &[wire]) } From 6d4a0e7eb1d00709015a115e1c575ecb487a77c0 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 17 Oct 2023 18:06:36 +0200 Subject: [PATCH 3/5] missing file --- pulse/src/ops/array/mask.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 pulse/src/ops/array/mask.rs diff --git a/pulse/src/ops/array/mask.rs b/pulse/src/ops/array/mask.rs new file mode 100644 index 0000000000..de44846056 --- /dev/null +++ b/pulse/src/ops/array/mask.rs @@ -0,0 +1,11 @@ +use crate::internal::*; +use tract_pulse_opl::ops::PulseMask; + +impl PulsedOp for PulseMask { + fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult> { + Ok(inputs.iter().cloned().cloned().collect()) + } + + as_op!(); + pulsed_op_to_typed_op!(); +} From 212af2d3425acb624365f6370e966d82918bd021 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 18 Oct 2023 10:55:59 +0200 Subject: [PATCH 4/5] DeconvSum f16 on arm8.2 --- core/src/ops/cnn/deconv/deconv_sum.rs | 682 ++++++++++++++------------ 1 file changed, 371 insertions(+), 311 deletions(-) diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index dc8a835f8f..83f4f5fc77 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -102,355 +102,415 @@ impl DeconvSum { if !self.pool_spec.data_format.has_n() { tensor.insert_axis(0)?; } - dispatch_floatlike!(Self::eval_t(dt)( + eval( self, &input_shape, &output_shape, &spatial_output_details, &n_o_hkwk_hw, - &mut tensor - ))?; + &mut tensor, + )?; if !self.pool_spec.data_format.has_n() { tensor.remove_axis(0)?; } Ok(tvec!(tensor.into_tvalue())) } +} + +impl TypedOp for DeconvSum { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + let shape = super::output_shape(&self.pool_spec, &self.input_shape, &self.adjustments)?; + Ok(tvec!(inputs[0].datum_type.fact(&*shape))) + } - fn eval_t>( + fn concretize_dims( &self, - input_shape: &DataShape, - output_shape: &DataShape, - spatial_output_details: &[ComputedPaddedDim], - n_o_hkwk_hw: &Tensor, - output: &mut Tensor, - ) -> TractResult<()> { - let output = output.to_array_view_mut::()?; - let n_o_hkwk_hw: ArrayView4 = n_o_hkwk_hw.to_array_view::()?.into_dimensionality()?; - match input_shape.hw_rank() { - 1 => self.main_loop_1d( - input_shape, - output_shape, - spatial_output_details, - &n_o_hkwk_hw, - &mut output.into_dimensionality().unwrap(), - ), - 2 => self.main_loop_2d( - input_shape, - output_shape, - spatial_output_details, - &n_o_hkwk_hw, - &mut output.into_dimensionality().unwrap(), - ), - 3 => self.main_loop_3d( - input_shape, - output_shape, - spatial_output_details, - &n_o_hkwk_hw, - &mut output.into_dimensionality().unwrap(), - ), - _ => self.main_loop( + _source: &TypedModel, + node: &TypedNode, + target: &mut TypedModel, + mapping: &HashMap, + values: &SymbolValues, + ) -> TractResult> { + target.wire_node( + &node.name, + Self { input_shape: self.input_shape.eval(values)?.into_owned(), ..self.clone() }, + &[mapping[&node.inputs[0]]], + ) + } + + as_op!(); +} + +fn eval( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &Tensor, + output: &mut Tensor, +) -> TractResult<()> { + let dt = output.datum_type(); + unsafe { + #[cfg(target_arch = "aarch64")] + if dt == f16::datum_type() && tract_linalg::arm64::has_fp16() { + return eval_t_aarch64fp16::( + op, input_shape, output_shape, spatial_output_details, - &n_o_hkwk_hw, - &mut output.into_dimensionality().unwrap(), - ), + n_o_hkwk_hw, + output, + |a, b| tract_linalg::arm64::add_f16(a, b), + ); } + dispatch_floatlike!(eval_t_generic(dt)( + op, + input_shape, + output_shape, + spatial_output_details, + n_o_hkwk_hw, + output, + |a, b| a + b + )) } +} - pub fn main_loop_1d( - &self, - input_shape: &DataShape, - output_shape: &DataShape, - spatial_output_details: &[ComputedPaddedDim], - n_o_hkwk_hw: &ArrayView4, - output: &mut ArrayViewMut3, - ) -> TractResult<()> { - let n = *output_shape.n().unwrap_or(&1); - let kernel_len = self.pool_spec.kernel_shape[0]; - let geo_input_len = input_shape.hw_dims()[0]; - let geo_output_len = output_shape.hw_dims()[0]; - let x_stride = self.pool_spec.strides().as_ref()[0]; - let x_dil = self.pool_spec.dilations().as_ref()[0]; - let x_pad = spatial_output_details[0].pad_before as isize; - for n in 0..n { - for o in 0..*output_shape.c() { - for kx in 0..kernel_len { - for gx in 0..geo_input_len { - let x = (kx * x_dil + gx * x_stride) as isize - x_pad; - if x < 0 || x >= geo_output_len as isize { - continue; +macro_rules! impl_eval { + ($(#[$meta: meta])* $suffix: ident) => { + paste::paste! { + $(#[$meta])* + unsafe fn []>( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &Tensor, + output: &mut Tensor, + add: impl Fn(T, T) -> T + Copy + 'static, + ) -> TractResult<()> { + let output = output.to_array_view_mut::()?; + let n_o_hkwk_hw: ArrayView4 = n_o_hkwk_hw.to_array_view::()?.into_dimensionality()?; + match input_shape.hw_rank() { + 1 => []( + op, + input_shape, + output_shape, + spatial_output_details, + &n_o_hkwk_hw, + &mut output.into_dimensionality().unwrap(), + add, + ), + 2 => []( + op, + input_shape, + output_shape, + spatial_output_details, + &n_o_hkwk_hw, + &mut output.into_dimensionality().unwrap(), + add, + ), + 3 => []( + op, + input_shape, + output_shape, + spatial_output_details, + &n_o_hkwk_hw, + &mut output.into_dimensionality().unwrap(), + add, + ), + _ => []( + op, + input_shape, + output_shape, + spatial_output_details, + &n_o_hkwk_hw, + &mut output.into_dimensionality().unwrap(), + add, + ), } - let coord = if self.pool_spec.data_format.c_is_last() { - [n, x as usize, o] - } else { - [n, o, x as usize] - }; - unsafe { - let value = *n_o_hkwk_hw.uget((n, o, kx, gx)); - *output.uget_mut(coord) += value; + } + + pub fn []( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &ArrayView4, + output: &mut ArrayViewMut3, + add: impl Fn(T, T) -> T + Copy + 'static, + ) -> TractResult<()> { + let n = *output_shape.n().unwrap_or(&1); + let kernel_len = op.pool_spec.kernel_shape[0]; + let geo_input_len = input_shape.hw_dims()[0]; + let geo_output_len = output_shape.hw_dims()[0]; + let x_stride = op.pool_spec.strides().as_ref()[0]; + let x_dil = op.pool_spec.dilations().as_ref()[0]; + let x_pad = spatial_output_details[0].pad_before as isize; + for n in 0..n { + for o in 0..*output_shape.c() { + for kx in 0..kernel_len { + for gx in 0..geo_input_len { + let x = (kx * x_dil + gx * x_stride) as isize - x_pad; + if x < 0 || x >= geo_output_len as isize { + continue; + } + let coord = if op.pool_spec.data_format.c_is_last() { + [n, x as usize, o] + } else { + [n, o, x as usize] + }; + unsafe { + let value = *n_o_hkwk_hw.uget((n, o, kx, gx)); + *output.uget_mut(coord) = add(*output.uget(coord), value); + } + } + } } } + Ok(()) } - } - } - Ok(()) - } - pub fn main_loop_2d( - &self, - input_shape: &DataShape, - output_shape: &DataShape, - spatial_output_details: &[ComputedPaddedDim], - n_o_hkwk_hw: &ArrayView4, - output: &mut ArrayViewMut4, - ) -> TractResult<()> { - let n = *output_shape.n().unwrap_or(&1); - let x_stride = self.pool_spec.strides().as_ref()[0]; - let y_stride = self.pool_spec.strides().as_ref()[1]; - let x_dil = self.pool_spec.dilations().as_ref()[0]; - let y_dil = self.pool_spec.dilations().as_ref()[1]; - let x_pad = spatial_output_details[0].pad_before as isize; - let y_pad = spatial_output_details[1].pad_before as isize; - let output_c = *output_shape.c(); - let output_c_stride = *output_shape.c_stride() as isize; - let output_x_stride = output_shape.hw_strides()[0] as isize; - let output_y_stride = output_shape.hw_strides()[1] as isize; - let temp_n_stride = n_o_hkwk_hw.strides()[0]; - let temp_o_stride = n_o_hkwk_hw.strides()[1]; - let temp_k_stride = n_o_hkwk_hw.strides()[2]; - let temp_i_stride = n_o_hkwk_hw.strides()[3]; - let ox_len = output_shape.hw_dims()[0]; - let oy_len = output_shape.hw_dims()[1]; - let ix_len = input_shape.hw_dims()[0]; - let iy_len = input_shape.hw_dims()[1]; - let kx_len = self.pool_spec.kernel_shape[0]; - let ky_len = self.pool_spec.kernel_shape[1]; - unsafe { - for n in 0..n { - let output = output.as_mut_ptr().add(n * *output_shape.n_stride().unwrap_or(&0)); - let temp = n_o_hkwk_hw.as_ptr().offset(n as isize * temp_n_stride); - for kx in 0..kx_len { - let temp = temp.offset((kx * ky_len) as isize * temp_k_stride); - for ix in 0..ix_len { - let ox = (kx * x_dil + ix * x_stride) as isize - x_pad; - if ox < 0 || ox >= ox_len as isize { - continue; - } - let temp = temp.offset((ix * iy_len) as isize * temp_i_stride); - let output = output.offset(ox * output_x_stride); - for ky in 0..ky_len { - let temp = temp.offset(ky as isize * temp_k_stride); - let oy = (ky * y_dil) as isize - y_pad; - for iy in 0..iy_len { - let oy = oy + (iy * y_stride) as isize; - if oy < 0 || oy >= oy_len as isize { - continue; + pub fn []( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &ArrayView4, + output: &mut ArrayViewMut4, + add: impl Fn(T, T) -> T + Copy + 'static, + ) -> TractResult<()> { + let n = *output_shape.n().unwrap_or(&1); + let x_stride = op.pool_spec.strides().as_ref()[0]; + let y_stride = op.pool_spec.strides().as_ref()[1]; + let x_dil = op.pool_spec.dilations().as_ref()[0]; + let y_dil = op.pool_spec.dilations().as_ref()[1]; + let x_pad = spatial_output_details[0].pad_before as isize; + let y_pad = spatial_output_details[1].pad_before as isize; + let output_c = *output_shape.c(); + let output_c_stride = *output_shape.c_stride() as isize; + let output_x_stride = output_shape.hw_strides()[0] as isize; + let output_y_stride = output_shape.hw_strides()[1] as isize; + let temp_n_stride = n_o_hkwk_hw.strides()[0]; + let temp_o_stride = n_o_hkwk_hw.strides()[1]; + let temp_k_stride = n_o_hkwk_hw.strides()[2]; + let temp_i_stride = n_o_hkwk_hw.strides()[3]; + let ox_len = output_shape.hw_dims()[0]; + let oy_len = output_shape.hw_dims()[1]; + let ix_len = input_shape.hw_dims()[0]; + let iy_len = input_shape.hw_dims()[1]; + let kx_len = op.pool_spec.kernel_shape[0]; + let ky_len = op.pool_spec.kernel_shape[1]; + unsafe { + for n in 0..n { + let output = output.as_mut_ptr().add(n * *output_shape.n_stride().unwrap_or(&0)); + let temp = n_o_hkwk_hw.as_ptr().offset(n as isize * temp_n_stride); + for kx in 0..kx_len { + let temp = temp.offset((kx * ky_len) as isize * temp_k_stride); + for ix in 0..ix_len { + let ox = (kx * x_dil + ix * x_stride) as isize - x_pad; + if ox < 0 || ox >= ox_len as isize { + continue; + } + let temp = temp.offset((ix * iy_len) as isize * temp_i_stride); + let output = output.offset(ox * output_x_stride); + for ky in 0..ky_len { + let temp = temp.offset(ky as isize * temp_k_stride); + let oy = (ky * y_dil) as isize - y_pad; + for iy in 0..iy_len { + let oy = oy + (iy * y_stride) as isize; + if oy < 0 || oy >= oy_len as isize { + continue; + } + let temp = temp.offset(iy as isize * temp_i_stride); + let output = output.offset(oy * output_y_stride); + []( + output_c, + temp, + temp_o_stride, + output, + output_c_stride, + add, + ) + } + } } - let temp = temp.offset(iy as isize * temp_i_stride); - let output = output.offset(oy * output_y_stride); - Self::main_loop_2d_inner( - output_c, - temp, - temp_o_stride, - output, - output_c_stride, - ) } } } + Ok(()) } - } - } - Ok(()) - } - #[inline(never)] - #[allow(clippy::erasing_op)] - #[allow(clippy::identity_op)] - unsafe fn main_loop_2d_inner( - output_c: usize, - temp: *const T, - temp_o_stride: isize, - output: *mut T, - output_c_stride: isize, - ) { - let mut c = 0; - let mut right = temp; - let mut left = output; - while c + 8 < output_c { - let mut left0 = *left.offset(0 * output_c_stride); - let mut left1 = *left.offset(1 * output_c_stride); - let mut left2 = *left.offset(2 * output_c_stride); - let mut left3 = *left.offset(3 * output_c_stride); - let mut left4 = *left.offset(4 * output_c_stride); - let mut left5 = *left.offset(5 * output_c_stride); - let mut left6 = *left.offset(6 * output_c_stride); - let mut left7 = *left.offset(7 * output_c_stride); - let right0 = *right.offset(0 * temp_o_stride); - let right1 = *right.offset(1 * temp_o_stride); - let right2 = *right.offset(2 * temp_o_stride); - let right3 = *right.offset(3 * temp_o_stride); - let right4 = *right.offset(4 * temp_o_stride); - let right5 = *right.offset(5 * temp_o_stride); - let right6 = *right.offset(6 * temp_o_stride); - let right7 = *right.offset(7 * temp_o_stride); - left0 += right0; - left1 += right1; - left2 += right2; - left3 += right3; - left4 += right4; - left5 += right5; - left6 += right6; - left7 += right7; - *left.offset(0 * output_c_stride) = left0; - *left.offset(1 * output_c_stride) = left1; - *left.offset(2 * output_c_stride) = left2; - *left.offset(3 * output_c_stride) = left3; - *left.offset(4 * output_c_stride) = left4; - *left.offset(5 * output_c_stride) = left5; - *left.offset(6 * output_c_stride) = left6; - *left.offset(7 * output_c_stride) = left7; - c += 8; - left = left.offset(8 * output_c_stride); - right = right.offset(8 * temp_o_stride); - } - for c in c..output_c { - let value = *temp.offset(c as isize * temp_o_stride); - *output.offset(c as isize * output_c_stride) += value; - } - } + #[inline(never)] + #[allow(clippy::erasing_op)] + #[allow(clippy::identity_op)] + unsafe fn []( + output_c: usize, + temp: *const T, + temp_o_stride: isize, + output: *mut T, + output_c_stride: isize, + add: impl Fn(T, T) -> T + Copy + 'static, + ) { + let mut c = 0; + let mut right = temp; + let mut left = output; + while c + 8 < output_c { + let mut left0 = *left.offset(0 * output_c_stride); + let mut left1 = *left.offset(1 * output_c_stride); + let mut left2 = *left.offset(2 * output_c_stride); + let mut left3 = *left.offset(3 * output_c_stride); + let mut left4 = *left.offset(4 * output_c_stride); + let mut left5 = *left.offset(5 * output_c_stride); + let mut left6 = *left.offset(6 * output_c_stride); + let mut left7 = *left.offset(7 * output_c_stride); + let right0 = *right.offset(0 * temp_o_stride); + let right1 = *right.offset(1 * temp_o_stride); + let right2 = *right.offset(2 * temp_o_stride); + let right3 = *right.offset(3 * temp_o_stride); + let right4 = *right.offset(4 * temp_o_stride); + let right5 = *right.offset(5 * temp_o_stride); + let right6 = *right.offset(6 * temp_o_stride); + let right7 = *right.offset(7 * temp_o_stride); + left0 = add(left0, right0); + left1 = add(left1, right1); + left2 = add(left2, right2); + left3 = add(left3, right3); + left4 = add(left4, right4); + left5 = add(left5, right5); + left6 = add(left6, right6); + left7 = add(left7, right7); + *left.offset(0 * output_c_stride) = left0; + *left.offset(1 * output_c_stride) = left1; + *left.offset(2 * output_c_stride) = left2; + *left.offset(3 * output_c_stride) = left3; + *left.offset(4 * output_c_stride) = left4; + *left.offset(5 * output_c_stride) = left5; + *left.offset(6 * output_c_stride) = left6; + *left.offset(7 * output_c_stride) = left7; + c += 8; + left = left.offset(8 * output_c_stride); + right = right.offset(8 * temp_o_stride); + } + for c in c..output_c { + let value = *temp.offset(c as isize * temp_o_stride); + let ptr = output.offset(c as isize * output_c_stride); + *ptr = add(*ptr, value); + } + } - pub fn main_loop_3d( - &self, - input_shape: &DataShape, - output_shape: &DataShape, - spatial_output_details: &[ComputedPaddedDim], - n_o_hkwk_hw: &ArrayView4, - output: &mut ArrayViewMut5, - ) -> TractResult<()> { - let n = *output_shape.n().unwrap_or(&1); - let kernel_shape: [usize; 3] = [ - self.pool_spec.kernel_shape[0], - self.pool_spec.kernel_shape[1], - self.pool_spec.kernel_shape[2], - ]; - let geo_input_shape: [usize; 3] = - [input_shape.hw_dims()[0], input_shape.hw_dims()[1], input_shape.hw_dims()[2]]; - let geo_output_shape: [usize; 3] = - [output_shape.hw_dims()[0], output_shape.hw_dims()[1], output_shape.hw_dims()[2]]; - let x_stride = self.pool_spec.strides().as_ref()[0]; - let y_stride = self.pool_spec.strides().as_ref()[1]; - let z_stride = self.pool_spec.strides().as_ref()[2]; - let x_dil = self.pool_spec.dilations().as_ref()[0]; - let y_dil = self.pool_spec.dilations().as_ref()[1]; - let z_dil = self.pool_spec.dilations().as_ref()[2]; - let x_pad = spatial_output_details[0].pad_before as isize; - let y_pad = spatial_output_details[1].pad_before as isize; - let z_pad = spatial_output_details[2].pad_before as isize; - for n in 0..n { - for o in 0..*output_shape.c() { - for (kix, (kx, ky, kz)) in - tract_ndarray::indices(kernel_shape).into_iter().enumerate() - { - for (gix, (gx, gy, gz)) in - tract_ndarray::indices(geo_input_shape).into_iter().enumerate() - { - let x = (kx * x_dil + gx * x_stride) as isize - x_pad; - let y = (ky * y_dil + gy * y_stride) as isize - y_pad; - let z = (kz * z_dil + gz * z_stride) as isize - z_pad; - if x < 0 - || y < 0 - || z < 0 - || x >= geo_output_shape[0] as isize - || y >= geo_output_shape[1] as isize - || z >= geo_output_shape[2] as isize - { - continue; - } - let coord = if self.pool_spec.data_format.c_is_last() { - [n, x as usize, y as usize, z as usize, o] - } else { - [n, o, x as usize, y as usize, z as usize] - }; - unsafe { - let value = *n_o_hkwk_hw.uget((n, o, kix, gix)); - *output.uget_mut(coord) += value; + pub fn []( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &ArrayView4, + output: &mut ArrayViewMut5, + add: impl Fn(T, T) -> T + Copy + 'static, + ) -> TractResult<()> { + let n = *output_shape.n().unwrap_or(&1); + let kernel_shape: [usize; 3] = + [op.pool_spec.kernel_shape[0], op.pool_spec.kernel_shape[1], op.pool_spec.kernel_shape[2]]; + let geo_input_shape: [usize; 3] = + [input_shape.hw_dims()[0], input_shape.hw_dims()[1], input_shape.hw_dims()[2]]; + let geo_output_shape: [usize; 3] = + [output_shape.hw_dims()[0], output_shape.hw_dims()[1], output_shape.hw_dims()[2]]; + let x_stride = op.pool_spec.strides().as_ref()[0]; + let y_stride = op.pool_spec.strides().as_ref()[1]; + let z_stride = op.pool_spec.strides().as_ref()[2]; + let x_dil = op.pool_spec.dilations().as_ref()[0]; + let y_dil = op.pool_spec.dilations().as_ref()[1]; + let z_dil = op.pool_spec.dilations().as_ref()[2]; + let x_pad = spatial_output_details[0].pad_before as isize; + let y_pad = spatial_output_details[1].pad_before as isize; + let z_pad = spatial_output_details[2].pad_before as isize; + for n in 0..n { + for o in 0..*output_shape.c() { + for (kix, (kx, ky, kz)) in tract_ndarray::indices(kernel_shape).into_iter().enumerate() + { + for (gix, (gx, gy, gz)) in + tract_ndarray::indices(geo_input_shape).into_iter().enumerate() + { + let x = (kx * x_dil + gx * x_stride) as isize - x_pad; + let y = (ky * y_dil + gy * y_stride) as isize - y_pad; + let z = (kz * z_dil + gz * z_stride) as isize - z_pad; + if x < 0 + || y < 0 + || z < 0 + || x >= geo_output_shape[0] as isize + || y >= geo_output_shape[1] as isize + || z >= geo_output_shape[2] as isize + { + continue; + } + let coord = if op.pool_spec.data_format.c_is_last() { + [n, x as usize, y as usize, z as usize, o] + } else { + [n, o, x as usize, y as usize, z as usize] + }; + unsafe { + let value = *n_o_hkwk_hw.uget((n, o, kix, gix)); + *output.uget_mut(coord) = add(*output.uget(coord), value); + } + } + } } } + Ok(()) } - } - } - Ok(()) - } - pub fn main_loop( - &self, - input_shape: &DataShape, - output_shape: &DataShape, - spatial_output_details: &[ComputedPaddedDim], - n_o_hkwk_hw: &ArrayView4, - output: &mut ArrayViewMutD, - ) -> TractResult<()> { - let n = *output_shape.n().unwrap_or(&1); - for n in 0..n { - for o in 0..*output_shape.c() { - for (kix, kcoords) in - tract_ndarray::indices(&*self.pool_spec.kernel_shape).into_iter().enumerate() - { - for (gix, gcoords) in - tract_ndarray::indices(input_shape.hw_dims()).into_iter().enumerate() - { - // h' = stride * hg + dil * hk - let ocoord: TVec = tract_itertools::izip!( - kcoords.slice(), - gcoords.slice(), - self.pool_spec.strides().as_ref(), - self.pool_spec.dilations().as_ref(), - spatial_output_details - ) - .map(|(k, g, s, d, details)| { - (k * d + g * s) as isize - details.pad_before as isize - }) - .collect(); - if ocoord - .iter() - .zip(output_shape.hw_dims().iter()) - .all(|(x, dim)| *x >= 0 && (*x as usize) < *dim) - { - let ocoord = ocoord.iter().map(|x| *x as usize).collect::>(); - let ocoord = - self.pool_spec.data_format.with_n().from_n_c_hw(n, o, ocoord)?; - let value = n_o_hkwk_hw[(n, o, kix, gix)]; - output[&*ocoord.shape] += value + + pub fn []( + op: &DeconvSum, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &ArrayView4, + output: &mut ArrayViewMutD, + add: impl Fn(T, T) -> T + Copy + 'static, + ) -> TractResult<()> { + let n = *output_shape.n().unwrap_or(&1); + for n in 0..n { + for o in 0..*output_shape.c() { + for (kix, kcoords) in + tract_ndarray::indices(&*op.pool_spec.kernel_shape).into_iter().enumerate() + { + for (gix, gcoords) in + tract_ndarray::indices(input_shape.hw_dims()).into_iter().enumerate() + { + // h' = stride * hg + dil * hk + let ocoord: TVec = tract_itertools::izip!( + kcoords.slice(), + gcoords.slice(), + op.pool_spec.strides().as_ref(), + op.pool_spec.dilations().as_ref(), + spatial_output_details + ) + .map(|(k, g, s, d, details)| { + (k * d + g * s) as isize - details.pad_before as isize + }) + .collect(); + if ocoord + .iter() + .zip(output_shape.hw_dims().iter()) + .all(|(x, dim)| *x >= 0 && (*x as usize) < *dim) + { + let ocoord = ocoord.iter().map(|x| *x as usize).collect::>(); + let ocoord = op.pool_spec.data_format.with_n().from_n_c_hw(n, o, ocoord)?; + let value = n_o_hkwk_hw[(n, o, kix, gix)]; + output[&*ocoord.shape] = add(output[&*ocoord.shape], value) + } + } + } } } + Ok(()) } } } - Ok(()) } -} -impl TypedOp for DeconvSum { - fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - let shape = super::output_shape(&self.pool_spec, &self.input_shape, &self.adjustments)?; - Ok(tvec!(inputs[0].datum_type.fact(&*shape))) - } - - fn concretize_dims( - &self, - _source: &TypedModel, - node: &TypedNode, - target: &mut TypedModel, - mapping: &HashMap, - values: &SymbolValues, - ) -> TractResult> { - target.wire_node( - &node.name, - Self { input_shape: self.input_shape.eval(values)?.into_owned(), ..self.clone() }, - &[mapping[&node.inputs[0]]], - ) +impl_eval!(generic); +impl_eval! { +#[target_feature(enable = "fp16")] +#[cfg(target_arch = "aarch64")] + aarch64fp16 } - - as_op!(); -} From 0a76d4825d3f5746ed4e067e4f257927ecd92993 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 18 Oct 2023 11:05:56 +0200 Subject: [PATCH 5/5] warnings --- core/src/ops/cnn/deconv/deconv_sum.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 83f4f5fc77..65aed7f7b3 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::ops::AddAssign; use crate::internal::*;