From 711c01bb12103f0c671611847bba04b6e946794b Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 23 May 2023 11:32:06 +0200 Subject: [PATCH] wip simplifying scan (cumsum test borken) --- core/src/ops/array/dyn_slice.rs | 6 ++--- hir/src/ops/array/strided_slice.rs | 4 +-- onnx/src/ops/cumsum.rs | 37 +++++++++++++++++++++++---- onnx/src/ops/rec/common.rs | 41 ++++++++++++++++++++++-------- 4 files changed, 68 insertions(+), 20 deletions(-) diff --git a/core/src/ops/array/dyn_slice.rs b/core/src/ops/array/dyn_slice.rs index 43f538c4ad..c5a68ee822 100644 --- a/core/src/ops/array/dyn_slice.rs +++ b/core/src/ops/array/dyn_slice.rs @@ -6,7 +6,7 @@ pub struct DynSlice { pub axis: usize, pub start_input: bool, pub end_input: bool, - pub symbol: Symbol, + pub len: TDim, } impl DynSlice { @@ -63,8 +63,8 @@ impl EvalOp for DynSlice { impl TypedOp for DynSlice { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - let mut fact = inputs[0].clone(); - fact.shape.set(self.axis, self.symbol.clone().into()); + let mut fact = inputs[0].without_value(); + fact.shape.set(self.axis, self.len.clone().into()); Ok(tvec!(fact)) } diff --git a/hir/src/ops/array/strided_slice.rs b/hir/src/ops/array/strided_slice.rs index 4068bdd94e..6dcb63c626 100644 --- a/hir/src/ops/array/strided_slice.rs +++ b/hir/src/ops/array/strided_slice.rs @@ -297,10 +297,10 @@ impl Expansion for StridedSlice { AxisOp::Rm(0), &right, )?[0]; - let sym = target.symbol_table.new_with_prefix("l"); + let len = target.symbol_table.new_with_prefix("len").to_dim(); wire = target.wire_node( format!("{prefix}.slice-axis-{axis}"), - tract_core::ops::array::DynSlice::new(axis, true, true, sym), + tract_core::ops::array::DynSlice::new(axis, true, true, len), &[wire, left, right], )?[0]; } diff --git a/onnx/src/ops/cumsum.rs b/onnx/src/ops/cumsum.rs index a060cc1741..6c551bd2d9 100644 --- a/onnx/src/ops/cumsum.rs +++ b/onnx/src/ops/cumsum.rs @@ -1,4 +1,5 @@ use tract_hir::internal::*; +use tract_hir::tract_core::ops::array::DynSlice; use tract_hir::tract_core::ops::scan::ScanInfo; use crate::model::{OnnxOpRegister, ParsingContext}; @@ -54,11 +55,17 @@ impl Expansion for CumSum { )?[0]; let chunk = if self.reverse { -1 } else { 1 }; let input_mapping = - vec![scan::InputMapping::Scan(ScanInfo { axis, chunk }), scan::InputMapping::State]; + vec![scan::InputMapping::Full, scan::InputMapping::State, scan::InputMapping::State]; // outputs will be // acc + x (!exclusive) // acc input (exclusive) let output_mapping = vec![ + scan::OutputMapping { + scan: None, + full_dim_hint: None, + last_value_slot: None, + state: true, + }, scan::OutputMapping { scan: Some((0, ScanInfo { axis, chunk })), full_dim_hint: None, @@ -74,12 +81,32 @@ impl Expansion for CumSum { ]; let mut body = TypedModel::default(); let var_fact = data.datum_type.fact(var_shape); - let x = body.add_source("scan_input", var_fact.clone())?; + let x = body.add_source("scan_input", data)?; + + let i = body.add_source("i", i64::scalar_fact())?; + let one = body.add_const("one", tensor0(1i64))?; + let i_plus_one = body.wire_node("inc_i", tract_core::ops::math::add(), &[i, one])?[0]; + let x_slice = body.wire_node( + "x", + DynSlice { + axis, + start_input: true, + end_input: true, + len: 1.to_dim(), + }, + &[x, i, i_plus_one], + )?[0]; + let acc = body.add_source("acc_input", var_fact)?; - let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0]; - body.set_output_outlets(&[sum, acc])?; + dbg!(axis); + dbg!(body.outlet_fact(x)); + dbg!(body.outlet_fact(x_slice)); + dbg!(body.outlet_fact(acc)); + let sum = body.wire_node("add", tract_core::ops::math::add(), &[x_slice, acc])?[0]; + body.set_output_outlets(&[i_plus_one, sum, acc])?; let scan = scan::Scan::new(body, input_mapping, output_mapping, 0, iters)?; - let wires = model.wire_node(prefix, scan, &[inputs[0], init])?; + let zero = model.add_const(format!("{prefix}.zero"), tensor0(0i64))?; + let wires = model.wire_node(prefix, scan, &[inputs[0], zero, init])?; let output = wires[self.exclusive as usize]; Ok(tvec![output]) } diff --git a/onnx/src/ops/rec/common.rs b/onnx/src/ops/rec/common.rs index e392d60c45..1079f66eb8 100644 --- a/onnx/src/ops/rec/common.rs +++ b/onnx/src/ops/rec/common.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use crate::pb::*; use tract_hir::internal::*; use tract_hir::tract_core::dyn_clone::{clone_trait_object, DynClone}; +use tract_hir::tract_core::ops::array::DynSlice; use tract_hir::tract_core::ops::scan::ScanInfo; pub trait WireBody: Debug + DynClone + Send + Sync { @@ -117,12 +118,21 @@ impl CommonRec { // scann inner interface: [chunk=1, batch_size, input_size] // onnx inner interface: [batch_size, input_size] outer_inputs.push(x_batch_first); - input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk })); - let mut x_source_fact = target.outlet_fact(x_batch_first)?.without_value(); + // input_mapping.push(scan::InputMapping::Scan(ScanInfo { axis: 1, chunk })); + input_mapping.push(scan::InputMapping::Full); + let x_source_fact = target.outlet_fact(x_batch_first)?.without_value(); let iters = x_source_fact.shape[1].clone(); - x_source_fact.shape.set(1, 1.to_dim()); let x_source = body.add_source("x_source", x_source_fact)?; - wire!(Xt = AxisOp::Rm(1), x_source); + + input_mapping.push(scan::InputMapping::State); + let zero = target.add_const(format!("{prefix}.zero"), tensor0(0i64))?; + outer_inputs.push(zero); + let i = body.add_source("i", i64::scalar_fact())?; + let one = body.add_const("one", tensor0(1i64))?; + wire!(i_plus_one = tract_core::ops::math::add(), i, one); + let dyn_slice = DynSlice { axis: 1, start_input: true, end_input: true, len: 1.to_dim() }; + wire!(x_slice = dyn_slice, x_source, i, i_plus_one); + wire!(Xt = AxisOp::Rm(1), x_slice); // W: onnx interface: [num_directions, 3*hidden_size, input_size] // scan interfaces: [3*hidden_size, input_size] @@ -229,13 +239,24 @@ impl CommonRec { }; self.body.wire_body(prefix, &mut body).context("Wiring body")?; + let mut outputs = body.outputs.clone(); + outputs.insert(0, i_plus_one); + body.set_output_outlets(&*outputs)?; - let mut output_mapping = vec![scan::OutputMapping { - state: true, - full_dim_hint: None, - last_value_slot: self.optional_y_h_output, - scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })), - }]; + let mut output_mapping = vec![ + scan::OutputMapping { + state: true, + full_dim_hint: None, + last_value_slot: None, + scan: None, + }, + scan::OutputMapping { + state: true, + full_dim_hint: None, + last_value_slot: self.optional_y_h_output, + scan: self.optional_y_output.map(|slot| (slot, ScanInfo { axis: 1, chunk })), + }, + ]; if self.body.have_extra_c_state() { output_mapping.push(scan::OutputMapping { state: true,