diff --git a/metal/src/kernels/nn/mod.rs b/metal/src/kernels/nn/mod.rs
index 61c170af42..4a1ed6b582 100644
--- a/metal/src/kernels/nn/mod.rs
+++ b/metal/src/kernels/nn/mod.rs
@@ -2,6 +2,7 @@
 pub mod apply_rope;
 pub mod new_gelu;
 pub mod reduce;
 pub mod rms_norm;
+pub mod scaled_masked_softmax;
 pub mod silu;
 pub mod softmax;
@@ -9,6 +10,7 @@
 pub use apply_rope::ApplyRope;
 pub use new_gelu::NewGelu;
 pub use reduce::Reducer;
 pub use rms_norm::RmsNorm;
+pub use scaled_masked_softmax::ScaledMaskedSoftmax;
 pub use silu::Silu;
 pub use softmax::Softmax;
@@ -28,5 +30,11 @@ pub fn all_functions() -> Vec<String> {
             .flat_map(|dt| Softmax.kernel_name(dt).into_iter()),
     );
 
+    functions.extend(
+        crate::MetalTensor::SUPPORTED_DT
+            .into_iter()
+            .flat_map(|dt| ScaledMaskedSoftmax.kernel_name(dt).into_iter()),
+    );
+
     functions.into_iter().collect()
 }
