Commit 61df81b: easier kernel selection

louisfd committed Nov 19, 2024 (1 parent: 70264ed)

Showing 3 changed files with 75 additions and 26 deletions.
First changed file: the algorithm module declarations, wiring in the new `selection` module.

@@ -1,6 +1,8 @@
 mod base;
+mod selection;
 
 pub mod cmma;
 pub mod plane_mma;
 
 pub use base::Algorithm;
+pub use selection::*;
Second changed file (new): the `selection` module with the two kernel selectors.

@@ -0,0 +1,39 @@
+use cubecl_core::{
+    client::ComputeClient,
+    prelude::{Numeric, TensorHandleRef},
+    Runtime,
+};
+
+use crate::matmul::{components::MatmulProblem, kernels::matmul::base::matmul_cube_preparation};
+
+use super::{cmma::Cmma, plane_mma::PlaneMma};
+
+pub struct CmmaSelector;
+
+impl CmmaSelector {
+    pub fn select_kernel<R: Runtime, EG: Numeric>(
+        client: &ComputeClient<R::Server, R::Channel>,
+        lhs: TensorHandleRef<'_, R>,
+        rhs: TensorHandleRef<'_, R>,
+        out: TensorHandleRef<'_, R>,
+        problem: MatmulProblem,
+    ) {
+        // TODO if problem.m < problem.n...
+        matmul_cube_preparation::<R, EG, Cmma<EG>>(client, lhs, rhs, out, problem);
+    }
+}
+
+pub struct PlaneMmaSelector;
+
+impl PlaneMmaSelector {
+    pub fn select_kernel<R: Runtime, EG: Numeric>(
+        client: &ComputeClient<R::Server, R::Channel>,
+        lhs: TensorHandleRef<'_, R>,
+        rhs: TensorHandleRef<'_, R>,
+        out: TensorHandleRef<'_, R>,
+        problem: MatmulProblem,
+    ) {
+        // TODO if problem.m < problem.n...
+        matmul_cube_preparation::<R, EG, PlaneMma<EG>>(client, lhs, rhs, out, problem);
+    }
+}
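The selectors are where shape-aware dispatch will eventually live (the `// TODO if problem.m < problem.n...` comments). A minimal sketch of what that heuristic could grow into, assuming a hypothetical `CmmaTall` algorithm tuned for m < n (the type does not exist in this commit):

impl CmmaSelector {
    pub fn select_kernel<R: Runtime, EG: Numeric>(
        client: &ComputeClient<R::Server, R::Channel>,
        lhs: TensorHandleRef<'_, R>,
        rhs: TensorHandleRef<'_, R>,
        out: TensorHandleRef<'_, R>,
        problem: MatmulProblem,
    ) {
        // Hypothetical: tall problems (m < n) go to a dedicated variant;
        // everything else keeps the default Cmma configuration.
        if problem.m < problem.n {
            matmul_cube_preparation::<R, EG, CmmaTall<EG>>(client, lhs, rhs, out, problem);
        } else {
            matmul_cube_preparation::<R, EG, Cmma<EG>>(client, lhs, rhs, out, problem);
        }
    }
}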
Third changed file: crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs (34 additions, 26 deletions).
@@ -10,26 +10,10 @@ use crate::matmul;
 use crate::matmul::components::{MatmulLaunch, MatmulProblem};
 use crate::tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle};
 
+use super::algorithm::{CmmaSelector, PlaneMmaSelector};
+use super::cmma::Cmma;
 use super::config::AdvancedConfig;
-use super::{cmma::Cmma, plane_mma::PlaneMma, Algorithm};
-
-/// Launch a matrix multiplication kernel.
-///
-/// Cmma will be used if available and enabled,
-/// otherwise it will fall back on a non-cmma implementation
-pub fn launch_ref<R: Runtime, EG: Numeric>(
-    client: &ComputeClient<R::Server, R::Channel>,
-    lhs: TensorHandleRef<'_, R>,
-    rhs: TensorHandleRef<'_, R>,
-    out: TensorHandleRef<'_, R>,
-    disable_cmma: bool,
-) {
-    if !disable_cmma && Cmma::<EG>::check_availability::<R>(client).is_ok() {
-        matmul_cmma_ref::<R, EG, Cmma<EG>>(client, lhs, rhs, out);
-    } else {
-        matmul_cmma_ref::<R, EG, PlaneMma<EG>>(client, lhs, rhs, out);
-    }
-}
+use super::Algorithm;
 
 /// Launch a matrix multiplication kernel.
 ///
@@ -47,16 +31,21 @@ pub fn launch<R: Runtime, EG: Numeric>(
         lhs.as_ref(),
         rhs.as_ref(),
         out.as_ref(),
-        disable_cmma,
+        disable_cmma || Cmma::<EG>::check_availability::<R>(client).is_err(),
     );
     out
 }
 
-fn matmul_cmma_ref<R: Runtime, EG: Numeric, D: Algorithm<EG>>(
+/// Launch a matrix multiplication kernel.
+///
+/// Cmma will be used if available and enabled,
+/// otherwise it will fall back on a non-cmma implementation
+pub fn launch_ref<R: Runtime, EG: Numeric>(
     client: &ComputeClient<R::Server, R::Channel>,
     lhs: TensorHandleRef<'_, R>,
     rhs: TensorHandleRef<'_, R>,
     out: TensorHandleRef<'_, R>,
+    disable_cmma: bool,
 ) {
     let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
         MatrixLayout::Contiguous => (false, false),
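The one-line change above moves the hardware probe to the outermost entry point: `launch` ORs `disable_cmma` with a failed `Cmma::<EG>::check_availability`, so the renamed, now-public `launch_ref` can treat the flag as final. A usage sketch; the import paths, the wgpu runtime, and the wrapper function are assumptions, not part of this diff:

use cubecl_core::{client::ComputeClient, Runtime};
use cubecl_linalg::matmul::kernels::matmul::launch; // path assumed
use cubecl_linalg::tensor::TensorHandle; // path assumed
use cubecl_wgpu::WgpuRuntime; // runtime assumed

fn matmul_f32(
    client: &ComputeClient<<WgpuRuntime as Runtime>::Server, <WgpuRuntime as Runtime>::Channel>,
    lhs: TensorHandle<WgpuRuntime, f32>,
    rhs: TensorHandle<WgpuRuntime, f32>,
    out: TensorHandle<WgpuRuntime, f32>,
) -> TensorHandle<WgpuRuntime, f32> {
    // `false` keeps cmma enabled; launch still falls back to PlaneMma
    // when check_availability reports that the device cannot run Cmma.
    launch::<WgpuRuntime, f32>(client, lhs, rhs, out, false)
}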
@@ -71,43 +60,48 @@ fn matmul_cmma_ref<R: Runtime, EG: Numeric, D: Algorithm<EG>>(
     let (rhs_make_contiguous, rhs_transposed) = check_layout(&rhs);
 
     match (lhs_make_contiguous, rhs_make_contiguous) {
-        (false, false) => matmul_cmma_ref_no_check::<R, EG, D>(
+        (false, false) => matmul_cmma_ref_no_check::<R, EG>(
             client,
             lhs,
             rhs,
             out,
             (lhs_transposed, rhs_transposed),
+            disable_cmma,
         ),
-        (false, true) => matmul_cmma_ref_no_check::<R, EG, D>(
+        (false, true) => matmul_cmma_ref_no_check::<R, EG>(
             client,
             lhs,
             into_contiguous::<R, EG>(client, rhs).as_ref(),
             out,
             (lhs_transposed, rhs_transposed),
+            disable_cmma,
         ),
-        (true, false) => matmul_cmma_ref_no_check::<R, EG, D>(
+        (true, false) => matmul_cmma_ref_no_check::<R, EG>(
             client,
             into_contiguous::<R, EG>(client, lhs).as_ref(),
             rhs,
             out,
             (lhs_transposed, rhs_transposed),
+            disable_cmma,
         ),
-        (true, true) => matmul_cmma_ref_no_check::<R, EG, D>(
+        (true, true) => matmul_cmma_ref_no_check::<R, EG>(
            client,
             into_contiguous::<R, EG>(client, lhs).as_ref(),
             into_contiguous::<R, EG>(client, rhs).as_ref(),
             out,
             (lhs_transposed, rhs_transposed),
+            disable_cmma,
         ),
     }
 }
 
-fn matmul_cmma_ref_no_check<R: Runtime, EG: Numeric, D: Algorithm<EG>>(
+fn matmul_cmma_ref_no_check<R: Runtime, EG: Numeric>(
     client: &ComputeClient<R::Server, R::Channel>,
     lhs: TensorHandleRef<'_, R>,
     rhs: TensorHandleRef<'_, R>,
     out: TensorHandleRef<'_, R>,
     transposed: (bool, bool),
+    disable_cmma: bool,
 ) {
     let rank = lhs.strides.len();

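In the hunk above, the four match arms differ only in which operand is copied with `into_contiguous`; separate arms keep each owned temporary alive for the duration of its call. A hypothetical condensed form (not in the commit) shows the design constraint: it needs late-initialized bindings so the owned tensors outlive the borrow:

// Sketch only; `lhs_owned`/`rhs_owned` are illustrative names.
let lhs_owned;
let lhs = if lhs_make_contiguous {
    lhs_owned = into_contiguous::<R, EG>(client, lhs);
    lhs_owned.as_ref()
} else {
    lhs
};
let rhs_owned;
let rhs = if rhs_make_contiguous {
    rhs_owned = into_contiguous::<R, EG>(client, rhs);
    rhs_owned.as_ref()
} else {
    rhs
};
matmul_cmma_ref_no_check::<R, EG>(
    client,
    lhs,
    rhs,
    out,
    (lhs_transposed, rhs_transposed),
    disable_cmma,
);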
@@ -141,6 +135,20 @@ fn matmul_cmma_ref_no_check<R: Runtime, EG: Numeric, D: Algorithm<EG>>(
         out_line_size,
     };
 
+    if disable_cmma {
+        PlaneMmaSelector::select_kernel::<R, EG>(client, lhs, rhs, out, problem);
+    } else {
+        CmmaSelector::select_kernel::<R, EG>(client, lhs, rhs, out, problem);
+    }
+}
+
+pub(crate) fn matmul_cube_preparation<R: Runtime, EG: Numeric, D: Algorithm<EG>>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    lhs: TensorHandleRef<'_, R>,
+    rhs: TensorHandleRef<'_, R>,
+    out: TensorHandleRef<'_, R>,
+    problem: MatmulProblem,
+) {
     let cube_dim = D::cube_dim();
     let cube_count = D::cube_count(&problem);
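With this split, `matmul_cmma_ref_no_check` builds the `MatmulProblem` and picks a selector, while `matmul_cube_preparation` is the only function still generic over `D: Algorithm<EG>`. From the calls visible in this diff (`cube_dim`, `cube_count`, `check_availability`), the trait must expose at least something like the following; the return and error types here are assumptions, and the real trait in cubecl-linalg may have more items:

use cubecl_core::{client::ComputeClient, prelude::Numeric, CubeCount, CubeDim, Runtime};

// Hedged reconstruction of the Algorithm surface implied by this diff.
pub trait Algorithm<EG: Numeric> {
    /// Cube (thread-block) dimensions used by this algorithm.
    fn cube_dim() -> CubeDim;
    /// Cube grid derived from the problem shape.
    fn cube_count(problem: &MatmulProblem) -> CubeCount;
    /// Whether this runtime/device can execute the algorithm at all.
    fn check_availability<R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
    ) -> Result<(), AvailabilityError>; // error type assumed
}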
