Commit

plug in maxpool
kali committed May 7, 2024
1 parent 7ed8895 commit 5576526
Showing 4 changed files with 173 additions and 57 deletions.
2 changes: 1 addition & 1 deletion core/src/ops/cnn/pools.rs
@@ -108,7 +108,7 @@ impl PoolSpec {
if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.padding {
let input = self.data_format.shape(input)?;
let input_hw = input.hw_dims();
let reference = self.computed_padding(&input_hw);
let reference = self.computed_padding(input_hw);
for replacement in [
PaddingSpec::Valid,
PaddingSpec::SameUpper,
10 changes: 8 additions & 2 deletions test-rt/test-tflite/suite.rs
@@ -47,6 +47,9 @@ fn ignore_onnx(t: &[String]) -> bool {
Conv1d
Conv2d
test_averagepool_2d
test_maxpool_2d
squeeze
_transpose_
test_concat
@@ -81,7 +84,6 @@ fn ignore_onnx(t: &[String]) -> bool {
test_sqrt
test_rsqrt
test_cos
test_sin
# lol, no tan :)
@@ -107,7 +109,11 @@ fn ignore_onnx(t: &[String]) -> bool {
test_split_zero_size
test_mul_uint8
test_div_uint8
test_reduce_log_sum_exp.* # tflite does not support f64 reducers 🤷
test_reduce_log_sum_exp.* # tflite does not support f64 reducers 🤷
pool_2d_ceil
pool_2d_pads
pool_2d_precomputed_pads_count_include_pad
pool_2d_same_lower
test_cosh.*
test_sinh.*
");
111 changes: 91 additions & 20 deletions tflite/src/ops/cnn.rs
@@ -4,52 +4,120 @@ use crate::ser::{BuiltinOp, SubgraphBuilder};
use crate::tflite::{
ActivationFunctionType, BuiltinOperator, BuiltinOptions, Conv2DOptions, Conv2DOptionsArgs,
DepthwiseConv2DOptions, DepthwiseConv2DOptionsArgs, PadOptions, PadOptionsArgs, Padding,
Pool2DOptions, Pool2DOptionsArgs,
};
use flatbuffers::{FlatBufferBuilder, WIPOffset};
use tract_core::internal::*;
use tract_core::ops as core;
use tract_core::ops::array::{Pad, PadMode};
use tract_core::ops::cast::cast;
use tract_core::ops::cnn::KernelFormat;
use tract_core::ops::cnn::{Conv, PaddingSpec};
use tract_core::ops::cnn::{Conv, MaxPool, PaddingSpec, PoolSpec};
use tract_core::ops::cnn::{KernelFormat, SumPool};
use tract_core::ops::nn::DataFormat;
use tract_core::prelude::tract_itertools::Itertools;

pub fn register_all(reg: &mut Registry) {
reg.reg_to_tract(BuiltinOperator::AVERAGE_POOL_2D, average_pool_2d);
reg.reg_to_tflite(ser_max_pool);
reg.reg_to_tflite(ser_sum_pool);
reg.reg_to_tract(BuiltinOperator::AVERAGE_POOL_2D, de_average_pool_2d);
reg.reg_to_tract(BuiltinOperator::MAX_POOL_2D, de_max_pool_2d);
reg.reg_to_tract(BuiltinOperator::CONV_2D, de_conv2d);
reg.reg_to_tflite(ser_conv);
reg.reg_to_tract(BuiltinOperator::DEPTHWISE_CONV_2D, de_dw_conv2d);
reg.reg_to_tflite(ser_pad);
}

fn average_pool_2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let options = builtin!(op, builtin_options_as_pool_2_doptions);
fn pool_2d_options<'fb>(
fb: &mut FlatBufferBuilder<'fb>,
pool_spec: &PoolSpec,
) -> TractResult<WIPOffset<Pool2DOptions<'fb>>> {
ensure!(pool_spec.data_format == DataFormat::NHWC);
ensure!(pool_spec.rank() == 2);
ensure!(
pool_spec.padding == PaddingSpec::Valid || pool_spec.padding == PaddingSpec::SameUpper,
"unsupported padding {:?}",
pool_spec.padding
);
let padding =
if pool_spec.padding == PaddingSpec::Valid { Padding::VALID } else { Padding::SAME };
let options = Pool2DOptions::create(
fb,
&Pool2DOptionsArgs {
padding,
stride_h: pool_spec.stride(0) as _,
stride_w: pool_spec.stride(1) as _,
filter_height: pool_spec.kernel_shape[0] as _,
filter_width: pool_spec.kernel_shape[1] as _,
fused_activation_function: ActivationFunctionType::NONE,
},
);
Ok(options)
}

fn ser_max_pool(
builder: &mut SubgraphBuilder,
model: &TypedModel,
node: &TypedNode,
op: &MaxPool,
) -> TractResult<()> {
let inputs = tvec!(builder.map_outlet(model, node.inputs[0])?);
let output = builder.outlets_to_tensors[&node.id.into()];
let options = pool_2d_options(builder.fb(), &op.pool_spec)?;
let op = BuiltinOp::new(17, 1, BuiltinOperator::MAX_POOL_2D, BuiltinOptions::Pool2DOptions);
builder.write_op_with_options(&inputs, &[output], op, options.as_union_value())
}

fn ser_sum_pool(
builder: &mut SubgraphBuilder,
model: &TypedModel,
node: &TypedNode,
op: &SumPool,
) -> TractResult<()> {
ensure!(op.normalize);
let inputs = tvec!(builder.map_outlet(model, node.inputs[0])?);
let output = builder.outlets_to_tensors[&node.id.into()];
let options = pool_2d_options(builder.fb(), &op.pool_spec)?;
let op = BuiltinOp::new(1, 1, BuiltinOperator::AVERAGE_POOL_2D, BuiltinOptions::Pool2DOptions);
builder.write_op_with_options(&inputs, &[output], op, options.as_union_value())
}

fn de_pool_2d_options(options: &Pool2DOptions, shape: &ShapeFact) -> TractResult<PoolSpec> {
let strides = tvec!(options.stride_h() as usize, options.stride_w() as usize);
let kernel_shape = tvec!(options.filter_height() as usize, options.filter_width() as usize);
let padding = match options.padding() {
Padding::SAME => PaddingSpec::SameUpper,
Padding::VALID => PaddingSpec::Valid,
_ => todo!(),
};
let ci = DataFormat::NHWC
.shape(&op.facts()?[0].shape)?
.c()
.to_usize()
.context("Except defined integer depth")?;
let pool_spec = core::cnn::PoolSpec {
let ci =
DataFormat::NHWC.shape(&shape)?.c().to_usize().context("Except defined integer depth")?;
Ok(core::cnn::PoolSpec {
data_format: DataFormat::NHWC,
kernel_shape,
padding,
strides: Some(strides),
dilations: None,
input_channels: ci,
output_channels: ci,
};
})
}

fn de_average_pool_2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let options = builtin!(op, builtin_options_as_pool_2_doptions);
let pool_spec = de_pool_2d_options(&options, &op.output_facts[0].shape)?;
let pool = core::cnn::SumPool { pool_spec, normalize: true, count_include_pad: false };
let wires = op.ctx.target.wire_node(op.prefix, pool, &op.inputs[0..1])?;
wire_fused_activation(op, &wires, &options.fused_activation_function())
}

fn de_max_pool_2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let options = builtin!(op, builtin_options_as_pool_2_doptions);
let pool_spec = de_pool_2d_options(&options, &op.output_facts[0].shape)?;
let pool = core::cnn::MaxPool { pool_spec, with_index_outputs: None };
let wires = op.ctx.target.wire_node(op.prefix, pool, &op.inputs[0..1])?;
wire_fused_activation(op, &wires, &options.fused_activation_function())
}

fn ser_conv(
builder: &mut SubgraphBuilder,
model: &TypedModel,
@@ -75,12 +143,16 @@ fn ser_conv(
let kscale = facts[6].konst.as_ref().unwrap().as_slice::<f32>()?;
let per_channel = !kscale.iter().all_equal();
if per_channel {
let kernel = model.outlet_fact(node.inputs[1])?.konst.as_ref().context(
"tract TODO: dynamic convolution and per-channel scales",
)?;
let bias = model.outlet_fact(node.inputs[2])?.konst.as_ref().context(
"tract TODO: dynamic convolution and per-channel scales",
)?;
let kernel = model
.outlet_fact(node.inputs[1])?
.konst
.as_ref()
.context("tract TODO: dynamic convolution and per-channel scales")?;
let bias = model
.outlet_fact(node.inputs[2])?
.konst
.as_ref()
.context("tract TODO: dynamic convolution and per-channel scales")?;
inputs.push(builder.write_fact_with_per_axis_q(
&format!("{node_name}.weights"),
kernel,
@@ -191,8 +263,7 @@ fn de_conv2d(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
inputs[2] =
op.ctx.target.wire_node(format!("{}.cast_bias", op.prefix), cast(bias_dt), &[inputs[2]])?
[0];
let conv =
core::cnn::Conv { pool_spec, kernel_fmt: KernelFormat::OHWI, group: 1, q_params };
let conv = core::cnn::Conv { pool_spec, kernel_fmt: KernelFormat::OHWI, group: 1, q_params };
let wires = op.ctx.target.wire_node(op.prefix, conv, &inputs)?;
wire_fused_activation(op, &wires, &options.fused_activation_function())
}
107 changes: 73 additions & 34 deletions tflite/src/rewriter.rs
@@ -1,7 +1,7 @@
use tract_core::internal::*;
use tract_core::ops::array::{Pad, PadMode};
use tract_core::ops::binary::wire_with_rank_broadcast;
use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat};
use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat, MaxPool, PoolSpec, SumPool};
use tract_core::ops::cnn::{Conv, PaddingSpec};
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::element_wise::ElementWiseOp;
@@ -15,10 +15,12 @@ pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
.with_rule_for("trivial_axes_around_matmul", trivial_axes_around_matmul)
.with_rule_for("kernel_in_ohwi", kernel_in_ohwi)
.with_rule_for("bias_as_vector", bias_as_vector)
// .with_rule_for("per_layer_in_u8", per_layer_in_u8)
// .with_rule_for("per_layer_in_u8", per_layer_in_u8)
.with_rule_for("make_1d_2d", make_1d_2d)
.with_rule_for("rewrite_conv_with_n_axis", rewrite_conv_with_n_axis)
.with_rule_for("nchw-to-nhwc", nchw_to_nhwc)
.with_rule_for("conv-nchw-to-nhwc", conv_nchw_to_nhwc)
.with_rule_for("maxpool-nchw-to-nhwc", maxpool_nchw_to_nhwc)
.with_rule_for("sumpool-nchw-to-nhwc", sumpool_nchw_to_nhwc)
.with_rule_for("padding", padding)
.with_rule_for("manual_recip", manual_recip)
.with_rule_for("softmax_on_last_axis", softmax_on_last_axis)
@@ -130,32 +132,32 @@ fn bias_as_vector(

/*
fn per_layer_in_u8(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
conv: &Conv,
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
conv: &Conv,
) -> TractResult<Option<TypedModelPatch>> {
let input_fact = model.outlet_fact(node.inputs[0])?;
let idt = input_fact.datum_type;
let kernel_fact = model.outlet_fact(node.inputs[1])?;
let kdt = kernel_fact.datum_type;
if idt.is_float() || model.outlet_fact(node.inputs[6])?.shape.len() > 1 {
return Ok(None);
}
if idt.unquantized() == u8::datum_type() && kdt.unquantized() == u8::datum_type() {
return Ok(None);
}
let mut patch = TypedModelPatch::default();
let wire = patch.taps(model, &node.inputs)?;
let [mut i, mut k, b, mut i0, is, mut k0, ks, o0, os] = &*wire else {
bail!("Unexpected number of inputs")
};
wire_ensure_q8_flavour(&mut patch, name, &mut i, "input", &mut i0, DatumType::U8)?;
wire_ensure_q8_flavour(&mut patch, name, &mut k, "kernel", &mut k0, DatumType::U8)?;
let output = patch.wire_node(name, conv.clone(), &[i, k, *b, i0, *is, k0, *ks, *o0, *os])?;
patch.shunt_outside(model, node.id.into(), output[0])?;
Ok(Some(patch))
let input_fact = model.outlet_fact(node.inputs[0])?;
let idt = input_fact.datum_type;
let kernel_fact = model.outlet_fact(node.inputs[1])?;
let kdt = kernel_fact.datum_type;
if idt.is_float() || model.outlet_fact(node.inputs[6])?.shape.len() > 1 {
return Ok(None);
}
if idt.unquantized() == u8::datum_type() && kdt.unquantized() == u8::datum_type() {
return Ok(None);
}
let mut patch = TypedModelPatch::default();
let wire = patch.taps(model, &node.inputs)?;
let [mut i, mut k, b, mut i0, is, mut k0, ks, o0, os] = &*wire else {
bail!("Unexpected number of inputs")
};
wire_ensure_q8_flavour(&mut patch, name, &mut i, "input", &mut i0, DatumType::U8)?;
wire_ensure_q8_flavour(&mut patch, name, &mut k, "kernel", &mut k0, DatumType::U8)?;
let output = patch.wire_node(name, conv.clone(), &[i, k, *b, i0, *is, k0, *ks, *o0, *os])?;
patch.shunt_outside(model, node.id.into(), output[0])?;
Ok(Some(patch))
}
*/

@@ -184,29 +186,66 @@ fn make_1d_2d(
Ok(None)
}

fn nchw_to_nhwc(
fn conv_nchw_to_nhwc(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
conv: &Conv,
) -> TractResult<Option<TypedModelPatch>> {
if !conv.pool_spec.data_format.c_is_last() {
let mut new = conv.clone();
new.pool_spec.data_format = match conv.pool_spec.data_format {
nchw_to_nhwc(_ctx, model, node, name, &conv.pool_spec, &|pool_spec| {
Box::new(Conv { pool_spec, ..conv.clone() })
})
}

fn maxpool_nchw_to_nhwc(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
op: &MaxPool,
) -> TractResult<Option<TypedModelPatch>> {
nchw_to_nhwc(_ctx, model, node, name, &op.pool_spec, &|pool_spec| {
Box::new(MaxPool { pool_spec, ..op.clone() })
})
}

fn sumpool_nchw_to_nhwc(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
op: &SumPool,
) -> TractResult<Option<TypedModelPatch>> {
nchw_to_nhwc(_ctx, model, node, name, &op.pool_spec, &|pool_spec| {
Box::new(SumPool { pool_spec, ..op.clone() })
})
}

fn nchw_to_nhwc(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
old: &PoolSpec,
op: &dyn Fn(PoolSpec) -> Box<dyn TypedOp>,
) -> TractResult<Option<TypedModelPatch>> {
if !old.data_format.c_is_last() {
let mut new = old.clone();
new.data_format = match new.data_format {
DataFormat::NHWC | DataFormat::HWC => unreachable!(),
DataFormat::CHW => DataFormat::HWC,
DataFormat::NCHW => DataFormat::NHWC,
};
let mut patch = TypedModelPatch::default();
let fact = model.outlet_fact(node.inputs[0])?;
let shape = conv.pool_spec.data_format.shape(&fact.shape)?;
let shape = old.data_format.shape(&fact.shape)?;
let before = shape.c_axis();
let after = fact.rank() - 1;
let mut wire = patch.taps(model, &node.inputs)?;
wire[0] =
patch.wire_node(format!("{name}.nhwc"), AxisOp::Move(before, after), &[wire[0]])?[0];
wire = patch.wire_node(name, new, &wire)?;
wire = patch.wire_node(name, op(new), &wire)?;
wire = patch.wire_node(format!("{name}.nchw"), AxisOp::Move(after, before), &wire)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
return Ok(Some(patch));
