diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs
index 183cec9bdb..60c06a4ea1 100644
--- a/core/src/ops/binary.rs
+++ b/core/src/ops/binary.rs
@@ -4,11 +4,12 @@ use std::fmt;
 use tract_data::itertools::izip;
 
 pub fn wire_cast(
-    prefix: &str,
+    prefix: impl AsRef<str>,
     target: &mut TypedModel,
     inputs: &[OutletId],
     operating_datum_type: DatumType,
 ) -> TractResult<TVec<OutletId>> {
+    let prefix = prefix.as_ref();
     let mut wires = tvec!();
     for (ix, mut wire) in inputs.iter().copied().enumerate() {
         if target.outlet_fact(wire)?.datum_type != operating_datum_type {
diff --git a/core/src/ops/nn/mod.rs b/core/src/ops/nn/mod.rs
index 28593be282..e26619cc49 100644
--- a/core/src/ops/nn/mod.rs
+++ b/core/src/ops/nn/mod.rs
@@ -3,7 +3,7 @@ mod reduce;
 mod softmax;
 
 pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
-pub use self::reduce::{Reduce, Reducer};
+pub use self::reduce::{Reduce, Reducer, expand_mean_of_squares};
 pub use self::softmax::{Softmax, SoftmaxExp};
 
 pub use crate::internal::*;
diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs
index 1e56ef9456..ce4c9afdda 100644
--- a/core/src/ops/nn/reduce.rs
+++ b/core/src/ops/nn/reduce.rs
@@ -1,8 +1,9 @@
 use crate::internal::Axis;
 use crate::internal::*;
-use crate::ops::binary::TypedBinOp;
+use crate::ops::binary::{wire_cast, wire_with_rank_broadcast, TypedBinOp};
+use crate::ops::cast::cast;
 use crate::ops::element_wise::ElementWiseOp;
-use crate::ops::math::{Mul, Square};
+use crate::ops::math::{div, mul, square, Mul, Square};
 use std::convert::TryFrom;
 use std::mem::transmute;
 use tract_data::internal::ClampCast;
@@ -189,7 +190,7 @@ impl Reducer {
                 input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
                 let mut output = unsafe { self.sum::<f32>(axis, &input) };
                 let norm = output.len() as f32 / input.len() as f32;
-                output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * norm);
+                output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
                 Ok(output.cast_to_dt(dt)?.into_owned())
             }
         }
@@ -406,3 +407,45 @@ impl TypedOp for Reduce {
 
     as_op!();
 }
+
+pub fn expand_mean_of_squares(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    name: &str,
+    op: &Reduce,
+) -> TractResult<Option<TypedModelPatch>> {
+    if op.reducer == Reducer::MeanOfSquares {
+        let mut patch = TypedModelPatch::default();
+        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
+        wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
+        let input_size = patch.outlet_fact(wire[0])?.shape.volume();
+        let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
+        wire = patch.wire_node(
+            format!("{name}.sum"),
+            Reduce::new(op.axes.clone(), Reducer::Sum),
+            &wire,
+        )?;
+        let output_size = patch.outlet_fact(wire[0])?.shape.volume();
+        let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
+        let norm = wire_cast(
+            format!("{name}.norm"),
+            &mut patch,
+            &[output_size, input_size],
+            f32::datum_type(),
+        )?;
+        let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
+        wire =
+            wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
+        wire = patch.wire_node(
+            format!("{name}.from_f32"),
+            cast(model.outlet_fact(node.inputs[0])?.datum_type),
+            &wire,
+        )?;
+        patch.shunt_outside(model, node.id.into(), wire[0])?;
+        Ok(Some(patch))
+    } else {
+        Ok(None)
+    }
+}
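Note on the rule added above: `expand_mean_of_squares` rests on a numeric identity. Reducing with `Reducer::MeanOfSquares` is the same as casting to f32, squaring, summing over the reduced axes, then scaling by `output_size / input_size`, exactly the `norm` the eval path computes. A standalone sketch of that identity (the toy 2x3 input and this `main` are illustrative only, no tract APIs involved):

```rust
// Standalone sketch of the identity behind the rewrite: MeanOfSquares over
// axis 1 of a 2x3 input equals Sum of squares scaled by output_size / input_size.
fn main() {
    let input = [[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let input_size = 6.0f32; // volume of the input shape, the `{name}.input_size` const

    // `{name}.sqr` then `{name}.sum`: square every element, sum over axis 1.
    let sum_sq: Vec<f32> =
        input.iter().map(|row| row.iter().map(|x| x * x).sum()).collect();

    let output_size = sum_sq.len() as f32; // volume of the reduced shape
    let norm = output_size / input_size; // the `{name}.norm` division
    let mean_sq: Vec<f32> = sum_sq.iter().map(|s| s * norm).collect();

    // Reference: the mean of squares computed directly, row by row.
    let reference: Vec<f32> =
        input.iter().map(|row| row.iter().map(|x| x * x).sum::<f32>() / 3.0).collect();
    for (a, b) in mean_sq.iter().zip(&reference) {
        assert!((a - b).abs() < 1e-5);
    }
    println!("{mean_sq:?}"); // ~[4.667, 25.667]
}
```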
diff --git a/core/src/plan.rs b/core/src/plan.rs
index d2563e0bbe..cf39bcc520 100644
--- a/core/src/plan.rs
+++ b/core/src/plan.rs
@@ -553,7 +553,7 @@ where
             None => node.op().eval(input),
         }
         .with_context(|| format!("Evaluating {node}"));
-        // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
+        // eprintln!("  ==> {}", r.as_ref().unwrap()[0].dump(true)?);
         r
     }
 
diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs
index 3168b88871..946573ee6e 100644
--- a/nnef/src/ser.rs
+++ b/nnef/src/ser.rs
@@ -24,6 +24,7 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
             "rewrite_consistent_quantized_conv",
             crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
         )
+        .with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares)
         .rewrite(&(), model)
 }
 
@@ -363,7 +364,8 @@ impl<'a> IntoAst<'a> {
                 Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
             );
         } else if have_tract_core && tensor.datum_type() == DatumType::F16 {
-            let array = Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
+            let array =
+                Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
             return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
         } else if have_tract_core && tensor.datum_type().is_integer() {
             if let Ok(value) = tensor.cast_to::<i64>() {
diff --git a/test-rt/test-nnef-cycle/src/lib.rs b/test-rt/test-nnef-cycle/src/lib.rs
index ba4aed4a68..e3ccbb2367 100644
--- a/test-rt/test-nnef-cycle/src/lib.rs
+++ b/test-rt/test-nnef-cycle/src/lib.rs
@@ -5,7 +5,7 @@ use log::*;
 use tract_nnef::internal::*;
 use tract_onnx_opl::*;
 
-#[path="../suite.rs"]
+#[path = "../suite.rs"]
 mod suite;
 
 mod nnef_predump {
@@ -42,7 +42,6 @@ mod nnef_predump {
     include!(concat!(env!("OUT_DIR"), "/tests/nnef_cycle.rs"));
 }
 
-
 mod nnef_cycle {
     use super::*;
 
diff --git a/tflite/src/ops/math.rs b/tflite/src/ops/math.rs
index 0b2edd9cd1..eb1e8cc9ee 100644
--- a/tflite/src/ops/math.rs
+++ b/tflite/src/ops/math.rs
@@ -3,8 +3,8 @@ use crate::registry::{DeserOp, Registry};
 use crate::ser::{BuiltinOp, SubgraphBuilder};
 use crate::tflite::{
     ActivationFunctionType, AddOptions, AddOptionsArgs, BuiltinOperator, BuiltinOptions,
-    MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions, MulOptionsArgs, SubOptions,
-    SubOptionsArgs, DivOptions, DivOptionsArgs,
+    DivOptions, DivOptionsArgs, MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions,
+    MulOptionsArgs, SubOptions, SubOptionsArgs,
 };
 use tract_core::internal::*;
 use tract_core::ops::binary::{wire_cast, wire_rank_broadcast, TypedBinOp};
@@ -32,7 +32,7 @@ pub fn register_all(reg: &mut Registry) {
 
 fn wire_cast_and_rank_broadcast(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
     let wire = wire_cast(
-        &format!("{}.cast", op.prefix),
+        format!("{}.cast", op.prefix),
         op.ctx.target,
         op.inputs,
         DatumType::super_type_for(op.facts()?.iter().map(|f| f.datum_type))
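Note on the caller change just above: it is enabled by the earlier `binary.rs` hunk, which loosens `wire_cast`'s first parameter from `&str` to `impl AsRef<str>`, so the tflite deserializer can pass `format!(...)` by value instead of borrowing with `&format!(...)`. A minimal sketch of the pattern (the demo function is hypothetical; only the signature shape comes from the patch):

```rust
// Minimal sketch of the `impl AsRef<str>` change to `wire_cast`.
fn wire_cast_demo(prefix: impl AsRef<str>) -> String {
    let prefix = prefix.as_ref(); // same first step as the patched wire_cast
    format!("{prefix}.cast")
}

fn main() {
    // A borrowed literal still works, and an owned String now works too,
    // which is why the `&format!("{}.cast", op.prefix)` borrow can go away.
    assert_eq!(wire_cast_demo("op"), "op.cast");
    assert_eq!(wire_cast_demo(format!("{}.add", "graph")), "graph.add.cast");
}
```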
diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs
index c4f89f2486..5f812c52b3 100644
--- a/tflite/src/rewriter.rs
+++ b/tflite/src/rewriter.rs
@@ -5,8 +5,8 @@ use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat};
 use tract_core::ops::cnn::{Conv, PaddingSpec};
 use tract_core::ops::einsum::BasicMatMul;
 use tract_core::ops::element_wise::ElementWiseOp;
-use tract_core::ops::math::{div, mul, square, Recip};
-use tract_core::ops::nn::{DataFormat, Reduce, Reducer, Softmax};
+use tract_core::ops::math::Recip;
+use tract_core::ops::nn::{expand_mean_of_squares, DataFormat, Softmax};
 use tract_core::tract_data::itertools::Itertools;
 
 pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
@@ -329,32 +329,3 @@ fn softmax_on_last_axis(
         Ok(None)
     }
 }
-
-fn expand_mean_of_squares(
-    _ctx: &(),
-    model: &TypedModel,
-    node: &TypedNode,
-    name: &str,
-    op: &Reduce,
-) -> TractResult<Option<TypedModelPatch>> {
-    if op.reducer == Reducer::MeanOfSquares {
-        let mut patch = TypedModelPatch::default();
-        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
-        let input_size = patch.outlet_fact(wire[0])?.shape.volume();
-        let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
-        wire = patch.wire_node(
-            format!("{name}.sum"),
-            Reduce::new(op.axes.clone(), Reducer::Sum),
-            &wire,
-        )?;
-        let output_size = patch.outlet_fact(wire[0])?.shape.volume();
-        let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
-        let norm = patch.wire_node(format!("{name}.norm"), div(), &[input_size, output_size])?[0];
-        wire = patch.wire_node(format!("{name}.card"), mul(), &[wire[0], norm])?;
-        patch.shunt_outside(model, node.id.into(), wire[0])?;
-        Ok(Some(patch))
-    } else {
-        Ok(None)
-    }
-}
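Note on the deletion above: the tflite-local copy of `expand_mean_of_squares` is dropped in favor of the shared rule in `core/src/ops/nn/reduce.rs`. The shared version differs in three ways: it round-trips through f32, it uses `wire_cast` plus `wire_with_rank_broadcast` so the scalar norm matches the reduced tensor's rank and dtype, and its `div` operand order is `output_size / input_size`, matching the `norm` in the eval path, where the deleted copy wired `&[input_size, output_size]`. A hedged sketch of how a backend plugs the shared rule into its rewrite pass, mirroring the `nnef/src/ser.rs` hunk (`rewrite_for_backend` is a hypothetical wrapper, and I am assuming `Rewriter` is reachable through `tract_core::internal::*`, as the existing nnef and tflite passes suggest):

```rust
// Hedged sketch: registering the shared MeanOfSquares expansion in a
// backend's pre-serialization rewrite pass.
use tract_core::internal::*;
use tract_core::ops::nn::expand_mean_of_squares;

pub fn rewrite_for_backend(model: &mut TypedModel) -> TractResult<()> {
    Rewriter::default()
        .with_rule_for("expand_mean_of_squares", expand_mean_of_squares)
        .rewrite(&(), model)
}
```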