From 9f788bcfd3ca40dc5444e9c8f7b89aaef9c35874 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 10 Dec 2024 11:19:32 +0100 Subject: [PATCH] extra unit-test for params gen --- metal/src/kernels/matmul/mod.rs | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/metal/src/kernels/matmul/mod.rs b/metal/src/kernels/matmul/mod.rs index 2845e5e6ce..83d3850fff 100644 --- a/metal/src/kernels/matmul/mod.rs +++ b/metal/src/kernels/matmul/mod.rs @@ -446,6 +446,46 @@ mod tests { }] ); + assert_eq!( + GemmDispatchParams::compute_dispatches_params( + dt, + 0, + &[2, k, m], + true, + 0, + &[1, k, n], + false, + 100, + &[2, m, n], + )?, + vec![ + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 0, + transpose_b: false, + b_offset: 0, + c_offset: 100, + }, + GemmDispatchParams { + dt, + batch: 1, + m, + n, + k, + transpose_a: true, + a_offset: 1 * m * k * dt.size_of(), + transpose_b: false, + b_offset: 0, + c_offset: 100 + 1 * m * n * dt.size_of(), + } + ] + ); + assert_eq!( GemmDispatchParams::compute_dispatches_params( dt,