diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 2845e5e6ce..83d3850fff 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -446,6 +446,46 @@ mod tests { }] ); + assert_eq!( + GemmDispatchParams::compute_dispatches_params( + dt, + 0, + &[2, k, m], + true, + 0, + &[1, k, n], + false, + 100, + &[2, m, n], + )?, + vec![ + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 0, + transpose_b: false, + b_offset: 0, + c_offset: 100, + }, + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 1 * m * k * dt.size_of(), + transpose_b: false, + b_offset: 0, + c_offset: 100 + 1 * m * n * dt.size_of(), + } + ] + ); + assert_eq!( GemmDispatchParams::compute_dispatches_params( dt,