Merge pull request #1347 from sonos/f16-for-whisper
F16 for whisper
kali authored Mar 25, 2024
2 parents 000196a + 2a3cc1d commit 9590809
Showing 12 changed files with 155 additions and 33 deletions.
5 changes: 5 additions & 0 deletions cli/src/main.rs
@@ -235,6 +235,11 @@ fn main() -> tract_core::anyhow::Result<()> {
.takes_value(true)
.help("Save intermediary values as a npz file"),
)
.arg(
Arg::new("check-f16-overflow")
.long("check-f16-overflow")
.help("Check for f16 overflow in all outputs"),
)
.arg(
Arg::new("assert-sane-floats")
.long("assert-sane-floats")
10 changes: 10 additions & 0 deletions cli/src/run.rs
@@ -147,6 +147,7 @@ fn run_regular(
sub_matches: &clap::ArgMatches,
) -> TractResult<TVec<Vec<TValue>>> {
let steps = sub_matches.is_present("steps");
let check_f16_overflow = sub_matches.is_present("check-f16-overflow");
let assert_sane_floats = sub_matches.is_present("assert-sane-floats");
let mut npz = if let Some(npz) = sub_matches.value_of("save-steps") {
let npz = std::fs::File::create(npz).with_context(|| format!("Creating {npz}"))?;
@@ -210,6 +211,15 @@ fn run_regular(
npz_add_tensor(npz, name, t)?;
}
}
if check_f16_overflow {
for (ix, o) in r.iter().enumerate() {
if let Ok(f32s) = o.as_slice::<f32>() {
if f32s.iter().any(|f| f.abs() > f16::MAX.to_f32()) {
warn!("{node}, output {ix} overflows f16");
}
}
}
}
if assert_sane_floats {
for (ix, o) in r.iter().enumerate() {
if node.op_is::<Im2Col>() || node.op_is::<MatMatMulPack>() {
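The new --check-f16-overflow pass above scans every f32 output of every node for magnitudes that half precision cannot represent. A minimal standalone sketch of that test, assuming the `half` crate for the f16 type: any value with |x| > f16::MAX (about 65504) collapses to ±infinity when cast down to f16.

use half::f16;

// Warn-worthy condition from the loop above: an f32 whose magnitude exceeds
// f16::MAX would become +/-infinity if the model were run in half precision.
fn overflows_f16(values: &[f32]) -> bool {
    values.iter().any(|v| v.abs() > f16::MAX.to_f32())
}

fn main() {
    assert!(!overflows_f16(&[1.0, -65000.0])); // still representable as f16
    assert!(overflows_f16(&[70000.0]));        // would overflow to +inf
}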
3 changes: 2 additions & 1 deletion core/src/ops/binary.rs
@@ -4,11 +4,12 @@ use std::fmt;
use tract_data::itertools::izip;

pub fn wire_cast(
prefix: &str,
prefix: impl AsRef<str>,
target: &mut TypedModel,
inputs: &[OutletId],
operating_datum_type: DatumType,
) -> TractResult<TVec<OutletId>> {
let prefix = prefix.as_ref();
let mut wires = tvec!();
for (ix, mut wire) in inputs.iter().copied().enumerate() {
if target.outlet_fact(wire)?.datum_type != operating_datum_type {
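Loosening the prefix parameter to `impl AsRef<str>` lets call sites hand over either a borrowed &str or an owned String (typically straight out of format!) without an extra borrow; the tflite call-site change further down in this commit relies on exactly that. A hedged sketch of the pattern, with illustrative names only:

// Illustrative only: mirrors the signature relaxation of wire_cast.
fn wire(prefix: impl AsRef<str>) {
    // Normalise to &str once, as wire_cast now does on its first line.
    let prefix: &str = prefix.as_ref();
    println!("wiring under prefix {prefix}");
}

fn main() {
    wire("fixed.prefix");              // &str still accepted
    wire(format!("{}.cast", "node"));  // owned String accepted without `&`
}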
2 changes: 1 addition & 1 deletion core/src/ops/nn/mod.rs
@@ -3,7 +3,7 @@ mod reduce;
mod softmax;

pub use self::data_formats::{BaseDataShape, DataFormat, DataShape, SymDataShape};
pub use self::reduce::{Reduce, Reducer};
pub use self::reduce::{Reduce, Reducer, expand_mean_of_squares};
pub use self::softmax::{Softmax, SoftmaxExp};

pub use crate::internal::*;
104 changes: 104 additions & 0 deletions core/src/ops/nn/reduce.rs
@@ -1,5 +1,9 @@
use crate::internal::Axis;
use crate::internal::*;
use crate::ops::binary::{wire_cast, wire_with_rank_broadcast, TypedBinOp};
use crate::ops::cast::cast;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{div, mul, square, Mul, Square};
use std::convert::TryFrom;
use std::mem::transmute;
use tract_data::internal::ClampCast;
@@ -50,6 +54,7 @@ pub enum Reducer {
Min,
Prod,
Sum,
MeanOfSquares,
}

impl Reducer {
@@ -90,6 +95,7 @@ impl Reducer {
))
}
}
MeanOfSquares => self.mean_of_squares(axes, input)?,
};
if input.datum_type().is_quantized()
&& input.datum_type().unquantized() == t.datum_type().unquantized()
@@ -177,6 +183,16 @@ impl Reducer {
}
output.unwrap().into_tensor()
}

fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
let dt = input.datum_type();
let mut input = input.cast_to::<f32>()?.into_owned();
input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
let mut output = unsafe { self.sum::<f32>(axis, &input) };
let norm = output.len() as f32 / input.len() as f32;
output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
Ok(output.cast_to_dt(dt)?.into_owned())
}
}

fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
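The eval path above casts to f32, squares in place, reuses the existing Sum reducer, then rescales by output.len() / input.len(), i.e. 1/n where n is the number of elements folded into each output value. A hedged, flat-slice illustration of the same arithmetic, reducing over all axes so there is a single output element and the factor is 1 / input.len():

fn mean_of_squares_all(input: &[f32]) -> f32 {
    let sum_of_squares: f32 = input.iter().map(|x| x * x).sum();
    // With one output element, norm = output.len() / input.len() = 1 / n.
    sum_of_squares * (1.0 / input.len() as f32)
}

fn main() {
    // (1 + 4 + 9 + 16) / 4 = 7.5
    assert_eq!(mean_of_squares_all(&[1.0, 2.0, 3.0, 4.0]), 7.5);
}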
@@ -298,6 +314,51 @@ impl TypedOp for Reduce {
Ok(tvec!(dt.fact(shape)))
}

fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.reducer == Reducer::Sum {
let Some(prec) = model.single_prec(node.id)? else { return Ok(None) };
let Some(prec_ew) = prec.op_as::<ElementWiseOp>() else { return Ok(None) };
if !prec_ew.0.is::<Square>() {
return Ok(None);
}
if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
return Ok(None);
}
let our_inlet = node.outputs[0].successors[0];
let succ = model.node(our_inlet.node);
let Some(succ_bin) = succ.op_as::<TypedBinOp>() else { return Ok(None) };
if !succ_bin.0.is::<Mul>() {
return Ok(None);
}
let other = succ.inputs[1 - our_inlet.slot];
let Some(other_konst) = model.outlet_fact(other)?.uniform.as_ref() else {
return Ok(None);
};
let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
let Some(norm) = norm.as_i64() else { return Ok(None) };
if norm == 0 {
return Ok(None);
}
let norm = tensor0((norm as f32).recip());
if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
let mut patch = TypedModelPatch::default();
let wire = patch.tap_model(model, prec.inputs[0])?;
let wire = patch.wire_node(
&node.name,
Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
&[wire],
)?[0];
patch.shunt_outside(model, succ.id.into(), wire)?;
return Ok(Some(patch));
}
}
Ok(None)
}

fn axes_mapping(
&self,
inputs: &[&TypedFact],
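The declutter rule above looks for the hand-written pattern Square → Sum-reduce → Mul by a constant, and only fires when that constant equals 1/n for the reduced dimensions — exactly the renorm subgraphs in the librispeech expected graph below, where 0.00390625 is 1/256. A hedged sketch of just the constant test (the real code compares tensors via close_enough; the scalar tolerance here is illustrative):

fn looks_like_mean_norm(konst: f32, reduced_elements: i64) -> bool {
    // Reject degenerate reductions, then check konst ~= 1 / n.
    reduced_elements != 0 && (konst - (reduced_elements as f32).recip()).abs() < 1e-7
}

fn main() {
    assert!(looks_like_mean_norm(0.00390625, 256)); // 1/256: pattern fuses
    assert!(!looks_like_mean_norm(0.004, 256));     // not 1/n: left alone
}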
@@ -346,3 +407,46 @@ impl TypedOp for Reduce {

as_op!();
}

pub fn expand_mean_of_squares(
_ctx: &(),
model: &TypedModel,
node: &TypedNode,
name: &str,
op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
if op.reducer == Reducer::MeanOfSquares {
let mut patch = TypedModelPatch::default();
let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
let dt = model.outlet_fact(node.inputs[0])?.datum_type;
if dt != f32::datum_type() {
wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
}
wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
let input_size = patch.outlet_fact(wire[0])?.shape.volume();
let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
wire = patch.wire_node(
format!("{name}.sum"),
Reduce::new(op.axes.clone(), Reducer::Sum),
&wire,
)?;
let output_size = patch.outlet_fact(wire[0])?.shape.volume();
let output_size = patch.add_const(format!("{name}.output_size"), tensor0(output_size))?;
let norm = wire_cast(
format!("{name}.norm"),
&mut patch,
&[output_size, input_size],
f32::datum_type(),
)?;
let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
wire =
wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
if dt != f32::datum_type() {
wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
}
patch.shunt_outside(model, node.id.into(), wire[0])?;
Ok(Some(patch))
} else {
Ok(None)
}
}
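expand_mean_of_squares is the inverse rewrite, used before serializing to formats with no fused operator (it is registered for NNEF and tflite later in this commit): cast to f32, square, Sum-reduce, then multiply by output volume / input volume, which comes out to the same 1/n factor. A hedged numeric check of that normalisation with an illustrative shape:

fn volume(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    // Illustrative shapes only: reduce axis 1 of a [2, 256, 50] tensor.
    let input = [2usize, 256, 50];
    let output = [2usize, 1, 50];
    let norm = volume(&output) as f32 / volume(&input) as f32;
    assert_eq!(norm, 1.0 / 256.0); // the 0.00390625 seen in the renorm graphs below
}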
2 changes: 1 addition & 1 deletion core/src/plan.rs
@@ -553,7 +553,7 @@ where
None => node.op().eval(input),
}
.with_context(|| format!("Evaluating {node}"));
// eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
// eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
r
}

42 changes: 21 additions & 21 deletions harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected
@@ -142,32 +142,32 @@ graph network(input) -> (output) {
i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape");
i"tdnn1.relu.output.low.cst" = [[[0.0]]];
i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn1.renorm.reduced.sq" = square(i"tdnn1.relu.output.low");
i"tdnn1.renorm.reduced.sum" = sum_reduce(i"tdnn1.renorm.reduced.sq", axes = [1]);
i"tdnn1.renorm.scaled-recip" = [[[0.00390625]]];
i"tdnn1.renorm.scaled" = mul(i"tdnn1.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.scaled");
i"tdnn1.renorm.reduced.sum.sqr" = square(i"tdnn1.relu.output.low");
i"tdnn1.renorm.reduced.sum.sum" = sum_reduce(i"tdnn1.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2" = [[[0.00390625]]];
i"tdnn1.renorm.reduced.sum.card" = mul(i"tdnn1.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.reduced.sum.card");
i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip");
i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 2, delay = 0, overlap = 2);
i"tdnn2.affine.kernel.0" = variable<scalar>(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn2.affine.bias.0" = variable<scalar>(label = "tdnn2.affine.bias.0", shape = [256]);
i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.delay", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn2.affine.output" = i"tdnn2.affine.output_conv";
i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn2.renorm.reduced.sq" = square(i"tdnn2.relu.output.low");
i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [1]);
i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled");
i"tdnn2.renorm.reduced.sum.sqr" = square(i"tdnn2.relu.output.low");
i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.reduced.sum.card");
i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip");
i"tdnn3.affine.kernel.0" = variable<scalar>(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn3.affine.bias.0" = variable<scalar>(label = "tdnn3.affine.bias.0", shape = [256]);
i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn3.affine.output" = i"tdnn3.affine.output_conv";
i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low");
i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [1]);
i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.scaled");
i"tdnn3.renorm.reduced.sum.sqr" = square(i"tdnn3.relu.output.low");
i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.reduced.sum.card");
i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip");
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable<scalar>(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]);
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false);
@@ -217,21 +217,21 @@ graph network(input) -> (output) {
i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn4.affine.output" = i"tdnn4.affine.output_conv";
i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low");
i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [1]);
i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled");
i"tdnn4.renorm.reduced.sum.sqr" = square(i"tdnn4.relu.output.low");
i"tdnn4.renorm.reduced.sum.sum" = sum_reduce(i"tdnn4.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.reduced.sum.card");
i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip");
i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2);
i"tdnn5.affine.kernel.0" = variable<scalar>(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn5.affine.bias.0" = variable<scalar>(label = "tdnn5.affine.bias.0", shape = [256]);
i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn5.affine.output" = i"tdnn5.affine.output_conv";
i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low");
i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [1]);
i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.scaled");
i"tdnn5.renorm.reduced.sum.sqr" = square(i"tdnn5.relu.output.low");
i"tdnn5.renorm.reduced.sum.sum" = sum_reduce(i"tdnn5.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.reduced.sum.card");
i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip");
i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable<scalar>(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]);
i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false);
4 changes: 3 additions & 1 deletion nnef/src/ser.rs
@@ -24,6 +24,7 @@ pub fn rewrite_model(model: &mut TypedModel) -> TractResult<()> {
"rewrite_consistent_quantized_conv",
crate::ops::nnef::ser::rewrite_consistent_quantized_conv,
)
.with_rule_for("expand_mean_of_square", tract_core::ops::nn::expand_mean_of_squares)
.rewrite(&(), model)
}

@@ -363,7 +364,8 @@ impl<'a> IntoAst<'a> {
Self::dump_rec_tensor(&tensor.to_array_view::<f32>()?, |f| numeric(f)).into()
);
} else if have_tract_core && tensor.datum_type() == DatumType::F16 {
let array = Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
let array =
Self::dump_rec_tensor(&tensor.to_array_view::<f16>()?, |f| numeric(f)).into();
return Ok(invocation("tract_core_cast", &[array], &[("to", string("f16"))]));
} else if have_tract_core && tensor.datum_type().is_integer() {
if let Ok(value) = tensor.cast_to::<i64>() {
3 changes: 1 addition & 2 deletions test-rt/test-nnef-cycle/src/lib.rs
@@ -5,7 +5,7 @@ use log::*;
use tract_nnef::internal::*;
use tract_onnx_opl::*;

#[path="../suite.rs"]
#[path = "../suite.rs"]
mod suite;

mod nnef_predump {
@@ -42,7 +42,6 @@ mod nnef_predump {
include!(concat!(env!("OUT_DIR"), "/tests/nnef_cycle.rs"));
}


mod nnef_cycle {
use super::*;

6 changes: 3 additions & 3 deletions tflite/src/ops/math.rs
@@ -3,8 +3,8 @@ use crate::registry::{DeserOp, Registry};
use crate::ser::{BuiltinOp, SubgraphBuilder};
use crate::tflite::{
ActivationFunctionType, AddOptions, AddOptionsArgs, BuiltinOperator, BuiltinOptions,
MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions, MulOptionsArgs, SubOptions,
SubOptionsArgs, DivOptions, DivOptionsArgs,
DivOptions, DivOptionsArgs, MaximumMinimumOptions, MaximumMinimumOptionsArgs, MulOptions,
MulOptionsArgs, SubOptions, SubOptionsArgs,
};
use tract_core::internal::*;
use tract_core::ops::binary::{wire_cast, wire_rank_broadcast, TypedBinOp};
@@ -32,7 +32,7 @@ pub fn register_all(reg: &mut Registry) {

fn wire_cast_and_rank_broadcast(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let wire = wire_cast(
&format!("{}.cast", op.prefix),
format!("{}.cast", op.prefix),
op.ctx.target,
op.inputs,
DatumType::super_type_for(op.facts()?.iter().map(|f| f.datum_type))
2 changes: 1 addition & 1 deletion tflite/src/ops/nn.rs
@@ -272,7 +272,7 @@ fn ser_reduce(
BuiltinOp::new(74, 1, BuiltinOperator::SUM, BuiltinOptions::ReducerOptions),
options.as_union_value(),
),
Reducer::ArgMin(_) | Reducer::ArgMax(_) => unreachable!(),
Reducer::ArgMin(_) | Reducer::ArgMax(_) | Reducer::MeanOfSquares => unreachable!(),
}
}
}
5 changes: 3 additions & 2 deletions tflite/src/rewriter.rs
@@ -1,12 +1,12 @@
use tract_core::internal::*;
use tract_core::ops::array::{Pad, PadMode};
use tract_core::ops::binary::wire_with_rank_broadcast;
use tract_core::ops::cnn::{KernelFormat, rewrite_conv_with_n_axis};
use tract_core::ops::cnn::{rewrite_conv_with_n_axis, KernelFormat};
use tract_core::ops::cnn::{Conv, PaddingSpec};
use tract_core::ops::einsum::BasicMatMul;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::math::Recip;
use tract_core::ops::nn::{DataFormat, Softmax};
use tract_core::ops::nn::{expand_mean_of_squares, DataFormat, Softmax};
use tract_core::tract_data::itertools::Itertools;

pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
@@ -22,6 +22,7 @@ pub fn rewrite_for_tflite(model: &mut TypedModel) -> TractResult<()> {
.with_rule_for("padding", padding)
.with_rule_for("manual_recip", manual_recip)
.with_rule_for("softmax_on_last_axis", softmax_on_last_axis)
.with_rule_for("expand-means-of-square", expand_mean_of_squares)
.rewrite(&(), model)
}
