Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/reuse cmma matrices #107

Closed
wants to merge 80 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
b7064e1
minor refactor
louisfd Aug 20, 2024
9527d3f
change accumulators for sequence
louisfd Aug 20, 2024
eb22cda
add failing test
Aug 21, 2024
e9d473d
Merge branch 'main' into refactor/cmma_generalize
Aug 21, 2024
72b3f89
wip
louisfd Aug 27, 2024
700c0cf
:wq Merge branch 'main' of github.com:tracel-ai/cubecl
louisfd Aug 27, 2024
55b5fd3
Merge branch 'main' into refactor/cmma_generalize
louisfd Aug 27, 2024
ec83d3c
wip
louisfd Aug 27, 2024
7a4f3e4
wip
louisfd Aug 27, 2024
5ed5caa
wip
Aug 27, 2024
f8aa418
wip
louisfd Aug 27, 2024
335e4c2
wip
louisfd Aug 27, 2024
6d20a18
wip
louisfd Aug 27, 2024
da02986
wip
Aug 27, 2024
9e917d5
coop and lane independent from unit pos
louisfd Aug 27, 2024
561f71c
custom block size
louisfd Aug 27, 2024
9a6fc84
num accumulators
louisfd Aug 28, 2024
6dbf866
fix k loop test
Aug 28, 2024
3aacdf6
allowing any config wip
louisfd Aug 28, 2024
c55dd64
merge
louisfd Aug 28, 2024
b6d778d
generalize fragment to sm
louisfd Aug 28, 2024
e37d9cd
Merge branch 'main' of github.com:tracel-ai/cubecl into refactor/cmma…
Aug 28, 2024
c7abc89
Merge branch 'refactor/cmma_generalize' of github.com:tracel-ai/cubec…
Aug 28, 2024
5831bd1
sm max in bytes
louisfd Aug 28, 2024
730e190
wip
louisfd Aug 29, 2024
644b4ea
Merge branch 'refactor/cmma_generalize' of github.com:tracel-ai/cubec…
Aug 29, 2024
0349242
add index of error
Aug 29, 2024
def320f
refactor load and write tests
louisfd Aug 30, 2024
99dc7dc
refactor compute loop test
louisfd Aug 30, 2024
f71a959
Merge branch 'refactor/cmma_generalize' of github.com:tracel-ai/cubec…
louisfd Aug 30, 2024
bfca4fa
Merge branch 'main' of github.com:tracel-ai/cubecl
louisfd Aug 30, 2024
daa39f5
Merge branch 'main' into refactor/cmma_generalize
louisfd Aug 30, 2024
829d50f
add vec1
louisfd Aug 30, 2024
0f6b146
vec tests
louisfd Aug 30, 2024
9772a3a
unhardcode
louisfd Aug 30, 2024
a73158d
wip refactor only two degrees of liberty
louisfd Aug 30, 2024
e9cbeca
block config
louisfd Sep 4, 2024
3a531d8
add tests
louisfd Sep 4, 2024
e1ea5e0
testing alternate block sizes
Sep 4, 2024
348953f
fix write
louisfd Sep 4, 2024
1f2b62f
played with tests
Sep 4, 2024
668ba03
ignore failing test
Sep 4, 2024
ef0e746
Merge branch 'main' of github.com:tracel-ai/cubecl
louisfd Sep 4, 2024
c343f66
Merge branch 'main' into refactor/cmma_generalize
louisfd Sep 4, 2024
49683cb
fmt
louisfd Sep 4, 2024
acb9285
fix
louisfd Sep 5, 2024
18a115a
back to using unit pos directly
louisfd Sep 5, 2024
a0db0e6
refactor vec
louisfd Sep 5, 2024
2f64b5d
fix equation
Sep 5, 2024
c97651d
reused smem
louisfd Sep 6, 2024
ff9cb68
Merge branch 'refactor/cmma_generalize' into feat/reuse_out_smem
Sep 6, 2024
e1cb240
works
Sep 6, 2024
6abeaaf
refactor wip
louisfd Sep 6, 2024
68e921a
wip refactor
louisfd Sep 6, 2024
a87e0f1
wip refactor runtime info
louisfd Sep 6, 2024
68ae416
runtime info wip
louisfd Sep 6, 2024
9e6bbb2
fix mixed args
Sep 6, 2024
a7e2ed3
complete runtime info refactor
louisfd Sep 6, 2024
ecfc8c4
still a bug when b_k>16
Sep 6, 2024
71414af
clippy
louisfd Sep 6, 2024
0b6aeb7
merge main
louisfd Sep 6, 2024
69d5e24
rename confusing lane_dim
louisfd Sep 6, 2024
e2e4283
fix mistake
louisfd Sep 9, 2024
b99c945
little refactor
Sep 9, 2024
5a9eb11
little refactor
Sep 9, 2024
87c3d7c
fix 32x32 test
Sep 9, 2024
6275052
fmt
louisfd Sep 9, 2024
0f180fc
merge main
louisfd Sep 9, 2024
903a84d
fix merge
louisfd Sep 9, 2024
07a31af
new only at beginning
louisfd Sep 9, 2024
e76a211
minor
Sep 9, 2024
6e83f33
ignore failing
Sep 9, 2024
9de899a
remove messy things from topology flaky test
louisfd Sep 9, 2024
b583675
Merge pull request #106 from tracel-ai/fix/topology_flaky_test
louisfd Sep 9, 2024
e0b59c9
Merge branch 'main' into feat/reuse_out_smem
louisfd Sep 9, 2024
e5865b5
Merge branch 'feat/reuse_out_smem' of github.com:tracel-ai/cubecl int…
louisfd Sep 9, 2024
d90d529
Merge pull request #101 from tracel-ai/feat/reuse_out_smem
louisfd Sep 9, 2024
46e44f2
Merge branch 'main' into refactor/reuse_cmma_matrices
louisfd Sep 9, 2024
d4d3d55
Merge branch 'main' into refactor/reuse_cmma_matrices
louisfd Sep 9, 2024
31596c5
Merge branch 'refactor/reuse_cmma_matrices' of github.com:tracel-ai/c…
louisfd Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
/// The maximal size of a shared memory
/// The maximal size of a shared memory, in bytes
fn max_shared_memory_size() -> usize;
}
1 change: 1 addition & 0 deletions crates/cubecl-core/src/frontend/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{cell::RefCell, rc::Rc};
/// All methods [push](Sequence::push), [index](Sequence::index) and
/// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead
/// on the generated kernel.
#[derive(Debug, Clone)]
pub struct Sequence<T: CubeType> {
values: Vec<T>,
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/ir/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ pub struct CubeDim {
}

impl CubeDim {
pub(crate) fn num_elems(&self) -> u32 {
pub fn num_elems(&self) -> u32 {
self.x * self.y * self.z
}
}
Expand Down
14 changes: 3 additions & 11 deletions crates/cubecl-core/src/runtime_tests/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,33 @@ use crate as cubecl;
use cubecl::prelude::*;

#[cube(launch)]
pub fn kernel_absolute_pos(output1: &mut Array<UInt>, output2: &mut Array<UInt>) {
pub fn kernel_absolute_pos(output1: &mut Array<UInt>) {
if ABSOLUTE_POS >= output1.len() {
return;
}

output1[ABSOLUTE_POS] = ABSOLUTE_POS;
output2[ABSOLUTE_POS] = ABSOLUTE_POS;
}

pub fn test_kernel_topology_absolute_pos<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let cube_count = (3, 5, 7);
let cube_dim = (16, 16, 1);
let extra: u32 = 3u32;

let length =
(cube_count.0 * cube_count.1 * cube_count.2 * cube_dim.0 * cube_dim.1 * cube_dim.2) + extra;
let length = cube_count.0 * cube_count.1 * cube_count.2 * cube_dim.0 * cube_dim.1 * cube_dim.2;
let handle1 = client.empty(length as usize * core::mem::size_of::<u32>());
let handle2 = client.empty(length as usize * core::mem::size_of::<u32>());

unsafe {
kernel_absolute_pos::launch::<R>(
&client,
CubeCount::Static(cube_count.0, cube_count.1, cube_count.2),
CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2),
ArrayArg::from_raw_parts(&handle1, length as usize, 1),
ArrayArg::from_raw_parts(&handle2, length as usize, 1),
)
};

let actual = client.read(handle1.binding());
let actual = u32::from_bytes(&actual);
let mut expect: Vec<u32> = (0..length - extra).collect();
expect.push(0);
expect.push(0);
expect.push(0);
let expect: Vec<u32> = (0..length).collect();

assert_eq!(actual, &expect);
}
Expand Down
3 changes: 1 addition & 2 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ impl Compiler for CudaCompiler {
}

fn max_shared_memory_size() -> usize {
// TODO: Find out this value.
usize::MAX
49152
}
}

