diff --git a/core/src/ops/einsum/as_matmul.rs b/core/src/ops/einsum/as_matmul.rs
index 929fd45e1d..812f230cd5 100644
--- a/core/src/ops/einsum/as_matmul.rs
+++ b/core/src/ops/einsum/as_matmul.rs
@@ -136,7 +136,7 @@ fn einsum_rules(
     Ok(Some(patch))
 }
 
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug, Copy, Default)]
 pub struct BasicMatMul {
     pub transpose_a: bool,
     pub transpose_b: bool,
diff --git a/metal/src/kernels/array/cast.rs b/metal/src/kernels/array/cast.rs
index 54ebf152c6..4a3e00b9bd 100644
--- a/metal/src/kernels/array/cast.rs
+++ b/metal/src/kernels/array/cast.rs
@@ -30,6 +30,7 @@ impl Cast {
                 | DatumType::I16
                 | DatumType::I32
                 | DatumType::I64
+                | DatumType::Bool
         )
     }
 
diff --git a/metal/src/rewrite_rules/mod.rs b/metal/src/rewrite_rules/mod.rs
index 93c01740dc..35ef699d4b 100644
--- a/metal/src/rewrite_rules/mod.rs
+++ b/metal/src/rewrite_rules/mod.rs
@@ -6,6 +6,7 @@ mod rms_norm;
 mod rotate_half;
 mod scaled_masked_softmax;
 mod silu;
+mod untranspose_matmul_output;
 
 use tract_core::internal::*;
 use tract_core::ops::konst::Const;
@@ -18,6 +19,7 @@ pub use rms_norm::{as_rms_norm_rule, remove_rms_norm_cast, BasicRmsNorm};
 pub use rotate_half::{as_rotate_half_rule, BasicRotateHalf};
 pub use scaled_masked_softmax::{as_scaled_masked_softmax_rule, BasicScaledMaskedSoftmax};
 pub use silu::{as_silu_rule, BasicSilu};
+pub use untranspose_matmul_output::untranspose_matmul_output;
 
 use tract_core::ops::binary::TypedBinOp;
 use tract_core::ops::math::{Add, Mul};
diff --git a/metal/src/rewrite_rules/untranspose_matmul_output.rs b/metal/src/rewrite_rules/untranspose_matmul_output.rs
new file mode 100644
index 0000000000..326252f5cf
--- /dev/null
+++ b/metal/src/rewrite_rules/untranspose_matmul_output.rs
@@ -0,0 +1,32 @@
+use crate::rule_ensure;
+use tract_core::internal::*;
+use tract_core::ops::einsum::BasicMatMul;
+
+/// Rewrite BasicMatMul { .. transpose_c: true } to BasicMatMul { .. transpose_c: false}
+///
+/// Uses the identity `(A * B)^T = B^T * A^T`: the two inputs are swapped, and each
+/// operand's transpose flag is taken from the other operand and negated, so the
+/// replacement op no longer needs to transpose its output.
+///
+/// Only applies when `op.transpose_c` is set; `rule_ensure!` presumably skips the
+/// rule (no patch) otherwise -- confirm against the macro definition.
+pub fn untranspose_matmul_output(
+    _ctx: &(),
+    model: &TypedModel,
+    node: &TypedNode,
+    _node_name: &str,
+    op: &BasicMatMul,
+) -> TractResult<Option<TypedModelPatch>> {
+    rule_ensure!(op.transpose_c);
+
+    let new_matmul = BasicMatMul {
+        transpose_a: !op.transpose_b,
+        transpose_b: !op.transpose_a,
+        transpose_c: false,
+        ..*op
+    };
+
+    // Inputs are wired in swapped order to match the swapped transpose flags above.
+    TypedModelPatch::replace_single_op(model, node, &[node.inputs[1], node.inputs[0]], new_matmul)
+        .map(Some)
+}
diff --git a/metal/src/transform.rs b/metal/src/transform.rs
index 440f62e674..6897d176ea 100644
--- a/metal/src/transform.rs
+++ b/metal/src/transform.rs
@@ -4,14 +4,12 @@ use crate::kernels::matmul::{MetalGemmImplKind, MfaGemm, MlxGemm, MpsMatMul};
 use crate::kernels::nn::{
     ApplyRope, NewGelu, Reducer, RmsNorm, ScaledMaskedSoftmax, Silu, Softmax,
 };
-use crate::ops::{self, MetalAxisOp, MetalSync, MetalSyncKind};
+use crate::ops::{self, MetalSync, MetalSyncKind};
 