diff --git a/metal/src/kernels/nn/nn_ops.metal b/metal/src/kernels/nn/nn_ops.metal
index fcf836b80f..2f82533490 100644
--- a/metal/src/kernels/nn/nn_ops.metal
+++ b/metal/src/kernels/nn/nn_ops.metal
@@ -276,6 +276,70 @@ typedef decltype(softmax_nd3<float>) softmax_nd3_t;
 
 template [[host_name("nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3<float>;
 template [[host_name("nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3<half>;
 
+template<typename F>
+[[kernel]] void scaled_masked_softmax_nd3(
+                device const void *input_b,
+                device const void *mask_b,
+                constant void *scale_b,
+                device void *output_b,
+                constant const size_t shape[3],
+                constant const size_t strides[3],
+                constant const size_t mask_strides[3],
+                uint3 tgpig[[threadgroup_position_in_grid]],
+                uint tiisg[[thread_index_in_simdgroup]],
+                uint tpsg[[threads_per_simdgroup]]
+                ) {
+
+    device const F *input = (device const F *)input_b;
+    device const F *mask = (device const F *)mask_b;
+    F scale = ((constant F *)scale_b)[0];
+    device F *output = (device F *)output_b;
+
+    size_t reduce_dim = shape[2];
+
+    size_t base_idx = tgpig.y * strides[1]
+        + tgpig.z * strides[0];
+
+    size_t mask_base_idx = tgpig.y * mask_strides[1]
+        + tgpig.z * mask_strides[0];
+
+    // Get max value on softmax reduce_dim after applying scale and mask
+    float partial_max = -INFINITY;
+    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
+        auto idx = base_idx + i * strides[2];
+        auto mask_idx = mask_base_idx + i * mask_strides[2];
+        output[idx] = input[idx] * scale + mask[mask_idx];
+        float el = static_cast<float>(output[idx]);
+        partial_max = max(partial_max, el);
+    }
+
+    float axis_max = simd_max(partial_max);
+
+    // Compute Sum(exp(x - max))
+    float partial_norm = 0;
+    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
+        auto idx = base_idx + i * strides[2];
+        float el = static_cast<float>(output[idx]);
+        float exp_el = fast::exp(el - axis_max);
+        partial_norm += exp_el;
+    }
+
+    float axis_norm = simd_sum(partial_norm);
+    float inv_axis_norm = 1.0 / axis_norm;
+
+    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
+        auto idx = base_idx + i * strides[2];
+        float el = static_cast<float>(output[idx]);
+        float exp_el = fast::exp(el - axis_max);
+        output[idx] = static_cast<F>(exp_el * inv_axis_norm);
+    }
+}
+
+typedef decltype(scaled_masked_softmax_nd3<float>) scaled_masked_softmax_nd3_t;
+
+template [[host_name("nn_ops::scaled_masked_softmax_nd3_f32")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<float>;
+template [[host_name("nn_ops::scaled_masked_softmax_nd3_f16")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<half>;
+
 constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
diff --git a/metal/src/kernels/nn/scaled_masked_softmax.rs b/metal/src/kernels/nn/scaled_masked_softmax.rs
new file mode 100644
index 0000000000..b38e6ec0e4
--- /dev/null
+++ b/metal/src/kernels/nn/scaled_masked_softmax.rs
@@ -0,0 +1,247 @@
+use crate::encoder::EncoderExt;
+use crate::{LibraryName, MetalContext, MetalTensor};
+use anyhow::Result;
+use metal::MTLSize;
+use tract_core::internal::*;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct ScaledMaskedSoftmax;
+
+impl ScaledMaskedSoftmax {
+    pub fn is_supported_dt(dt: DatumType) -> bool {
+        matches!(dt, DatumType::F32 | DatumType::F16)
+    }
+
+    pub fn kernel_name(&self, dt: DatumType) -> Result<String> {
+        ensure!(
+            Self::is_supported_dt(dt),
+            "Unsupported dt {:?} for metal scaled masked softmax op",
+            dt
+        );
+        let tname = MetalTensor::tname(dt)?;
+        Ok(format!("nn_ops::scaled_masked_softmax_nd3_{tname}"))
+    }
+
+    pub fn eval(
+        &self,
+        context: &MetalContext,
+        input: &MetalTensor,
+        scale: &Tensor,
+        mask: &MetalTensor,
+    ) -> Result<MetalTensor> {
+        let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? };
+        self.dispatch_eval(context, input, scale, mask, &output)?;
+        context.wait_until_completed()?;
+        Ok(output)
+    }
+
+    pub fn dispatch_eval(
+        &self,
+        context: &MetalContext,
+        input: &MetalTensor,
+        scale: &Tensor,
+        mask: &MetalTensor,
+        output: &MetalTensor,
+    ) -> Result<()> {
+        input.retained_until_completion();
+        mask.retained_until_completion();
+        output.retained_until_completion();
+
+        ensure!(output.shape() == input.shape());
+        ensure!(mask.rank() == 3 && input.rank() == 3);
+        ensure!(output.datum_type() == input.datum_type());
+
+        let shape = input.shape();
+        let strides = input.strides();
+        let mask_strides_nd3 =
+            crate::utils::compute_broadcast_strides::<usize>(mask.shape(), mask.strides())?;
+
+        let pipeline = context
+            .shared_context()
+            .load_pipeline(LibraryName::NNOps, &self.kernel_name(input.datum_type())?)?;
+
+        let command_buffer = context.command_buffer();
+        command_buffer.encode(|encoder| {
+            encoder.set_compute_pipeline_state(&pipeline);
+            encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read);
+            encoder.set_metal_tensor(1, mask, metal::MTLResourceUsage::Read);
+            encoder.set_tensor(2, scale);
+            encoder.set_metal_tensor(3, output, metal::MTLResourceUsage::Write);
+            encoder.set_slice(4, shape);
+            encoder.set_slice(5, strides);
+            encoder.set_slice(6, &mask_strides_nd3);
+            let grid_size = MTLSize { width: 1 as _, height: shape[1] as _, depth: shape[0] as _ };
+            let group_size = MTLSize { width: usize::min(32, shape[2]) as _, height: 1, depth: 1 };
+            encoder.dispatch_thread_groups(grid_size, group_size);
+            encoder.end_encoding();
+        });
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::rewrite_rules::BasicScaledMaskedSoftmax;
+    use crate::IntoMetal;
+    use derive_new::new;
+    use num_traits::AsPrimitive;
+    use num_traits::Float;
+    use proptest::collection::vec;
+    use proptest::prelude::*;
+    use proptest::strategy::Strategy;
+    use tract_core::internal::Tensor;
+
+    #[test]
+    fn test_scaled_masked_softmax_f32() -> Result<()> {
+        objc::rc::autoreleasepool(|| {
+            crate::METAL_CONTEXT.with_borrow(|context| {
+                let m = 4;
+                let n = 4;
+                let scale: Arc<_> = tensor0(0.125f32).into();
+                let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?;
+
+                let a = Tensor::from_shape(
+                    &[1, m, n],
+                    &(0..m * n).map(|f| f as f32).collect::<Vec<_>>(),
+                )?
+                .into_metal()?;
+
+                let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() };
+
+                let cpu_output = cpu
+                    .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0]
+                    .clone()
+                    .into_tensor();
+                let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
+                cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?;
+                Ok(())
+            })
+        })
+    }
+
+    #[test]
+    fn test_scaled_masked_softmax_f32_2() -> Result<()> {
+        objc::rc::autoreleasepool(|| {
+            crate::METAL_CONTEXT.with_borrow(|context| {
+                let m = 4;
+                let n = 1024;
+                let scale: Arc<_> = tensor0(0.125f32).into();
+                let mask = Tensor::from_shape(&[1, m, n], &vec![-1000f32; m * n])?.into_metal()?;
+
+                let a = Tensor::from_shape(
+                    &[1, m, n],
+                    &(0..m * n).map(|f| f as f32).collect::<Vec<_>>(),
+                )?
+                .into_metal()?;
+
+                let cpu = BasicScaledMaskedSoftmax { scale: scale.clone() };
+
+                let cpu_output = cpu
+                    .eval(tvec![a.to_cpu()?.into_tvalue(), mask.to_cpu()?.into_tvalue()])?[0]
+                    .clone()
+                    .into_tensor();
+                let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
+                cpu_output.close_enough(&metal_output.to_cpu()?, Approximation::Approximate)?;
+                Ok(())
+            })
+        })
+    }
+
+    proptest::proptest! {
+        #[test]
+        fn scaled_masked_softmax_prop_f32(pb in any::<ScaledMaskedSoftmaxProblem<f32>>()) {
+            fn run(pb: ScaledMaskedSoftmaxProblem<f32>) -> TractResult<()> {
+                let out = pb.run()?;
+                let reference = pb.reference()?;
+
+                out.close_enough(&reference, Approximation::Approximate)
+                    .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
+            }
+            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
+        }
+
+        #[test]
+        fn scaled_masked_softmax_prop_f16(pb in any::<ScaledMaskedSoftmaxProblem<f16>>()) {
+            fn run(pb: ScaledMaskedSoftmaxProblem<f16>) -> TractResult<()> {
+                let out = pb.run()?;
+                let reference = pb.reference()?;
+
+                out.close_enough(&reference, Approximation::Approximate)
+                    .with_context(|| anyhow!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
+            }
+
+            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
+        }
+    }
+
+    #[derive(Debug, new)]
+    pub struct ScaledMaskedSoftmaxProblem<F>
+    where
+        F: Datum + Float,
+        usize: AsPrimitive<F>,
+    {
+        pub shape: Vec<usize>,
+        pub mask_shape: Vec<usize>,
+        pub input: Vec<F>,
+        pub mask: Vec<F>,
+    }
+
+    impl<F> Arbitrary for ScaledMaskedSoftmaxProblem<F>
+    where
+        F: Datum + Float,
+        usize: AsPrimitive<F>,
+    {
+        type Parameters = ();
+        type Strategy = BoxedStrategy<Self>;
+
+        fn arbitrary_with(_: ()) -> Self::Strategy {
+            vec(1usize..10, 3..=3)
+                .prop_map(|shape| {
+                    let mut mask_shape = shape.clone();
+                    mask_shape[0] = 1;
+
+                    let input = (0..shape.iter().product::<usize>())
+                        .map(|f| f.as_() / 1000.as_())
+                        .collect::<Vec<_>>();
+
+                    let mask = (0..mask_shape.iter().product::<usize>())
+                        .map(|f| f.as_() / 1000.as_())
+                        .collect::<Vec<_>>();
+                    Self { shape, input, mask_shape, mask }
+                })
+                .boxed()
+        }
+    }
+
+    impl<F> ScaledMaskedSoftmaxProblem<F>
+    where
+        F: Datum + Float + std::ops::AddAssign,
+        usize: AsPrimitive<F>,
+    {
+        pub fn reference(&self) -> Result<Tensor> {
+            let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
+            let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?;
+            let scale: Arc<_> = tensor0(0.125f32).into();
+
+            let cpu_output = BasicScaledMaskedSoftmax { scale }
+                .eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0]
+                .clone()
+                .into_tensor();
+            Ok(cpu_output)
+        }
+
+        pub fn run(&self) -> Result<Tensor> {
+            objc::rc::autoreleasepool(|| {
+                crate::METAL_CONTEXT.with_borrow(|context| {
+                    let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?;
+                    let mask =
+                        Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?;
+                    let scale: Arc<_> = tensor0(0.125f32).into();
+                    let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
+                    metal_output.to_cpu()
+                })
+            })
+        }
+    }
+}
diff --git a/metal/src/kernels/nn/softmax.rs b/metal/src/kernels/nn/softmax.rs
index 6024c1d911..34a19bd3fb 100644
--- a/metal/src/kernels/nn/softmax.rs
+++ b/metal/src/kernels/nn/softmax.rs
@@ -63,7 +63,6 @@ impl Softmax {
                 MTLSize { width: shape_nd3[2] as _, height: 1, depth: shape_nd3[0] as _ };
             let group_size =
                 MTLSize { width: usize::min(32, shape_nd3[1]) as _, height: 1, depth: 1 };
-
             encoder.dispatch_thread_groups(grid_size, group_size);
             encoder.end_encoding();
         });
diff --git a/metal/src/ops/mod.rs b/metal/src/ops/mod.rs
index 2fca47f0c6..613554ccb8 100644
--- a/metal/src/ops/mod.rs
+++ b/metal/src/ops/mod.rs
@@ -11,6 +11,7 @@ pub mod new_gelu;
 pub mod reduce;
 pub mod rms_norm;
 pub mod rotate_half;
+pub mod scaled_masked_softmax;
 pub mod silu;
 pub mod slice;
 pub mod softmax;
@@ -29,6 +30,7 @@ pub use new_gelu::MetalNewGelu;
 pub use reduce::MetalReduce;
 pub use rms_norm::MetalRmsNorm;
 pub use rotate_half::MetalRotateHalf;
+pub use scaled_masked_softmax::MetalScaledMaskedSoftmax;
 pub use silu::MetalSilu;
 pub use slice::MetalSlice;
 pub use softmax::MetalSoftmax;
diff --git a/metal/src/ops/scaled_masked_softmax.rs b/metal/src/ops/scaled_masked_softmax.rs
new file mode 100644
index 0000000000..5715385ef2
--- /dev/null
+++ b/metal/src/ops/scaled_masked_softmax.rs
@@ -0,0 +1,57 @@
+use crate::kernels::nn::ScaledMaskedSoftmax;
+use crate::ops::MetalEvalOp;
+use crate::tensor::MetalTensorExt;
+use crate::MetalContext;
+use derive_new::new;
+use tract_core::internal::*;
+
+/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2)
+/// Only inputs of rank 3 are supported.
+#[derive(Clone, Debug, new, Hash)]
+pub struct MetalScaledMaskedSoftmax {
+    pub scale: Arc<Tensor>,
+}
+
+impl Op for MetalScaledMaskedSoftmax {
+    fn name(&self) -> Cow<str> {
+        "MetalScaledMaskedSoftmax".into()
+    }
+
+    op_as_typed_op!();
+}
+
+impl MetalEvalOp for MetalScaledMaskedSoftmax {
+    fn metal_eval(
+        &self,
+        context: &MetalContext,
+        node_id: usize,
+        session: &mut SessionState,
+        inputs: TVec<TValue>,
+    ) -> TractResult<TVec<TValue>> {
+        let (opaque_input, opaque_mask) = args_2!(inputs);
+        let input = opaque_input.to_metal_tensor()?;
+        let mask = opaque_mask.to_metal_tensor()?;
+        let output =
+            crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), input.shape())?;
+        ScaledMaskedSoftmax.dispatch_eval(context, input, &self.scale, mask, &output)?;
+        Ok(tvec!(output.into_opaque_tensor().into_tvalue()))
+    }
+}
+
+impl TypedOp for MetalScaledMaskedSoftmax {
+    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        crate::utils::metal_facts_from_gpu(inputs, |facts| {
+            ensure!(facts.len() == 2);
+            let dt = facts[0].datum_type;
+            ensure!(dt == facts[1].datum_type);
+            ensure!(facts[0].rank() == 3 && facts[1].rank() == 3);
+            let fact = dt.fact(facts[0].shape.clone());
+            Ok(tvec!(fact))
+        })
+        .with_context(|| anyhow::anyhow!("Error while computing facts for {:?}", self.name()))
+    }
+
+    as_op!();
+}
+
+crate::impl_eval_op_for_metal_op!(MetalScaledMaskedSoftmax);
diff --git a/metal/src/rewrite_rules/fuse_axis_op.rs b/metal/src/rewrite_rules/fuse_axis_op.rs
index b68b9facd0..9495390214 100644
--- a/metal/src/rewrite_rules/fuse_axis_op.rs
+++ b/metal/src/rewrite_rules/fuse_axis_op.rs
@@ -109,6 +109,7 @@ pub fn fuse_axis_op(
         crate::ops::MetalSlice,
         crate::ops::MetalConcat,
         crate::ops::MetalCast,
+        crate::ops::MetalScaledMaskedSoftmax,
     );
 
     // Handle AxisOp::Move operator.
