Skip to content

Commit

Permalink
fix mm scratch space in-session
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 24, 2024
1 parent 69c2805 commit 1eded5b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
44 changes: 23 additions & 21 deletions core/src/ops/matmul/lir_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ndarray::*;
use tract_itertools::Itertools;

use tract_linalg::mmm::{
BinOp, FusedSpec, InputStoreSpec, MatMatMul, OutputStoreSpec, ScratchSpace,
BinOp, FusedSpec, InputStoreSpec, MatMatMul, OutputStoreSpec,
};
use tract_linalg::Scaler;
use tract_smallvec::ToSmallVec;
Expand Down Expand Up @@ -283,6 +283,7 @@ impl Op for LirMatMulUnary {
op_as_typed_op!();
}

/*
#[derive(Clone, Debug)]
struct State;
trivial_op_state_freeeze!(State);
Expand All @@ -295,61 +296,62 @@ impl OpState for State {
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let op = op.downcast_ref::<LirMatMulUnary>().unwrap();
unsafe {
if session
.cached_mmm_scratch_space
.as_deref()
.map(|scratch| op.mmm.can_use_scratch_space(scratch))
== Some(false)
{
session.cached_mmm_scratch_space = None
}
let scratch = session
.cached_mmm_scratch_space
.get_or_insert_with(|| op.mmm.allocate_scratch_space());
eval(op, &session.resolved_symbols, scratch.as_mut(), &inputs)
}
unsafe { eval(op, &session.resolved_symbols, session, &inputs) }
}
}
*/

impl EvalOp for LirMatMulUnary {
fn is_stateless(&self) -> bool {
self.geometry.is_concrete()
true
}

/*
fn state(
&self,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(State)))
}
*/

fn eval_with_session(
&self,
session: &SessionState,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
let mut scratch = unsafe { self.mmm.allocate_scratch_space() };
eval(self, &session.resolved_symbols, scratch.as_mut(), &inputs)
eval(self, &session.resolved_symbols, session, &inputs)
}
}

#[allow(clippy::too_many_arguments)]
fn eval(
op: &LirMatMulUnary,
symbols: &SymbolValues,
scratch: &mut dyn ScratchSpace,
session: &SessionState,
inputs: &[TValue],
) -> TractResult<TVec<TValue>> {
unsafe {
if session
.cached_mmm_scratch_space
.borrow_mut()
.as_deref()
.map(|scratch| op.mmm.can_use_scratch_space(scratch))
== Some(false)
{
session.cached_mmm_scratch_space.replace(None);
}
let mut cell = session.cached_mmm_scratch_space.borrow_mut();
let scratch = cell.get_or_insert_with(|| op.mmm.allocate_scratch_space());

if op.trivial_path {
let c_shape = op.c_fact.shape.as_concrete().unwrap_unchecked();
let geometry = op.geometry.as_concrete().unwrap_unchecked();
let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, c_shape)?;
let uops: Vec<FusedSpec> =
op.micro_ops.iter().map(|o| o.resolve_trivial(inputs, &mut c)).collect();
op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch, &uops)?;
op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?;
Ok(tvec!(c.into_tvalue()))
} else {
let geometry = op.geometry.to_concrete(symbols)?;
Expand All @@ -368,7 +370,7 @@ fn eval(
&c,
);
}
op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch, &uops)?;
op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?;
}
Ok(tvec!(c.into_tvalue()))
}
Expand Down
7 changes: 4 additions & 3 deletions core/src/plan.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Borrow;
use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;

Expand All @@ -14,7 +15,7 @@ pub struct SessionState {
pub inputs: HashMap<usize, TValue>,
pub resolved_symbols: SymbolValues,
pub tensors: HashMap<String, Tensor>,
pub cached_mmm_scratch_space: Option<Box<dyn tract_linalg::mmm::ScratchSpace>>,
pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
}

impl Clone for SessionState {
Expand All @@ -23,7 +24,7 @@ impl Clone for SessionState {
inputs: self.inputs.clone(),
resolved_symbols: self.resolved_symbols.clone(),
tensors: self.tensors.clone(),
cached_mmm_scratch_space: None,
cached_mmm_scratch_space: None.into()
}
}
}
Expand Down Expand Up @@ -590,7 +591,7 @@ where
inputs: self.inputs.iter().map(|(ix, t)| (*ix, t.clone().into_tvalue())).collect(),
resolved_symbols: self.resolved_symbols.clone(),
tensors: self.tensors.clone(),
cached_mmm_scratch_space: None,
cached_mmm_scratch_space: None.into(),
},
states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
values: self
Expand Down

0 comments on commit 1eded5b

Please sign in to comment.