diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs index 69accf46..76cc98a1 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/mod.rs @@ -1,6 +1,8 @@ mod base; +mod selection; pub mod cmma; pub mod plane_mma; pub use base::Algorithm; +pub use selection::*; diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selection.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selection.rs new file mode 100644 index 00000000..b7f7c016 --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/selection.rs @@ -0,0 +1,39 @@ +use cubecl_core::{ + client::ComputeClient, + prelude::{Numeric, TensorHandleRef}, + Runtime, +}; + +use crate::matmul::{components::MatmulProblem, kernels::matmul::base::matmul_cube_preparation}; + +use super::{cmma::Cmma, plane_mma::PlaneMma}; + +pub struct CmmaSelector; + +impl CmmaSelector { + pub fn select_kernel( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, + problem: MatmulProblem, + ) { + // TODO if problem.m < problem.n... + matmul_cube_preparation::>(client, lhs, rhs, out, problem); + } +} + +pub struct PlaneMmaSelector; + +impl PlaneMmaSelector { + pub fn select_kernel( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, + problem: MatmulProblem, + ) { + // TODO if problem.m < problem.n... + matmul_cube_preparation::>(client, lhs, rhs, out, problem); + } +} diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs index 9584e657..1ae4e4c4 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs @@ -10,26 +10,10 @@ use crate::matmul; use crate::matmul::components::{MatmulLaunch, MatmulProblem}; use crate::tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle}; +use super::algorithm::{CmmaSelector, PlaneMmaSelector}; +use super::cmma::Cmma; use super::config::AdvancedConfig; -use super::{cmma::Cmma, plane_mma::PlaneMma, Algorithm}; - -/// Launch a matrix multiplication kernel. -/// -/// Cmma will be used if available and enabled, -/// otherwise it will fall back on a non-cmma implementation -pub fn launch_ref( - client: &ComputeClient, - lhs: TensorHandleRef<'_, R>, - rhs: TensorHandleRef<'_, R>, - out: TensorHandleRef<'_, R>, - disable_cmma: bool, -) { - if !disable_cmma && Cmma::::check_availability::(client).is_ok() { - matmul_cmma_ref::>(client, lhs, rhs, out); - } else { - matmul_cmma_ref::>(client, lhs, rhs, out); - } -} +use super::Algorithm; /// Launch a matrix multiplication kernel. /// @@ -47,16 +31,21 @@ pub fn launch( lhs.as_ref(), rhs.as_ref(), out.as_ref(), - disable_cmma, + disable_cmma || Cmma::::check_availability::(client).is_err(), ); out } -fn matmul_cmma_ref>( +/// Launch a matrix multiplication kernel. +/// +/// Cmma will be used if available and enabled, +/// otherwise it will fall back on a non-cmma implementation +pub fn launch_ref( client: &ComputeClient, lhs: TensorHandleRef<'_, R>, rhs: TensorHandleRef<'_, R>, out: TensorHandleRef<'_, R>, + disable_cmma: bool, ) { let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) { MatrixLayout::Contiguous => (false, false), @@ -71,43 +60,48 @@ fn matmul_cmma_ref>( let (rhs_make_contiguous, rhs_transposed) = check_layout(&rhs); match (lhs_make_contiguous, rhs_make_contiguous) { - (false, false) => matmul_cmma_ref_no_check::( + (false, false) => matmul_cmma_ref_no_check::( client, lhs, rhs, out, (lhs_transposed, rhs_transposed), + disable_cmma, ), - (false, true) => matmul_cmma_ref_no_check::( + (false, true) => matmul_cmma_ref_no_check::( client, lhs, into_contiguous::(client, rhs).as_ref(), out, (lhs_transposed, rhs_transposed), + disable_cmma, ), - (true, false) => matmul_cmma_ref_no_check::( + (true, false) => matmul_cmma_ref_no_check::( client, into_contiguous::(client, lhs).as_ref(), rhs, out, (lhs_transposed, rhs_transposed), + disable_cmma, ), - (true, true) => matmul_cmma_ref_no_check::( + (true, true) => matmul_cmma_ref_no_check::( client, into_contiguous::(client, lhs).as_ref(), into_contiguous::(client, rhs).as_ref(), out, (lhs_transposed, rhs_transposed), + disable_cmma, ), } } -fn matmul_cmma_ref_no_check>( +fn matmul_cmma_ref_no_check( client: &ComputeClient, lhs: TensorHandleRef<'_, R>, rhs: TensorHandleRef<'_, R>, out: TensorHandleRef<'_, R>, transposed: (bool, bool), + disable_cmma: bool, ) { let rank = lhs.strides.len(); @@ -141,6 +135,20 @@ fn matmul_cmma_ref_no_check>( out_line_size, }; + if disable_cmma { + PlaneMmaSelector::select_kernel::(client, lhs, rhs, out, problem); + } else { + CmmaSelector::select_kernel::(client, lhs, rhs, out, problem); + } +} + +pub(crate) fn matmul_cube_preparation>( + client: &ComputeClient, + lhs: TensorHandleRef<'_, R>, + rhs: TensorHandleRef<'_, R>, + out: TensorHandleRef<'_, R>, + problem: MatmulProblem, +) { let cube_dim = D::cube_dim(); let cube_count = D::cube_count(&problem);