Expand Down
135 changes: 87 additions & 48 deletions crates/cubecl-linalg/src/matmul/cmma/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,7 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use super::block_loop::block_loop;
use super::config::CmmaConfig;

#[cube(launch_unchecked)]
#[allow(unused_mut)]
pub fn cmma_kernel<F: Float, FC: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &mut Tensor<F>,
config: Comptime<CmmaConfig>,
) {
let dims = get_dims::<F>(lhs, rhs);
let offsets = calculate_offsets::<F>(lhs, rhs, out, config);
let shared_memories = make_shared_memories::<FC>(config);
let accumulate = make_accumulators::<F>();
block_loop::<F, FC>(
lhs,
rhs,
out,
offsets,
shared_memories,
accumulate,
config,
dims,
);
}
use super::config::ComptimeCmmaInfo;

#[derive(CubeType, Copy, Clone)]
pub(crate) struct Dimensions {
Expand All @@ -36,15 +12,22 @@ pub(crate) struct Dimensions {
}

#[derive(CubeType, Copy, Clone)]
pub(crate) struct SharedMemories<FC: Float> {
pub lhs: SharedMemory<FC>,
pub rhs: SharedMemory<FC>,
pub(crate) struct Ids {
pub coop: UInt,
pub lane: UInt,
}

#[derive(CubeType, Copy, Clone)]
pub(crate) struct Accumulators<F: Float> {
pub first: cmma::Matrix<F>,
pub second: cmma::Matrix<F>,
pub(crate) struct RuntimeCmmaInfo {
pub ids: Ids,
pub dims: Dimensions,
pub offsets: Offsets,
}

#[derive(CubeType, Copy, Clone)]
pub(crate) struct SharedMemories<FC: Float> {
pub lhs: SharedMemory<FC>,
pub rhs: SharedMemory<FC>,
}

#[derive(CubeType, Copy, Clone)]
Expand All @@ -57,7 +40,39 @@ pub(crate) struct Offsets {
pub batch_out: UInt,
pub cube_row: UInt,
pub cube_col: UInt,
pub k: UInt,
}

#[derive(CubeType)]
pub(crate) struct CmmaMatrices<F: Float, FC: Float> {
pub accumulators: Sequence<cmma::Matrix<F>>,
pub lhs: cmma::Matrix<FC>,
pub rhs: cmma::Matrix<FC>,
}

#[cube(launch_unchecked)]
#[allow(unused_mut)]
pub fn cmma_kernel<F: Float, FC: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &mut Tensor<F>,
comptime_info: Comptime<ComptimeCmmaInfo>,
) {
let ids = get_ids();
let dims = get_dims::<F>(lhs, rhs);
let offsets = calculate_offsets::<F>(lhs, rhs, out, comptime_info);
let runtime_info = RuntimeCmmaInfo { ids, dims, offsets };

let shared_memories = make_shared_memories::<FC>(comptime_info);
let cmma_matrices = make_cmma_matrices::<F, FC>(comptime_info);
block_loop::<F, FC>(
lhs,
rhs,
out,
shared_memories,
cmma_matrices,
runtime_info,
comptime_info,
);
}

#[cube]
Expand All @@ -77,7 +92,7 @@ fn calculate_offsets<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &Tensor<F>,
config: Comptime<CmmaConfig>,
config: Comptime<ComptimeCmmaInfo>,
) -> Offsets {
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_n = Comptime::map(config, |c| c.block_size_n);
Expand Down Expand Up @@ -109,12 +124,11 @@ fn calculate_offsets<F: Float>(
batch_out,
cube_row,
cube_col,
k: UInt::new(0), // Changes during kernel
}
}

#[cube]
fn make_shared_memories<FC: Float>(config: Comptime<CmmaConfig>) -> SharedMemories<FC> {
fn make_shared_memories<FC: Float>(config: Comptime<ComptimeCmmaInfo>) -> SharedMemories<FC> {
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_n = Comptime::map(config, |c| c.block_size_n);
Expand All @@ -126,28 +140,53 @@ fn make_shared_memories<FC: Float>(config: Comptime<CmmaConfig>) -> SharedMemori
}

#[cube]
pub(crate) fn make_accumulators<F: Float>() -> Accumulators<F> {
// Assumes two per warp. TODO generalize
let acc0 = cmma::Matrix::<F>::new(
cmma::MatrixIdent::Accumulator,
pub(crate) fn make_cmma_matrices<F: Float, FC: Float>(
config: Comptime<ComptimeCmmaInfo>,
) -> CmmaMatrices<F, FC> {
let num_accumulators = Comptime::map(config, |c| c.num_accumulators);
let mut accumulators = Sequence::<cmma::Matrix<F>>::new();

for _ in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) {
let acc = cmma::Matrix::<F>::new(
cmma::MatrixIdent::Accumulator,
16,
16,
16,
cmma::MatrixLayout::Undefined,
);

cmma::fill::<F>(&acc, F::new(0.0));

accumulators.push(acc);
}

let lhs = cmma::Matrix::<FC>::new(
cmma::MatrixIdent::A,
16,
16,
16,
cmma::MatrixLayout::Undefined,
cmma::MatrixLayout::RowMajor,
);
let acc1 = cmma::Matrix::<F>::new(
cmma::MatrixIdent::Accumulator,

let rhs = cmma::Matrix::<FC>::new(
cmma::MatrixIdent::B,
16,
16,
16,
cmma::MatrixLayout::Undefined,
cmma::MatrixLayout::RowMajor,
);

cmma::fill::<F>(&acc0, F::new(0.0));
cmma::fill::<F>(&acc1, F::new(0.0));
CmmaMatrices {
accumulators,
lhs,
rhs,
}
}

Accumulators {
first: acc0,
second: acc1,
#[cube]
fn get_ids() -> Ids {
Ids {
coop: UNIT_POS_Y,
lane: UNIT_POS_X,
}
}
3 changes: 0 additions & 3 deletions crates/cubecl-linalg/src/matmul/cmma/block_io/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::base::Dimensions;
use crate::matmul::cmma::config::CmmaConfig;

#[cube]
pub(crate) trait BlockLoader<F: Float, FC: Float>: Send + Sync + 'static {
Expand All @@ -24,12 +23,10 @@ pub(crate) trait BlockWriter<F: Float>: Send + Sync + 'static {
fn write_output(
out: &mut Tensor<F>,
accumulator_sm: SharedMemory<F>,
n_iter: UInt,
batch_offset: UInt,
read_position: UInt,
write_row: UInt,
write_col: UInt,
dims: Dimensions,
config: Comptime<CmmaConfig>,
);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::{base::Dimensions, config::CmmaConfig};
use crate::matmul::cmma::base::Dimensions;

use super::base::{BlockLoader, BlockWriter};

Expand All @@ -21,13 +21,18 @@ impl<F: Float, FC: Float> BlockLoader<F, FC> for HorizontalCheckBlockIO {
) {
let tensor_vec = Comptime::vectorization(tensor);
let tensor_vec_r = Comptime::runtime(tensor_vec);
let is_scalar = Comptime::map(tensor_vec, |v| v.val == 1);

if read_col < dim_horizontal {
let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec_r;
let value = tensor[read_pos];

for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) {
shared_memory[write_pos + i] = FC::cast_from(value[i]);
if Comptime::get(is_scalar) {
shared_memory[write_pos] = FC::cast_from(value);
} else {
for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) {
shared_memory[write_pos + i] = FC::cast_from(value[i]);
}
}
} else {
for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) {
Expand All @@ -42,33 +47,31 @@ impl<F: Float> BlockWriter<F> for HorizontalCheckBlockIO {
fn write_output(
out: &mut Tensor<F>,
accumulator_sm: SharedMemory<F>,
n_iter: UInt,
batch_offset: UInt,
read_position: UInt,
write_row: UInt,
write_col: UInt,
dims: Dimensions,
config: Comptime<CmmaConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let out_vec = Comptime::vectorization(out);
let out_vec_r = Comptime::runtime(out_vec);
let is_scalar = Comptime::map(out_vec, |v| v.val == 1);

let col_with_n_iter = write_col + n_iter * Comptime::runtime(tile_size);

if col_with_n_iter < dims.n {
let n_iter_read_offset = n_iter * Comptime::runtime(tile_size * tile_size);
let read_position = read_position + n_iter_read_offset;
if write_col < dims.n {
let write_position = batch_offset + write_row * dims.n + write_col;

let write_position = batch_offset + write_row * dims.n + col_with_n_iter;
if Comptime::get(is_scalar) {
let val = accumulator_sm[read_position];
out[write_position / out_vec_r] = val;
} else {
let mut value = F::vectorized_empty(Comptime::get(out_vec));

let mut value = F::vectorized_empty(Comptime::get(out_vec));
for i in range(0u32, Comptime::get(out_vec), Comptime::new(true)) {
value[i] = accumulator_sm[read_position + i];
}

for i in range(0u32, 4u32, Comptime::new(true)) {
value[i] = accumulator_sm[read_position + i];
out[write_position / out_vec_r] = value;
}

out[write_position / out_vec_r] = value;
}
}
}
Loading
Loading