From bba97125b78f5d5df87637e1857015c8d4b77285 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Tue, 3 Dec 2024 14:39:33 +0100 Subject: [PATCH] First implementation of FusedAxisOp --- metal/src/kernels/matmul/mod.rs | 7 +-- metal/src/memory/pool.rs | 1 + metal/src/ops/change_axes.rs | 79 +++++++++++++------------- metal/src/ops/fused_axis_op.rs | 98 +++++++++++++++++++++++++++++++++ metal/src/ops/mod.rs | 1 + metal/src/tensor/arena_view.rs | 23 +++++++- metal/src/tensor/mod.rs | 2 +- 7 files changed, 166 insertions(+), 45 deletions(-) create mode 100644 metal/src/ops/fused_axis_op.rs diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 23349005ef..be85935bd4 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -548,16 +548,15 @@ mod tests { .prop_flat_map(|(b, m, k, n)| { let lhs_len = b * m * k; let rhs_len = b * k * n; - let lhs = (0usize..10).prop_map(|x| x.as_()); - let rhs = (0usize..10).prop_map(|x| x.as_()); + let datum = (0usize..10).prop_map(|x| x.as_()); ( Just(b), Just(m), Just(k), Just(n), - vec(lhs, lhs_len..=lhs_len), + vec(datum.clone(), lhs_len..=lhs_len), proptest::bool::ANY, - vec(rhs, rhs_len..=rhs_len), + vec(datum, rhs_len..=rhs_len), proptest::bool::ANY, ) }) diff --git a/metal/src/memory/pool.rs b/metal/src/memory/pool.rs index c1895e19dc..6ddf9f4ead 100644 --- a/metal/src/memory/pool.rs +++ b/metal/src/memory/pool.rs @@ -45,6 +45,7 @@ impl MetalMemoryPool { Ok(MetalArenaView { arena: Arc::clone(&self.storage), dt, + len: shape.iter().product(), shape: shape.into(), strides: Tensor::natural_strides(shape), offset_bytes: offset, diff --git a/metal/src/ops/change_axes.rs b/metal/src/ops/change_axes.rs index 50cce23227..2b080c4b19 100644 --- a/metal/src/ops/change_axes.rs +++ b/metal/src/ops/change_axes.rs @@ -102,14 +102,8 @@ impl TypedOp for MetalAxisOp { as_op!(); fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - crate::utils::metal_facts_from_gpu(inputs, |facts| { - let mut shape = facts[0].shape.clone(); - self.0 - .change_shape(&mut shape, false) - .with_context(|| format!("Applying {self:?} to {:?}", facts[0]))?; - Ok(tvec!(facts[0].datum_type.fact(shape))) - }) - .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) + crate::utils::metal_facts_from_gpu(inputs, |facts| self.0.output_facts(facts)) + .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) } fn axes_mapping( @@ -147,36 +141,30 @@ impl TypedOp for MetalAxisOp { model: &TypedModel, node: &TypedNode, ) -> TractResult> { - let conc_shape = - crate::utils::metal_fact(&node.outputs[0].fact, |fact| Ok(fact.shape.as_concrete()))?; - if let Some(shape) = conc_shape { - if !matches!(self, MetalAxisOp(AxisOp::Move(_, _))) { - let (inputs, outputs) = model.node_facts(node.id)?; - let mapping = self.axes_mapping(&inputs, &outputs)?; - let op = MetalIntoShape(IntoShape { - mapping, - len: shape.iter().product(), - strides: Tensor::natural_strides(shape), - dims: shape.into(), - }); - return Ok(Some(TypedModelPatch::replace_single_op( - model, - node, - &node.inputs, - op, - )?)); - } + let shape = + crate::utils::metal_fact(&node.outputs[0].fact, |fact| Ok(fact.shape.to_tvec()))?; + if !matches!(self, MetalAxisOp(AxisOp::Move(_, _))) { + let (inputs, outputs) = model.node_facts(node.id)?; + let mapping = self.axes_mapping(&inputs, &outputs)?; + let op = MetalIntoShape { mapping, dims: shape }; + return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?)); } Ok(None) } } #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct MetalIntoShape(IntoShape); +pub struct MetalIntoShape { + pub mapping: AxesMapping, + pub dims: TVec, +} impl MetalIntoShape { pub fn from_tract_core(core_op: IntoShape) -> Self { - MetalIntoShape(core_op) + MetalIntoShape { + mapping: core_op.mapping, + dims: core_op.dims.into_iter().map(|it| it.into()).collect(), + } } } @@ -186,7 +174,7 @@ impl Op for MetalIntoShape { } fn info(&self) -> TractResult> { - self.0.info() + Ok(vec![format!("{}", self.mapping)]) } op_as_typed_op!(); @@ -205,9 +193,20 @@ impl MetalEvalOp for MetalIntoShape { ) -> TractResult> { let opaque = args_1!(inputs).into_tensor(); let input = opaque.to_metal_tensor()?; - ensure!(input.len() == self.0.len); - let output = - crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), &self.0.dims)?; + let dims = self + .dims + .iter() + .map(|d| d.eval_to_i64(&session.resolved_symbols).map(|it| it as usize)) + .collect::>>()?; + let len = dims.iter().product::(); + + ensure!(input.len() == len); + let output = crate::ops::make_tensor_for_node( + session, + node_id, + input.datum_type(), + dims.as_slice(), + )?; Memcpy.dispatch_eval(context, input, 0, &output)?; @@ -217,8 +216,10 @@ impl MetalEvalOp for MetalIntoShape { impl TypedOp for MetalIntoShape { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - crate::utils::metal_facts_from_gpu(inputs, |facts| self.0.output_facts(facts)) - .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) + crate::utils::metal_facts_from_gpu(inputs, |facts| { + Ok(tvec!(facts[0].datum_type.fact(&self.dims))) + }) + .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name())) } fn declutter( @@ -228,10 +229,10 @@ impl TypedOp for MetalIntoShape { ) -> TractResult> { if let Some(succ) = model.single_succ(node.id)? { if let Some(into_shape) = succ.op_as::() { - let op = Self(IntoShape { - mapping: self.0.mapping.compose(&into_shape.0.mapping)?, - ..into_shape.0.clone() - }); + let op = Self { + mapping: self.mapping.compose(&into_shape.mapping)?, + dims: into_shape.dims.clone(), + }; return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?)); } } diff --git a/metal/src/ops/fused_axis_op.rs b/metal/src/ops/fused_axis_op.rs new file mode 100644 index 0000000000..09a627ad83 --- /dev/null +++ b/metal/src/ops/fused_axis_op.rs @@ -0,0 +1,98 @@ +use crate::ops::{MetalAxisOp, MetalEvalOp, MetalOpState}; +use crate::tensor::MetalTensorExt; +use crate::MetalContext; +use derive_new::new; +use tract_core::internal::*; + +#[derive(Clone, Debug, new, Hash)] +pub struct FusedAxisOp { + pub axis_ops: Vec>, + pub op: O, +} + +impl Op for FusedAxisOp { + fn name(&self) -> Cow { + self.op.name() + } + + op_as_typed_op!(); +} + +impl MetalEvalOp for FusedAxisOp { + fn metal_eval( + &self, + context: &MetalContext, + node_id: usize, + session: &mut SessionState, + inputs: TVec, + ) -> TractResult> { + // Apply Axis Op + let inputs = inputs + .into_iter() + .zip(self.axis_ops.iter()) + .map(|(input, axis_op)| { + let Some(axis_op) = axis_op else { return Ok(input) }; + let new_shape = match &axis_op.0 { + AxisOp::Move(..) => bail!("Cannot fused {:?} with metal op", &axis_op.0), + AxisOp::Reshape(skip, from, to) => { + let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect(); + let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect(); + let mut shape: TVec = input.shape().into(); + AxisOp::Reshape(*skip, from, to).change_shape_array(&mut shape, false)?; + shape + } + _ => { + let mut shape: TVec = input.shape().into(); + axis_op.0.change_shape_array(&mut shape, false)?; + shape + } + }; + let t = input.to_metal_tensor()?; + Ok(t.reshaped(new_shape)?.into_opaque_tensor().into()) + }) + .collect::>>()?; + self.op.metal_eval(context, node_id, session, inputs) + } +} + +impl TypedOp for FusedAxisOp { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + ensure!( + inputs.len() == self.axis_ops.len(), + "Number of inputs and fused axis ops are not aligned" + ); + // Apply AxisOp + let inputs = inputs + .into_iter() + .zip(self.axis_ops.iter()) + .map(|(i, axis_op)| { + Ok(axis_op + .as_ref() + .map(|a| -> TractResult<_> { Ok(a.output_facts(&[i])?.pop()) }) + .transpose()? + .flatten() + .unwrap_or_else(|| (*i).clone())) + }) + .collect::>>()?; + let inputs_ref = inputs.iter().collect::>(); + // Apply Op + self.op.output_facts(&inputs_ref) + } + + as_op!(); +} + +impl EvalOp for FusedAxisOp { + fn is_stateless(&self) -> bool { + false + } + + #[allow(unused_variables)] + fn state( + &self, + session: &mut tract_core::internal::SessionState, + node_id: usize, + ) -> TractResult>> { + Ok(Some(Box::new(MetalOpState::new(node_id, self.clone())))) + } +} diff --git a/metal/src/ops/mod.rs b/metal/src/ops/mod.rs index 3e669106ba..5c19907f78 100644 --- a/metal/src/ops/mod.rs +++ b/metal/src/ops/mod.rs @@ -5,6 +5,7 @@ pub mod cast; pub mod change_axes; pub mod concat; pub mod element_wise; +pub mod fused_axis_op; pub mod gemm; pub mod konst; pub mod new_gelu; diff --git a/metal/src/tensor/arena_view.rs b/metal/src/tensor/arena_view.rs index e205e2fb20..30e9362b33 100644 --- a/metal/src/tensor/arena_view.rs +++ b/metal/src/tensor/arena_view.rs @@ -54,6 +54,7 @@ impl Hash for MetalArenaStorage { pub struct MetalArenaView { pub(crate) arena: Arc, pub(crate) dt: DatumType, + pub(crate) len: usize, pub(crate) shape: TVec, pub(crate) strides: TVec, pub(crate) offset_bytes: usize, @@ -93,7 +94,7 @@ impl MetalArenaView { #[inline] #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { - self.shape().iter().product() + self.len } pub fn as_bytes(&self) -> &[u8] { @@ -112,6 +113,26 @@ impl MetalArenaView { ) } } + + /// Reshaped tensor with given shape. + pub fn reshaped(&self, shape: impl Into>) -> TractResult { + let shape = shape.into(); + if self.len() != shape.iter().product::() { + bail!("Invalid reshape {:?} to {:?}", self.shape(), shape); + } + if shape.as_slice() != self.shape() { + Ok(Self { + arena: Arc::clone(&self.arena), + dt: self.dt, + len: self.len, + strides: Tensor::natural_strides(&shape), + shape, + offset_bytes: self.offset_bytes, + }) + } else { + Ok(self.clone()) + } + } } impl Display for MetalArenaView { diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index b94d7a5c84..2597738aa8 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -150,7 +150,7 @@ impl MetalTensor { pub fn reshaped(&self, shape: impl Into>) -> Result { match self { Self::Owned(t) => Ok(Self::Owned(t.reshaped(shape)?)), - Self::ArenaView(_t) => bail!("Reshape a Metal Arena View is not supported"), + Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)), } }