Skip to content

Commit

Permalink
CMMA: cube dispatch strategy (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Sep 17, 2024
1 parent 9a8fdec commit fac3cfc
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 52 deletions.
23 changes: 16 additions & 7 deletions crates/cubecl-linalg/src/matmul/cmma/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use cubecl_core::{self as cubecl, prelude::*};

use super::block_loop::block_loop;
use super::config::ComptimeCmmaInfo;
use super::cube_dispatch::base::{
ColMajorCubeDispatch, CubeDispatch, RowMajorCubeDispatch, SwizzleCubeDispatch,
};

#[cube(launch_unchecked)]
#[allow(unused_mut)]
Expand Down Expand Up @@ -92,14 +95,9 @@ fn calculate_offsets<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &Tensor<F>,
#[comptime] config: ComptimeCmmaInfo,
#[comptime] comptime_info: ComptimeCmmaInfo,
) -> Offsets {
let block_size_m = config.block_size_m;
let block_size_n = config.block_size_m;

// Cube offset
let cube_row = CUBE_POS_X * block_size_m;
let cube_col = CUBE_POS_Y * block_size_n;
let (cube_row, cube_col) = get_row_col(comptime_info);

let rank = out.rank();

Expand Down Expand Up @@ -127,6 +125,17 @@ fn calculate_offsets<F: Float>(
}
}

/// Maps the current cube's grid position to the (row, col) element offset of
/// the output block it computes, delegating to the dispatch strategy selected
/// by `comptime_info.cube_dispatch`.
///
/// Encoding (see `From<CubeDispatchStrategy> for u32` in config.rs):
/// 0 = RowMajor, 1 = ColMajor, anything else = Swizzle.
///
/// NOTE(review): `cube_dispatch` is a comptime value, so this branch is
/// presumably specialized away at kernel expansion time — confirm against
/// the cubecl `#[cube]` comptime semantics.
#[cube]
pub(crate) fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) {
    if comptime_info.cube_dispatch == 0 {
        RowMajorCubeDispatch::get_row_col(comptime_info)
    } else if comptime_info.cube_dispatch == 1 {
        ColMajorCubeDispatch::get_row_col(comptime_info)
    } else {
        SwizzleCubeDispatch::get_row_col(comptime_info)
    }
}

#[cube]
fn make_shared_memories<FC: Float>(#[comptime] config: ComptimeCmmaInfo) -> SharedMemories<FC> {
let block_size_m = config.block_size_m;
Expand Down
44 changes: 38 additions & 6 deletions crates/cubecl-linalg/src/matmul/cmma/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use cubecl_core::prelude::*;

// It is assumed that CMMA uses 32 units to compute 16x16x16 tiles
// TODO put it in config and split tile size into three different parameters
// TODO add number of smem banks
pub(crate) const CMMA_COOP_DIM: usize = 32;
pub(crate) const CMMA_TILE_SIZE: usize = 16;

Expand All @@ -14,6 +15,28 @@ pub enum WriteOutStrategy {
ReuseSmem,
}

/// How cubes are dispatched in the hypercube.
/// Should impact L2 cache reuse.
///
/// Converted to its `u32` tag (via `From`) when building `ComptimeCmmaInfo`.
// Debug/PartialEq/Eq added so the strategy can be logged, asserted on in
// tests, and compared like the sibling `WriteOutStrategy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CubeDispatchStrategy {
    /// Cubes are dispatched row major
    RowMajor,
    /// Cubes are dispatched col major
    ColMajor,
    /// Cubes follow swizzle pattern, see https://bruce-lee-ly.medium.com/nvidia-tensor-core-cuda-hgemm-advanced-optimization-5a17eb77dd85
    Swizzle,
}

impl From<CubeDispatchStrategy> for u32 {
fn from(value: CubeDispatchStrategy) -> Self {
match value {
CubeDispatchStrategy::RowMajor => 0,
CubeDispatchStrategy::ColMajor => 1,
CubeDispatchStrategy::Swizzle => 2,
}
}
}