-#[allow(unused_imports)]
+use crate::rewrite_rules;
 use crate::rewrite_rules::{
-    as_apply_rope_rule, as_new_gelu_rule, as_rms_norm_rule, as_rotate_half_rule,
-    as_scaled_masked_softmax_rule, as_silu_rule, fuse_axis_op, remove_rms_norm_cast,
-    rewire_metal_sync, rewire_metal_sync_after_const, BasicApplyRope, BasicNewGelu, BasicRmsNorm,
-    BasicRotateHalf, BasicScaledMaskedSoftmax, BasicSilu,
+    BasicApplyRope, BasicNewGelu, BasicRmsNorm, BasicRotateHalf, BasicScaledMaskedSoftmax,
+    BasicSilu,
 };
 use crate::tensor::MetalTensorExt;
 use crate::{IntoMetal, MetalFact, MetalTensor};
@@ -64,20 +62,25 @@ impl ModelTransform for MetalTransform {
 }
 
 impl MetalTransform {
-    pub fn transform_up_to_phase(&self, model: &mut TypedModel, stop_at_phase: usize) -> TractResult<()> {
+    pub fn transform_up_to_phase(
+        &self,
+        model: &mut TypedModel,
+        stop_at_phase: usize,
+    ) -> TractResult<()> {
         rewrite_einsums_as_matmul(model)?;
         if stop_at_phase == 0 {
             return Ok(());
         }
 
         Rewriter::default()
-            .with_rule_for::("as-rms-norm", as_rms_norm_rule)
-            .with_rule_for::("remove_rms_norm_cast", remove_rms_norm_cast)
-            .with_rule_for::("as-silu", as_silu_rule)
-            .with_rule_for::("as-new-gelu", as_new_gelu_rule)
-            .with_rule_for::("as-rotate-half", as_rotate_half_rule)
-            .with_rule_for::("as-apply-rope", as_apply_rope_rule)
-            .with_rule_for::("as-scaled-masked-softmax", as_scaled_masked_softmax_rule)
+            .with_rule_for("as-rms-norm", rewrite_rules::as_rms_norm_rule)
+            .with_rule_for("remove_rms_norm_cast", rewrite_rules::remove_rms_norm_cast)
+            .with_rule_for("as-silu", rewrite_rules::as_silu_rule)
+            .with_rule_for("as-new-gelu", rewrite_rules::as_new_gelu_rule)
+            .with_rule_for("as-rotate-half", rewrite_rules::as_rotate_half_rule)
+            .with_rule_for("as-apply-rope", rewrite_rules::as_apply_rope_rule)
+            .with_rule_for("as-scaled-masked-softmax", rewrite_rules::as_scaled_masked_softmax_rule)
+            .with_rule_for("untranspose-matmul-output", rewrite_rules::untranspose_matmul_output)
             .rewrite(&(), model)?;
 
         if stop_at_phase == 1 {
@@ -91,9 +94,12 @@ impl MetalTransform {
         }
 
         Rewriter::default()
-            .with_rule_for::("rewire-metal-sync", rewire_metal_sync)
-            .with_rule_for::("rewire-metal-sync-after-const", rewire_metal_sync_after_const)
-            .with_rule_for::("fuse_axis_op", fuse_axis_op)
+            .with_rule_for("rewire-metal-sync", rewrite_rules::rewire_metal_sync)
+            .with_rule_for(
+                "rewire-metal-sync-after-const",
+                rewrite_rules::rewire_metal_sync_after_const,
+            )
+            .with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
             .rewrite(&(), model)?;
         Ok(())
     }
diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs
index b1c44c5d09..39bffaa067 100644
--- a/nnef/src/ops/nnef/ser.rs
+++ b/nnef/src/ops/nnef/ser.rs
@@ -557,7 +557,7 @@ pub fn rewrite_matmul_to_same_rank(
         inputs[1] =
             patch.wire_node(format!("{prefix}.extra_b_axis.{i}"), AxisOp::Add(0), &[inputs[1]])?[0];
     }
-    let result = patch.wire_node(prefix, op.clone(), &inputs)?[0];
+    let result = patch.wire_node(prefix, *op, &inputs)?[0];
     patch.shunt_outside(model, node.id.into(), result)?;
     Ok(Some(patch))
 }
diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs
index 7c92629b4b..9d6aa409cc 100644
--- a/tflite/src/rewriter.rs
+++ b/tflite/src/rewriter.rs
@@ -51,7 +51,7 @@ fn trivial_axes_around_matmul(
         wire[1] =
             patch.wire_node(format!("{name}.rm_b_axis_{axis}"), AxisOp::Rm(*axis), &[wire[1]])?[0];
     }
-    let mut out = patch.wire_node(&node.name, conv.clone(), &wire)?;
+    let mut out = patch.wire_node(&node.name, *conv, &wire)?;
     for axis in trivial_axes {
         out = patch.wire_node(format!("{name}.add_axis_{axis}"), AxisOp::Add(axis), &out)?;
     }