Skip to content

Commit

Permalink
expand mean_of_squares for NNEF dumping
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Mar 22, 2024
1 parent 062d053 commit 2d0c6c7
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 43 deletions.
3 changes: 2 additions & 1 deletion core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use std::fmt;
use tract_data::itertools::izip;

pub fn wire_cast(
prefix: &str,
prefix: impl AsRef<str>,
target: &mut TypedModel,
inputs: &[OutletId],
operating_datum_type: DatumType,
) -> TractResult<TVec<OutletId>> {
let prefix = prefix.as_ref();
let mut wires = tvec!();
for (ix, mut wire) in inputs.iter().copied().enumerate() {
if target.outlet_fact(wire)?.datum_type != operating_datum_type {
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod reduce;
mod softmax;

pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
pub use self::reduce::{Reduce, Reducer};
pub use self::reduce::{Reduce, Reducer, expand_mean_of_squares};
pub use self::softmax::{Softmax, SoftmaxExp};

pub use crate::internal::*;
Expand Down
49 changes: 46 additions & 3 deletions core/src/ops/nn/reduce.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::internal::Axis;
use crate::internal::*;
use crate::ops::binary::TypedBinOp;
use crate::ops::binary::{wire_cast, wire_with_rank_broadcast, TypedBinOp};
use crate::ops::cast::cast;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{Mul, Square};
use crate::ops::math::{div, mul, square, Mul, Square};
use std::convert::TryFrom;
use std::mem::transmute;
use tract_data::internal::ClampCast;
Expand Down Expand Up @@ -189,7 +190,7 @@ impl Reducer {
input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
let mut output = unsafe { self.sum::<f32>(axis, &input) };
let norm = output.len() as f32 / input.len() as f32;
output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * norm);
output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
Ok(output.cast_to_dt(dt)?.into_owned())
}
}
Expand Down Expand Up @@ -406,3 +407,45 @@ impl TypedOp for Reduce {

as_op!();
}

/// Rewriter rule: decompose a `MeanOfSquares` reduction into primitive ops
/// (cast-to-f32, square, sum, scale by the element-count ratio, cast back),
/// so that serializers without a native mean-of-squares operator (NNEF dump)
/// can emit it.
///
/// Returns `Ok(Some(patch))` shunting the node when `op` is a
/// `MeanOfSquares` reduce, `Ok(None)` to leave any other reducer untouched.
pub fn expand_mean_of_squares(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
    if op.reducer != Reducer::MeanOfSquares {
        return Ok(None);
    }
    let mut patch = TypedModelPatch::default();
    // Tap the reduce's input and do the arithmetic in f32.
    let mut current = tvec!(patch.tap_model(model, node.inputs[0])?);
    current = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &current)?;
    current = patch.wire_node(format!("{name}.sqr"), square(), &current)?;
    // Element counts before and after the reduction; their ratio is the
    // 1/count normalization that turns the sum into a mean.
    let input_card = patch.outlet_fact(current[0])?.shape.volume();
    let input_card = patch.add_const(format!("{name}.input_size"), tensor0(input_card))?;
    current = patch.wire_node(
        format!("{name}.sum"),
        Reduce::new(op.axes.clone(), Reducer::Sum),
        &current,
    )?;
    let output_card = patch.outlet_fact(current[0])?.shape.volume();
    let output_card = patch.add_const(format!("{name}.output_size"), tensor0(output_card))?;
    // norm = output_size / input_size (sizes may be symbolic, hence the cast).
    let ratio_operands = wire_cast(
        format!("{name}.norm"),
        &mut patch,
        &[output_card, input_card],
        f32::datum_type(),
    )?;
    let ratio = patch.wire_node(format!("{name}.norm"), div(), &ratio_operands)?[0];
    // Scale the sum; broadcast the scalar ratio against the reduced tensor.
    current = wire_with_rank_broadcast(
        format!("{name}.card"),
        &mut patch,
        mul(),
        &[current[0], ratio],
    )?;
    // Restore the original input datum type.
    current = patch.wire_node(
        format!("{name}.from_f32"),
        cast(model.outlet_fact(node.inputs[0])?.datum_type),
        &current,
    )?;
    patch.shunt_outside(model, node.id.into(), current[0])?;
    Ok(Some(patch))
}
2 changes: 1 addition & 1 deletion core/src/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ where
None => node.op().eval(input),
}
.with_context(|| format!("Evaluating {node}"));
// eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
// eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
r
}

