From 5576526b7ddda189d7dbe834f7a14d700022836d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Sun, 5 May 2024 14:57:21 +0200 Subject: [PATCH] plug in maxpool --- core/src/ops/cnn/pools.rs | 2 +- test-rt/test-tflite/suite.rs | 10 +++- tflite/src/ops/cnn.rs | 111 ++++++++++++++++++++++++++++------- tflite/src/rewriter.rs | 107 ++++++++++++++++++++++----------- 4 files changed, 173 insertions(+), 57 deletions(-) diff --git a/core/src/ops/cnn/pools.rs b/core/src/ops/cnn/pools.rs index 5112720546..3638130a4d 100644 --- a/core/src/ops/cnn/pools.rs +++ b/core/src/ops/cnn/pools.rs @@ -108,7 +108,7 @@ impl PoolSpec { if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.padding { let input = self.data_format.shape(input)?; let input_hw = input.hw_dims(); - let reference = self.computed_padding(&input_hw); + let reference = self.computed_padding(input_hw); for replacement in [ PaddingSpec::Valid, PaddingSpec::SameUpper, diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 6d6f9f03af..1db443dde9 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -47,6 +47,9 @@ fn ignore_onnx(t: &[String]) -> bool { Conv1d Conv2d + test_averagepool_2d + test_maxpool_2d + squeeze _transpose_ test_concat @@ -81,7 +84,6 @@ fn ignore_onnx(t: &[String]) -> bool { test_sqrt test_rsqrt - test_cos test_sin # lol, no tan :) @@ -107,7 +109,11 @@ fn ignore_onnx(t: &[String]) -> bool { test_split_zero_size test_mul_uint8 test_div_uint8 - test_reduce_log_sum_exp.* # tflite does not support f64 reducers 🤷 + test_reduce_log_sum_exp.* # tflite does not support f64 reducers 🤷 + pool_2d_ceil + pool_2d_pads + pool_2d_precomputed_pads_count_include_pad + pool_2d_same_lower test_cosh.* test_sinh.* "); diff --git a/tflite/src/ops/cnn.rs b/tflite/src/ops/cnn.rs index da69795ec3..5798777821 100644 --- a/tflite/src/ops/cnn.rs +++ b/tflite/src/ops/cnn.rs @@ -4,26 +4,84 @@ use crate::ser::{BuiltinOp, SubgraphBuilder}; use crate::tflite::{ ActivationFunctionType, BuiltinOperator, BuiltinOptions, Conv2DOptions, Conv2DOptionsArgs, DepthwiseConv2DOptions, DepthwiseConv2DOptionsArgs, PadOptions, PadOptionsArgs, Padding, + Pool2DOptions, Pool2DOptionsArgs, }; +use flatbuffers::{FlatBufferBuilder, WIPOffset}; use tract_core::internal::*; use tract_core::ops as core; use tract_core::ops::array::{Pad, PadMode}; use tract_core::ops::cast::cast; -use tract_core::ops::cnn::KernelFormat; -use tract_core::ops::cnn::{Conv, PaddingSpec}; +use tract_core::ops::cnn::{Conv, MaxPool, PaddingSpec, PoolSpec}; +use tract_core::ops::cnn::{KernelFormat, SumPool}; use tract_core::ops::nn::DataFormat; use tract_core::prelude::tract_itertools::Itertools; pub fn register_all(reg: &mut Registry) { - reg.reg_to_tract(BuiltinOperator::AVERAGE_POOL_2D, average_pool_2d); + reg.reg_to_tflite(ser_max_pool); + reg.reg_to_tflite(ser_sum_pool); + reg.reg_to_tract(BuiltinOperator::AVERAGE_POOL_2D, de_average_pool_2d); + reg.reg_to_tract(BuiltinOperator::MAX_POOL_2D, de_max_pool_2d); reg.reg_to_tract(BuiltinOperator::CONV_2D, de_conv2d); reg.reg_to_tflite(ser_conv); reg.reg_to_tract(BuiltinOperator::DEPTHWISE_CONV_2D, de_dw_conv2d); reg.reg_to_tflite(ser_pad); } -fn average_pool_2d(op: &mut DeserOp) -> TractResult> { - let options = builtin!(op, builtin_options_as_pool_2_doptions); +fn pool_2d_options<'fb>( + fb: &mut FlatBufferBuilder<'fb>, + pool_spec: &PoolSpec, +) -> TractResult>> { + ensure!(pool_spec.data_format == DataFormat::NHWC); + ensure!(pool_spec.rank() == 2); + ensure!( + pool_spec.padding == PaddingSpec::Valid || pool_spec.padding == PaddingSpec::SameUpper, + "unsupported padding {:?}", + pool_spec.padding + ); + let padding = + if pool_spec.padding == PaddingSpec::Valid { Padding::VALID } else { Padding::SAME }; + let options = Pool2DOptions::create( + fb, + &Pool2DOptionsArgs { + padding, + stride_h: pool_spec.stride(0) as _, + stride_w: pool_spec.stride(1) as _, + filter_height: pool_spec.kernel_shape[0] as _, + filter_width: pool_spec.kernel_shape[1] as _, + fused_activation_function: ActivationFunctionType::NONE, + }, + ); + Ok(options) +} + +fn ser_max_pool( + builder: &mut SubgraphBuilder, + model: &TypedModel, + node: &TypedNode, + op: &MaxPool, +) -> TractResult<()> { + let inputs = tvec!(builder.map_outlet(model, node.inputs[0])?); + let output = builder.outlets_to_tensors[&node.id.into()]; + let options = pool_2d_options(builder.fb(), &op.pool_spec)?; + let op = BuiltinOp::new(17, 1, BuiltinOperator::MAX_POOL_2D, BuiltinOptions::Pool2DOptions); + builder.write_op_with_options(&inputs, &[output], op, options.as_union_value()) +} + +fn ser_sum_pool( + builder: &mut SubgraphBuilder, + model: &TypedModel, + node: &TypedNode, + op: &SumPool, +) -> TractResult<()> { + ensure!(op.normalize); + let inputs = tvec!(builder.map_outlet(model, node.inputs[0])?); + let output = builder.outlets_to_tensors[&node.id.into()]; + let options = pool_2d_options(builder.fb(), &op.pool_spec)?; + let op = BuiltinOp::new(1, 1, BuiltinOperator::AVERAGE_POOL_2D, BuiltinOptions::Pool2DOptions); + builder.write_op_with_options(&inputs, &[output], op, options.as_union_value()) +} + +fn de_pool_2d_options(options: &Pool2DOptions, shape: &ShapeFact) -> TractResult { let strides = tvec!(options.stride_h() as usize, options.stride_w() as usize); let kernel_shape = tvec!(options.filter_height() as usize, options.filter_width() as usize); let padding = match options.padding() { @@ -31,12 +89,9 @@ fn average_pool_2d(op: &mut DeserOp) -> TractResult> { Padding::VALID => PaddingSpec::Valid, _ => todo!(), }; - let ci = DataFormat::NHWC - .shape(&op.facts()?[0].shape)? - .c() - .to_usize() - .context("Except defined integer depth")?; - let pool_spec = core::cnn::PoolSpec { + let ci = + DataFormat::NHWC.shape(&shape)?.c().to_usize().context("Except defined integer depth")?; + Ok(core::cnn::PoolSpec { data_format: DataFormat::NHWC, kernel_shape, padding, @@ -44,12 +99,25 @@ fn average_pool_2d(op: &mut DeserOp) -> TractResult> { dilations: None, input_channels: ci, output_channels: ci, - }; + }) +} + +fn de_average_pool_2d(op: &mut DeserOp) -> TractResult> { + let options = builtin!(op, builtin_options_as_pool_2_doptions); + let pool_spec = de_pool_2d_options(&options, &op.output_facts[0].shape)?; let pool = core::cnn::SumPool { pool_spec, normalize: true, count_include_pad: false }; let wires = op.ctx.target.wire_node(op.prefix, pool, &op.inputs[0..1])?; wire_fused_activation(op, &wires, &options.fused_activation_function()) } +fn de_max_pool_2d(op: &mut DeserOp) -> TractResult> { + let options = builtin!(op, builtin_options_as_pool_2_doptions); + let pool_spec = de_pool_2d_options(&options, &op.output_facts[0].shape)?; + let pool = core::cnn::MaxPool { pool_spec, with_index_outputs: None }; + let wires = op.ctx.target.wire_node(op.prefix, pool, &op.inputs[0..1])?; + wire_fused_activation(op, &wires, &options.fused_activation_function()) +} + fn ser_conv( builder: &mut SubgraphBuilder, model: &TypedModel, @@ -75,12 +143,16 @@ fn ser_conv( let kscale = facts[6].konst.as_ref().unwrap().as_slice::()?; let per_channel = !kscale.iter().all_equal(); if per_channel { - let kernel = model.outlet_fact(node.inputs[1])?.konst.as_ref().context( - "tract TODO: dynamic convolution and per-channel scales", - )?; - let bias = model.outlet_fact(node.inputs[2])?.konst.as_ref().context( - "tract TODO: dynamic convolution and per-channel scales", - )?; + let kernel = model + .outlet_fact(node.inputs[1])? + .konst + .as_ref() + .context("tract TODO: dynamic convolution and per-channel scales")?; + let bias = model + .outlet_fact(node.inputs[2])? + .konst + .as_ref() + .context("tract TODO: dynamic convolution and per-channel scales")?; inputs.push(builder.write_fact_with_per_axis_q( &format!("{node_name}.weights"), kernel, @@ -191,8 +263,7 @@ fn de_conv2d(op: &mut DeserOp) -> TractResult> { inputs[2] = op.ctx.target.wire_node(format!("{}.cast_bias", op.prefix), cast(bias_dt), &[inputs[2]])? [0]; - let conv = - core::cnn::Conv { pool_spec, kernel_fmt: KernelFormat::OHWI, group: 1, q_params }; + let conv = core::cnn::Conv { pool_spec, kernel_fmt: KernelFormat::OHWI, group: 1, q_params }; let wires = op.ctx.target.wire_node(op.prefix, conv, &inputs)?; wire_fused_activation(op, &wires, &options.fused_activation_function()) } diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs index 5f812c52b3..16c2e6dd34 100644 --- a/tflite/src/rewriter.rs +++ b/tflite/src/rewriter.rs @@ -1,7 +1,7 @@ use tract_core::internal::*; use tract_core::ops::array::{Pad, PadMode}; use tract_core::ops::binary::wire_with_rank_broadcast; -use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat}; +use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat, MaxPool, PoolSpec, SumPool}; use tract_core::ops::cnn::{Conv, PaddingSpec}; use tract_core::ops::einsum::BasicMatMul; use tract_core::ops::element_wise::ElementWiseOp; @@ -15,10 +15,12 @@ pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> { .with_rule_for("trivial_axes_around_matmul", trivial_axes_around_matmul) .with_rule_for("kernel_in_ohwi", kernel_in_ohwi) .with_rule_for("bias_as_vector", bias_as_vector) -// .with_rule_for("per_layer_in_u8", per_layer_in_u8) + // .with_rule_for("per_layer_in_u8", per_layer_in_u8) .with_rule_for("make_1d_2d", make_1d_2d) .with_rule_for("rewrite_conv_with_n_axis", rewrite_conv_with_n_axis) - .with_rule_for("nchw-to-nhwc", nchw_to_nhwc) + .with_rule_for("conv-nchw-to-nhwc", conv_nchw_to_nhwc) + .with_rule_for("maxpool-nchw-to-nhwc", maxpool_nchw_to_nhwc) + .with_rule_for("sumpool-nchw-to-nhwc", sumpool_nchw_to_nhwc) .with_rule_for("padding", padding) .with_rule_for("manual_recip", manual_recip) .with_rule_for("softmax_on_last_axis", softmax_on_last_axis) @@ -130,32 +132,32 @@ fn bias_as_vector( /* fn per_layer_in_u8( - _ctx: &(), - model: &TypedModel, - node: &TypedNode, - name: &str, - conv: &Conv, +_ctx: &(), +model: &TypedModel, +node: &TypedNode, +name: &str, +conv: &Conv, ) -> TractResult> { - let input_fact = model.outlet_fact(node.inputs[0])?; - let idt = input_fact.datum_type; - let kernel_fact = model.outlet_fact(node.inputs[1])?; - let kdt = kernel_fact.datum_type; - if idt.is_float() || model.outlet_fact(node.inputs[6])?.shape.len() > 1 { - return Ok(None); - } - if idt.unquantized() == u8::datum_type() && kdt.unquantized() == u8::datum_type() { - return Ok(None); - } - let mut patch = TypedModelPatch::default(); - let wire = patch.taps(model, &node.inputs)?; - let [mut i, mut k, b, mut i0, is, mut k0, ks, o0, os] = &*wire else { - bail!("Unexpected number of inputs") - }; - wire_ensure_q8_flavour(&mut patch, name, &mut i, "input", &mut i0, DatumType::U8)?; - wire_ensure_q8_flavour(&mut patch, name, &mut k, "kernel", &mut k0, DatumType::U8)?; - let output = patch.wire_node(name, conv.clone(), &[i, k, *b, i0, *is, k0, *ks, *o0, *os])?; - patch.shunt_outside(model, node.id.into(), output[0])?; - Ok(Some(patch)) +let input_fact = model.outlet_fact(node.inputs[0])?; +let idt = input_fact.datum_type; +let kernel_fact = model.outlet_fact(node.inputs[1])?; +let kdt = kernel_fact.datum_type; +if idt.is_float() || model.outlet_fact(node.inputs[6])?.shape.len() > 1 { +return Ok(None); +} +if idt.unquantized() == u8::datum_type() && kdt.unquantized() == u8::datum_type() { +return Ok(None); +} +let mut patch = TypedModelPatch::default(); +let wire = patch.taps(model, &node.inputs)?; +let [mut i, mut k, b, mut i0, is, mut k0, ks, o0, os] = &*wire else { +bail!("Unexpected number of inputs") +}; +wire_ensure_q8_flavour(&mut patch, name, &mut i, "input", &mut i0, DatumType::U8)?; +wire_ensure_q8_flavour(&mut patch, name, &mut k, "kernel", &mut k0, DatumType::U8)?; +let output = patch.wire_node(name, conv.clone(), &[i, k, *b, i0, *is, k0, *ks, *o0, *os])?; +patch.shunt_outside(model, node.id.into(), output[0])?; +Ok(Some(patch)) } */ @@ -184,29 +186,66 @@ fn make_1d_2d( Ok(None) } -fn nchw_to_nhwc( +fn conv_nchw_to_nhwc( _ctx: &(), model: &TypedModel, node: &TypedNode, name: &str, conv: &Conv, ) -> TractResult> { - if !conv.pool_spec.data_format.c_is_last() { - let mut new = conv.clone(); - new.pool_spec.data_format = match conv.pool_spec.data_format { + nchw_to_nhwc(_ctx, model, node, name, &conv.pool_spec, &|pool_spec| { + Box::new(Conv { pool_spec, ..conv.clone() }) + }) +} + +fn maxpool_nchw_to_nhwc( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + name: &str, + op: &MaxPool, +) -> TractResult> { + nchw_to_nhwc(_ctx, model, node, name, &op.pool_spec, &|pool_spec| { + Box::new(MaxPool { pool_spec, ..op.clone() }) + }) +} + +fn sumpool_nchw_to_nhwc( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + name: &str, + op: &SumPool, +) -> TractResult> { + nchw_to_nhwc(_ctx, model, node, name, &op.pool_spec, &|pool_spec| { + Box::new(SumPool { pool_spec, ..op.clone() }) + }) +} + +fn nchw_to_nhwc( + _ctx: &(), + model: &TypedModel, + node: &TypedNode, + name: &str, + old: &PoolSpec, + op: &dyn Fn(PoolSpec) -> Box, +) -> TractResult> { + if !old.data_format.c_is_last() { + let mut new = old.clone(); + new.data_format = match new.data_format { DataFormat::NHWC | DataFormat::HWC => unreachable!(), DataFormat::CHW => DataFormat::HWC, DataFormat::NCHW => DataFormat::NHWC, }; let mut patch = TypedModelPatch::default(); let fact = model.outlet_fact(node.inputs[0])?; - let shape = conv.pool_spec.data_format.shape(&fact.shape)?; + let shape = old.data_format.shape(&fact.shape)?; let before = shape.c_axis(); let after = fact.rank() - 1; let mut wire = patch.taps(model, &node.inputs)?; wire[0] = patch.wire_node(format!("{name}.nhwc"), AxisOp::Move(before, after), &[wire[0]])?[0]; - wire = patch.wire_node(name, new, &wire)?; + wire = patch.wire_node(name, op(new), &wire)?; wire = patch.wire_node(format!("{name}.nchw"), AxisOp::Move(after, before), &wire)?; patch.shunt_outside(model, node.id.into(), wire[0])?; return Ok(Some(patch));