diff --git a/core/src/ops/scan/lir.rs b/core/src/ops/scan/lir.rs index f953165770..942ef59838 100644 --- a/core/src/ops/scan/lir.rs +++ b/core/src/ops/scan/lir.rs @@ -6,13 +6,14 @@ use tract_data::internal::*; #[derive(Debug, Clone, new)] pub struct LirScanOpParams { pub skip: usize, + pub iters: TDim, pub plan: Arc>, pub input_mapping: Vec, pub output_mapping: Vec>, } #[derive(Debug, Clone, new)] -pub struct LirScan(Arc); +pub struct LirScan(pub Arc); impl std::ops::Deref for LirScan { type Target = LirScanOpParams; @@ -21,12 +22,6 @@ impl std::ops::Deref for LirScan { } } -impl LirScan { - pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option { - super::iteration_count(&self.input_mapping, inputs) - } -} - impl Op for LirScan { fn name(&self) -> Cow { "Scan".into() @@ -60,14 +55,14 @@ impl EvalOp for LirScan { position: 0, hidden_state: tvec!(), model_state: TypedSimpleState::new(Arc::clone(&self.plan))?, - op: Arc::clone(&self.0), + params: Arc::clone(&self.0), }))) } } #[derive(Clone, Debug)] pub struct State { - op: Arc, + pub params: Arc, position: usize, hidden_state: TVec, pub model_state: TypedSimpleState>>, @@ -84,7 +79,7 @@ struct FrozenState { impl OpStateFreeze for State { fn freeze(&self) -> Box { Box::new(FrozenState { - op: self.op.clone(), + op: self.params.clone(), position: self.position, hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(), model_state: self.model_state.freeze(), @@ -95,7 +90,7 @@ impl OpStateFreeze for State { impl FrozenOpState for FrozenState { fn unfreeze(&self) -> Box { Box::new(State { - op: self.op.clone(), + params: self.op.clone(), position: self.position, hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(), model_state: self.model_state.unfreeze(), @@ -104,17 +99,6 @@ impl FrozenOpState for FrozenState { } impl State { - pub fn iteration_count(&self, inputs: &TVec) -> usize { - let (slot, info) = self - .op - .input_mapping - .iter() - .enumerate() - .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan))) - .unwrap(); - inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs()) - } - pub(super) fn slice_input( input: &Tensor, axis: usize, @@ -177,13 +161,12 @@ impl OpState for State { _op: &dyn Op, inputs: TVec, ) -> TractResult> { - let iters = self.iteration_count(&inputs); - - let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self; + let State { params, ref mut hidden_state, ref mut position, ref mut model_state } = self; + let iters: usize = params.iters.eval(&session.resolved_symbols).to_usize().unwrap(); // initialize state at first pass if hidden_state.len() == 0 { - for (slot, input) in op.input_mapping.iter().enumerate() { + for (slot, input) in params.input_mapping.iter().enumerate() { if input.is_state() { hidden_state.push(inputs[slot].clone()); } @@ -191,9 +174,9 @@ impl OpState for State { } let mut outputs = tvec!(); - for (ix, output) in op.output_mapping.iter().enumerate() { + for (ix, output) in params.output_mapping.iter().enumerate() { if let Some((slot, info)) = output.scan { - let fact = op.plan.model().output_fact(ix)?; + let fact = params.plan.model().output_fact(ix)?; let mut shape: TVec = fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned(); let scanning_dim = output @@ -212,14 +195,18 @@ impl OpState for State { outputs.sort_by_key(|a| a.0); let mut outputs: TVec = outputs.into_iter().map(|(_slot, v)| v).collect(); - for i in 0..iters { + let mut i = 0; + loop { + if i >= iters { + break; + } *position += 1; - if *position <= op.skip { + if *position <= params.skip { continue; } hidden_state.reverse(); - let iter_inputs: TVec = op + let iter_inputs: TVec = params .input_mapping .iter() .enumerate() @@ -243,7 +230,7 @@ impl OpState for State { model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?; trace!("iter_outputs #{}: {:?}", i, iter_outputs); - for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) { + for (v, mapping) in iter_outputs.into_iter().zip(¶ms.output_mapping) { if let Some((slot, info)) = mapping.scan { Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0); } @@ -256,6 +243,7 @@ impl OpState for State { hidden_state.push(v); } } + i = i + 1; } Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect()) @@ -267,7 +255,6 @@ impl TypedOp for LirScan { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { let mut outputs = tvec!(); - let iters = super::iteration_count(&self.input_mapping, inputs).unwrap(); for (ix, output) in self.output_mapping.iter().enumerate() { let fact = self.plan.model().output_fact(ix)?; if let Some(slot) = output.last_value_slot { @@ -275,8 +262,10 @@ impl TypedOp for LirScan { } if let Some((slot, info)) = output.scan { let mut shape = fact.shape.clone(); - let scanning_dim = - output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters); + let scanning_dim = output + .full_dim_hint + .clone() + .unwrap_or(shape[info.axis].clone() * &self.0.iters); shape.set(info.axis, scanning_dim); outputs.push((slot, fact.datum_type.fact(shape))); } diff --git a/core/src/ops/scan/mir.rs b/core/src/ops/scan/mir.rs index fd5ec26eb0..1249101974 100644 --- a/core/src/ops/scan/mir.rs +++ b/core/src/ops/scan/mir.rs @@ -1,7 +1,3 @@ -use crate::ops::einsum::EinSum; -use crate::ops::konst::Const; -use crate::optim::OptimizerSession; - use super::lir::{LirScan, LirScanOpParams}; use tract_data::internal::*; @@ -10,6 +6,7 @@ use super::*; #[derive(Debug, Clone, Default)] pub struct Scan { pub skip: usize, + pub iters: TDim, pub body: TypedModel, pub input_mapping: Vec, pub output_mapping: Vec>, @@ -25,6 +22,7 @@ impl Scan { Ok(LirScan::new(Arc::new(LirScanOpParams::new( self.skip, + self.iters.clone(), Arc::new(plan), self.input_mapping.clone(), self.output_mapping.clone(), @@ -36,114 +34,12 @@ impl Scan { input_mapping: Vec, output_mapping: Vec>, skip: usize, + iters: TDim, ) -> TractResult { body.check_consistency()?; ensure!(input_mapping.len() == body.input_outlets()?.len()); ensure!(output_mapping.len() == body.output_outlets()?.len()); - Ok(Scan { skip, body, input_mapping, output_mapping }) - } - - pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option { - self.to_codegen_op(false).unwrap().iteration_count(inputs) - } - - fn body_bounds(&self) -> TractResult>> { - let input_state_outlets = self - .input_mapping - .iter() - .zip(self.body.input_outlets()?.iter()) - .filter(|(m, _)| m.is_state()) - .map(|(_, o)| o); - let output_state_outlets = self - .output_mapping - .iter() - .zip(self.body.output_outlets()?.iter()) - .filter(|(m, _)| m.state) - .map(|(_, o)| o); - Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect()) - } - - fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult> { - let input_outlets = - self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| { - if node_input_facts[slot].konst.is_none() { - Some(o) - } else { - None - } - }); - let output_outlets = self - .output_mapping - .iter() - .zip(self.body.output_outlets()?.iter()) - .filter(|(m, _)| !m.invisible()) - .map(|(_, o)| o); - Ok(input_outlets.chain(output_outlets).cloned().collect()) - } - - fn try_body_axes_change( - &self, - change: AxisChange, - locked_interface: bool, - node_input_facts: &[&TypedFact], - ) -> TractResult> { - self.body.check_consistency()?; - let locked_outlets = self.body_locked_outlets(node_input_facts)?; - let (body_patch, body_changed_wires) = if let Some(changes) = - crate::optim::change_axes::change_axes( - &self.body, - &change, - if locked_interface { &locked_outlets } else { &[] }, - &self.body_bounds()?, - )? { - changes - } else { - return Ok(None); - }; - let mut body = self.body.clone(); - body_patch.apply(&mut body)?; - body.compact()?; - let mut wire_changes = tvec!(); - let mut input_mapping: Vec = self.input_mapping.clone(); - for (slot, m) in input_mapping.iter_mut().enumerate() { - if let Some(change) = body_changed_wires - .iter() - .find(|(iface, _change)| iface == &InOut::In(slot)) - .map(|pair| pair.1.clone()) - { - wire_changes.push((InOut::In(slot), change.clone())); - if let InputMapping::Scan(info) = m { - if let Some(axis) = change.transform_axis(info.axis) { - info.axis = axis; - } else { - return Ok(None); - }; - }; - } - } - let mut output_mapping: Vec> = self.output_mapping.clone(); - for (ix, m) in output_mapping.iter_mut().enumerate() { - if let Some(change) = body_changed_wires - .iter() - .find(|(iface, _change)| iface == &InOut::Out(ix)) - .map(|pair| pair.1.clone()) - { - if let Some((slot, info)) = m.scan.as_mut() { - if let Some(new_axis) = change.transform_axis(info.axis) { - info.axis = new_axis; - } else { - return Ok(None); - } - wire_changes.push((InOut::Out(*slot), change.clone())); - } - if let Some(slot) = m.last_value_slot { - wire_changes.push((InOut::Out(slot), change.clone())); - } - }; - } - body.check_consistency()?; - let op = Some(Box::new(Scan { body, input_mapping, output_mapping, ..self.clone() }) as _); - Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes })) + Ok(Scan { skip, iters, body, input_mapping, output_mapping }) } } @@ -153,7 +49,7 @@ impl Op for Scan { } fn info(&self) -> TractResult> { - let mut lines = vec![]; + let mut lines = vec![format!("iters: {:?}", self.iters)]; for (ix, im) in self.input_mapping.iter().enumerate() { lines.push(format!("Model input #{ix}: {im:?}")); } @@ -204,13 +100,12 @@ impl TypedOp for Scan { ) } let mut outputs = tvec!(); - let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?; for (ix, output) in self.output_mapping.iter().enumerate() { let fact = self.body.output_fact(ix)?; if let Some((slot, info)) = output.scan { let mut shape = fact.shape.clone(); let scanning_dim = - output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters); + output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &self.iters); shape.set(info.axis, scanning_dim); outputs.push((slot, fact.datum_type.fact(shape))); } @@ -224,89 +119,6 @@ impl TypedOp for Scan { Ok(outputs) } - fn axes_mapping( - &self, - inputs: &[&TypedFact], - outputs: &[&TypedFact], - ) -> TractResult { - let mut mappings = vec![]; - let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?; - for body_axis in body_invs.iter_all_axes() { - let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len()); - info.inputs = body_axis.inputs.clone(); - for (ix, output_mapping) in self.output_mapping.iter().enumerate() { - let mut slots = vec![]; - if let Some((slot, _scan)) = output_mapping.scan { - slots.push(slot); - } - if let Some(slot) = output_mapping.last_value_slot { - slots.push(slot); - } - for slot in slots { - info.outputs[slot] = body_axis.outputs[ix].clone(); - } - } - if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) { - mappings.push(info); - } - } - AxesMapping::new(inputs.len(), outputs.len(), mappings) - } - - fn suggested_axis_changes(&self) -> TractResult> { - let mut suggestions = tvec!(); - for (slot, input) in self.input_mapping.iter().enumerate() { - if let InputMapping::Scan(info) = input { - if info.axis != 0 { - suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0))) - } - } - } - for output in &self.output_mapping { - if let Some((slot, scan)) = output.scan { - if scan.axis != 0 { - suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0))) - } - } - } - Ok(suggestions) - } - - fn change_axes( - &self, - model: &TypedModel, - node: &TypedNode, - io: InOut, - change: &AxisOp, - ) -> TractResult> { - trace!("Propagating through {}: {:?} {:?}", node, io, change); - let body_leading_outlet = match io { - InOut::In(ix) => self.body.input_outlets()?[ix], - InOut::Out(slot) => { - let output = self - .output_mapping - .iter() - .position(|im| { - im.scan.map(|(slot, _i)| slot) == Some(slot) - || im.last_value_slot == Some(slot) - }) - .unwrap(); - self.body.output_outlets()?[output] - } - }; - let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() }; - let node_input_facts = model.node_input_facts(node.id)?; - let result = self - .try_body_axes_change(axis_change, false, &node_input_facts) - .with_context(|| "Attemping to run change through scan body".to_string())?; - if result.is_some() { - trace!("{} accepted axis change", node); - } else { - trace!("{} rejected axis change", node); - } - Ok(result) - } - fn concretize_dims( &self, _source: &TypedModel, diff --git a/core/src/ops/scan/mod.rs b/core/src/ops/scan/mod.rs index 8d81b4e8d0..b1e7b8f686 100644 --- a/core/src/ops/scan/mod.rs +++ b/core/src/ops/scan/mod.rs @@ -78,6 +78,7 @@ impl fmt::Debug for OutputMapping { } } +/* pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> Option { let Some((slot, info)) = input_mapping .iter() @@ -87,3 +88,4 @@ pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> let outside_dim = inputs[slot].shape[info.axis].clone(); Some(outside_dim.div_ceil(info.chunk.unsigned_abs() as u64)) } +*/ diff --git a/hir/src/ops/scan.rs b/hir/src/ops/scan.rs index 6c18560cc1..2c77aba04f 100644 --- a/hir/src/ops/scan.rs +++ b/hir/src/ops/scan.rs @@ -49,17 +49,19 @@ impl EvalOp for InferenceScan { impl InferenceScan { pub(super) fn to_mir_scan(&self) -> TractResult> { - let typed_model = self.body.clone().into_typed()?; + let iters = self.iter_count_fact.concretize().unwrap(); + let typed_body = self.body.clone().into_typed()?; let input_mapping = self .input_mapping .iter() .enumerate() .map(|(ix, im)| { Ok(match im { - InputMapping::Scan(info) => InputMapping::Scan(ScanInfo { - chunk: typed_model.input_fact(ix)?.shape[info.axis].to_isize()?, + InputMapping::Scan(info) => { + InputMapping::Scan(ScanInfo { + chunk: typed_body.input_fact(ix)?.shape[info.axis].to_isize()?, ..*info - }), + })}, other => other.clone(), }) }) @@ -70,10 +72,13 @@ impl InferenceScan { .enumerate() .map(|(ix, im)| { let scan = if let Some((slot, scan)) = im.scan { - Some((slot, ScanInfo { - chunk: typed_model.input_fact(ix)?.shape[scan.axis].to_isize()?, - ..scan - })) + Some(( + slot, + ScanInfo { + chunk: typed_body.input_fact(ix)?.shape[scan.axis].to_isize()?, + ..scan + }, + )) } else { None }; @@ -85,12 +90,7 @@ impl InferenceScan { }) }) .collect::>()?; - Ok(Box::new(Scan::new( - typed_model, - input_mapping, - output_mapping, - 0, - )?)) + Ok(Box::new(Scan::new(typed_body, input_mapping, output_mapping, 0, iters)?)) } fn unify_scanning_tensor_fact( @@ -248,7 +248,8 @@ impl InferenceOp for InferenceScan { .filter_map(|om| om.last_value_slot) .chain(self.output_mapping.iter().filter_map(|om| om.scan.map(|si| si.0))) .max() - .context("No output slot found")? + 1; + .context("No output slot found")? + + 1; if inputs.len() != expected_op_inputs { bail!("Scan receives {} inputs, mappings expects {}", inputs.len(), expected_op_inputs) } diff --git a/libcli/src/model.rs b/libcli/src/model.rs index fa9761e1ec..b6e9af2fe5 100644 --- a/libcli/src/model.rs +++ b/libcli/src/model.rs @@ -85,9 +85,9 @@ pub trait Model: if let Some(submodel) = self.node_op(id).downcast_ref::() { submodel.iteration_count(input) } else if let Some(lir) = self.node_op(id).downcast_ref::() { - lir.iteration_count(input) + Some(lir.iters.clone()) } else if let Some(mir) = self.node_op(id).downcast_ref::() { - mir.iteration_count(input) + Some(mir.iters.clone()) } else { None } diff --git a/libcli/src/profile.rs b/libcli/src/profile.rs index a14739a0ea..2e3c261174 100644 --- a/libcli/src/profile.rs +++ b/libcli/src/profile.rs @@ -146,7 +146,7 @@ fn profile_submodel( new_prefix.push((node.id, "loop".to_string())); let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?; - let multi = scan_state.iteration_count(&input); + let multi = scan_state.params.iters.to_usize().unwrap(); rec_profiler( &mut scan_state.model_state, diff --git a/nnef/src/ops/core/scan.rs b/nnef/src/ops/core/scan.rs index 8b7758105f..b6debe78a4 100644 --- a/nnef/src/ops/core/scan.rs +++ b/nnef/src/ops/core/scan.rs @@ -154,12 +154,14 @@ fn de_scan(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tract let scan: TVec<(String, OutletId, usize, isize)> = invocation.named_arg_as(builder, "scan")?; let full: TVec<(String, OutletId)> = invocation.named_arg_as(builder, "full")?; let state: TVec<(String, OutletId, String)> = invocation.named_arg_as(builder, "state")?; + let mut iters:Option = None; for par in &fragment.decl.parameters { let (outer_input_wire, inner_fact) = if let Some((_, wire, axis, chunk)) = scan.iter().find(|s| s.0 == par.id.0 || escape(&s.0) == par.id.0) { input_mapping.push(InputMapping::Scan(ScanInfo { axis: *axis, chunk: *chunk })); let mut fact = builder.model.outlet_fact(*wire)?.clone(); + iters = Some(fact.shape[*axis].clone().div_ceil(chunk.abs() as _)); fact.shape.set(*axis, chunk.abs().to_dim()); (*wire, fact) } else if let Some((_, wire)) = @@ -183,6 +185,7 @@ fn de_scan(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tract Value::Wire(body.model.add_source(par.id.0.to_string(), inner_fact)?), ); } + let iters = iters.unwrap(); body.wire_body(fragment.body.as_deref().unwrap()).context("wiring scan body")?; let body_outputs = fragment .decl @@ -246,7 +249,7 @@ fn de_scan(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tract }); } let skip: usize = invocation.named_arg_as(builder, "skip")?; - let op = Scan::new(body.model, input_mapping, output_mapping, skip)?; + let op = Scan::new(body.model, input_mapping, output_mapping, skip, iters)?; builder.wire(op, &outer_inputs) } diff --git a/onnx/src/ops/cumsum.rs b/onnx/src/ops/cumsum.rs index 151a73653d..a060cc1741 100644 --- a/onnx/src/ops/cumsum.rs +++ b/onnx/src/ops/cumsum.rs @@ -45,6 +45,7 @@ impl Expansion for CumSum { format!("{prefix}.zero"), Tensor::zero_dt(data.datum_type, &[])?.into_arc_tensor(), )?; + let iters = var_shape[axis].clone(); var_shape.set(axis, 1.to_dim()); let init = model.wire_node( format!("{prefix}.init"), @@ -77,7 +78,7 @@ impl Expansion for CumSum { let acc = body.add_source("acc_input", var_fact)?; let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0]; body.set_output_outlets(&[sum, acc])?; - let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?; + let scan = scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?; let wires = model.wire_node(prefix, scan, &[inputs[0], init])?; let output = wires[self.exclusive as usize]; Ok(tvec![output]) diff --git a/onnx/src/ops/rec/common.rs b/onnx/src/ops/rec/common.rs index 82106446c0..1d8246525c 100644 --- a/onnx/src/ops/rec/common.rs +++ b/onnx/src/ops/rec/common.rs @@ -119,6 +119,7 @@ impl CommonRec { outer_inputs.push(x_batch_first); input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk })); let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value(); + let iters = x_source_fact.shape[1].clone(); x_source_fact.shape.set(1, 1.to_dim()); let x_source = body.add_source("x_source", x_source_fact)?; wire!(Xt = AxisOp::Rm(1), x_source); @@ -246,7 +247,7 @@ impl CommonRec { let scan_outputs = target.wire_node( prefix, - tract_core::ops::scan::Scan::new(body, input_mapping, output_mapping, 0)?, + tract_core::ops::scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?, &outer_inputs, )?; diff --git a/tensorflow/src/ops/rec/block_lstm.rs b/tensorflow/src/ops/rec/block_lstm.rs index 5837f9f1d7..02e862fd2e 100644 --- a/tensorflow/src/ops/rec/block_lstm.rs +++ b/tensorflow/src/ops/rec/block_lstm.rs @@ -99,6 +99,7 @@ impl Expansion for BlockLSTM { outer_inputs.push(inputs[1]); input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 0, chunk: 1 })); let mut x_source_fact = model.outlet_fact(inputs[1])?.clone(); + let iters = x_source_fact.shape[0].clone(); x_source_fact.shape.set(0, 1.to_dim()); let x_source = body.add_source("x_source", x_source_fact)?; wire!(x = AxisOp::Rm(0), x_source); @@ -174,7 +175,7 @@ impl Expansion for BlockLSTM { if seqlen.to_scalar::()? != &model.outlet_fact(inputs[1])?.shape[0] { bail!("seq_len only supported for trivial noop case"); }; - let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?; + let scan = scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?; model.wire_node(prefix, scan, &outer_inputs) } }