From 9b47a0ba6b4a128716f5957d14680611e4eade52 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Dec 2024 13:38:08 +0100 Subject: [PATCH 1/4] dont propagate axes changes through pseudo-scalar constants --- core/src/optim/change_axes.rs | 53 ++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/core/src/optim/change_axes.rs b/core/src/optim/change_axes.rs index 8cafb7060d..c5c51561aa 100644 --- a/core/src/optim/change_axes.rs +++ b/core/src/optim/change_axes.rs @@ -4,6 +4,7 @@ use crate::internal::*; use crate::model::*; use crate::ops::dummy::Dummy; use crate::ops::einsum::EinSum; +use crate::ops::konst::Const; use std::collections::hash_map::Entry; use std::collections::HashSet; use std::fmt::Debug; @@ -41,9 +42,7 @@ impl TypedPass for ChangeAxes { let change = AxisChange { outlet, op: suggestion.1 }; if self.0.insert(change.clone()) { if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[]) - .with_context(|| { - format!("Making patch for {:?} from {}", change, node) - })? + .with_context(|| format!("Making patch for {:?} from {}", change, node))? 
{ self.1 = node.id; return Ok(Some(patch)); @@ -80,6 +79,9 @@ pub fn change_axes( locked: &[OutletId], bounds: &[TVec], ) -> TractResult)>> { + if model.node(change.outlet.node).op_as::().is_some_and(|c| c.0.volume() == 1) { + return Ok(None); + } debug!(" Considering change {:?}", change); let mut todo_changes = vec![(change.clone(), None)]; let mut changed_wires: HashMap, AxisOp> = HashMap::new(); @@ -88,14 +90,15 @@ pub fn change_axes( }; changed_wires.insert(bound_outlets(change.outlet), change.op.clone()); let mut changed_ops: HashMap> = HashMap::new(); - while let Some((c, emitter)) = todo_changes.pop() { - let outlet_group = bound_outlets(c.outlet); + let mut rewired_scalar_input: HashMap = Default::default(); + while let Some((change, emitter)) = todo_changes.pop() { + let outlet_group = bound_outlets(change.outlet); for &outlet in &outlet_group { if locked.contains(&outlet) { debug!(" Change {:?} blocked by locked interface {:?}", change, outlet); return Ok(None); } - let mut interfaces = vec![(outlet.node, InOut::Out(outlet.slot))]; + let mut interfaces: Vec<(usize, InOut)> = vec![(outlet.node, InOut::Out(outlet.slot))]; for inlet in model.outlet_successors(outlet) { interfaces.push((inlet.node, InOut::In(inlet.slot))); } @@ -104,6 +107,7 @@ pub fn change_axes( continue; } let node = model.node(node_id); + // if this is a revisit... 
let op = if let Some(op) = changed_ops.get(&node_id) { trace!(" Change {:?} revisiting {}", change, model.node(node_id)); if op.is::() { @@ -117,20 +121,33 @@ pub fn change_axes( &node.op }; let more = op - .change_axes(model, node, io, &c.op) + .change_axes(model, node, io, &change.op) .with_context(|| format!("Propagating {change:?} to node {node}"))?; if more.is_none() { debug!(" Propagation of {:?} blocked by {}", change, node); return Ok(None); } let AxisChangeConsequence { substitute_op, wire_changes } = more.unwrap(); - trace!(" Change {:?} enters {} from {:?}", c.op, node, io); + trace!(" Change {:?} enters {} from {:?}", change.op, node, io); trace!(" propagates as {:?}", wire_changes); if let Some(op) = substitute_op { trace!(" replace op by {:?}", op); changed_ops.insert(node.id, op); } for (wire, op) in wire_changes.into_iter() { + let outlet = wire.as_outlet(node); + // stop upstream propagation to a scalar constant: we will clone it and alter it + // at patch generation time + if let InOut::In(inlet) = wire { + if model + .node(outlet.node) + .op_as::() + .is_some_and(|k| k.0.volume() == 1) + { + rewired_scalar_input.insert(InletId::new(node.id, inlet), (outlet, op)); + continue; + } + } let outlet_group = bound_outlets(wire.as_outlet(node)); match changed_wires.entry(outlet_group.clone()) { Entry::Vacant(entry) => { @@ -180,11 +197,21 @@ pub fn change_axes( let node = model.node(node_id); if nodes_to_replace.contains(&node_id) { let mut inputs = tvec!(); - for orig in &node.inputs { - let tgt = replaced_wires - .entry(*orig) - .or_insert_with(|| patch.tap_model(model, *orig).unwrap()); - inputs.push(*tgt); + for (slot, orig) in node.inputs.iter().enumerate() { + let tgt = if let Some((outlet, alteration)) = + rewired_scalar_input.get(&InletId::new(node_id, slot)) + { + let const_node = model.node(outlet.node); + let mut value = const_node.op_as::().unwrap().0.clone().into_tensor(); + alteration.change_tensor(&mut value, false)?; + let name =
model.unique_name(&const_node.name); + patch.add_const(name, value)? + } else { + *replaced_wires + .entry(*orig) + .or_insert_with(|| patch.tap_model(model, *orig).unwrap()) + }; + inputs.push(tgt); } let op: Box = changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone()); From ba0e018f146195cd8fbb68e4104544e557625281 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Dec 2024 14:06:22 +0100 Subject: [PATCH 2/4] no more propagation through scalar constants --- .../mdl-en-2019-Q3-librispeech/expected | 165 +++++++++--------- 1 file changed, 82 insertions(+), 83 deletions(-) diff --git a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected index 8c5d7f19d1..0106afa885 100644 --- a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected +++ b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected @@ -27,20 +27,20 @@ fragment scan_body_0( ) -> (i"fastlstm1.c_new": tensor, i"fastlstm1.r_new": tensor, i"fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0": tensor) { i"fastlstm1.peephole0.mul" = mul(i"fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.c"); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256" = add(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256"); i"fastlstm1.four_parts.split-over-1.0..256" = 
add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"); i"fastlstm1.peephole0.output" = add(i"fastlstm1.peephole0.mul", i"fastlstm1.four_parts.split-over-1.0..256"); i"fastlstm1.peephole0.output.nolin" = sigmoid(i"fastlstm1.peephole0.output"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"fastlstm1.four_parts.split-over-1.512..768" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"); i"fastlstm1.four_parts.j.nolin" = tanh(i"fastlstm1.four_parts.split-over-1.512..768"); i"fastlstm1.c_update" = mul(i"fastlstm1.peephole0.output.nolin", i"fastlstm1.four_parts.j.nolin"); i"fastlstm1.peephole1.mul" = mul(i"fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.c"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" 
= tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"fastlstm1.four_parts.split-over-1.256..512" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"); i"fastlstm1.peephole1.output" = add(i"fastlstm1.peephole1.mul", i"fastlstm1.four_parts.split-over-1.256..512"); @@ -50,13 +50,13 @@ fragment scan_body_0( i"fastlstm1.tanh_c" = tanh(i"fastlstm1.c_new"); i"fastlstm1.peephole2.mul" = mul(i"fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.c_new"); i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); - i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm1.r", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = 
add(i"fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"fastlstm1.four_parts.split-over-1.768..1024" = add(i"fastlstm1.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"); i"fastlstm1.peephole2.output" = add(i"fastlstm1.peephole2.mul", i"fastlstm1.four_parts.split-over-1.768..1024"); i"fastlstm1.peephole2.output.nolin" = sigmoid(i"fastlstm1.peephole2.output"); i"fastlstm1.m" = mul(i"fastlstm1.tanh_c", i"fastlstm1.peephole2.output.nolin"); - i"fastlstm1.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm1.m", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->mna", acc = "f32", output = ""); + i"fastlstm1.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm1.m", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->an", acc = "f32", output = ""); i"fastlstm1.h_new.split-over-1.0..128" = add(i"fastlstm1.h_new.W.split-over-1.0..128", i"fastlstm1.h_new.split-1-over-1.0..128.slice"); i"fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(i"fastlstm1.m", axes = [0]); i"fastlstm1.r_new" = i"fastlstm1.h_new.split-over-1.0..128"; @@ -85,20 +85,20 @@ fragment scan_body_1( ) -> (i"fastlstm2.c_new": tensor, i"fastlstm2.r_new": tensor, i"fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0": tensor) { i"fastlstm2.peephole0.mul" = mul(i"fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.c"); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"], expr = 
"ak,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256" = add(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.0..256"); i"fastlstm2.four_parts.split-over-1.0..256" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"); i"fastlstm2.peephole0.output" = add(i"fastlstm2.peephole0.mul", i"fastlstm2.four_parts.split-over-1.0..256"); i"fastlstm2.peephole0.output.nolin" = sigmoid(i"fastlstm2.peephole0.output"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"fastlstm2.four_parts.split-over-1.512..768" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"); i"fastlstm2.four_parts.j.nolin" = tanh(i"fastlstm2.four_parts.split-over-1.512..768"); i"fastlstm2.c_update" = mul(i"fastlstm2.peephole0.output.nolin", i"fastlstm2.four_parts.j.nolin"); i"fastlstm2.peephole1.mul" = 
mul(i"fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.c"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"fastlstm2.four_parts.split-over-1.256..512" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"); i"fastlstm2.peephole1.output" = add(i"fastlstm2.peephole1.mul", i"fastlstm2.four_parts.split-over-1.256..512"); @@ -108,13 +108,13 @@ fragment scan_body_1( i"fastlstm2.tanh_c" = tanh(i"fastlstm2.c_new"); i"fastlstm2.peephole2.mul" = mul(i"fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.c_new"); i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); - i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "mka,kn->bn", acc = "f32", output = ""); + 
i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([i"fastlstm2.r", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = add(i"fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"fastlstm2.four_parts.split-over-1.768..1024" = add(i"fastlstm2.four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"); i"fastlstm2.peephole2.output" = add(i"fastlstm2.peephole2.mul", i"fastlstm2.four_parts.split-over-1.768..1024"); i"fastlstm2.peephole2.output.nolin" = sigmoid(i"fastlstm2.peephole2.output"); i"fastlstm2.m" = mul(i"fastlstm2.tanh_c", i"fastlstm2.peephole2.output.nolin"); - i"fastlstm2.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm2.m", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->mna", acc = "f32", output = ""); + i"fastlstm2.h_new.W.split-over-1.0..128" = tract_core_einsum([i"fastlstm2.m", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->an", acc = "f32", output = ""); i"fastlstm2.h_new.split-over-1.0..128" = add(i"fastlstm2.h_new.W.split-over-1.0..128", i"fastlstm2.h_new.split-1-over-1.0..128.slice"); i"fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(i"fastlstm2.m", axes = [0]); i"fastlstm2.r_new" = i"fastlstm2.h_new.split-over-1.0..128"; @@ -136,60 +136,59 @@ graph network(input) -> (output) { i"lda.output_conv" = conv(i"lda.output_input", i"lda.kernel.0", i"lda.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"lda.output" = transpose(i"lda.output_conv", axes = [0, 2, 1]); i"lda.output.rm_n" = squeeze(i"lda.output", axes = [0]); - i"tdnn1.affine.output.einsum.fix_a" = 
unsqueeze(i"lda.output.rm_n", axes = [0]); - i"tdnn1.affine.output.einsum.fix_b" = variable(label = "tdnn1.affine.output.einsum.fix_b", shape = [1, 256, 200]); - i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.output.einsum.fix_b", i"tdnn1.affine.output.einsum.fix_a", transposeA = false, transposeB = true); - i"tdnn1.affine.output.bias.reshape.1" = variable(label = "tdnn1.affine.output.bias.reshape.1", shape = [1, 256, 1]); - i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.output.bias.reshape.1"); - i"tdnn1.relu.output.low.cst" = [[[0.0]]]; - i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst"); + i"tdnn1.affine.kernel.0" = variable(label = "tdnn1.affine.kernel.0", shape = [256, 200]); + i"tdnn1.affine.output.einsum" = matmul(i"tdnn1.affine.kernel.0", i"lda.output.rm_n", transposeA = false, transposeB = true); + i"tdnn1.affine.bias.0" = variable(label = "tdnn1.affine.bias.0", shape = [256, 1]); + i"tdnn1.affine.output" = add(i"tdnn1.affine.output.einsum", i"tdnn1.affine.bias.0"); + i"tdnn1.relu.output.low.cst.1" = [[0.0]]; + i"tdnn1.relu.output.low" = max(i"tdnn1.affine.output", i"tdnn1.relu.output.low.cst.1"); i"tdnn1.renorm.reduced.sum.sqr" = square(i"tdnn1.relu.output.low"); - i"tdnn1.renorm.reduced.sum.sum" = sum_reduce(i"tdnn1.renorm.reduced.sum.sqr", axes = [1]); - i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2" = [[[0.00390625]]]; - i"tdnn1.renorm.reduced.sum.card" = mul(i"tdnn1.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn1.renorm.reduced.sum.sum" = sum_reduce(i"tdnn1.renorm.reduced.sum.sqr", axes = [0]); + i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1" = [[0.00390625]]; + i"tdnn1.renorm.reduced.sum.card" = mul(i"tdnn1.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1"); i"tdnn1.renorm.output-recip" = rsqrt(i"tdnn1.renorm.reduced.sum.card"); i"tdnn1.renorm.output" = mul(i"tdnn1.relu.output.low", i"tdnn1.renorm.output-recip"); - 
i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 2, delay = 0, overlap = 2); + i"tdnn2.affine.output.delay" = tract_pulse_delay(i"tdnn1.renorm.output", axis = 1, delay = 0, overlap = 2); + i"tdnn2.affine.output.add_n" = unsqueeze(i"tdnn2.affine.output.delay", axes = [0]); i"tdnn2.affine.kernel.0" = variable(label = "tdnn2.affine.kernel.0", shape = [256, 256, 3]); i"tdnn2.affine.bias.0" = variable(label = "tdnn2.affine.bias.0", shape = [256]); - i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.delay", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn2.affine.output_conv" = conv(i"tdnn2.affine.output.add_n", i"tdnn2.affine.kernel.0", i"tdnn2.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn2.affine.output" = i"tdnn2.affine.output_conv"; - i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output", i"tdnn1.relu.output.low.cst"); + i"tdnn2.affine.output.rm_n" = squeeze(i"tdnn2.affine.output", axes = [0]); + i"tdnn2.relu.output.low" = max(i"tdnn2.affine.output.rm_n", i"tdnn1.relu.output.low.cst.1"); i"tdnn2.renorm.reduced.sum.sqr" = square(i"tdnn2.relu.output.low"); - i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [1]); - i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn2.renorm.reduced.sum.sum" = sum_reduce(i"tdnn2.renorm.reduced.sum.sqr", axes = [0]); + i"tdnn2.renorm.reduced.sum.card" = mul(i"tdnn2.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1"); i"tdnn2.renorm.output-recip" = rsqrt(i"tdnn2.renorm.reduced.sum.card"); i"tdnn2.renorm.output" = mul(i"tdnn2.relu.output.low", i"tdnn2.renorm.output-recip"); + i"tdnn3.affine.output.add_n" = unsqueeze(i"tdnn2.renorm.output", axes = [0]); i"tdnn3.affine.kernel.0" = variable(label = "tdnn3.affine.kernel.0", 
shape = [256, 256, 3]); i"tdnn3.affine.bias.0" = variable(label = "tdnn3.affine.bias.0", shape = [256]); - i"tdnn3.affine.output_conv" = conv(i"tdnn2.renorm.output", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn3.affine.output_conv" = conv(i"tdnn3.affine.output.add_n", i"tdnn3.affine.kernel.0", i"tdnn3.affine.bias.0", dilation = [1], stride = [3], border = "constant", groups = 1, padding = [(0, 0)]); i"tdnn3.affine.output" = i"tdnn3.affine.output_conv"; - i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output", i"tdnn1.relu.output.low.cst"); + i"tdnn3.affine.output.rm_n" = squeeze(i"tdnn3.affine.output", axes = [0]); + i"tdnn3.relu.output.low" = max(i"tdnn3.affine.output.rm_n", i"tdnn1.relu.output.low.cst.1"); i"tdnn3.renorm.reduced.sum.sqr" = square(i"tdnn3.relu.output.low"); - i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [1]); - i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn3.renorm.reduced.sum.sum" = sum_reduce(i"tdnn3.renorm.reduced.sum.sqr", axes = [0]); + i"tdnn3.renorm.reduced.sum.card" = mul(i"tdnn3.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1"); i"tdnn3.renorm.output-recip" = rsqrt(i"tdnn3.renorm.reduced.sum.card"); i"tdnn3.renorm.output" = mul(i"tdnn3.relu.output.low", i"tdnn3.renorm.output-recip"); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b", shape = [1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b", transposeA = true, transposeB = false); - 
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", axes = [0]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", transposeA = true, transposeB = false); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", transposeA = true, transposeB = false); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1" = 
transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", axes = [1, 0, 2]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b", shape = [1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1, 0, 2]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", transposeA = true, transposeB = false); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = 
[1]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", axes = [1, 0, 2]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b", shape = [1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1, 0, 2]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a" = unsqueeze(i"tdnn3.renorm.output", axes = [0]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", transposeA = true, transposeB = false); - 
i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = squeeze(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", axes = [1, 0, 2]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b", shape = [1, 256, 256]); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b", transposeA = true, transposeB = false); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = transpose(i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1, 0, 2]); i"tap.tap.fastlstm1.c_init.0-35/0-100/0" = variable(label = "tap.tap.fastlstm1.c_init.0-35/0-100/0", shape = [1, 256]); - i"tap.fastlstm1.r_init.0-36/0" = variable(label = "tap.fastlstm1.r_init.0-36/0", shape = [1, 128, 1]); + i"tap.fastlstm1.r_init.0-36/0" = variable(label = "tap.fastlstm1.r_init.0-36/0", shape = [1, 128]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice" = variable(label = 
"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", shape = [128, 256]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", shape = [128, 256]); @@ -199,59 +198,60 @@ graph network(input) -> (output) { i"fastlstm1.four_parts.split-1-over-1.512..768.slice" = variable(label = "fastlstm1.four_parts.split-1-over-1.512..768.slice", shape = [1, 256]); i"fastlstm1.four_parts.split-1-over-1.768..1024.slice" = variable(label = "fastlstm1.four_parts.split-1-over-1.768..1024.slice", shape = [1, 256]); i"fastlstm1.h_new.W.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.W.split-1-over-1.0..128.slice", shape = [256, 128]); - i"fastlstm1.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.split-1-over-1.0..128.slice", shape = [1, 128, 1]); + i"fastlstm1.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm1.h_new.split-1-over-1.0..128.slice", shape = [1, 128]); i"fastlstm1.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), 
("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = 
[("fastlstm1.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm1.c_new"), ("fastlstm1.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 2, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); + ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", 0, 1), ("fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", 0, 1)], full = [("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm1.four_parts.split-1-over-1.0..256.slice", 
i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("fastlstm1.four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("fastlstm1.four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("fastlstm1.four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm1.h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm1.h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("fastlstm1.peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("fastlstm1.peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("fastlstm1.peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("fastlstm1.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm1.c_new"), ("fastlstm1.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm1.r_new")], output = [("fastlstm1.r_new", "full", 0, 1), ("fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2, reset_every_turn = false); i"fastlstm1.h_new.W.split-over-1.128..256.fix_a" = transpose(i"fastlstm1.c_final_1", axes = [1, 0, 2]); - i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.1" = unsqueeze(i"fastlstm1.h_new.W.split-over-1.128..256.fix_a", axes = [0]); - i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.1" = variable(label = "fastlstm1.h_new.W.split-over-1.128..256.fix_b.1", shape = [1, 1, 256, 128]); - i"fastlstm1.h_new.W.split-over-1.128..256" = matmul(i"fastlstm1.h_new.W.split-over-1.128..256.fix_b.1", i"fastlstm1.h_new.W.split-over-1.128..256.fix_a.1", transposeA = true, transposeB = true); - i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0" = squeeze(i"fastlstm1.h_new.W.split-over-1.128..256", axes = [1]); - i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice" = variable(label = 
"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice", shape = [1, 128, 1]); + i"fastlstm1.h_new.W.split-over-1.128..256.fix_b" = variable(label = "fastlstm1.h_new.W.split-over-1.128..256.fix_b", shape = [1, 256, 128]); + i"fastlstm1.h_new.W.split-over-1.128..256" = matmul(i"fastlstm1.h_new.W.split-over-1.128..256.fix_a", i"fastlstm1.h_new.W.split-over-1.128..256.fix_b", transposeA = false, transposeB = false); + i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0" = squeeze(i"fastlstm1.h_new.W.split-over-1.128..256", axes = [0]); + i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice" = variable(label = "fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice", shape = [1, 128]); i"fastlstm1.h_new.split-over-1.128..256" = add(i"fastlstm1.h_new.W.split-over-1.128..256.fix_c.0", i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice"); i"fastlstm1.h_new.concat-1" = concat([i"fastlstm1.c_final", i"fastlstm1.h_new.split-over-1.128..256"], axis = 1); - i"tdnn4.affine.output.delay" = tract_pulse_delay(i"fastlstm1.h_new.concat-1", axis = 2, delay = 0, overlap = 2); + i"tdnn4.affine.output.delay" = tract_pulse_delay(i"fastlstm1.h_new.concat-1", axis = 0, delay = 0, overlap = 2); + i"tdnn4.affine.output.add_n" = unsqueeze(i"tdnn4.affine.output.delay", axes = [0]); i"tdnn4.affine.kernel.0" = variable(label = "tdnn4.affine.kernel.0", shape = [256, 256, 3]); i"tdnn4.affine.bias.0" = variable(label = "tdnn4.affine.bias.0", shape = [256]); - i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output.delay", i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); - i"tdnn4.affine.output" = i"tdnn4.affine.output_conv"; - i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output", i"tdnn1.relu.output.low.cst"); + i"tdnn4.affine.output_input" = transpose(i"tdnn4.affine.output.add_n", axes = [0, 2, 1]); + i"tdnn4.affine.output_conv" = conv(i"tdnn4.affine.output_input", 
i"tdnn4.affine.kernel.0", i"tdnn4.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn4.affine.output" = transpose(i"tdnn4.affine.output_conv", axes = [0, 2, 1]); + i"tdnn4.affine.output.rm_n" = squeeze(i"tdnn4.affine.output", axes = [0]); + i"tdnn4.relu.output.low" = max(i"tdnn4.affine.output.rm_n", i"tdnn1.relu.output.low.cst.1"); i"tdnn4.renorm.reduced.sum.sqr" = square(i"tdnn4.relu.output.low"); i"tdnn4.renorm.reduced.sum.sum" = sum_reduce(i"tdnn4.renorm.reduced.sum.sqr", axes = [1]); - i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn4.renorm.reduced.sum.card" = mul(i"tdnn4.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1"); i"tdnn4.renorm.output-recip" = rsqrt(i"tdnn4.renorm.reduced.sum.card"); i"tdnn4.renorm.output" = mul(i"tdnn4.relu.output.low", i"tdnn4.renorm.output-recip"); - i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 2, delay = 0, overlap = 2); + i"tdnn5.affine.output.delay" = tract_pulse_delay(i"tdnn4.renorm.output", axis = 0, delay = 0, overlap = 2); + i"tdnn5.affine.output.add_n" = unsqueeze(i"tdnn5.affine.output.delay", axes = [0]); i"tdnn5.affine.kernel.0" = variable(label = "tdnn5.affine.kernel.0", shape = [256, 256, 3]); i"tdnn5.affine.bias.0" = variable(label = "tdnn5.affine.bias.0", shape = [256]); - i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output.delay", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = [1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); - i"tdnn5.affine.output" = i"tdnn5.affine.output_conv"; - i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output", i"tdnn1.relu.output.low.cst"); + i"tdnn5.affine.output_input" = transpose(i"tdnn5.affine.output.add_n", axes = [0, 2, 1]); + i"tdnn5.affine.output_conv" = conv(i"tdnn5.affine.output_input", i"tdnn5.affine.kernel.0", i"tdnn5.affine.bias.0", dilation = 
[1], stride = [1], border = "constant", groups = 1, padding = [(0, 0)]); + i"tdnn5.affine.output" = transpose(i"tdnn5.affine.output_conv", axes = [0, 2, 1]); + i"tdnn5.affine.output.rm_n" = squeeze(i"tdnn5.affine.output", axes = [0]); + i"tdnn5.relu.output.low" = max(i"tdnn5.affine.output.rm_n", i"tdnn1.relu.output.low.cst.1"); i"tdnn5.renorm.reduced.sum.sqr" = square(i"tdnn5.relu.output.low"); i"tdnn5.renorm.reduced.sum.sum" = sum_reduce(i"tdnn5.renorm.reduced.sum.sqr", axes = [1]); - i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-2"); + i"tdnn5.renorm.reduced.sum.card" = mul(i"tdnn5.renorm.reduced.sum.sum", i"tdnn1.renorm.reduced.sum.card.fix-rank-1-1"); i"tdnn5.renorm.output-recip" = rsqrt(i"tdnn5.renorm.reduced.sum.card"); i"tdnn5.renorm.output" = mul(i"tdnn5.relu.output.low", i"tdnn5.renorm.output-recip"); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b", shape = [1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_b", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", axes = [0]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); + 
i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = matmul(i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", transposeA = false, transposeB = false); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b.1", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", axes = [1, 0, 2]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b", shape = [1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = 
matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_b", transposeA = false, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [1, 0, 2]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b.1", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", axes = [1, 0, 2]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b", shape = 
[1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_b", transposeA = false, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [1, 0, 2]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a" = unsqueeze(i"tdnn5.renorm.output", axes = [0]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", shape = [1, 1, 256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b.1", transposeA = true, transposeB = false); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = squeeze(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", axes = [1, 0, 2]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b" = 
variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b", shape = [1, 256, 256]); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = matmul(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_a", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_b", transposeA = false, transposeB = false); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0" = transpose(i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [1, 0, 2]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", shape = [128, 256]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", shape = [128, 256]); @@ -261,14 +261,13 @@ graph network(input) -> (output) { i"fastlstm2.four_parts.split-1-over-1.512..768.slice" = variable(label = "fastlstm2.four_parts.split-1-over-1.512..768.slice", shape = [1, 256]); i"fastlstm2.four_parts.split-1-over-1.768..1024.slice" = variable(label = "fastlstm2.four_parts.split-1-over-1.768..1024.slice", shape = [1, 256]); i"fastlstm2.h_new.W.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.W.split-1-over-1.0..128.slice", shape = [256, 128]); - i"fastlstm2.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.split-1-over-1.0..128.slice", shape = [1, 128, 1]); + 
i"fastlstm2.h_new.split-1-over-1.0..128.slice" = variable(label = "fastlstm2.h_new.split-1-over-1.0..128.slice", shape = [1, 128]); i"fastlstm2.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.1", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.1", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), 
("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm2.c_new"), ("fastlstm2.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 2, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); - i"output.affine.output.W.concat-einsum-k.0..128.fix_a" = variable(label = "output.affine.output.W.concat-einsum-k.0..128.fix_a", shape = [1, 1690, 128]); - i"output.affine.output.W.concat-einsum-k.0..128" = matmul(i"fastlstm2.c_final", i"output.affine.output.W.concat-einsum-k.0..128.fix_a", transposeA = true, transposeB = true); - i"output.affine.output.W.concat-einsum-k.0..128.fix_c.0" = squeeze(i"output.affine.output.W.concat-einsum-k.0..128", axes = [0]); + ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = 
[("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.fix_c.0", 0, 1), ("fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.fix_c.0", 0, 1)], full = [("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("fastlstm2.four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("fastlstm2.four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("fastlstm2.four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("fastlstm2.four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), 
("fastlstm2.h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("fastlstm2.h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("fastlstm2.peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("fastlstm2.peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("fastlstm2.peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("fastlstm2.c", i"tap.tap.fastlstm1.c_init.0-35/0-100/0", "fastlstm2.c_new"), ("fastlstm2.r", i"tap.fastlstm1.r_init.0-36/0", "fastlstm2.r_new")], output = [("fastlstm2.r_new", "full", 0, 1), ("fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6, reset_every_turn = false); + i"output.affine.output.W.concat-einsum-slice-k.0.0..128" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.0..128", shape = [1690, 128]); + i"output.affine.output.W.concat-einsum-k.0..128" = matmul(i"fastlstm2.c_final", i"output.affine.output.W.concat-einsum-slice-k.0.0..128", transposeA = false, transposeB = true); i"fastlstm2.h_new.W.split-over-1.128..256.fix_a" = transpose(i"fastlstm2.c_final_1", axes = [1, 0, 2]); i"fastlstm2.h_new.W.split-over-1.128..256.fix_b" = variable(label = "fastlstm2.h_new.W.split-over-1.128..256.fix_b", shape = [1, 256, 128]); i"fastlstm2.h_new.W.split-over-1.128..256" = matmul(i"fastlstm2.h_new.W.split-over-1.128..256.fix_b", i"fastlstm2.h_new.W.split-over-1.128..256.fix_a", transposeA = true, transposeB = true); @@ -277,7 +276,7 @@ graph network(input) -> (output) { i"fastlstm2.h_new.split-over-1.128..256" = add(i"fastlstm2.h_new.W.split-over-1.128..256.fix_c.0", i"fastlstm2.c_final.fastlstm2.h_new.split-1-over-1.128..256.slice"); i"output.affine.output.W.concat-einsum-slice-k.0.128..256" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.128..256", shape = [1690, 128]); i"output.affine.output.W.concat-einsum-k.128..256" = 
matmul(i"fastlstm2.h_new.split-over-1.128..256", i"output.affine.output.W.concat-einsum-slice-k.0.128..256", transposeA = true, transposeB = true); - i"output.affine.output.W.concat-einsum-k.add-1" = add(i"output.affine.output.W.concat-einsum-k.0..128.fix_c.0", i"output.affine.output.W.concat-einsum-k.128..256"); + i"output.affine.output.W.concat-einsum-k.add-1" = add(i"output.affine.output.W.concat-einsum-k.0..128", i"output.affine.output.W.concat-einsum-k.128..256"); i"output.affine.bias.0" = variable(label = "output.affine.bias.0", shape = [1, 1690]); i"output.affine.output" = add(i"output.affine.output.W.concat-einsum-k.add-1", i"output.affine.bias.0"); output = i"output.affine.output"; From 94c6ffc3231b1298301af107f97b55d3f577fca7 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Dec 2024 14:07:14 +0100 Subject: [PATCH 3/4] no more propagation through scalar constants --- harness/pre-optimized-graphes/hey_snips_v4_model17/expected | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/harness/pre-optimized-graphes/hey_snips_v4_model17/expected b/harness/pre-optimized-graphes/hey_snips_v4_model17/expected index 98fed5870c..a9375e56d4 100644 --- a/harness/pre-optimized-graphes/hey_snips_v4_model17/expected +++ b/harness/pre-optimized-graphes/hey_snips_v4_model17/expected @@ -548,11 +548,11 @@ graph network(input_node) -> (i"wavenet_2/post_proc_2-1x1_conv-conv1d/convolutio i"wavenet_2/dilation_layer_23-dilation_rate_8-1x1_conv_skip-conv1d/convolution/Conv2D.filters_as_co_ci" = variable(label = "wavenet_2/dilation_layer_23-dilation_rate_8-1x1_conv_skip-conv1d/convolution/Conv2D.filters_as_co_ci", shape = [32, 64]); i"wavenet_2/dilation_layer_23-dilation_rate_8-1x1_conv_skip-conv1d/convolution/Conv2D" = matmul(i"wavenet_2/mul_23", i"wavenet_2/dilation_layer_23-dilation_rate_8-1x1_conv_skip-conv1d/convolution/Conv2D.filters_as_co_ci", transposeA = false, transposeB = true); i"wavenet_2/AddN.22" = add(i"wavenet_2/AddN.21", 
i"wavenet_2/dilation_layer_23-dilation_rate_8-1x1_conv_skip-conv1d/convolution/Conv2D"); - i"wavenet_2/Relu.low.cst" = [[0.0]]; - i"wavenet_2/Relu.low" = max(i"wavenet_2/AddN.22", i"wavenet_2/Relu.low.cst"); + i"wavenet_2/Relu.low.cst.1.1.1" = [[0.0]]; + i"wavenet_2/Relu.low" = max(i"wavenet_2/AddN.22", i"wavenet_2/Relu.low.cst.1.1.1"); i"wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci" = variable(label = "wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci", shape = [32, 32]); i"wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D" = matmul(i"wavenet_2/Relu.low", i"wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci", transposeA = false, transposeB = true); - i"wavenet_2/post_proc_1-1x1_conv-conv1d/Relu.low" = max(i"wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D", i"wavenet_2/Relu.low.cst"); + i"wavenet_2/post_proc_1-1x1_conv-conv1d/Relu.low" = max(i"wavenet_2/post_proc_1-1x1_conv-conv1d/convolution/Conv2D", i"wavenet_2/Relu.low.cst.1.1.1"); i"wavenet_2/post_proc_2-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci" = variable(label = "wavenet_2/post_proc_2-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci", shape = [2, 32]); i"wavenet_2/post_proc_2-1x1_conv-conv1d/convolution/Conv2D" = matmul(i"wavenet_2/post_proc_1-1x1_conv-conv1d/Relu.low", i"wavenet_2/post_proc_2-1x1_conv-conv1d/convolution/Conv2D.filters_as_co_ci", transposeA = false, transposeB = true); } From 10827982b3c2aa5622a25236e60cf671a5e93411 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 4 Dec 2024 17:21:24 +0100 Subject: [PATCH 4/4] more aggressive discarding of change axis candidates --- core/src/ops/scan/decluttered.rs | 4 ++++ core/src/optim/change_axes.rs | 25 +++++++++++++++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/core/src/ops/scan/decluttered.rs b/core/src/ops/scan/decluttered.rs index 75b886cfcb..ccb9eef2fe 100644 --- a/core/src/ops/scan/decluttered.rs +++ 
b/core/src/ops/scan/decluttered.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use crate::ops::einsum::EinSum; use crate::ops::konst::Const; use crate::optim::OptimizerSession; @@ -592,12 +594,14 @@ impl Scan { ) -> TractResult> { self.body.check_consistency()?; let locked_outlets = self.body_locked_outlets(node_input_facts)?; + let mut explored: HashSet = Default::default(); let (body_patch, body_changed_wires) = if let Some(changes) = crate::optim::change_axes::change_axes( &self.body, &change, if locked_interface { &locked_outlets } else { &[] }, &self.body_bounds()?, + &mut explored )? { changes } else { diff --git a/core/src/optim/change_axes.rs b/core/src/optim/change_axes.rs index c5c51561aa..963b88d78b 100644 --- a/core/src/optim/change_axes.rs +++ b/core/src/optim/change_axes.rs @@ -31,6 +31,7 @@ impl TypedPass for ChangeAxes { _session: &mut OptimizerSession, model: &TypedModel, ) -> TractResult> { + let mut explored: HashSet = Default::default(); let mut interfaces = model.output_outlets()?.to_vec(); interfaces.extend(model.input_outlets()?.iter()); for node in &model.nodes[self.1..] { @@ -41,8 +42,10 @@ impl TypedPass for ChangeAxes { let outlet = suggestion.0.as_outlet(node); let change = AxisChange { outlet, op: suggestion.1 }; if self.0.insert(change.clone()) { - if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[]) - .with_context(|| format!("Making patch for {:?} from {}", change, node))? + if let Some((patch, _)) = + change_axes(model, &change, &interfaces, &[], &mut explored).with_context( + || format!("Making patch for {:?} from {}", change, node), + )? { self.1 = node.id; return Ok(Some(patch)); @@ -55,10 +58,11 @@ impl TypedPass for ChangeAxes { let change = AxisChange { outlet: OutletId::new(node.id, slot), op: AxisOp::Rm(ix) }; if self.0.insert(change.clone()) { - if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[]) - .with_context(|| { - format!("Making patch for {:?} from {}", change, node) - })? 
+ if let Some((patch, _)) = + change_axes(model, &change, &interfaces, &[], &mut explored) + .with_context(|| { + format!("Making patch for {:?} from {}", change, node) + })? { self.1 = node.id; return Ok(Some(patch)); @@ -78,8 +82,14 @@ pub fn change_axes( change: &AxisChange, locked: &[OutletId], bounds: &[TVec], + explored: &mut HashSet, ) -> TractResult)>> { + if explored.contains(change) { + debug!(" Not considering change because deja vu {:?}", change); + return Ok(None); + } if model.node(change.outlet.node).op_as::().is_some_and(|c| c.0.volume() == 1) { + debug!(" Not considering change from const {:?}", change); return Ok(None); } debug!(" Considering change {:?}", change); @@ -92,6 +102,9 @@ pub fn change_axes( let mut changed_ops: HashMap> = HashMap::new(); let mut rewired_scalar_input: HashMap = Default::default(); while let Some((change, emitter)) = todo_changes.pop() { + if !explored.insert(change.clone()) { + return Ok(None); + } let outlet_group = bound_outlets(change.outlet); for &outlet in &outlet_group { if locked.contains(&outlet) {