
Commit

Refactor/reuse cmma matrices 2 (#111)
louisfd authored Sep 10, 2024
1 parent 4e3106d commit ccd7299
Showing 7 changed files with 234 additions and 180 deletions.
11 changes: 10 additions & 1 deletion crates/cubecl-core/src/frontend/sequence.rs
@@ -1,4 +1,7 @@
-use super::{branch::Iterable, indexation::Index, CubeContext, CubeType, ExpandElementTyped, Init};
+use super::{
+    branch::Iterable, indexation::Index, CubeContext, CubeType, ExpandElementTyped, Init,
+    IntoRuntime,
+};
 use crate::unexpanded;
 use std::{cell::RefCell, rc::Rc};
 
@@ -147,3 +150,9 @@ impl<T: CubeType> SequenceExpand<T> {
         self.values.borrow()[index].clone()
     }
 }
+
+impl<T: CubeType> IntoRuntime for Sequence<T> {
+    fn __expand_runtime_method(self, _context: &mut CubeContext) -> SequenceExpand<T> {
+        unimplemented!("Sequence doesn't exist at compile time");
+    }
+}
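
Context for the hunk above: a Sequence only exists while a kernel is being expanded, so it has no runtime representation; the new IntoRuntime impl appears to be there so that Sequence can be used as a field of a #[derive(CubeType)] struct (such as the CmmaMatrices introduced in base.rs below), and it simply panics if the conversion is ever requested. A minimal standalone Rust sketch of that pattern, with illustrative names that are not part of the CubeCL API:

// Simplified stand-in for a conversion trait in the spirit of IntoRuntime.
trait IntoRuntimeSketch {
    type Expanded;
    fn expand_runtime(self) -> Self::Expanded;
}

// A container that is only meaningful during kernel expansion.
struct ComptimeOnlySeq<T>(Vec<T>);

impl<T> IntoRuntimeSketch for ComptimeOnlySeq<T> {
    type Expanded = Vec<T>;
    fn expand_runtime(self) -> Self::Expanded {
        // Mirrors the impl above: there is no runtime value to produce,
        // so requesting one is a usage error.
        unimplemented!("ComptimeOnlySeq has no runtime representation")
    }
}

fn main() {
    let seq = ComptimeOnlySeq(vec![1, 2, 3]);
    // Calling seq.expand_runtime() would panic by design.
    assert_eq!(seq.0.len(), 3);
}
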
37 changes: 32 additions & 5 deletions crates/cubecl-linalg/src/matmul/cmma/base.rs
@@ -18,13 +18,13 @@ pub fn cmma_kernel<F: Float, FC: Float>(
     let runtime_info = RuntimeCmmaInfo { ids, dims, offsets };
 
     let shared_memories = make_shared_memories::<FC>(comptime_info);
-    let accumulate = make_accumulators::<F>(comptime_info);
+    let cmma_matrices = make_cmma_matrices::<F, FC>(comptime_info);
     block_loop::<F, FC>(
         lhs,
         rhs,
         out,
         shared_memories,
-        accumulate,
+        cmma_matrices,
         runtime_info,
         comptime_info,
     );
@@ -68,6 +68,13 @@ pub(crate) struct Offsets {
     pub cube_col: u32,
 }
 
+#[derive(CubeType)]
+pub(crate) struct CmmaMatrices<F: Float, FC: Float> {
+    pub accumulators: Sequence<cmma::Matrix<F>>,
+    pub lhs: cmma::Matrix<FC>,
+    pub rhs: cmma::Matrix<FC>,
+}
+
 #[cube]
 fn get_dims<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>) -> Dimensions {
     let rank = lhs.rank();
@@ -133,9 +140,9 @@ fn make_shared_memories<FC: Float>(#[comptime] config: ComptimeCmmaInfo) -> Shar
 }
 
 #[cube]
-pub(crate) fn make_accumulators<F: Float>(
+pub(crate) fn make_cmma_matrices<F: Float, FC: Float>(
     #[comptime] config: ComptimeCmmaInfo,
-) -> Sequence<cmma::Matrix<F>> {
+) -> CmmaMatrices<F, FC> {
     let num_accumulators = config.num_accumulators;
     let mut accumulators = Sequence::<cmma::Matrix<F>>::new();
 
@@ -154,7 +161,27 @@ pub(crate) fn make_accumulators<F: Float>(
         accumulators.push(acc);
     }
 
-    accumulators
+    let lhs = cmma::Matrix::<FC>::new(
+        cmma::MatrixIdent::A,
+        16,
+        16,
+        16,
+        cmma::MatrixLayout::RowMajor,
+    );
+
+    let rhs = cmma::Matrix::<FC>::new(
+        cmma::MatrixIdent::B,
+        16,
+        16,
+        16,
+        cmma::MatrixLayout::RowMajor,
+    );
+
+    CmmaMatrices::<F, FC> {
+        accumulators,
+        lhs,
+        rhs,
+    }
 }
 
 #[cube]
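
The new CmmaMatrices struct above bundles the accumulator sequence with a single lhs and rhs fragment, each created once in make_cmma_matrices with the fixed 16x16x16 row-major shape this kernel uses, so they can be reused for every tile instead of being rebuilt inside compute_tile. A rough standalone Rust analogue of that allocate-once grouping (plain arrays stand in for cmma::Matrix; the names are illustrative, not the CubeCL API):

struct FragmentsSketch {
    accumulators: Vec<[f32; 16 * 16]>, // one accumulator tile per output tile
    lhs: [f32; 16 * 16],               // A fragment, reused for every tile
    rhs: [f32; 16 * 16],               // B fragment, reused for every tile
}

fn make_fragments_sketch(num_accumulators: usize) -> FragmentsSketch {
    FragmentsSketch {
        accumulators: vec![[0.0; 16 * 16]; num_accumulators],
        lhs: [0.0; 16 * 16],
        rhs: [0.0; 16 * 16],
    }
}

fn main() {
    let frags = make_fragments_sketch(4);
    assert_eq!(frags.accumulators.len(), 4);
}
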
20 changes: 15 additions & 5 deletions crates/cubecl-linalg/src/matmul/cmma/block_loop.rs
@@ -2,7 +2,7 @@ use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 
 use super::{
-    base::{RuntimeCmmaInfo, SharedMemories},
+    base::{CmmaMatrices, RuntimeCmmaInfo, SharedMemories},
     compute_loop::compute_loop,
     config::ComptimeCmmaInfo,
     load_shared_memory::load_to_shared_memories,
@@ -15,7 +15,7 @@ pub(crate) fn block_loop<F: Float, FC: Float>(
     rhs: &Tensor<F>,
     out: &mut Tensor<F>,
     shared_memories: SharedMemories<FC>,
-    mut accumulators: Sequence<cmma::Matrix<F>>,
+    mut cmma_matrices: CmmaMatrices<F, FC>,
     runtime_info: RuntimeCmmaInfo,
     #[comptime] comptime_info: ComptimeCmmaInfo,
 ) {
@@ -42,7 +42,7 @@ pub(crate) fn block_loop<F: Float, FC: Float>(
 
         compute_loop::<F, FC>(
             shared_memories,
-            &mut accumulators,
+            &mut cmma_matrices,
             runtime_info.ids,
             comptime_info,
         );
@@ -51,8 +51,18 @@ pub(crate) fn block_loop<F: Float, FC: Float>(
     }
 
     if write_out_reuse_smem {
-        ReuseSmemWriter::write_to_output(out, accumulators, runtime_info, comptime_info);
+        ReuseSmemWriter::write_to_output(
+            out,
+            cmma_matrices.accumulators,
+            runtime_info,
+            comptime_info,
+        );
     } else {
-        LargeSmemWriter::write_to_output(out, accumulators, runtime_info, comptime_info);
+        LargeSmemWriter::write_to_output(
+            out,
+            cmma_matrices.accumulators,
+            runtime_info,
+            comptime_info,
+        );
     }
 }
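
Note on the writer calls above: cmma_matrices.accumulators is moved out of the struct when it is handed to the writer, which works because the lhs/rhs fragments are no longer needed once the block loop has finished. A tiny standalone Rust sketch of that partial-move pattern (illustrative types, not the actual writer API):

struct MatricesSketch {
    accumulators: Vec<f32>,
    lhs: Vec<f32>,
    rhs: Vec<f32>,
}

fn write_to_output_sketch(out: &mut Vec<f32>, accumulators: Vec<f32>) {
    out.extend(accumulators); // takes ownership of the accumulators only
}

fn main() {
    let matrices = MatricesSketch {
        accumulators: vec![1.0, 2.0],
        lhs: vec![0.0; 4],
        rhs: vec![0.0; 4],
    };
    let mut out = Vec::new();
    // Partial move: only the accumulators field leaves the struct.
    write_to_output_sketch(&mut out, matrices.accumulators);
    assert_eq!(out, vec![1.0, 2.0]);
    // The remaining fields are still accessible after the partial move.
    assert_eq!(matrices.lhs.len(), 4);
    assert_eq!(matrices.rhs.len(), 4);
}
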
37 changes: 15 additions & 22 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs
@@ -1,14 +1,14 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 
-use super::base::{Ids, SharedMemories};
+use super::base::{CmmaMatrices, Ids, SharedMemories};
 use super::config::ComptimeCmmaInfo;
 
 #[cube]
 #[allow(unused_mut)]
 pub(crate) fn compute_loop<F: Float, FC: Float>(
     shared_memories: SharedMemories<FC>,
-    accumulators: &mut Sequence<cmma::Matrix<F>>,
+    cmma_matrices: &mut CmmaMatrices<F, FC>,
     ids: Ids,
     #[comptime] comptime_info: ComptimeCmmaInfo,
 ) {
@@ -20,13 +20,19 @@ pub(crate) fn compute_loop<F: Float, FC: Float>(
     let tile_row = ids.coop / num_coop_per_row;
     let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators;
 
+    let lhs = &cmma_matrices.lhs;
+    let rhs = &cmma_matrices.rhs;
+    let accumulators = &cmma_matrices.accumulators;
+
     #[unroll]
     for n in 0..num_accumulators {
         compute_tile::<F, FC>(
             tile_row,
             tile_col_base + n,
             shared_memories,
-            *accumulators.index(n),
+            lhs,
+            rhs,
+            accumulators.index(n),
             comptime_info,
         );
     }
@@ -37,7 +43,9 @@ fn compute_tile<F: Float, FC: Float>(
     tile_row: u32,
     tile_col: u32,
     shared_memories: SharedMemories<FC>,
-    accumulator: cmma::Matrix<F>,
+    lhs: &cmma::Matrix<FC>,
+    rhs: &cmma::Matrix<FC>,
+    accumulator: &cmma::Matrix<F>,
     #[comptime] comptime_info: ComptimeCmmaInfo,
 ) {
     let block_size_k = comptime_info.block_size_k;
@@ -61,24 +69,9 @@ fn compute_tile<F: Float, FC: Float>(
             .rhs
             .slice(shared_rhs_pos, shared_rhs_pos + smem_stride);
 
-        let a = cmma::Matrix::<FC>::new(
-            cmma::MatrixIdent::A,
-            16,
-            16,
-            16,
-            cmma::MatrixLayout::RowMajor,
-        );
-        let b = cmma::Matrix::<FC>::new(
-            cmma::MatrixIdent::B,
-            16,
-            16,
-            16,
-            cmma::MatrixLayout::RowMajor,
-        );
-
-        cmma::load(&a, lhs_slice, 16);
-        cmma::load(&b, rhs_slice, 16);
+        cmma::load::<FC>(lhs, lhs_slice, 16);
+        cmma::load::<FC>(rhs, rhs_slice, 16);
 
-        cmma::execute::<FC, FC, F, F>(&a, &b, &accumulator, &accumulator);
+        cmma::execute::<FC, FC, F, F>(lhs, rhs, accumulator, accumulator);
     }
 }
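
This hunk is the core of the refactor: compute_tile no longer constructs fresh A/B fragments on every pass through the k-loop; it receives the fragments created once in make_cmma_matrices and just reloads them before each multiply-accumulate. A standalone Rust analogue of hoisting that per-iteration setup out of the hot loop (flat arrays stand in for CMMA fragments; the arithmetic is a placeholder, not a real matrix multiply):

fn compute_tile_sketch(
    lhs: &mut [f32; 16 * 16],
    rhs: &mut [f32; 16 * 16],
    acc: &mut [f32; 16 * 16],
    k_tiles: usize,
) {
    for _k in 0..k_tiles {
        // Before the refactor, two fresh 16x16 fragments were built here on
        // every iteration; now the preallocated ones are simply refilled.
        lhs.fill(1.0);
        rhs.fill(1.0);
        for i in 0..lhs.len() {
            acc[i] += lhs[i] * rhs[i]; // stand-in for the multiply-accumulate
        }
    }
}

fn main() {
    let (mut lhs, mut rhs, mut acc) = ([0.0; 256], [0.0; 256], [0.0; 256]);
    compute_tile_sketch(&mut lhs, &mut rhs, &mut acc, 8);
    assert_eq!(acc[0], 8.0);
}
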
2 changes: 1 addition & 1 deletion crates/cubecl-linalg/src/matmul/cmma/config.rs
@@ -29,7 +29,7 @@ pub struct CmmaConfig {
 
 impl Default for CmmaConfig {
     fn default() -> Self {
-        Self::new(128, 16, true, WriteOutStrategy::ReuseSmem)
+        Self::new(128, 16, false, WriteOutStrategy::ReuseSmem)
     }
 }
 
7 changes: 4 additions & 3 deletions crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs
@@ -2,7 +2,7 @@ use cubecl::prelude::*;
 use cubecl_core as cubecl;
 
 use crate::matmul::cmma::{
-    base::{make_accumulators, Ids, SharedMemories},
+    base::{make_cmma_matrices, Ids, SharedMemories},
     compute_loop::compute_loop,
     config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy},
 };
@@ -34,11 +34,11 @@ fn compute_loop_test<F: Float, FC: Float>(
     }
 
     let shared_memories = SharedMemories::<FC> { lhs, rhs };
-    let mut accumulators = make_accumulators::<F>(comptime_info);
+    let mut matrices = make_cmma_matrices::<F, FC>(comptime_info);
 
     compute_loop(
         shared_memories,
-        &mut accumulators,
+        &mut matrices,
         Ids {
             coop: UNIT_POS_Y,
             lane: UNIT_POS_X,
@@ -51,6 +51,7 @@ fn compute_loop_test<F: Float, FC: Float>(
     let slice_offset = tile_size * tile_size;
     let offset = UNIT_POS_Y * slice_offset * num_accumulators;
 
+    let accumulators = matrices.accumulators;
     #[unroll]
     for n in 0..num_accumulators {
         let slice =
