Skip to content

Commit

Permalink
First implementation of FusedAxisOp
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Dec 3, 2024
1 parent 37132e0 commit bba9712
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 45 deletions.
7 changes: 3 additions & 4 deletions metal/src/kernels/matmul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,16 +548,15 @@ mod tests {
.prop_flat_map(|(b, m, k, n)| {
let lhs_len = b * m * k;
let rhs_len = b * k * n;
let lhs = (0usize..10).prop_map(|x| x.as_());
let rhs = (0usize..10).prop_map(|x| x.as_());
let datum = (0usize..10).prop_map(|x| x.as_());
(
Just(b),
Just(m),
Just(k),
Just(n),
vec(lhs, lhs_len..=lhs_len),
vec(datum.clone(), lhs_len..=lhs_len),
proptest::bool::ANY,
vec(rhs, rhs_len..=rhs_len),
vec(datum, rhs_len..=rhs_len),
proptest::bool::ANY,
)
})
Expand Down
1 change: 1 addition & 0 deletions metal/src/memory/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl MetalMemoryPool {
Ok(MetalArenaView {
arena: Arc::clone(&self.storage),
dt,
len: shape.iter().product(),
shape: shape.into(),
strides: Tensor::natural_strides(shape),
offset_bytes: offset,
Expand Down
79 changes: 40 additions & 39 deletions metal/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,8 @@ impl TypedOp for MetalAxisOp {
as_op!();

fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
crate::utils::metal_facts_from_gpu(inputs, |facts| {
let mut shape = facts[0].shape.clone();
self.0
.change_shape(&mut shape, false)
.with_context(|| format!("Applying {self:?} to {:?}", facts[0]))?;
Ok(tvec!(facts[0].datum_type.fact(shape)))
})
.with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
crate::utils::metal_facts_from_gpu(inputs, |facts| self.0.output_facts(facts))
.with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
}

fn axes_mapping(
Expand Down Expand Up @@ -147,36 +141,30 @@ impl TypedOp for MetalAxisOp {
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let conc_shape =
crate::utils::metal_fact(&node.outputs[0].fact, |fact| Ok(fact.shape.as_concrete()))?;
if let Some(shape) = conc_shape {
if !matches!(self, MetalAxisOp(AxisOp::Move(_, _))) {
let (inputs, outputs) = model.node_facts(node.id)?;
let mapping = self.axes_mapping(&inputs, &outputs)?;
let op = MetalIntoShape(IntoShape {
mapping,
len: shape.iter().product(),
strides: Tensor::natural_strides(shape),
dims: shape.into(),
});
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&node.inputs,
op,
)?));
}
let shape =
crate::utils::metal_fact(&node.outputs[0].fact, |fact| Ok(fact.shape.to_tvec()))?;
if !matches!(self, MetalAxisOp(AxisOp::Move(_, _))) {
let (inputs, outputs) = model.node_facts(node.id)?;
let mapping = self.axes_mapping(&inputs, &outputs)?;
let op = MetalIntoShape { mapping, dims: shape };
return Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?));
}
Ok(None)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MetalIntoShape(IntoShape);
pub struct MetalIntoShape {
pub mapping: AxesMapping,
pub dims: TVec<TDim>,
}

impl MetalIntoShape {
pub fn from_tract_core(core_op: IntoShape) -> Self {
MetalIntoShape(core_op)
MetalIntoShape {
mapping: core_op.mapping,
dims: core_op.dims.into_iter().map(|it| it.into()).collect(),
}
}
}

Expand All @@ -186,7 +174,7 @@ impl Op for MetalIntoShape {
}

fn info(&self) -> TractResult<Vec<String>> {
self.0.info()
Ok(vec![format!("{}", self.mapping)])
}

op_as_typed_op!();
Expand All @@ -205,9 +193,20 @@ impl MetalEvalOp for MetalIntoShape {
) -> TractResult<TVec<TValue>> {
let opaque = args_1!(inputs).into_tensor();
let input = opaque.to_metal_tensor()?;
ensure!(input.len() == self.0.len);
let output =
crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), &self.0.dims)?;
let dims = self
.dims
.iter()
.map(|d| d.eval_to_i64(&session.resolved_symbols).map(|it| it as usize))
.collect::<TractResult<TVec<_>>>()?;
let len = dims.iter().product::<usize>();

ensure!(input.len() == len);
let output = crate::ops::make_tensor_for_node(
session,
node_id,
input.datum_type(),
dims.as_slice(),
)?;

Memcpy.dispatch_eval(context, input, 0, &output)?;

Expand All @@ -217,8 +216,10 @@ impl MetalEvalOp for MetalIntoShape {

impl TypedOp for MetalIntoShape {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
crate::utils::metal_facts_from_gpu(inputs, |facts| self.0.output_facts(facts))
.with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
crate::utils::metal_facts_from_gpu(inputs, |facts| {
Ok(tvec!(facts[0].datum_type.fact(&self.dims)))
})
.with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
}

fn declutter(
Expand All @@ -228,10 +229,10 @@ impl TypedOp for MetalIntoShape {
) -> TractResult<Option<TypedModelPatch>> {
if let Some(succ) = model.single_succ(node.id)? {
if let Some(into_shape) = succ.op_as::<MetalIntoShape>() {
let op = Self(IntoShape {
mapping: self.0.mapping.compose(&into_shape.0.mapping)?,
..into_shape.0.clone()
});
let op = Self {
mapping: self.mapping.compose(&into_shape.mapping)?,
dims: into_shape.dims.clone(),
};
return Ok(Some(TypedModelPatch::fuse_with_next(model, node, op)?));
}
}
Expand Down
98 changes: 98 additions & 0 deletions metal/src/ops/fused_axis_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use crate::ops::{MetalAxisOp, MetalEvalOp, MetalOpState};
use crate::tensor::MetalTensorExt;
use crate::MetalContext;
use derive_new::new;
use tract_core::internal::*;

#[derive(Clone, Debug, new, Hash)]
pub struct FusedAxisOp<O: MetalEvalOp + TypedOp> {
pub axis_ops: Vec<Option<MetalAxisOp>>,
pub op: O,
}

impl<O: MetalEvalOp + TypedOp> Op for FusedAxisOp<O> {
fn name(&self) -> Cow<str> {
self.op.name()
}

op_as_typed_op!();
}

impl<O: MetalEvalOp + TypedOp> MetalEvalOp for FusedAxisOp<O> {
fn metal_eval(
&self,
context: &MetalContext,
node_id: usize,
session: &mut SessionState,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
// Apply Axis Op
let inputs = inputs
.into_iter()
.zip(self.axis_ops.iter())
.map(|(input, axis_op)| {
let Some(axis_op) = axis_op else { return Ok(input) };
let new_shape = match &axis_op.0 {
AxisOp::Move(..) => bail!("Cannot fused {:?} with metal op", &axis_op.0),
AxisOp::Reshape(skip, from, to) => {
let from = from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
let mut shape: TVec<usize> = input.shape().into();
AxisOp::Reshape(*skip, from, to).change_shape_array(&mut shape, false)?;
shape
}
_ => {
let mut shape: TVec<usize> = input.shape().into();
axis_op.0.change_shape_array(&mut shape, false)?;
shape
}
};
let t = input.to_metal_tensor()?;
Ok(t.reshaped(new_shape)?.into_opaque_tensor().into())
})
.collect::<TractResult<TVec<_>>>()?;
self.op.metal_eval(context, node_id, session, inputs)
}
}

impl<O: MetalEvalOp + TypedOp> TypedOp for FusedAxisOp<O> {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(
inputs.len() == self.axis_ops.len(),
"Number of inputs and fused axis ops are not aligned"
);
// Apply AxisOp
let inputs = inputs
.into_iter()
.zip(self.axis_ops.iter())
.map(|(i, axis_op)| {
Ok(axis_op
.as_ref()
.map(|a| -> TractResult<_> { Ok(a.output_facts(&[i])?.pop()) })
.transpose()?
.flatten()
.unwrap_or_else(|| (*i).clone()))
})
.collect::<TractResult<TVec<_>>>()?;
let inputs_ref = inputs.iter().collect::<TVec<_>>();
// Apply Op
self.op.output_facts(&inputs_ref)
}

as_op!();
}

impl<O: MetalEvalOp + TypedOp> EvalOp for FusedAxisOp<O> {
fn is_stateless(&self) -> bool {
false
}

#[allow(unused_variables)]
fn state(
&self,
session: &mut tract_core::internal::SessionState,
node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(MetalOpState::new(node_id, self.clone()))))
}
}
1 change: 1 addition & 0 deletions metal/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod cast;
pub mod change_axes;
pub mod concat;
pub mod element_wise;
pub mod fused_axis_op;
pub mod gemm;
pub mod konst;
pub mod new_gelu;
Expand Down
23 changes: 22 additions & 1 deletion metal/src/tensor/arena_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl Hash for MetalArenaStorage {
pub struct MetalArenaView {
pub(crate) arena: Arc<MetalArenaStorage>,
pub(crate) dt: DatumType,
pub(crate) len: usize,
pub(crate) shape: TVec<usize>,
pub(crate) strides: TVec<isize>,
pub(crate) offset_bytes: usize,
Expand Down Expand Up @@ -93,7 +94,7 @@ impl MetalArenaView {
#[inline]
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.shape().iter().product()
self.len
}

pub fn as_bytes(&self) -> &[u8] {
Expand All @@ -112,6 +113,26 @@ impl MetalArenaView {
)
}
}

/// Reshaped tensor with given shape.
pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> TractResult<Self> {
let shape = shape.into();
if self.len() != shape.iter().product::<usize>() {
bail!("Invalid reshape {:?} to {:?}", self.shape(), shape);
}
if shape.as_slice() != self.shape() {
Ok(Self {
arena: Arc::clone(&self.arena),
dt: self.dt,
len: self.len,
strides: Tensor::natural_strides(&shape),
shape,
offset_bytes: self.offset_bytes,
})
} else {
Ok(self.clone())
}
}
}

impl Display for MetalArenaView {
Expand Down
2 changes: 1 addition & 1 deletion metal/src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl MetalTensor {
pub fn reshaped(&self, shape: impl Into<TVec<usize>>) -> Result<Self> {
match self {
Self::Owned(t) => Ok(Self::Owned(t.reshaped(shape)?)),
Self::ArenaView(_t) => bail!("Reshape a Metal Arena View is not supported"),
Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)),
}
}

Expand Down

0 comments on commit bba9712

Please sign in to comment.