diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index be85935bd4..83d3850fff 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -96,7 +96,7 @@ impl GemmDispatchParams { n, k, transpose_a, - a_offset: a_offset + a_batch_idx * n * k * dt.size_of(), + a_offset: a_offset + a_batch_idx * m * k * dt.size_of(), transpose_b, b_offset, c_offset: c_offset + a_batch_idx * m * n * dt.size_of(), @@ -446,6 +446,46 @@ mod tests { }] ); + assert_eq!( + GemmDispatchParams::compute_dispatches_params( + dt, + 0, + &[2, k, m], + true, + 0, + &[1, k, n], + false, + 100, + &[2, m, n], + )?, + vec![ + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 0, + transpose_b: false, + b_offset: 0, + c_offset: 100, + }, + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 1 * m * k * dt.size_of(), + transpose_b: false, + b_offset: 0, + c_offset: 100 + 1 * m * n * dt.size_of(), + } + ] + ); + assert_eq!( GemmDispatchParams::compute_dispatches_params( dt, diff --git a/metal/src/transform.rs b/metal/src/transform.rs index ba990475c3..748e2a4c57 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -56,7 +56,17 @@ impl ModelTransform for MetalTransform { } fn transform(&self, model: &mut TypedModel) -> TractResult<()> { + self.transform_up_to_phase(model, usize::MAX) + } +} + +impl MetalTransform { + pub fn transform_up_to_phase(&self, model: &mut TypedModel, stop_at_phase: usize) -> TractResult<()> { rewrite_einsums_as_matmul(model)?; + if stop_at_phase == 0 { + return Ok(()); + } + Rewriter::default() .with_rule_for::("as-rms-norm", as_rms_norm_rule) .with_rule_for::("remove_rms_norm_cast", remove_rms_norm_cast) @@ -66,18 +76,23 @@ impl ModelTransform for MetalTransform { //.with_rule_for::("as-apply-rope", as_apply_rope_rule) .rewrite(&(), model)?; - let mut new = self.translate_model(model)?; + if stop_at_phase == 1 { + return Ok(()); + } + + *model = self.translate_model(model)?; + + if stop_at_phase == 2 { + return Ok(()); + } Rewriter::default() .with_rule_for::("rewire-metal-sync", rewire_metal_sync) .with_rule_for::("fuse_axis_op", fuse_axis_op) - .rewrite(&(), &mut new)?; - *model = new; + .rewrite(&(), model)?; Ok(()) } -} -impl MetalTransform { fn sync_inputs_if_required( &self, model: &mut TypedModel, diff --git a/test-rt/suite-unit/src/conv_f32.rs b/test-rt/suite-unit/src/conv_f32.rs index 0ab23f493f..b050b89a24 100644 --- a/test-rt/suite-unit/src/conv_f32.rs +++ b/test-rt/suite-unit/src/conv_f32.rs @@ -1229,5 +1229,19 @@ pub fn suite() -> TractResult { }, ); + suite.add( + "bug_metal_0", + ConvProblem { + shape_in: DataFormat::NHWC.from_n_c_hw(2, 1, [4]).unwrap(), + kernel_format: KernelFormat::OIHW, + group: 1, + data: arr3(&[[[0f32], [0.], [0.], [0.]], [[0.], [0.], [0.], [1.]]]).into_dyn(), + kernel: arr3(&[[[0f32]], [[1.]]]).into_dyn(), + bias: None, + pad: PaddingSpec::Valid, + strides: tvec!(1), + }, + ); + Ok(suite) } diff --git a/test-rt/test-metal/src/lib.rs b/test-rt/test-metal/src/lib.rs index 3d0dd12874..c042852219 100644 --- a/test-rt/test-metal/src/lib.rs +++ b/test-rt/test-metal/src/lib.rs @@ -1,33 +1,55 @@ #![cfg(all(test, any(target_os = "macos", target_os = "ios")))] +use std::borrow::Cow; +use std::sync::Arc; +use tract_core::internal::*; + +use tract_core::runtime::Runtime; + #[path = "../suite.rs"] mod suite; -mod run_with_metal { - use super::*; - use tract_core::internal::*; - use tract_core::transform::ModelTransform; - - #[derive(Debug)] - struct RunWithMetal; +#[derive(Debug)] +struct MetalTestRuntime { + name: &'static str, + phase: usize, + optimize: bool, +} - impl Runtime for RunWithMetal { - fn name(&self) -> Cow { - "run_with_metal".into() - } +impl Runtime for MetalTestRuntime { + fn name(&self) -> Cow { + self.name.into() + } - fn prepare(&self, model: TypedModel) -> TractResult> { - let metal_model = tract_metal::transform::MetalTransform::default().transform_into(&model)?; - Ok(Box::new(Arc::new(metal_model.into_optimized()?.into_runnable()?))) + fn prepare(&self, mut model: TypedModel) -> TractResult> { + tract_metal::transform::MetalTransform::default() + .transform_up_to_phase(&mut model, self.phase)?; + if self.optimize { + model = model.into_optimized()?; } + Ok(Box::new(Arc::new(model.into_runnable()?))) } +} - fn runtime() -> &'static RunWithMetal { - lazy_static::lazy_static! { - static ref RT: RunWithMetal = RunWithMetal; - }; - &RT - } +macro_rules! metal_test_suite { + ($id: ident, $phase: expr, $optimize: expr) => { + mod $id { + use super::*; + + fn runtime() -> &'static MetalTestRuntime { + lazy_static::lazy_static! { + static ref RT: MetalTestRuntime = MetalTestRuntime { name: stringify!($id), phase: $phase, optimize: $optimize }; + }; + &RT + } - include!(concat!(env!("OUT_DIR"), "/tests/tests.rs")); + include!(concat!(env!("OUT_DIR"), "/tests/tests.rs")); + } + }; } + +metal_test_suite!(metal_phase_0_einsum, 0, false); +metal_test_suite!(metal_phase_1_pre_translate, 1, false); +metal_test_suite!(metal_phase_2_translate, 2, false); +metal_test_suite!(metal_phase_3_post_translate, 3, false); +metal_test_suite!(optimized_metal, usize::MAX, true);