pub struct CmmaConfig {
/// Corresponds to the number of tiles in the m and n dimensions for a block
pub b_mn: usize,
Expand All @@ -23,13 +46,19 @@ pub struct CmmaConfig {
pub unroll: bool,
/// Whether to write all accumulators in different spots of a large shared memory or reuse the space
pub write_out_strategy: WriteOutStrategy,
/// Corresponds to the number of accumulators per warp. Equals b_mn / b_k
pub alpha: usize,
/// Order in which to dispatch cubes
pub cube_dispatch: CubeDispatchStrategy,
}

impl Default for CmmaConfig {
fn default() -> Self {
Self::new(128, 16, false, WriteOutStrategy::ReuseSmem)
Self::new(
128,
16,
false,
WriteOutStrategy::ReuseSmem,
CubeDispatchStrategy::ColMajor,
)
}
}

Expand All @@ -39,16 +68,17 @@ impl CmmaConfig {
b_k: usize,
unroll: bool,
write_out_strategy: WriteOutStrategy,
cube_dispatch: CubeDispatchStrategy,
) -> CmmaConfig {
assert!(b_mn % CMMA_TILE_SIZE == 0);
assert!(b_k % CMMA_TILE_SIZE == 0);
assert!(b_mn % b_k == 0);
CmmaConfig {
b_mn,
b_k,
alpha: b_mn / b_k,
unroll,
write_out_strategy,
cube_dispatch,
}
}

Expand All @@ -66,8 +96,9 @@ impl CmmaConfig {
check_n_bounds: n % self.b_mn != 0,
coop_dim: CMMA_COOP_DIM as u32,
num_coops: num_coops as u32,
num_accumulators: self.alpha as u32,
num_accumulators: (self.b_mn / self.b_k) as u32,
write_out_reuse_smem: self.write_out_strategy == WriteOutStrategy::ReuseSmem,
cube_dispatch: self.cube_dispatch.into(),
}
}

Expand Down Expand Up @@ -114,7 +145,6 @@ impl Init for ComptimeCmmaInfo {
}

#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
/// Tiling 2D parameters
pub struct ComptimeCmmaInfo {
/// Block size along dimension of lhs
pub block_size_m: u32,
Expand All @@ -140,4 +170,6 @@ pub struct ComptimeCmmaInfo {
pub num_accumulators: u32,
/// Write out strategy: false = large, true = reuse
pub write_out_reuse_smem: bool,
/// 0 = RowMajor, 1 = ColMajor, 2 = Swizzle
pub cube_dispatch: u32,
}
56 changes: 56 additions & 0 deletions crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::config::ComptimeCmmaInfo;

/// Strategy for mapping a cube's position in the cube grid to the offsets of
/// the output block it computes.
#[cube]
pub(crate) trait CubeDispatch {
    /// Returns the (row, col) offsets of this cube's block, computed from the
    /// cube's grid position scaled by the comptime block sizes.
    fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32);
}

// Zero-sized marker types; each carries one `CubeDispatch` impl and is
// selected by `ComptimeCmmaInfo::cube_dispatch` (0, 1, 2 respectively).
pub(crate) struct RowMajorCubeDispatch {}
pub(crate) struct ColMajorCubeDispatch {}
pub(crate) struct SwizzleCubeDispatch {}

#[cube]
impl CubeDispatch for RowMajorCubeDispatch {
    /// Row offset comes from `CUBE_POS_Y`, column offset from `CUBE_POS_X`:
    /// cubes that differ only in `CUBE_POS_X` cover consecutive blocks of the
    /// same output row.
    fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) {
        // Block sizes are comptime configuration values.
        let block_size_m = comptime_info.block_size_m;
        let block_size_n = comptime_info.block_size_n;

        (CUBE_POS_Y * block_size_m, CUBE_POS_X * block_size_n)
    }
}

#[cube]
impl CubeDispatch for ColMajorCubeDispatch {
    /// Row offset comes from `CUBE_POS_X`, column offset from `CUBE_POS_Y`:
    /// the mirror of `RowMajorCubeDispatch`, so cubes that differ only in
    /// `CUBE_POS_Y` cover consecutive blocks of the same output column.
    fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) {
        // Block sizes are comptime configuration values.
        let block_size_m = comptime_info.block_size_m;
        let block_size_n = comptime_info.block_size_n;

        (CUBE_POS_X * block_size_m, CUBE_POS_Y * block_size_n)
    }
}

