Skip to content

Commit

Permalink
Fast but not right
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Nov 22, 2024
1 parent 95be077 commit f176524
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 48 deletions.
4 changes: 2 additions & 2 deletions crates/cubecl-linalg/src/matmul/components/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ pub enum MatrixLayout {
/// Maps the matmul `MatrixLayout` to cmma's `MatrixLayout`, for use in the Cmma API.
///
/// NOTE(review): the mapping is deliberately swapped (RowMajor -> ColMajor and
/// ColMajor -> RowMajor) as of this commit — it pairs with the inverted
/// lhs/rhs layout interpretation and `TransposedDispatch` introduced in the
/// same change. Confirm against the cmma fragment-loading code before
/// "correcting" it back to the identity mapping.
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::ColMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::RowMajor,
    }
}

Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-linalg/src/matmul/components/stage/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,7 @@ create_cmma_stage!(S2x2x1, 2, 2, 1);
create_cmma_stage!(S2x2x2, 2, 2, 2);
create_cmma_stage!(S4x4x1, 4, 4, 1);
create_cmma_stage!(S4x4x2, 4, 4, 2);
create_cmma_stage!(S4x4x4, 4, 4, 4);
create_cmma_stage!(S8x1x1, 8, 1, 1);
create_cmma_stage!(S8x8x1, 8, 8, 1);
create_cmma_stage!(S8x8x2, 8, 8, 2);
15 changes: 10 additions & 5 deletions crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,26 @@ use std::marker::PhantomData;

use cubecl_core::prelude::*;

use crate::matmul::components::stage::{self, S8x8x1, StageSize};
use crate::matmul::components::stage::{self, S8x8x2, StageSize};
use crate::matmul::components::tile::accelerated::Accelerated16x16x16;
use crate::matmul::components::tile::Matmul;
use crate::matmul::components::MatmulProblem;
use crate::matmul::components::{batch, global};

use super::base;

type Stage = S8x8x1;
type Stage = S8x8x2;

pub struct Cmma<EG: Numeric> {
pub _eg: PhantomData<EG>,
}

