diff --git a/core/src/ops/matmul/lir_unary.rs b/core/src/ops/matmul/lir_unary.rs index 5634b19348..b7c34e0df1 100644 --- a/core/src/ops/matmul/lir_unary.rs +++ b/core/src/ops/matmul/lir_unary.rs @@ -6,7 +6,7 @@ use ndarray::*; use tract_itertools::Itertools; use tract_linalg::mmm::{ - BinOp, FusedSpec, InputStoreSpec, MatMatMul, OutputStoreSpec, ScratchSpace, + BinOp, FusedSpec, InputStoreSpec, MatMatMul, OutputStoreSpec, }; use tract_linalg::Scaler; use tract_smallvec::ToSmallVec; @@ -283,6 +283,7 @@ impl Op for LirMatMulUnary { op_as_typed_op!(); } +/* #[derive(Clone, Debug)] struct State; trivial_op_state_freeeze!(State); @@ -295,28 +296,17 @@ impl OpState for State { inputs: TVec, ) -> TractResult> { let op = op.downcast_ref::().unwrap(); - unsafe { - if session - .cached_mmm_scratch_space - .as_deref() - .map(|scratch| op.mmm.can_use_scratch_space(scratch)) - == Some(false) - { - session.cached_mmm_scratch_space = None - } - let scratch = session - .cached_mmm_scratch_space - .get_or_insert_with(|| op.mmm.allocate_scratch_space()); - eval(op, &session.resolved_symbols, scratch.as_mut(), &inputs) - } + unsafe { eval(op, &session.resolved_symbols, session, &inputs) } } } +*/ impl EvalOp for LirMatMulUnary { fn is_stateless(&self) -> bool { - self.geometry.is_concrete() + true } + /* fn state( &self, _session: &mut SessionState, @@ -324,14 +314,14 @@ impl EvalOp for LirMatMulUnary { ) -> TractResult>> { Ok(Some(Box::new(State))) } + */ fn eval_with_session( &self, session: &SessionState, inputs: TVec, ) -> TractResult> { - let mut scratch = unsafe { self.mmm.allocate_scratch_space() }; - eval(self, &session.resolved_symbols, scratch.as_mut(), &inputs) + eval(self, &session.resolved_symbols, session, &inputs) } } @@ -339,17 +329,29 @@ impl EvalOp for LirMatMulUnary { fn eval( op: &LirMatMulUnary, symbols: &SymbolValues, - scratch: &mut dyn ScratchSpace, + session: &SessionState, inputs: &[TValue], ) -> TractResult> { unsafe { + if session + .cached_mmm_scratch_space + .borrow_mut() + .as_deref() + .map(|scratch| op.mmm.can_use_scratch_space(scratch)) + == Some(false) + { + session.cached_mmm_scratch_space.replace(None); + } + let mut cell = session.cached_mmm_scratch_space.borrow_mut(); + let scratch = cell.get_or_insert_with(|| op.mmm.allocate_scratch_space()); + if op.trivial_path { let c_shape = op.c_fact.shape.as_concrete().unwrap_unchecked(); let geometry = op.geometry.as_concrete().unwrap_unchecked(); let mut c = Tensor::uninitialized_dt(op.c_fact.datum_type, c_shape)?; let uops: Vec = op.micro_ops.iter().map(|o| o.resolve_trivial(inputs, &mut c)).collect(); - op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch, &uops)?; + op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?; Ok(tvec!(c.into_tvalue())) } else { let geometry = op.geometry.to_concrete(symbols)?; @@ -368,7 +370,7 @@ fn eval( &c, ); } - op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch, &uops)?; + op.mmm.run_with_scratch_space(geometry.m, geometry.n, scratch.as_mut(), &uops)?; } Ok(tvec!(c.into_tvalue())) } diff --git a/core/src/plan.rs b/core/src/plan.rs index 343c9621de..d4c6e746ce 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::cell::RefCell; use std::fmt::{Debug, Display}; use std::marker::PhantomData; @@ -14,7 +15,7 @@ pub struct SessionState { pub inputs: HashMap, pub resolved_symbols: SymbolValues, pub tensors: HashMap, - pub cached_mmm_scratch_space: Option>, + pub cached_mmm_scratch_space: RefCell>>, } impl Clone for SessionState { @@ -23,7 +24,7 @@ impl Clone for SessionState { inputs: self.inputs.clone(), resolved_symbols: self.resolved_symbols.clone(), tensors: self.tensors.clone(), - cached_mmm_scratch_space: None, + cached_mmm_scratch_space: None.into() } } } @@ -590,7 +591,7 @@ where inputs: self.inputs.iter().map(|(ix, t)| (*ix, t.clone().into_tvalue())).collect(), resolved_symbols: self.resolved_symbols.clone(), tensors: self.tensors.clone(), - cached_mmm_scratch_space: None, + cached_mmm_scratch_space: None.into(), }, states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(), values: self