diff --git a/cli/src/main.rs b/cli/src/main.rs
index 6361758c89..73e7b0572a 100644
--- a/cli/src/main.rs
+++ b/cli/src/main.rs
@@ -235,6 +235,11 @@ fn main() -> tract_core::anyhow::Result<()> {
                 .takes_value(true)
                 .help("Save intermediary values as a npz file"),
         )
+        .arg(
+            Arg::new("check-f16-overflow")
+                .long("check-f16-overflow")
+                .help("Check for f16 overflow in all outputs"),
+        )
         .arg(
             Arg::new("assert-sane-floats")
                 .long("assert-sane-floats")
diff --git a/cli/src/run.rs b/cli/src/run.rs
index 50f08acd00..20453be12d 100644
--- a/cli/src/run.rs
+++ b/cli/src/run.rs
@@ -147,6 +147,7 @@ fn run_regular(
     sub_matches: &clap::ArgMatches,
 ) -> TractResult<TVec<TVec<TValue>>> {
     let steps = sub_matches.is_present("steps");
+    let check_f16_overflow = sub_matches.is_present("check-f16-overflow");
     let assert_sane_floats = sub_matches.is_present("assert-sane-floats");
     let mut npz = if let Some(npz) = sub_matches.value_of("save-steps") {
         let npz = std::fs::File::create(npz).with_context(|| format!("Creating {npz}"))?;
@@ -210,6 +211,15 @@ fn run_regular(
                 npz_add_tensor(npz, name, t)?;
             }
         }
+        if check_f16_overflow {
+            for (ix, o) in r.iter().enumerate() {
+                if let Ok(f32s) = o.as_slice::<f32>() {
+                    if f32s.iter().any(|f| f.abs() > f16::MAX.to_f32()) {
+                        warn!("{node}, output {ix} overflows f16");
+                    }
+                }
+            }
+        }
         if assert_sane_floats {
             for (ix, o) in r.iter().enumerate() {
                 if node.op_is::() || node.op_is::() {
diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 183cec9bdb..60c06a4ea1 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -4,11 +4,12 @@ use std::fmt;
 use tract_data::itertools::izip;
 
 pub fn wire_cast(
-    prefix: &str,
+    prefix: impl AsRef<str>,
     target: &mut TypedModel,
     inputs: &[OutletId],
     operating_datum_type: DatumType,
 ) -> TractResult<TVec<OutletId>> {
+    let prefix = prefix.as_ref();
     let mut wires = tvec!();
     for (ix, mut wire) in inputs.iter().copied().enumerate() {
         if target.outlet_fact(wire)?.datum_type != operating_datum_type {
diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs
index 28593be282..e26619cc49 100644
--- a/core/src/ops/nn/mod.rs
+++ b/core/src/ops/nn/mod.rs
@@ -3,7 +3,7 @@ mod reduce;
 mod softmax;
 
 pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
-pub use self::reduce::{Reduce, Reducer};
+pub use self::reduce::{Reduce, Reducer, expand_mean_of_squares};
 pub use self::softmax::{Softmax, SoftmaxExp};
 
 pub use crate::internal::*;
diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs
index f26cba0f5a..944ca9fdcc 100644
--- a/core/src/ops/nn/reduce.rs
+++ b/core/src/ops/nn/reduce.rs
@@ -1,5 +1,9 @@
 use crate::internal::Axis;
 use crate::internal::*;
+use crate::ops::binary::{wire_cast, wire_with_rank_broadcast, TypedBinOp};
+use crate::ops::cast::cast;
+use crate::ops::element_wise::ElementWiseOp;
+use crate::ops::math::{div, mul, square, Mul, Square};
 use std::convert::TryFrom;
 use std::mem::transmute;
 use tract_data::internal::ClampCast;
@@ -50,6 +54,7 @@ pub enum Reducer {
     Min,
     Prod,
     Sum,
+    MeanOfSquares,
 }
 
 impl Reducer {
@@ -90,6 +95,7 @@ impl Reducer {
                     ))
                 }
             }
+            MeanOfSquares => self.mean_of_squares(axes, input)?,
         };
         if input.datum_type().is_quantized()
             && input.datum_type().unquantized() == t.datum_type().unquantized()
@@ -177,6 +183,16 @@ impl Reducer {
         }
         output.unwrap().into_tensor()
     }
+
+    fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
+        let dt = input.datum_type();
+        let mut input = input.cast_to::<f32>()?.into_owned();
+        input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
+        let mut output = unsafe { self.sum::<f32>(axis, &input) };
+        let norm = output.len() as f32 / input.len() as f32;
+        output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
+        Ok(output.cast_to_dt(dt)?.into_owned())
+    }
 }
 
 fn argmax_t(v: ArrayViewD, last: bool) -> i64
@@ -298,6 +314,51 @@ impl TypedOp for Reduce {
         Ok(tvec!(dt.fact(shape)))
     }
 
+    fn declutter(
+        &self,
+        model: &TypedModel,
+        node: &TypedNode,
+    ) -> TractResult<Option<TypedModelPatch>> {
+        if self.reducer == Reducer::Sum {
+            let Some(prec) = model.single_prec(node.id)? else { return Ok(None) };
+            let Some(prec_ew) = prec.op_as::<ElementWiseOp>() else { return Ok(None) };
+            if !prec_ew.0.is::<Square>() {
+                return Ok(None);
+            }
+            if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
+                return Ok(None);
+            }
+            let our_inlet = node.outputs[0].successors[0];
+            let succ = model.node(our_inlet.node);
+            let Some(succ_bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
+            if !succ_bin.0.is::<Mul>() {
+                return Ok(None);
+            }
+            let other = succ.inputs[1 - our_inlet.slot];
+            let Some(other_konst) = model.outlet_fact(other)?.uniform.as_ref() else {
+                return Ok(None);
+            };
+            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
+            let Some(norm) = norm.as_i64() else { return Ok(None) };
+            if norm == 0 {
+                return Ok(None);
+            }
+            let norm = tensor0((norm as f32).recip());
+            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
+                let mut patch = TypedModelPatch::default();
+                let wire = patch.tap_model(model, prec.inputs[0])?;
+                let wire = patch.wire_node(
+                    &node.name,
+                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
+                    &[wire],
+                )?[0];
+                patch.shunt_outside(model, succ.id.into(), wire)?;
+                return Ok(Some(patch));
+            }
+        }
+        Ok(None)
+    }
+
     fn axes_mapping(
         &self,
         inputs: &[&TypedFact],
@@ -346,3 +407,46 @@ impl TypedOp for Reduce {
 
     as_op!();
 }
+
+pub fn expand_mean_of_squares(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    name: &str,
+    op: &Reduce,
+) -> TractResult<Option<TypedModelPatch>> {
+    if op.reducer == Reducer::MeanOfSquares {
+        let mut patch = TypedModelPatch::default();
+        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
+        let dt = model.outlet_fact(node.inputs[0])?.datum_type;
+        if dt != f32::datum_type() {
+            wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+        }
+        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
+        let input_size = patch.outlet_fact(wire[0])?.shape.volume();
+        let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
+        wire = patch.wire_node(
+            format!("{name}.sum"),
+            Reduce::new(op.axes.clone(), Reducer::Sum),
+            &wire,
+        )?;
+        let output_size = patch.outlet_fact(wire[0])?.shape.volume();
+        let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
+        let norm = wire_cast(
+            format!("{name}.norm"),
+            &mut patch,
+            &[output_size, input_size],
+            f32::datum_type(),
+        )?;
+        let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
+        wire =
+            wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
+        if dt != f32::datum_type() {
+            wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
+        }
+        patch.shunt_outside(model, node.id.into(), wire[0])?;
+        Ok(Some(patch))
+    } else {
+        Ok(None)
+    }
+}
diff --git a/core/src/plan.rs b/core/src/plan.rs
index d2563e0bbe..cf39bcc520 100644
--- a/core/src/plan.rs
+++ b/core/src/plan.rs
@@ -553,7 +553,7 @@ where
             None => node.op().eval(input),
         }
         .with_context(|| format!("Evaluating {node}"));
-        // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
+        // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
         r
     }
 
diff --git a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected
index d396503a7f..a49837fedf 100644
--- a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected
+++ b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected
@@ -142,11 +142,11 @@ graph network(input) -> (output) {
 i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape");
 i"tdnn1.relu.output.low.cst" = [[[0.0]]];
 i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst");
-i"tdnn1.renorm.reduced.sq" = square(i"tdnn1.relu.output.low");
-i"tdnn1.renorm.reduced.sum" = sum_reduce(i"tdnn1.renorm.reduced.sq", axes = [1]);
-i"tdnn1.renorm.scaled-recip" = [[[0.00390625]]];
-i"tdnn1.renorm.scaled" = mul(i"tdnn1.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
-i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.scaled");
+i"tdnn1.renorm.reduced.sum.sqr" = square(i"tdnn1.relu.output.low");
+i"tdnn1.renorm.reduced.sum.sum" = sum_reduce(i"tdnn1.renorm.reduced.sum.sqr", axes = [1]);
+i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2" = [[[0.00390625]]];
+i"tdnn1.renorm.reduced.sum.card" = mul(i"tdnn1.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
+i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.reduced.sum.card");
 i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip");
 i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 2, delay = 0, overlap = 2);
 i"tdnn2.affine.kernel.0" = variable(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]);
@@ -154,20 +154,20 @@ i"tdnn2.affine.bias.0" = variable(label = "tdnn2.affine.bias.0", shape = [256]);
 i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.delay", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
 i"tdnn2.affine.output" = i"tdnn2.affine.output_conv";
 i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn1.relu.output.low.cst");
-i"tdnn2.renorm.reduced.sq" = square(i"tdnn2.relu.output.low");
-i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [1]);
-i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
-i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled");
+i"tdnn2.renorm.reduced.sum.sqr" = square(i"tdnn2.relu.output.low");
+i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [1]);
+i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
+i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.reduced.sum.card");
 i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip");
 i"tdnn3.affine.kernel.0" = variable(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]);
 i"tdnn3.affine.bias.0" = variable(label = "tdnn3.affine.bias.0", shape = [256]);
 i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]);
 i"tdnn3.affine.output" = i"tdnn3.affine.output_conv";
 i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst");
-i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low");
sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [1]); - i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.scaled"); + i"tdnn3.renorm.reduced.sum.sqr" = square(i"tdnn3.relu.output.low"); + i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.reduced.sum.card"); i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip"); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false); @@ -217,10 +217,10 @@ graph network(input) -> (output) { i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn4.affine.output" = i"tdnn4.affine.output_conv"; i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst"); - i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low"); - i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [1]); - i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled"); + i"tdnn4.renorm.reduced.sum.sqr" = square(i"tdnn4.relu.output.low"); + i"tdnn4.renorm.reduced.sum.sum" = sum_reduce(i"tdnn4.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.reduced.sum.card"); i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip"); i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2); i"tdnn5.affine.kernel.0" = variable(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]); @@ -228,10 +228,10 @@ graph network(input) -> (output) { i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn5.affine.output" = i"tdnn5.affine.output_conv"; i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst"); - i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low"); - i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [1]); - i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.scaled"); + i"tdnn5.renorm.reduced.sum.sqr" = square(i"tdnn5.relu.output.low"); + i"tdnn5.renorm.reduced.sum.sum" = sum_reduce(i"tdnn5.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn5.renorm.output-recip" = 
rsqrt(i"tdnn5.renorm.reduced.sum.card"); i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip"); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false); diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index 3168b88871..946573ee6e 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -24,6 +24,7 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> { "rewrite_consistent_quantized_conv", crate::ops::nnef::ser::rewrite_consistent_quantized_conv, ) + .with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares) .rewrite(&(), model) } @@ -363,7 +364,8 @@ impl<'a> IntoAst<'a> { Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| numeric(f)).into() ); } else if have_tract_core && tensor.datum_type() == DatumType::F16 { - let array = Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| numeric(f)).into(); + let array = + Self::dump_rec_tensor(&tensor.to_array_view::()?, |f| numeric(f)).into(); return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))])); } else if have_tract_core && tensor.datum_type().is_integer() { if let Ok(value) = tensor.cast_to::() { diff --git a/test-rt/test-nnef-cycle/src/lib.rs b/test-rt/test-nnef-cycle/src/lib.rs index ba4aed4a68..e3ccbb2367 100644 --- a/test-rt/test-nnef-cycle/src/lib.rs +++ b/test-rt/test-nnef-cycle/src/lib.rs @@ -5,7 +5,7 @@ use log::*; use tract_nnef::internal::*; use tract_onnx_opl::*; -#[path="../suite.rs"] +#[path = "../suite.rs"] mod suite; mod nnef_predump { @@ -42,7 +42,6 @@ mod nnef_predump { include!(concat!(env!("OUT_DIR"), "/tests/nnef_cycle.rs")); } - mod nnef_cycle { use super::*; diff --git a/tflite/src/ops/math.rs b/tflite/src/ops/math.rs index 0b2edd9cd1..eb1e8cc9ee 100644 --- a/tflite/src/ops/math.rs +++ b/tflite/src/ops/math.rs @@ -3,8 +3,8 @@ use crate::registry::{DeserOp, Registry}; use crate::ser::{BuiltinOp, SubgraphBuilder}; use crate::tflite::{ ActivationFunctionType, AddOptions, AddOptionsArgs, BuiltinOperator, BuiltinOptions, - MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions, MulOptionsArgs, SubOptions, - SubOptionsArgs, DivOptions, DivOptionsArgs, + DivOptions, DivOptionsArgs, MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions, + MulOptionsArgs, SubOptions, SubOptionsArgs, }; use tract_core::internal::*; use tract_core::ops::binary::{wire_cast, wire_rank_broadcast, TypedBinOp}; @@ -32,7 +32,7 @@ pub fn register_all(reg: &mut Registry) { fn wire_cast_and_rank_broadcast(op: &mut DeserOp) -> TractResult> { let wire = wire_cast( - &format!("{}.cast", op.prefix), + format!("{}.cast", op.prefix), op.ctx.target, op.inputs, DatumType::super_type_for(op.facts()?.iter().map(|f| f.datum_type)) diff --git a/tflite/src/ops/nn.rs b/tflite/src/ops/nn.rs index c612475c08..38270e1130 100644 --- a/tflite/src/ops/nn.rs +++ b/tflite/src/ops/nn.rs @@ -272,7 +272,7 @@ fn ser_reduce( BuiltinOp::new(74, 1, BuiltinOperator::SUM, BuiltinOptions::ReducerOptions), options.as_union_value(), ), - Reducer::ArgMin(_) | Reducer::ArgMax(_) => unreachable!(), + Reducer::ArgMin(_) | 
+        Reducer::ArgMin(_) | Reducer::ArgMax(_) | Reducer::MeanOfSquares => unreachable!(),
     }
 }
 
diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs
index 0e342bb4b2..5f812c52b3 100644
--- a/tflite/src/rewriter.rs
+++ b/tflite/src/rewriter.rs
@@ -1,12 +1,12 @@
 use tract_core::internal::*;
 use tract_core::ops::array::{Pad, PadMode};
 use tract_core::ops::binary::wire_with_rank_broadcast;
-use tract_core::ops::cnn::{KernelFormat, rewrite_conv_with_n_axis};
+use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat};
 use tract_core::ops::cnn::{Conv, PaddingSpec};
 use tract_core::ops::einsum::BasicMatMul;
 use tract_core::ops::element_wise::ElementWiseOp;
 use tract_core::ops::math::Recip;
-use tract_core::ops::nn::{DataFormat, Softmax};
+use tract_core::ops::nn::{expand_mean_of_squares, DataFormat, Softmax};
 use tract_core::tract_data::itertools::Itertools;
 
 pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
@@ -22,6 +22,7 @@
         .with_rule_for("padding", padding)
         .with_rule_for("manual_recip", manual_recip)
        .with_rule_for("softmax_on_last_axis", softmax_on_last_axis)
+        .with_rule_for("expand-means-of-square", expand_mean_of_squares)
        .rewrite(&(), model)
 }
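Not part of the patch: a minimal, self-contained sketch of the arithmetic behind the new Reducer::MeanOfSquares variant, which the declutter rule substitutes for the square, sum_reduce, multiply-by-1/N chain visible in the updated librispeech expected graph, and which expand_mean_of_squares re-expands for backends (NNEF serialization, tflite) that have no such reducer. The helper name and the fixed two-dimensional layout below are illustrative choices, not tract API.

// Illustration only (not tract code): mean of squares over axis 1 of a
// row-major [rows, cols] matrix. Mirrors the helper added in
// core/src/ops/nn/reduce.rs: square every element, sum over the reduced
// axis, then scale by output.len() / input.len(), i.e. the reciprocal of
// the number of reduced elements.
fn mean_of_squares_axis1(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0f32; rows];
    for r in 0..rows {
        for c in 0..cols {
            let x = input[r * cols + c];
            out[r] += x * x;
        }
    }
    let norm = out.len() as f32 / input.len() as f32; // == 1.0 / cols as f32
    out.iter_mut().for_each(|x| *x *= norm);
    out
}

fn main() {
    let data = [1f32, 2., 3., 4., 5., 6.];
    // (1 + 4 + 9) / 3 and (16 + 25 + 36) / 3: roughly [4.667, 25.667]
    println!("{:?}", mean_of_squares_axis1(&data, 2, 3));
}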