impl<EG: Numeric> base::Algorithm<EG> for Cmma<EG> {
const PLANE_DIM: u32 = 32;

type EG = EG;
type ES = half::f16;
type EA = half::f16;
type EA = f32;

type TileMatmul = Accelerated16x16x16<Self::ES, Self::EA>;

Expand All @@ -29,8 +30,12 @@ impl<EG: Numeric> base::Algorithm<EG> for Cmma<EG> {

type GlobalMatmul = global::homogeneous::Matmul<Self::EG, Self::ES, Self::StageMatmul>;

type BatchMatmul =
batch::one_to_one::Matmul<Self::EG, Self::ES, Self::GlobalMatmul, batch::NaturalDispatch>;
type BatchMatmul = batch::one_to_one::Matmul<
Self::EG,
Self::ES,
Self::GlobalMatmul,
batch::TransposedDispatch,
>;

fn cube_dim() -> CubeDim {
CubeDim::new(Self::PLANE_DIM, Stage::NUM_M, 1)
Expand Down
14 changes: 9 additions & 5 deletions crates/cubecl-linalg/src/matmul/kernels/matmul/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ fn matmul_cmma_ref_no_check<R: Runtime, EG: Numeric>(
k: k as usize,
batches: out.shape[..out.shape.len() - 2].to_vec(),
lhs_layout: match transposed.0 {
true => matmul::components::MatrixLayout::ColMajor,
false => matmul::components::MatrixLayout::RowMajor,
true => matmul::components::MatrixLayout::RowMajor,
false => matmul::components::MatrixLayout::ColMajor,
},
rhs_layout: match transposed.1 {
true => matmul::components::MatrixLayout::ColMajor,
false => matmul::components::MatrixLayout::RowMajor,
true => matmul::components::MatrixLayout::RowMajor,
false => matmul::components::MatrixLayout::ColMajor,
},
lhs_line_size,
rhs_line_size,
Expand All @@ -152,7 +152,11 @@ pub(crate) fn matmul_cube_preparation<R: Runtime, EG: Numeric, D: Algorithm<EG>>
let cube_dim = D::cube_dim();
let cube_count = D::cube_count(&problem);

let advanced_config = Default::default();
let advanced_config = AdvancedConfig {
lhs_tiling_order: matmul::components::stage::TilingOrderConfig::ColMajor,
rhs_tiling_order: matmul::components::stage::TilingOrderConfig::ColMajor,
enforced_tile_layout: (None, None),
};

launch_matmul::<R, EG, D>(
client,
Expand Down
78 changes: 42 additions & 36 deletions crates/cubecl/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ fn run<R: Runtime, E: Float>(device: R::Device, strategy: matmul::Strategy) {
client.enable_timestamps();

let bench = MatmulBench::<R, E> {
b: 8,
m: 2048,
k: 2048,
n: 2048,
b: 2,
m: 4096,
k: 4096,
n: 4096,
client,
device,
strategy,
Expand All @@ -75,25 +75,31 @@ fn run<R: Runtime, E: Float>(device: R::Device, strategy: matmul::Strategy) {
}

fn main() {
#[cfg(feature = "wgpu")]
{
run::<cubecl::wgpu::WgpuRuntime, f32>(
Default::default(),
matmul::Strategy::Tiling2D(Default::default()),
);
run::<cubecl::wgpu::WgpuRuntime, f32>(Default::default(), matmul::Strategy::PlaneMma);
}
// #[cfg(feature = "wgpu")]
// {
// run::<cubecl::wgpu::WgpuRuntime, f32>(
// Default::default(),
// matmul::Strategy::Tiling2D(Default::default()),
// );
// run::<cubecl::wgpu::WgpuRuntime, f32>(Default::default(), matmul::Strategy::PlaneMma);
// }

#[cfg(feature = "wgpu-spirv")]
{
run::<cubecl::wgpu::WgpuRuntime<cubecl::wgpu::spirv::SpirvCompiler>, f32>(
Default::default(),
matmul::Strategy::Tiling2D(Default::default()),
);
run::<cubecl::wgpu::WgpuRuntime<cubecl::wgpu::spirv::SpirvCompiler>, f32>(
Default::default(),
matmul::Strategy::PlaneMma,
);
type R = cubecl::wgpu::WgpuRuntime<cubecl::wgpu::spirv::SpirvCompiler>;
// run::<cubecl::wgpu::WgpuRuntime<cubecl::wgpu::spirv::SpirvCompiler>, f32>(
// Default::default(),
// matmul::Strategy::Tiling2D(Default::default()),
// );
// run::<cubecl::wgpu::WgpuRuntime<cubecl::wgpu::spirv::SpirvCompiler>, f32>(
// Default::default(),
// matmul::Strategy::PlaneMma,
// );
run::<R, half::f16>(Default::default(), matmul::Strategy::default());
// run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::PlaneMma);
// run::<cubecl::cuda::CudaRuntime, half::f16>(Default::default(), matmul::Strategy::PlaneMma);
// run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::Accelerated);
run::<R, half::f16>(Default::default(), matmul::Strategy::Accelerated);
}

#[cfg(all(feature = "hip", target_os = "linux"))]
Expand Down Expand Up @@ -130,25 +136,25 @@ fn main() {

#[cfg(feature = "cuda")]
{
run::<cubecl::cuda::CudaRuntime, f32>(
Default::default(),
matmul::Strategy::Tiling2D(Default::default()),
);
run::<cubecl::cuda::CudaRuntime, half::f16>(
Default::default(),
matmul::Strategy::Tiling2D(Default::default()),
);
run::<cubecl::cuda::CudaRuntime, f32>(
Default::default(),
matmul::Strategy::CmmaOld(Default::default()),
);
// run::<cubecl::cuda::CudaRuntime, f32>(
// Default::default(),
// matmul::Strategy::Tiling2D(Default::default()),
// );
// run::<cubecl::cuda::CudaRuntime, half::f16>(
// Default::default(),
// matmul::Strategy::Tiling2D(Default::default()),
// );
// run::<cubecl::cuda::CudaRuntime, f32>(
// Default::default(),
// matmul::Strategy::CmmaOld(Default::default()),
// );
run::<cubecl::cuda::CudaRuntime, half::f16>(
Default::default(),
matmul::Strategy::CmmaOld(Default::default()),
matmul::Strategy::default(),
);
run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::PlaneMma);
run::<cubecl::cuda::CudaRuntime, half::f16>(Default::default(), matmul::Strategy::PlaneMma);
run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::Accelerated);
// run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::PlaneMma);
// run::<cubecl::cuda::CudaRuntime, half::f16>(Default::default(), matmul::Strategy::PlaneMma);
// run::<cubecl::cuda::CudaRuntime, f32>(Default::default(), matmul::Strategy::Accelerated);
run::<cubecl::cuda::CudaRuntime, half::f16>(
Default::default(),
matmul::Strategy::Accelerated,
Expand Down

0 comments on commit f176524

Please sign in to comment.