avoid cast clutter in meanofsquare
kali committed Mar 22, 2024
1 parent 2d0c6c7 commit 2a3cc1d
Showing 2 changed files with 28 additions and 27 deletions.
core/src/ops/nn/reduce.rs (13 changes: 7 additions & 6 deletions)
@@ -418,7 +418,10 @@ pub fn expand_mean_of_squares(
    if op.reducer == Reducer::MeanOfSquares {
        let mut patch = TypedModelPatch::default();
        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
-       wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+       let dt = model.outlet_fact(node.inputs[0])?.datum_type;
+       if dt != f32::datum_type() {
+           wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
+       }
        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
        let input_size = patch.outlet_fact(wire[0])?.shape.volume();
        let input_size = patch.add_const(format!("{name}.input_size"), tensor0(input_size))?;
@@ -438,11 +441,9 @@ pub fn expand_mean_of_squares(
        let norm = patch.wire_node(format!("{name}.norm"), div(), &norm)?[0];
        wire =
            wire_with_rank_broadcast(format!("{name}.card"), &mut patch, mul(), &[wire[0], norm])?;
-       wire = patch.wire_node(
-           format!("{name}.from_f32"),
-           cast(model.outlet_fact(node.inputs[0])?.datum_type),
-           &wire,
-       )?;
+       if dt != f32::datum_type() {
+           wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
+       }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    } else {
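For orientation, here is a minimal standalone sketch, not tract code, of what the expansion computes and why the casts became conditional. The reducer is decomposed into square, sum, and a multiply by 1/N, accumulated in f32; the to_f32/from_f32 pair is only needed when the input tensor is not already f32, so skipping it for f32 inputs removes two no-op nodes per reduce. The DatumType enum below is a hypothetical stand-in for tract's own type, used only to show the guard.

/// Hypothetical stand-in for tract's datum type, for illustration only.
#[derive(Clone, Copy, PartialEq, Debug)]
enum DatumType {
    F16,
    F32,
}

/// MeanOfSquares(x) = sum(x_i^2) / N, accumulated in f32.
/// The `steps` trace mirrors the nodes the patch wires: the to_f32/from_f32
/// casts appear only when the input type is not already f32.
fn mean_of_squares(input: &[f32], dt: DatumType) -> f32 {
    let mut steps: Vec<&str> = vec![];
    if dt != DatumType::F32 {
        steps.push("to_f32"); // cast in, only when needed
    }
    steps.extend(["sqr", "sum_reduce", "mul 1/N"]);
    if dt != DatumType::F32 {
        steps.push("from_f32"); // cast back to the original type
    }
    println!("{dt:?} input wires: {steps:?}");

    let n = input.len() as f32;
    input.iter().map(|x| x * x).sum::<f32>() / n
}

fn main() {
    let x = [1.0f32, 2.0, 3.0, 4.0];
    // (1 + 4 + 9 + 16) / 4 = 7.5; an f32 input wires no casts at all.
    assert_eq!(mean_of_squares(&x, DatumType::F32), 7.5);
    // A non-f32 input still gets the to_f32 / from_f32 pair.
    mean_of_squares(&x, DatumType::F16);
}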
harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected (42 changes: 21 additions & 21 deletions)
@@ -142,32 +142,32 @@ graph network(input) -> (output) {
i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape");
i"tdnn1.relu.output.low.cst" = [[[0.0]]];
i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn1.renorm.reduced.sq" = square(i"tdnn1.relu.output.low");
i"tdnn1.renorm.reduced.sum" = sum_reduce(i"tdnn1.renorm.reduced.sq", axes = [1]);
i"tdnn1.renorm.scaled-recip" = [[[0.00390625]]];
i"tdnn1.renorm.scaled" = mul(i"tdnn1.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.scaled");
i"tdnn1.renorm.reduced.sum.sqr" = square(i"tdnn1.relu.output.low");
i"tdnn1.renorm.reduced.sum.sum" = sum_reduce(i"tdnn1.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2" = [[[0.00390625]]];
i"tdnn1.renorm.reduced.sum.card" = mul(i"tdnn1.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.reduced.sum.card");
i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip");
i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 2, delay = 0, overlap = 2);
i"tdnn2.affine.kernel.0" = variable<scalar>(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn2.affine.bias.0" = variable<scalar>(label = "tdnn2.affine.bias.0", shape = [256]);
i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.delay", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn2.affine.output" = i"tdnn2.affine.output_conv";
i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn2.renorm.reduced.sq" = square(i"tdnn2.relu.output.low");
i"tdnn2.renorm.reduced.sum" = sum_reduce(i"tdnn2.renorm.reduced.sq", axes = [1]);
i"tdnn2.renorm.scaled" = mul(i"tdnn2.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.scaled");
i"tdnn2.renorm.reduced.sum.sqr" = square(i"tdnn2.relu.output.low");
i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.reduced.sum.card");
i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip");
i"tdnn3.affine.kernel.0" = variable<scalar>(label = "tdnn3.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn3.affine.bias.0" = variable<scalar>(label = "tdnn3.affine.bias.0", shape = [256]);
i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn3.affine.output" = i"tdnn3.affine.output_conv";
i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn3.renorm.reduced.sq" = square(i"tdnn3.relu.output.low");
i"tdnn3.renorm.reduced.sum" = sum_reduce(i"tdnn3.renorm.reduced.sq", axes = [1]);
i"tdnn3.renorm.scaled" = mul(i"tdnn3.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.scaled");
i"tdnn3.renorm.reduced.sum.sqr" = square(i"tdnn3.relu.output.low");
i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.reduced.sum.card");
i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip");
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable<scalar>(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]);
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false);
@@ -217,21 +217,21 @@ graph network(input) -> (output) {
i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn4.affine.output" = i"tdnn4.affine.output_conv";
i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn4.renorm.reduced.sq" = square(i"tdnn4.relu.output.low");
i"tdnn4.renorm.reduced.sum" = sum_reduce(i"tdnn4.renorm.reduced.sq", axes = [1]);
i"tdnn4.renorm.scaled" = mul(i"tdnn4.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.scaled");
i"tdnn4.renorm.reduced.sum.sqr" = square(i"tdnn4.relu.output.low");
i"tdnn4.renorm.reduced.sum.sum" = sum_reduce(i"tdnn4.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.reduced.sum.card");
i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip");
i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2);
i"tdnn5.affine.kernel.0" = variable<scalar>(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]);
i"tdnn5.affine.bias.0" = variable<scalar>(label = "tdnn5.affine.bias.0", shape = [256]);
i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]);
i"tdnn5.affine.output" = i"tdnn5.affine.output_conv";
i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst");
i"tdnn5.renorm.reduced.sq" = square(i"tdnn5.relu.output.low");
i"tdnn5.renorm.reduced.sum" = sum_reduce(i"tdnn5.renorm.reduced.sq", axes = [1]);
i"tdnn5.renorm.scaled" = mul(i"tdnn5.renorm.reduced.sum", i"tdnn1.renorm.scaled-recip");
i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.scaled");
i"tdnn5.renorm.reduced.sum.sqr" = square(i"tdnn5.relu.output.low");
i"tdnn5.renorm.reduced.sum.sum" = sum_reduce(i"tdnn5.renorm.reduced.sum.sqr", axes = [1]);
i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2");
i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.reduced.sum.card");
i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip");
i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0" = variable<scalar>(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", shape = [1, 256, 256]);
i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b.0", transposeA = true, transposeB = false);
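The expected-graph changes are renames only: each renorm block still computes square, sum_reduce, mul, rsqrt, mul; the intermediate labels now follow the expansion's naming (".sqr", ".sum", ".card"). Numerically each block is an RMS-style renorm over the 256-channel axis, and the constant 0.00390625 is 1/256, the reciprocal of the reduced axis length. A minimal sketch of that computation in plain Rust (not the harness code):

/// RMS-style renorm matching the expected-graph pattern:
/// y = x * rsqrt(mean(x^2)), with mean(x^2) = sum(x^2) * (1/N).
fn renorm(x: &[f32]) -> Vec<f32> {
    let recip_n = 1.0 / x.len() as f32; // 1/256 = 0.00390625 in the graph
    let mean_sq: f32 = x.iter().map(|v| v * v).sum::<f32>() * recip_n;
    let scale = 1.0 / mean_sq.sqrt(); // the rsqrt node
    x.iter().map(|v| v * scale).collect()
}

fn main() {
    let x = vec![3.0f32; 256];
    let y = renorm(&x);
    // mean(x^2) = 9, rsqrt(9) = 1/3, so every output is 1.0
    assert!((y[0] - 1.0).abs() < 1e-6);
}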