diff --git a/metal/src/rewrite_rules/mod.rs b/metal/src/rewrite_rules/mod.rs
index 73e7d05ec2..93c01740dc 100644
--- a/metal/src/rewrite_rules/mod.rs
+++ b/metal/src/rewrite_rules/mod.rs
@@ -4,6 +4,7 @@ mod new_gelu;
 mod rewire_metal_sync;
 mod rms_norm;
 mod rotate_half;
+mod scaled_masked_softmax;
 mod silu;
 
 use tract_core::internal::*;
@@ -15,6 +16,7 @@ pub use new_gelu::{as_new_gelu_rule, BasicNewGelu};
 pub use rewire_metal_sync::{rewire_metal_sync, rewire_metal_sync_after_const};
 pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm};
 pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf};
+pub use scaled_masked_softmax::{as_scaled_masked_softmax_rule, BasicScaledMaskedSoftmax};
 pub use silu::{as_silu_rule, BasicSilu};
 
 use tract_core::ops::binary::TypedBinOp;
diff --git a/metal/src/rewrite_rules/scaled_masked_softmax.rs b/metal/src/rewrite_rules/scaled_masked_softmax.rs
new file mode 100644
index 0000000000..afdd3f8111
--- /dev/null
+++ b/metal/src/rewrite_rules/scaled_masked_softmax.rs
@@ -0,0 +1,118 @@
+use crate::rewrite_rules::{collect_node_const_inputs, previous_node, previous_nodes};
+use crate::rule_ensure;
+use tract_core::ops::binary::TypedBinOp;
+
+use std::sync::Arc;
+use tract_core::internal::*;
+use tract_core::ops::binary::BinMiniOp;
+use tract_core::ops::math::{Add, Mul};
+use tract_core::ops::nn::{Softmax, SoftmaxExp};
+
+/// A = SOFTMAX(INPUT * SCALE + MASK, AXIS=2)
+/// Only inputs of rank 3 are supported.
+#[derive(Clone, Debug, Hash)]
+pub struct BasicScaledMaskedSoftmax {
+    pub scale: Arc<Tensor>,
+}
+
+impl Op for BasicScaledMaskedSoftmax {
+    fn name(&self) -> Cow<str> {
+        "BasicScaledMaskedSoftmax".to_string().into()
+    }
+    fn info(&self) -> TractResult<Vec<String>> {
+        Ok(vec![format!("scale: {:?}", self.scale)])
+    }
+    op_as_typed_op!();
+}
+
+impl EvalOp for BasicScaledMaskedSoftmax {
+    fn is_stateless(&self) -> bool {
+        true
+    }
+
+    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
+        let (input, mask) = args_2!(inputs);
+        let dt = input.datum_type();
+        let scaled_input = Mul.eval(input, self.scale.clone().into_tvalue(), dt)?;
+        let masked_input = Add.eval(scaled_input.into(), mask, dt)?;
+        let softmax = Softmax::new(tvec![2], None, SoftmaxExp::Libc)
+            .eval(tvec![masked_input.into()])?[0]
+            .clone();
+        Ok(tvec![softmax])
+    }
+}
+
+impl TypedOp for BasicScaledMaskedSoftmax {
+    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
+        ensure!(inputs.len() == 2);
+        let (input, mask) = (inputs[0], inputs[1]);
+        ensure!(input.datum_type == mask.datum_type);
+        ensure!(input.rank() == 3 && mask.rank() == 3);
+        let dt = input.datum_type;
+        let fact = dt.fact(input.shape.clone());
+        Ok(tvec!(fact))
+    }
+
+    as_op!();
+}
+
+/// Search pattern => A = SOFTMAX(A * SCALE + MASK, AXIS=2)
+pub fn as_scaled_masked_softmax_rule(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    node_name: &str,
+    op: &Softmax,
+) -> TractResult<Option<TypedModelPatch>> {
+    rule_ensure!(op.axes.as_slice() == [2]);
+
+    let in_fact = model.node_input_facts(node.id)?[0];
+    let dt = in_fact.datum_type;
+    // Only F16 and F32 are supported.
+    rule_ensure!(matches!(dt, DatumType::F32 | DatumType::F16));
+
+    // Identify Add operator (Mask)
+    let Some(add_prev) = previous_node(model, node) else { return Ok(None) };
+    let Some(add_prev_op) = add_prev.op_as::<TypedBinOp>() else { return Ok(None) };
+    rule_ensure!(add_prev_op.0.is::<Add>());
+
+    let mut in_add = previous_nodes(model, add_prev);
+    rule_ensure!(in_add.len() == 2);
+
+    in_add.reverse();
+    let (left, right) = (in_add.pop().unwrap(), in_add.pop().unwrap());
+
+    let (scale_node, mask_outlet) = if left.op_is::<TypedBinOp>() {
+        (left, add_prev.inputs[1])
+    } else {
+        (right, add_prev.inputs[0])
+    };
+
+    let Some(scale_op) = scale_node.op_as::<TypedBinOp>() else { return Ok(None) };
+    rule_ensure!(scale_op.0.is::<Mul>());
+
+    // Retrieve Scale
+    let mul_consts = collect_node_const_inputs(model, scale_node);
+    rule_ensure!(mul_consts.len() == 1);
+    let scale = mul_consts[0].0.clone();
+
+    rule_ensure!(scale.len() == 1);
+    rule_ensure!(scale.datum_type() == dt);
+
+    // Ensure input and mask have the same rank
+    rule_ensure!(model.outlet_fact(scale_node.inputs[0])?.shape.rank() == 3);
+    rule_ensure!(model.outlet_fact(mask_outlet)?.shape.rank() == 3);
+
+    let mut patch = TypedModelPatch::default();
+    let input = patch.taps(model, &scale_node.inputs)?[0];
+    let mask = patch.taps(model, &[mask_outlet])?[0];
+
+    let out = patch.wire_node(
+        format!("{node_name}.scaled_masked_softmax"),
+        BasicScaledMaskedSoftmax { scale },
+        &[input, mask],
+    )?;
+
+    patch.shunt_outside(model, node.id.into(), out[0])?;
+    Ok(Some(patch))
+}
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index aa65b5b5bf..b5a7590003 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -1,14 +1,17 @@
 use crate::fact::MetalTypedFactExt;
 use crate::kernels::array::RotateHalf;
 use crate::kernels::matmul::{MetalGemmImplKind, MfaGemm, MlxGemm, MpsMatMul};
