From fac3cfc818c4899b70f3a34e85ee979682c1d3c6 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Tue, 17 Sep 2024 11:19:38 -0400 Subject: [PATCH] CMMA: cube dispatch strategy (#126) --- crates/cubecl-linalg/src/matmul/cmma/base.rs | 23 ++++-- .../cubecl-linalg/src/matmul/cmma/config.rs | 44 +++++++++-- .../src/matmul/cmma/cube_dispatch/base.rs | 56 ++++++++++++++ .../src/matmul/cmma/cube_dispatch/mod.rs | 1 + crates/cubecl-linalg/src/matmul/cmma/mod.rs | 1 + .../src/matmul/tests/cmma/compute_loop.rs | 20 ++++- .../matmul/tests/cmma/load_shared_memory.rs | 74 +++++++++++++++---- .../src/matmul/tests/cmma/write_output.rs | 73 +++++++++++++----- .../src/matmul/tests/matmul_tests.rs | 11 +-- 9 files changed, 251 insertions(+), 52 deletions(-) create mode 100644 crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/base.rs create mode 100644 crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/mod.rs diff --git a/crates/cubecl-linalg/src/matmul/cmma/base.rs b/crates/cubecl-linalg/src/matmul/cmma/base.rs index c9aabd7d..711a4e69 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/base.rs @@ -3,6 +3,9 @@ use cubecl_core::{self as cubecl, prelude::*}; use super::block_loop::block_loop; use super::config::ComptimeCmmaInfo; +use super::cube_dispatch::base::{ + ColMajorCubeDispatch, CubeDispatch, RowMajorCubeDispatch, SwizzleCubeDispatch, +}; #[cube(launch_unchecked)] #[allow(unused_mut)] @@ -92,14 +95,9 @@ fn calculate_offsets( lhs: &Tensor, rhs: &Tensor, out: &Tensor, - #[comptime] config: ComptimeCmmaInfo, + #[comptime] comptime_info: ComptimeCmmaInfo, ) -> Offsets { - let block_size_m = config.block_size_m; - let block_size_n = config.block_size_m; - - // Cube offset - let cube_row = CUBE_POS_X * block_size_m; - let cube_col = CUBE_POS_Y * block_size_n; + let (cube_row, cube_col) = get_row_col(comptime_info); let rank = out.rank(); @@ -127,6 +125,17 @@ fn calculate_offsets( } } +#[cube] +pub(crate) fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) { + if comptime_info.cube_dispatch == 0 { + RowMajorCubeDispatch::get_row_col(comptime_info) + } else if comptime_info.cube_dispatch == 1 { + ColMajorCubeDispatch::get_row_col(comptime_info) + } else { + SwizzleCubeDispatch::get_row_col(comptime_info) + } +} + #[cube] fn make_shared_memories(#[comptime] config: ComptimeCmmaInfo) -> SharedMemories { let block_size_m = config.block_size_m; diff --git a/crates/cubecl-linalg/src/matmul/cmma/config.rs b/crates/cubecl-linalg/src/matmul/cmma/config.rs index d418b497..2a49b20e 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/config.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/config.rs @@ -2,6 +2,7 @@ use cubecl_core::prelude::*; // It is assumed that CMMA uses 32 units to compute 16x16x16 tiles // TODO put it in config and split tile size into three different parameters +// TODO add number of smem banks pub(crate) const CMMA_COOP_DIM: usize = 32; pub(crate) const CMMA_TILE_SIZE: usize = 16; @@ -14,6 +15,28 @@ pub enum WriteOutStrategy { ReuseSmem, } +/// How cubes are dispatched in the hypercube +/// Should impact L2 cache reuse +#[derive(Clone, Copy)] +pub enum CubeDispatchStrategy { + /// Cubes are dispatched row major + RowMajor, + /// Cubes are dispatched col major + ColMajor, + /// Cubes follow swizzle pattern, see https://bruce-lee-ly.medium.com/nvidia-tensor-core-cuda-hgemm-advanced-optimization-5a17eb77dd85 + Swizzle, +} + +impl From for u32 { + fn from(value: CubeDispatchStrategy) -> Self { + match value { + CubeDispatchStrategy::RowMajor => 0, + CubeDispatchStrategy::ColMajor => 1, + CubeDispatchStrategy::Swizzle => 2, + } + } +} + pub struct CmmaConfig { /// Corresponds to the number of tiles in the m and n dimensions for a block pub b_mn: usize, @@ -23,13 +46,19 @@ pub struct CmmaConfig { pub unroll: bool, /// Whether to write all accumulators in different spots of a large shared memory or reuse the space pub write_out_strategy: WriteOutStrategy, - /// Corresponds to the number of accumulators per warp. Equals b_mn / b_k - pub alpha: usize, + /// Order in which to dispatch cubes + pub cube_dispatch: CubeDispatchStrategy, } impl Default for CmmaConfig { fn default() -> Self { - Self::new(128, 16, false, WriteOutStrategy::ReuseSmem) + Self::new( + 128, + 16, + false, + WriteOutStrategy::ReuseSmem, + CubeDispatchStrategy::ColMajor, + ) } } @@ -39,6 +68,7 @@ impl CmmaConfig { b_k: usize, unroll: bool, write_out_strategy: WriteOutStrategy, + cube_dispatch: CubeDispatchStrategy, ) -> CmmaConfig { assert!(b_mn % CMMA_TILE_SIZE == 0); assert!(b_k % CMMA_TILE_SIZE == 0); @@ -46,9 +76,9 @@ impl CmmaConfig { CmmaConfig { b_mn, b_k, - alpha: b_mn / b_k, unroll, write_out_strategy, + cube_dispatch, } } @@ -66,8 +96,9 @@ impl CmmaConfig { check_n_bounds: n % self.b_mn != 0, coop_dim: CMMA_COOP_DIM as u32, num_coops: num_coops as u32, - num_accumulators: self.alpha as u32, + num_accumulators: (self.b_mn / self.b_k) as u32, write_out_reuse_smem: self.write_out_strategy == WriteOutStrategy::ReuseSmem, + cube_dispatch: self.cube_dispatch.into(), } } @@ -114,7 +145,6 @@ impl Init for ComptimeCmmaInfo { } #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] -/// Tiling 2D parameters pub struct ComptimeCmmaInfo { /// Block size along dimension of lhs pub block_size_m: u32, @@ -140,4 +170,6 @@ pub struct ComptimeCmmaInfo { pub num_accumulators: u32, /// Write out strategy: false = large, true = reuse pub write_out_reuse_smem: bool, + /// 0 = RowMajor, 1 = ColMajor, 2 = Swizzle + pub cube_dispatch: u32, } diff --git a/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/base.rs b/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/base.rs new file mode 100644 index 00000000..b0f65efb --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/base.rs @@ -0,0 +1,56 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::matmul::cmma::config::ComptimeCmmaInfo; + +#[cube] +pub(crate) trait CubeDispatch { + fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32); +} + +pub(crate) struct RowMajorCubeDispatch {} +pub(crate) struct ColMajorCubeDispatch {} +pub(crate) struct SwizzleCubeDispatch {} + +#[cube] +impl CubeDispatch for RowMajorCubeDispatch { + fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) { + let block_size_m = comptime_info.block_size_m; + let block_size_n = comptime_info.block_size_n; + + (CUBE_POS_Y * block_size_m, CUBE_POS_X * block_size_n) + } +} + +#[cube] +impl CubeDispatch for ColMajorCubeDispatch { + fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) { + let block_size_m = comptime_info.block_size_m; + let block_size_n = comptime_info.block_size_n; + + (CUBE_POS_X * block_size_m, CUBE_POS_Y * block_size_n) + } +} + +#[cube] +impl CubeDispatch for SwizzleCubeDispatch { + #[allow(clippy::modulo_one)] // it somehow seems assumed that cube count is 1 + fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) { + let block_size_m = comptime_info.block_size_m; + let block_size_n = comptime_info.block_size_n; + + let num_elem_per_swizzle_col = CUBE_COUNT_Y * 2; + let nth_cube = CUBE_POS_X * CUBE_COUNT_Y + CUBE_POS_Y; + let swizzle_id = nth_cube % num_elem_per_swizzle_col; + + let swizzle_col = nth_cube / num_elem_per_swizzle_col; + let col_within_swizzle = swizzle_id / CUBE_COUNT_Y; + let cube_col = swizzle_col * 2 + col_within_swizzle; + + let topdown_row = swizzle_id % CUBE_COUNT_Y; + let is_bottom_up = (nth_cube / num_elem_per_swizzle_col) % 2; + let cube_row = topdown_row + is_bottom_up * (CUBE_COUNT_Y - 2 * topdown_row - 1); + + (cube_row * block_size_m, cube_col * block_size_n) + } +} diff --git a/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/mod.rs b/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/mod.rs new file mode 100644 index 00000000..671c5262 --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/mod.rs @@ -0,0 +1 @@ +pub(crate) mod base; diff --git a/crates/cubecl-linalg/src/matmul/cmma/mod.rs b/crates/cubecl-linalg/src/matmul/cmma/mod.rs index 7bfd4946..b844aed6 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/mod.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/mod.rs @@ -3,6 +3,7 @@ mod block_io; mod block_loop; pub(crate) mod compute_loop; pub(crate) mod config; +pub(crate) mod cube_dispatch; mod launch; pub(crate) mod load_shared_memory; pub(crate) mod write_output; diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index 34b2686b..949dfcf8 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -4,7 +4,7 @@ use cubecl_core as cubecl; use crate::matmul::cmma::{ base::{make_cmma_matrices, Ids, SharedMemories}, compute_loop::compute_loop, - config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy}, + config::{CmmaConfig, ComptimeCmmaInfo}, }; use crate::matmul::tests::test_utils::{ assert_equals, cmma_available, create_empty, range_tensor_f16, @@ -105,7 +105,11 @@ fn compute_loop_test_case( /// Exported test pub fn cmma_compute_loop_block_equal_tile_test(device: &R::Device) { compute_loop_test_case::( - CmmaConfig::new(16, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 16, + b_k: 16, + ..Default::default() + }, &[ 19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0, 20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0, @@ -147,7 +151,11 @@ pub fn cmma_compute_loop_block_equal_tile_test(device: &R::Device) { /// Exported test pub fn cmma_compute_loop_block_larger_than_tile_test(device: &R::Device) { compute_loop_test_case::( - CmmaConfig::new(32, 32, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 32, + ..Default::default() + }, &[ 1610496.0, 1614832.0, 1619168.0, 1623504.0, 1627840.0, 1632176.0, 1636512.0, 1640848.0, 1645184.0, 1649520.0, 1653856.0, 1658192.0, 1662528.0, 1666864.0, 1671200.0, 1675536.0, @@ -290,7 +298,11 @@ pub fn cmma_compute_loop_block_larger_than_tile_test(device: &R::Dev /// Exported test pub fn cmma_compute_loop_b_mn_larger_than_b_k_test(device: &R::Device) { compute_loop_test_case::( - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + ..Default::default() + }, &[ 19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0, 20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index 588e12ec..b00b319c 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -4,7 +4,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use crate::matmul::cmma::base::{Dimensions, Ids, Offsets, RuntimeCmmaInfo}; -use crate::matmul::cmma::config::{CmmaConfig, WriteOutStrategy}; +use crate::matmul::cmma::config::CmmaConfig; use crate::matmul::tests::test_utils::{assert_equals_range, create_empty}; use crate::matmul::{ cmma::{config::ComptimeCmmaInfo, load_shared_memory::*}, @@ -184,7 +184,11 @@ pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, @@ -223,7 +227,11 @@ pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, @@ -262,7 +270,11 @@ pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, @@ -300,7 +312,11 @@ pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(devi n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, @@ -337,7 +353,11 @@ pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: & n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, @@ -373,7 +393,11 @@ pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 144., 145., @@ -409,7 +433,11 @@ pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., 1036., 1037., 1038., 1039., 1088., 1089., 1090., 1091., 1092., 1093., 1094., 1095., @@ -449,7 +477,11 @@ pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, @@ -488,7 +520,11 @@ pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { n: 64, }, 0, - CmmaConfig::new(64, 32, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 64, + b_k: 32, + ..Default::default() + }, &[ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 144., 145., @@ -524,7 +560,11 @@ pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { n: 64, }, 32, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, @@ -564,7 +604,11 @@ pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { n: 64, }, 32, - CmmaConfig::new(B_MN, B_K, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: B_MN, + b_k: B_K, + ..Default::default() + }, &[ 2048., 2049., 2050., 2051., 2052., 2053., 2054., 2055., 2056., 2057., 2058., 2059., 2060., 2061., 2062., 2063., 2112., 2113., 2114., 2115., 2116., 2117., 2118., 2119., @@ -604,7 +648,11 @@ pub fn load_shared_memory_rhs_larger_block_test(device: &R::Device) n: 32, }, 0, - CmmaConfig::new(32, 32, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 32, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index de5d27b9..3dccea93 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -3,8 +3,10 @@ use std::ops::Range; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::{Dimensions, Ids, Offsets, RuntimeCmmaInfo}; -use crate::matmul::cmma::config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy}; +use crate::matmul::cmma::base::{get_row_col, Dimensions, Ids, Offsets, RuntimeCmmaInfo}; +use crate::matmul::cmma::config::{ + CmmaConfig, ComptimeCmmaInfo, CubeDispatchStrategy, WriteOutStrategy, +}; use crate::matmul::cmma::write_output::base::shared_memory_to_output; use crate::matmul::tests::test_utils::{ assert_equals, assert_equals_range, range_tensor, zeros_tensor, @@ -19,13 +21,11 @@ fn write_output_test( m: u32, k: u32, n: u32, - #[comptime] config: ComptimeCmmaInfo, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let num_accumulators = config.num_accumulators; - let tile_size = config.tile_size; - let num_coops = config.num_coops; - let block_size_m = config.block_size_m; - let block_size_n = config.block_size_n; + let num_accumulators = comptime_info.num_accumulators; + let tile_size = comptime_info.tile_size; + let num_coops = comptime_info.num_coops; let sm_stride = tile_size * tile_size; let sm_size = num_accumulators * num_coops * sm_stride; @@ -35,12 +35,14 @@ fn write_output_test( accumulate[i] = acc_sm_arr[i]; } + let (cube_row, cube_col) = get_row_col(comptime_info); + let offsets = Offsets { batch_lhs: 0, batch_rhs: 0, batch_out: 0, - cube_row: CUBE_POS_X * block_size_m, - cube_col: CUBE_POS_Y * block_size_n, + cube_row, + cube_col, }; let dims = Dimensions { m, k, n }; let ids = Ids { @@ -50,6 +52,7 @@ fn write_output_test( let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; let smem_position_base = num_accumulators * ids.coop; + #[unroll] for n_iter in 0..num_accumulators { shared_memory_to_output( @@ -58,7 +61,7 @@ fn write_output_test( accumulate, n_iter, runtime_info, - config, + comptime_info, ); } } @@ -108,7 +111,13 @@ pub fn cmma_write_output_warp_test(device: &R::Device) { k: 16, n: 32, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, @@ -166,7 +175,13 @@ pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: k: 16, n: 28, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, @@ -218,7 +233,13 @@ pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R k: 16, n: 32, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, @@ -270,7 +291,13 @@ pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::D k: 16, n: 28, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, @@ -318,7 +345,13 @@ pub fn cmma_write_output_second_warp_test(device: &R::Device) { k: 16, n: 64, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, @@ -417,7 +450,13 @@ pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) k: 16, n: 64, }, - CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem), + CmmaConfig { + b_mn: 32, + b_k: 16, + write_out_strategy: WriteOutStrategy::LargeSmem, + cube_dispatch: CubeDispatchStrategy::ColMajor, + ..Default::default() + }, &[ 512.0, 513.0, 514.0, 515.0, 516.0, 517.0, 518.0, 519.0, 520.0, 521.0, 522.0, 523.0, 524.0, 525.0, 526.0, 527.0, 768.0, 769.0, 770.0, 771.0, 772.0, 773.0, 774.0, 775.0, diff --git a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs index 0bda6d29..05563611 100644 --- a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs +++ b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs @@ -3,10 +3,7 @@ use half::f16; use crate::{ matmul::{ - cmma::{ - config::{CmmaConfig, WriteOutStrategy}, - launch, - }, + cmma::{config::CmmaConfig, launch}, tiling2d, }, tensor::TensorHandle, @@ -42,7 +39,11 @@ macro_rules! alternate_block_sizes { compute_f16: true, } .test_cmma::( - CmmaConfig::new($b_mn, $b_k, false, WriteOutStrategy::ReuseSmem), + CmmaConfig { + b_mn: $b_mn, + b_k: $b_k, + ..Default::default() + }, device, ); }