From f17652430565eeba75ae9e4e2691349d4fbd4a56 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Fri, 22 Nov 2024 14:05:35 -0500 Subject: [PATCH] Fast but not right --- .../src/matmul/components/config.rs | 4 +- .../src/matmul/components/stage/base.rs | 2 + .../matmul/kernels/matmul/algorithm/cmma.rs | 15 ++-- .../src/matmul/kernels/matmul/base.rs | 14 ++-- crates/cubecl/benches/matmul.rs | 78 ++++++++++--------- 5 files changed, 65 insertions(+), 48 deletions(-) diff --git a/crates/cubecl-linalg/src/matmul/components/config.rs b/crates/cubecl-linalg/src/matmul/components/config.rs index 52089f28a..643767b35 100644 --- a/crates/cubecl-linalg/src/matmul/components/config.rs +++ b/crates/cubecl-linalg/src/matmul/components/config.rs @@ -48,8 +48,8 @@ pub enum MatrixLayout { /// Maps the matmul MatrixLayout to cmma's MatrixLayout, for use in Cmma API. pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout { match layout { - MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor, - MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor, + MatrixLayout::RowMajor => cmma::MatrixLayout::ColMajor, + MatrixLayout::ColMajor => cmma::MatrixLayout::RowMajor, } } diff --git a/crates/cubecl-linalg/src/matmul/components/stage/base.rs b/crates/cubecl-linalg/src/matmul/components/stage/base.rs index c8acc652b..43a0fa22e 100644 --- a/crates/cubecl-linalg/src/matmul/components/stage/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/stage/base.rs @@ -142,5 +142,7 @@ create_cmma_stage!(S2x2x1, 2, 2, 1); create_cmma_stage!(S2x2x2, 2, 2, 2); create_cmma_stage!(S4x4x1, 4, 4, 1); create_cmma_stage!(S4x4x2, 4, 4, 2); +create_cmma_stage!(S4x4x4, 4, 4, 4); create_cmma_stage!(S8x1x1, 8, 1, 1); create_cmma_stage!(S8x8x1, 8, 8, 1); +create_cmma_stage!(S8x8x2, 8, 8, 2); diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs index f72b29daa..3f0a1fcb4 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use cubecl_core::prelude::*; -use crate::matmul::components::stage::{self, S8x8x1, StageSize}; +use crate::matmul::components::stage::{self, S8x8x2, StageSize}; use crate::matmul::components::tile::accelerated::Accelerated16x16x16; use crate::matmul::components::tile::Matmul; use crate::matmul::components::MatmulProblem; @@ -10,7 +10,7 @@ use crate::matmul::components::{batch, global}; use super::base; -type Stage = S8x8x1; +type Stage = S8x8x2; pub struct Cmma { pub _eg: PhantomData, @@ -18,9 +18,10 @@ pub struct Cmma { impl base::Algorithm for Cmma { const PLANE_DIM: u32 = 32; + type EG = EG; type ES = half::f16; - type EA = half::f16; + type EA = f32; type TileMatmul = Accelerated16x16x16; @@ -29,8 +30,12 @@ impl base::Algorithm for Cmma { type GlobalMatmul = global::homogeneous::Matmul; - type BatchMatmul = - batch::one_to_one::Matmul; + type BatchMatmul = batch::one_to_one::Matmul< + Self::EG, + Self::ES, + Self::GlobalMatmul, + batch::TransposedDispatch, + >; fn cube_dim() -> CubeDim { CubeDim::new(Self::PLANE_DIM, Stage::NUM_M, 1) diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs index 1ae4e4c4e..6fed5f4ba 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs @@ -123,12 +123,12 @@ fn matmul_cmma_ref_no_check( k: k as usize, batches: out.shape[..out.shape.len() - 2].to_vec(), lhs_layout: match transposed.0 { - true => matmul::components::MatrixLayout::ColMajor, - false => matmul::components::MatrixLayout::RowMajor, + true => matmul::components::MatrixLayout::RowMajor, + false => matmul::components::MatrixLayout::ColMajor, }, rhs_layout: match transposed.1 { - true => matmul::components::MatrixLayout::ColMajor, - false => matmul::components::MatrixLayout::RowMajor, + true => matmul::components::MatrixLayout::RowMajor, + false => matmul::components::MatrixLayout::ColMajor, }, lhs_line_size, rhs_line_size, @@ -152,7 +152,11 @@ pub(crate) fn matmul_cube_preparation> let cube_dim = D::cube_dim(); let cube_count = D::cube_count(&problem); - let advanced_config = Default::default(); + let advanced_config = AdvancedConfig { + lhs_tiling_order: matmul::components::stage::TilingOrderConfig::ColMajor, + rhs_tiling_order: matmul::components::stage::TilingOrderConfig::ColMajor, + enforced_tile_layout: (None, None), + }; launch_matmul::( client, diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index fe7a2dbff..82cad4054 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -61,10 +61,10 @@ fn run(device: R::Device, strategy: matmul::Strategy) { client.enable_timestamps(); let bench = MatmulBench:: { - b: 8, - m: 2048, - k: 2048, - n: 2048, + b: 2, + m: 4096, + k: 4096, + n: 4096, client, device, strategy, @@ -75,25 +75,31 @@ fn run(device: R::Device, strategy: matmul::Strategy) { } fn main() { - #[cfg(feature = "wgpu")] - { - run::( - Default::default(), - matmul::Strategy::Tiling2D(Default::default()), - ); - run::(Default::default(), matmul::Strategy::PlaneMma); - } + // #[cfg(feature = "wgpu")] + // { + // run::( + // Default::default(), + // matmul::Strategy::Tiling2D(Default::default()), + // ); + // run::(Default::default(), matmul::Strategy::PlaneMma); + // } #[cfg(feature = "wgpu-spirv")] { - run::, f32>( - Default::default(), - matmul::Strategy::Tiling2D(Default::default()), - ); - run::, f32>( - Default::default(), - matmul::Strategy::PlaneMma, - ); + type R = cubecl::wgpu::WgpuRuntime; + // run::, f32>( + // Default::default(), + // matmul::Strategy::Tiling2D(Default::default()), + // ); + // run::, f32>( + // Default::default(), + // matmul::Strategy::PlaneMma, + // ); + run::(Default::default(), matmul::Strategy::default()); + // run::(Default::default(), matmul::Strategy::PlaneMma); + // run::(Default::default(), matmul::Strategy::PlaneMma); + // run::(Default::default(), matmul::Strategy::Accelerated); + run::(Default::default(), matmul::Strategy::Accelerated); } #[cfg(all(feature = "hip", target_os = "linux"))] @@ -130,25 +136,25 @@ fn main() { #[cfg(feature = "cuda")] { - run::( - Default::default(), - matmul::Strategy::Tiling2D(Default::default()), - ); - run::( - Default::default(), - matmul::Strategy::Tiling2D(Default::default()), - ); - run::( - Default::default(), - matmul::Strategy::CmmaOld(Default::default()), - ); + // run::( + // Default::default(), + // matmul::Strategy::Tiling2D(Default::default()), + // ); + // run::( + // Default::default(), + // matmul::Strategy::Tiling2D(Default::default()), + // ); + // run::( + // Default::default(), + // matmul::Strategy::CmmaOld(Default::default()), + // ); run::( Default::default(), - matmul::Strategy::CmmaOld(Default::default()), + matmul::Strategy::default(), ); - run::(Default::default(), matmul::Strategy::PlaneMma); - run::(Default::default(), matmul::Strategy::PlaneMma); - run::(Default::default(), matmul::Strategy::Accelerated); + // run::(Default::default(), matmul::Strategy::PlaneMma); + // run::(Default::default(), matmul::Strategy::PlaneMma); + // run::(Default::default(), matmul::Strategy::Accelerated); run::( Default::default(), matmul::Strategy::Accelerated,