-use crate::kernels::nn::{ApplyRope, NewGelu, Reducer, RmsNorm, Silu, Softmax};
+use crate::kernels::nn::{
+    ApplyRope, NewGelu, Reducer, RmsNorm, ScaledMaskedSoftmax, Silu, Softmax,
+};
 use crate::ops::{self, MetalAxisOp, MetalSync, MetalSyncKind};
 #[allow(unused_imports)]
 use crate::rewrite_rules::{
-    as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule, as_silu_rule,
-    fuse_axis_op, remove_rms_norm_cast, rewire_metal_sync, rewire_metal_sync_after_const,
-    BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicSilu,
+    as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule,
+    as_scaled_masked_softmax_rule, as_silu_rule, fuse_axis_op, remove_rms_norm_cast,
+    rewire_metal_sync, rewire_metal_sync_after_const, BasicApplyRope, BasicNewGelu, BasicRmsNorm,
+    BasicRotateHalf, BasicScaledMaskedSoftmax, BasicSilu,
 };
 use crate::tensor::MetalTensorExt;
 use crate::{IntoMetal, MetalFact, MetalTensor};
@@ -64,6 +67,7 @@ impl ModelTransform for MetalTransform {
             .with_rule_for::("as-new-gelu", as_new_gelu_rule)
             .with_rule_for::("as-rotate-half", as_rotate_half_rule)
             .with_rule_for::("as-apply-rope", as_apply_rope_rule)
+            .with_rule_for::<Softmax>("as-scaled-masked-softmax", as_scaled_masked_softmax_rule)
             .rewrite(&(), model)?;
 
         let mut new = self.translate_model(model)?;
@@ -209,6 +213,10 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Met
                 .then(|| ops::MetalSoftmax::from_tract_core(op).ok())
                 .flatten()
                 .map(|o| -> Box<dyn TypedOp> { Box::new(o) })
+        } else if let Some(op) = node.op_as::<BasicScaledMaskedSoftmax>() {
+            check_in_dts_are_supported(source, node.id, ScaledMaskedSoftmax::is_supported_dt)?
+                .then(|| ops::MetalScaledMaskedSoftmax { scale: op.scale.clone() })
+                .map(|o| -> Box<dyn TypedOp> { Box::new(o) })
         } else if let Some(op) = node.op_as::<BasicRmsNorm>() {
             check_in_dts_are_supported(source, node.id, RmsNorm::is_supported_dt)?
                 .then(|| ops::MetalRmsNorm::new(op.axis, op.eps.clone()))