Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd committed Nov 19, 2024
1 parent d84d743 commit 1f7063f
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 175 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use cubecl_core::prelude::*;

use crate::matmul::components::stage::{self, StageSize};
use crate::matmul::components::stage::{self};
use crate::matmul::components::{batch, global, tile};
use crate::matmul::components::{MatmulKernel, MatmulProblem};
use crate::matmul::kernels::matmul::AdvancedConfig;
Expand All @@ -26,7 +26,6 @@ pub trait Algorithm<EG: Numeric> {

type TileMatmul: tile::Matmul<Self::ES, Self::EA> + MatmulKernel<Self::ES, Self::EA>;

type StageSize: StageSize;
type StageMatmul: stage::Matmul<
Self::ES,
Self::EG,
Expand Down
16 changes: 5 additions & 11 deletions crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,21 @@ impl<EG: Numeric> base::Algorithm<EG> for Cmma<EG> {

type TileMatmul = Accelerated16x16x16<Self::ES, Self::EA>;

type StageSize = S4x4x2;
type StageMatmul = stage::multi_buffer::Matmul<
Self::ES,
Self::EG,
Self::EA,
Self::TileMatmul,
Self::StageSize,
>;
type StageMatmul =
stage::multi_buffer::Matmul<Self::ES, Self::EG, Self::EA, Self::TileMatmul, S4x4x2>;

type GlobalMatmul = global::homogeneous::Matmul<Self::EG, Self::ES, Self::StageMatmul>;

type BatchMatmul =
batch::one_to_one::Matmul<Self::EG, Self::ES, Self::GlobalMatmul, batch::NaturalDispatch>;

fn cube_dim() -> CubeDim {
CubeDim::new(Self::PLANE_DIM, Self::StageSize::NUM_M, 1)
CubeDim::new(Self::PLANE_DIM, S4x4x2::NUM_M, 1)
}

fn cube_count(problem: &MatmulProblem) -> CubeCount {
let m_stage = Self::StageSize::NUM_M * Self::TileMatmul::M;
let n_stage = Self::StageSize::NUM_N * Self::TileMatmul::N;
let m_stage = S4x4x2::NUM_M * Self::TileMatmul::M;
let n_stage = S4x4x2::NUM_N * Self::TileMatmul::N;
let cubes_needed_m = (problem.m as u32 + m_stage - 1) / m_stage;
let cubes_needed_n = (problem.n as u32 + n_stage - 1) / n_stage;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,21 @@ impl<EG: Numeric> base::Algorithm<EG> for PlaneMma<EG> {

type TileMatmul = PlaneMma16x16x16<Self::ES, Self::EA>;

type StageSize = S4x4x2;
type StageMatmul = stage::multi_buffer::Matmul<
Self::ES,
Self::EG,
Self::EA,
Self::TileMatmul,
Self::StageSize,
>;
type StageMatmul =
stage::multi_buffer::Matmul<Self::ES, Self::EG, Self::EA, Self::TileMatmul, S4x4x2>;

type GlobalMatmul = global::homogeneous::Matmul<Self::EG, Self::ES, Self::StageMatmul>;

type BatchMatmul =
batch::one_to_one::Matmul<Self::EG, Self::ES, Self::GlobalMatmul, batch::NaturalDispatch>;

fn cube_dim() -> CubeDim {
CubeDim::new(Self::PLANE_DIM, Self::StageSize::NUM_M, 1)
CubeDim::new(Self::PLANE_DIM, S4x4x2::NUM_M, 1)
}

fn cube_count(problem: &MatmulProblem) -> CubeCount {
let m_stage = Self::StageSize::NUM_M * Self::TileMatmul::M;
let n_stage = Self::StageSize::NUM_N * Self::TileMatmul::K;
let m_stage = S4x4x2::NUM_M * Self::TileMatmul::M;
let n_stage = S4x4x2::NUM_N * Self::TileMatmul::K;
let cubes_needed_m = (problem.m as u32 + m_stage - 1) / m_stage;
let cubes_needed_n = (problem.n as u32 + n_stage - 1) / n_stage;

Expand Down
Loading

0 comments on commit 1f7063f

Please sign in to comment.