Skip to content

Commit

Permalink
wip, making iter count a field
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed May 24, 2023
1 parent f684c85 commit 3e91496
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 256 deletions.
59 changes: 24 additions & 35 deletions core/src/ops/scan/lir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use tract_data::internal::*;
#[derive(Debug, Clone, new)]
pub struct LirScanOpParams {
pub skip: usize,
pub iters: TDim,
pub plan: Arc<TypedSimplePlan<TypedModel>>,
pub input_mapping: Vec<InputMapping>,
pub output_mapping: Vec<OutputMapping<TDim>>,
}

#[derive(Debug, Clone, new)]
pub struct LirScan(Arc<LirScanOpParams>);
pub struct LirScan(pub Arc<LirScanOpParams>);

impl std::ops::Deref for LirScan {
type Target = LirScanOpParams;
Expand All @@ -21,12 +22,6 @@ impl std::ops::Deref for LirScan {
}
}

impl LirScan {
pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
super::iteration_count(&self.input_mapping, inputs)
}
}

impl Op for LirScan {
fn name(&self) -> Cow<str> {
"Scan".into()
Expand Down Expand Up @@ -60,14 +55,14 @@ impl EvalOp for LirScan {
position: 0,
hidden_state: tvec!(),
model_state: TypedSimpleState::new(Arc::clone(&self.plan))?,
op: Arc::clone(&self.0),
params: Arc::clone(&self.0),
})))
}
}

#[derive(Clone, Debug)]
pub struct State {
op: Arc<LirScanOpParams>,
pub params: Arc<LirScanOpParams>,
position: usize,
hidden_state: TVec<TValue>,
pub model_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
Expand All @@ -84,7 +79,7 @@ struct FrozenState {
impl OpStateFreeze for State {
fn freeze(&self) -> Box<dyn FrozenOpState> {
Box::new(FrozenState {
op: self.op.clone(),
op: self.params.clone(),
position: self.position,
hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
model_state: self.model_state.freeze(),
Expand All @@ -95,7 +90,7 @@ impl OpStateFreeze for State {
impl FrozenOpState for FrozenState {
fn unfreeze(&self) -> Box<dyn OpState> {
Box::new(State {
op: self.op.clone(),
params: self.op.clone(),
position: self.position,
hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
model_state: self.model_state.unfreeze(),
Expand All @@ -104,17 +99,6 @@ impl FrozenOpState for FrozenState {
}

impl State {
pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
let (slot, info) = self
.op
.input_mapping
.iter()
.enumerate()
.find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
.unwrap();
inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
}

pub(super) fn slice_input(
input: &Tensor,
axis: usize,
Expand Down Expand Up @@ -177,23 +161,22 @@ impl OpState for State {
_op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let iters = self.iteration_count(&inputs);

let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self;
let State { params, ref mut hidden_state, ref mut position, ref mut model_state } = self;
let iters: usize = params.iters.eval(&session.resolved_symbols).to_usize().unwrap();

// initialize state at first pass
if hidden_state.len() == 0 {
for (slot, input) in op.input_mapping.iter().enumerate() {
for (slot, input) in params.input_mapping.iter().enumerate() {
if input.is_state() {
hidden_state.push(inputs[slot].clone());
}
}
}

let mut outputs = tvec!();
for (ix, output) in op.output_mapping.iter().enumerate() {
for (ix, output) in params.output_mapping.iter().enumerate() {
if let Some((slot, info)) = output.scan {
let fact = op.plan.model().output_fact(ix)?;
let fact = params.plan.model().output_fact(ix)?;
let mut shape: TVec<usize> =
fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
let scanning_dim = output
Expand All @@ -212,14 +195,18 @@ impl OpState for State {
outputs.sort_by_key(|a| a.0);
let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();

for i in 0..iters {
let mut i = 0;
loop {
if i >= iters {
break;
}
*position += 1;
if *position <= op.skip {
if *position <= params.skip {
continue;
}
hidden_state.reverse();

let iter_inputs: TVec<TValue> = op
let iter_inputs: TVec<TValue> = params
.input_mapping
.iter()
.enumerate()
Expand All @@ -243,7 +230,7 @@ impl OpState for State {
model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
trace!("iter_outputs #{}: {:?}", i, iter_outputs);

for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
for (v, mapping) in iter_outputs.into_iter().zip(&params.output_mapping) {
if let Some((slot, info)) = mapping.scan {
Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
}
Expand All @@ -256,6 +243,7 @@ impl OpState for State {
hidden_state.push(v);
}
}
i = i + 1;
}

Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
Expand All @@ -267,16 +255,17 @@ impl TypedOp for LirScan {

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut outputs = tvec!();
let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
for (ix, output) in self.output_mapping.iter().enumerate() {
let fact = self.plan.model().output_fact(ix)?;
if let Some(slot) = output.last_value_slot {
outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
}
if let Some((slot, info)) = output.scan {
let mut shape = fact.shape.clone();
let scanning_dim =
output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
let scanning_dim = output
.full_dim_hint
.clone()
.unwrap_or(shape[info.axis].clone() * &self.0.iters);
shape.set(info.axis, scanning_dim);
outputs.push((slot, fact.datum_type.fact(shape)));
}
Expand Down
Loading

0 comments on commit 3e91496

Please sign in to comment.