diff --git a/core/src/ops/nn/reduce.rs b/core/src/ops/nn/reduce.rs
index ce4c9afdda..944ca9fdcc 100644
--- a/core/src/ops/nn/reduce.rs
+++ b/core/src/ops/nn/reduce.rs
@@ -418,7 +418,10 @@ pub fn expand_mean_of_squares(
     if op.reducer == Reducer::MeanOfSquares {
         let mut patch = TypedModelPatch::default();
         let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-        wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+        let dt = model.outlet_fact(node.inputs[0])?.datum_type;
+        if dt != f32::datum_type() {
+            wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+        }
         wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
         let input_size = patch.outlet_fact(wire[0])?.shape.volume();
         let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
@@ -438,11 +441,9 @@ pub fn expand_mean_of_squares(
         let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
         wire =
             wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
-        wire = patch.wire_node(
-            format!("{name}.from_f32"),
-            cast(model.outlet_fact(node.inputs[0])?.datum_type),
-            &wire,
-        )?;
+        if dt != f32::datum_type() {
+            wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
+        }
         patch.shunt_outside(model, node.id.into(), wire[0])?;
         Ok(Some(patch))
     } else {
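Note: the hunk above makes the `{name}.to_f32` / `{name}.from_f32` cast nodes conditional, so an input that is already f32 no longer gets a cast round-trip wired around the expansion. As a reminder of what the expansion computes, here is a minimal sketch of the arithmetic on plain slices (illustrative only, not tract's `TypedModelPatch` graph API; the `mean_of_squares` helper is hypothetical), using the `.sqr` / `.card` step names from the diff:

```rust
/// MeanOfSquares(x) decomposed the way expand_mean_of_squares wires it:
/// square each element ({name}.sqr), sum-reduce, then scale by the
/// reciprocal of the reduced element count ({name}.card). Accumulation
/// stays in f32; with this patch the to_f32/from_f32 casts are only
/// inserted for non-f32 inputs.
fn mean_of_squares(xs: &[f32]) -> f32 {
    let sum_sq: f32 = xs.iter().map(|x| x * x).sum(); // {name}.sqr + sum reduce
    sum_sq / xs.len() as f32 // {name}.card: multiply by 1/N
}

fn main() {
    let xs = [1.0_f32, 2.0, 3.0, 4.0];
    // (1 + 4 + 9 + 16) / 4 = 7.5
    assert_eq!(mean_of_squares(&xs), 7.5);
}
```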
sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [1]); - i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled"); + i"tdnn2.renorm.reduced.sum.sqr" = square(i"tdnn2.relu.output.low"); + i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.reduced.sum.card"); i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip"); i"tdnn3.affine.kernel.0" = variable(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]); i"tdnn3.affine.bias.0" = variable(label = "tdnn3.affine.bias.0", shape = [256]); i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn3.affine.output" = i"tdnn3.affine.output_conv"; i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst"); - i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low"); - i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [1]); - i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.scaled"); + i"tdnn3.renorm.reduced.sum.sqr" = square(i"tdnn3.relu.output.low"); + i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.reduced.sum.card"); i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip"); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false); @@ -217,10 +217,10 @@ graph network(input) -> (output) { i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn4.affine.output" = i"tdnn4.affine.output_conv"; i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst"); - i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low"); - i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [1]); - i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled"); + i"tdnn4.renorm.reduced.sum.sqr" = square(i"tdnn4.relu.output.low"); + i"tdnn4.renorm.reduced.sum.sum" = sum_reduce(i"tdnn4.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.reduced.sum.card"); i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", 
i"tdnn4.renorm.output-recip"); i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2); i"tdnn5.affine.kernel.0" = variable(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]); @@ -228,10 +228,10 @@ graph network(input) -> (output) { i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn5.affine.output" = i"tdnn5.affine.output_conv"; i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst"); - i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low"); - i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [1]); - i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip"); - i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.scaled"); + i"tdnn5.renorm.reduced.sum.sqr" = square(i"tdnn5.relu.output.low"); + i"tdnn5.renorm.reduced.sum.sum" = sum_reduce(i"tdnn5.renorm.reduced.sum.sqr", axes = [1]); + i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.reduced.sum.card"); i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip"); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false);