Expand Down
4 changes: 3 additions & 1 deletion nnef/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
"rewrite_consistent_quantized_conv",
crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
)
.with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares)
.rewrite(&(), model)
}

Expand Down Expand Up @@ -363,7 +364,8 @@ impl<'a> IntoAst<'a> {
Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
);
} else if have_tract_core && tensor.datum_type() == DatumType::F16 {
let array = Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
let array =
Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
} else if have_tract_core && tensor.datum_type().is_integer() {
if let Ok(value) = tensor.cast_to::<i64>() {
Expand Down
3 changes: 1 addition & 2 deletions test-rt/test-nnef-cycle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use log::*;
use tract_nnef::internal::*;
use tract_onnx_opl::*;

#[path="../suite.rs"]
#[path = "../suite.rs"]
mod suite;

mod nnef_predump {
Expand Down Expand Up @@ -42,7 +42,6 @@ mod nnef_predump {
include!(concat!(env!("OUT_DIR"), "/tests/nnef_cycle.rs"));
}


mod nnef_cycle {
use super::*;

Expand Down
6 changes: 3 additions & 3 deletions tflite/src/ops/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::registry::{DeserOp, Registry};
use crate::ser::{BuiltinOp, SubgraphBuilder};
use crate::tflite::{
ActivationFunctionType, AddOptions, AddOptionsArgs, BuiltinOperator, BuiltinOptions,
MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions, MulOptionsArgs, SubOptions,
SubOptionsArgs, DivOptions, DivOptionsArgs,
DivOptions, DivOptionsArgs, MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions,
MulOptionsArgs, SubOptions, SubOptionsArgs,
};
use tract_core::internal::*;
use tract_core::ops::binary::{wire_cast, wire_rank_broadcast, TypedBinOp};
Expand Down Expand Up @@ -32,7 +32,7 @@ pub fn register_all(reg: &mut Registry) {

fn wire_cast_and_rank_broadcast(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let wire = wire_cast(
&format!("{}.cast", op.prefix),
format!("{}.cast", op.prefix),
op.ctx.target,
op.inputs,
DatumType::super_type_for(op.facts()?.iter().map(|f| f.datum_type))
Expand Down
33 changes: 2 additions & 31 deletions tflite/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat};
use tract_core::ops::cnn::{Conv, PaddingSpec};
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::{div, mul, square, Recip};
use tract_core::ops::nn::{DataFormat, Reduce, Reducer, Softmax};
use tract_core::ops::math::Recip;
use tract_core::ops::nn::{expand_mean_of_squares, DataFormat, Softmax};
use tract_core::tract_data::itertools::Itertools;

pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
Expand Down Expand Up @@ -329,32 +329,3 @@ fn softmax_on_last_axis(
Ok(None)
}
}

/// (Removed by this commit in favor of the shared
/// `tract_core::ops::nn::expand_mean_of_squares`.)
/// Decomposes a `MeanOfSquares` reduction into square -> sum -> scale for
/// tflite export; returns `Ok(None)` for any other reducer.
fn expand_mean_of_squares(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
    if op.reducer == Reducer::MeanOfSquares {
        let mut patch = TypedModelPatch::default();
        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
        // NOTE(review): no cast to f32 here — unlike the core replacement,
        // this operated directly in the input's datum type.
        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
        let input_size = patch.outlet_fact(wire[0])?.shape.volume();
        let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
        wire = patch.wire_node(
            format!("{name}.sum"),
            Reduce::new(op.axes.clone(), Reducer::Sum),
            &wire,
        )?;
        let output_size = patch.outlet_fact(wire[0])?.shape.volume();
        let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
        // NOTE(review): operands are (input_size, output_size), the reverse of
        // the output.len()/input.len() norm used in Reducer's eval and of the
        // (output_size, input_size) order in the core replacement — presumably
        // one reason this version was dropped; verify against div semantics.
        let norm = patch.wire_node(format!("{name}.norm"), div(), &[input_size, output_size])?[0];
        // Plain mul with no rank broadcast — assumes norm's rank already
        // matches the reduced tensor's; TODO confirm.
        wire = patch.wire_node(format!("{name}.card"), mul(), &[wire[0], norm])?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    } else {
        Ok(None)
    }
}

0 comments on commit 2d0c6c7

Please sign in to comment.