#[cube]
impl CubeDispatch for SwizzleCubeDispatch {
    /// Serpentine ("swizzle") dispatch, intended to improve L2 reuse; see the
    /// article linked on `CubeDispatchStrategy::Swizzle`.
    ///
    /// Cubes are numbered column-major over the cube grid; each group of
    /// `2 * CUBE_COUNT_Y` consecutive cubes covers a pair of output block
    /// columns, walking rows in one direction and then mirrored.
    #[allow(clippy::modulo_one)] // it somehow seems assumed that cube count is 1
    fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) {
        let block_size_m = comptime_info.block_size_m;
        let block_size_n = comptime_info.block_size_n;

        // One "swizzle column" spans 2 output columns of CUBE_COUNT_Y rows each.
        let num_elem_per_swizzle_col = CUBE_COUNT_Y * 2;
        // Linear cube index, column-major over (CUBE_POS_X, CUBE_POS_Y).
        let nth_cube = CUBE_POS_X * CUBE_COUNT_Y + CUBE_POS_Y;
        // Position of this cube within its swizzle column.
        let swizzle_id = nth_cube % num_elem_per_swizzle_col;

        // Output column: which swizzle column (pair of output columns), then
        // which half of the pair this cube falls in.
        let swizzle_col = nth_cube / num_elem_per_swizzle_col;
        let col_within_swizzle = swizzle_id / CUBE_COUNT_Y;
        let cube_col = swizzle_col * 2 + col_within_swizzle;

        // Row walked top-down, mirrored to (CUBE_COUNT_Y - 1 - topdown_row)
        // when is_bottom_up == 1.
        // NOTE(review): `is_bottom_up` is the parity of the swizzle-column
        // index (`nth_cube / num_elem_per_swizzle_col`), not of
        // `col_within_swizzle` — confirm this is the intended alternation
        // point for the serpentine walk.
        let topdown_row = swizzle_id % CUBE_COUNT_Y;
        let is_bottom_up = (nth_cube / num_elem_per_swizzle_col) % 2;
        let cube_row = topdown_row + is_bottom_up * (CUBE_COUNT_Y - 2 * topdown_row - 1);

        (cube_row * block_size_m, cube_col * block_size_n)
    }
}
1 change: 1 addition & 0 deletions crates/cubecl-linalg/src/matmul/cmma/cube_dispatch/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod base;
1 change: 1 addition & 0 deletions crates/cubecl-linalg/src/matmul/cmma/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod block_io;
mod block_loop;
pub(crate) mod compute_loop;
pub(crate) mod config;
pub(crate) mod cube_dispatch;
mod launch;
pub(crate) mod load_shared_memory;
pub(crate) mod write_output;
Expand Down
20 changes: 16 additions & 4 deletions crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cubecl_core as cubecl;
use crate::matmul::cmma::{
base::{make_cmma_matrices, Ids, SharedMemories},
compute_loop::compute_loop,
config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy},
config::{CmmaConfig, ComptimeCmmaInfo},
};
use crate::matmul::tests::test_utils::{
assert_equals, cmma_available, create_empty, range_tensor_f16,
Expand Down Expand Up @@ -105,7 +105,11 @@ fn compute_loop_test_case<R: Runtime>(
/// Exported test
pub fn cmma_compute_loop_block_equal_tile_test<R: Runtime>(device: &R::Device) {
compute_loop_test_case::<R>(
CmmaConfig::new(16, 16, false, WriteOutStrategy::LargeSmem),
CmmaConfig {
b_mn: 16,
b_k: 16,
..Default::default()
},
&[
19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0,
20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0,
Expand Down Expand Up @@ -147,7 +151,11 @@ pub fn cmma_compute_loop_block_equal_tile_test<R: Runtime>(device: &R::Device) {
/// Exported test
pub fn cmma_compute_loop_block_larger_than_tile_test<R: Runtime>(device: &R::Device) {
compute_loop_test_case::<R>(
CmmaConfig::new(32, 32, false, WriteOutStrategy::LargeSmem),
CmmaConfig {
b_mn: 32,
b_k: 32,
..Default::default()
},
&[
1610496.0, 1614832.0, 1619168.0, 1623504.0, 1627840.0, 1632176.0, 1636512.0, 1640848.0,
1645184.0, 1649520.0, 1653856.0, 1658192.0, 1662528.0, 1666864.0, 1671200.0, 1675536.0,
Expand Down Expand Up @@ -290,7 +298,11 @@ pub fn cmma_compute_loop_block_larger_than_tile_test<R: Runtime>(device: &R::Dev
/// Exported test
pub fn cmma_compute_loop_b_mn_larger_than_b_k_test<R: Runtime>(device: &R::Device) {
compute_loop_test_case::<R>(
CmmaConfig::new(32, 16, false, WriteOutStrategy::LargeSmem),
CmmaConfig {
b_mn: 32,
b_k: 16,
..Default::default()
},
&[
19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0,
20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0,
Expand Down
Loading

0 comments on commit fac3cfc

Please sign in to comment.