diff --git a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs index 2b3f8dae..c53ba17d 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/buffered/pipelined/base.rs @@ -56,7 +56,7 @@ where let range = k_range.1 - k_range.0; let num_stages = (range + k_step - 1) / k_step; - let num_loops = num_stages - 1; // one stage is computed outside the loop + let num_loops = num_stages; SMM::zero_accumulator(acc, config.to_smm_config()); @@ -67,52 +67,31 @@ where // Load A Self::LhsLoader::fill_stage(&mut lhs_loader, config); Self::RhsLoader::fill_stage(&mut rhs_loader, config); + let lhs_buffer_reader_a = Self::LhsLoader::as_stage_reader(&lhs_loader); let rhs_buffer_reader_a = Self::RhsLoader::as_stage_reader(&rhs_loader); - sync_units(); - /////////////// - // Compute A & Load B - SMM::execute( - &lhs_buffer_reader_a, - &rhs_buffer_reader_a, - &mut lhs_tile_a, - &mut rhs_tile_a, - acc, - config.to_smm_config(), - ); - + // Get B Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step); Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step); - Self::LhsLoader::fill_stage(&mut lhs_loader, config); - Self::RhsLoader::fill_stage(&mut rhs_loader, config); + let lhs_buffer_reader_b = Self::LhsLoader::as_stage_reader(&lhs_loader); let rhs_buffer_reader_b = Self::RhsLoader::as_stage_reader(&rhs_loader); - sync_units(); - for _ in 0..num_loops { - /////////////// - // Compute B & Load A - SMM::execute( - &lhs_buffer_reader_b, - &rhs_buffer_reader_b, - &mut lhs_tile_b, - &mut rhs_tile_b, - acc, - config.to_smm_config(), - ); + sync_units(); - Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step); - Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step); + /////////////// + // Load B & Advance Self::LhsLoader::fill_stage(&mut lhs_loader, config); Self::RhsLoader::fill_stage(&mut rhs_loader, config); - sync_units(); + Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step); + Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step); /////////////// - // Compute A & Load B + // Execute A SMM::execute( &lhs_buffer_reader_a, &rhs_buffer_reader_a, @@ -122,24 +101,27 @@ where config.to_smm_config(), ); - Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step); - Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step); + sync_units(); + + /////////////// + // Load Next A Self::LhsLoader::fill_stage(&mut lhs_loader, config); Self::RhsLoader::fill_stage(&mut rhs_loader, config); - sync_units(); - } + Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step); + Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step); - /////////////// - // Compute B - SMM::execute( - &lhs_buffer_reader_b, - &rhs_buffer_reader_b, - &mut lhs_tile_b, - &mut rhs_tile_b, - acc, - config.to_smm_config(), - ); + /////////////// + // Execute B + SMM::execute( + &lhs_buffer_reader_b, + &rhs_buffer_reader_b, + &mut lhs_tile_b, + &mut rhs_tile_b, + acc, + config.to_smm_config(), + ); + } sync_units(); diff --git a/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs b/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs index 814354f6..373ea82f 100644 --- a/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs +++ b/crates/cubecl-linalg/src/matmul/components/global/full_load/base.rs @@ -182,7 +182,7 @@ where } #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)] -/// Configuration for the full load matmul +/// Configuration for the full load matmul pub struct Config { smm_config: S, check_m_bounds: bool, diff --git a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs index d94a25c9..d9cbe8b0 100644 --- a/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs +++ b/crates/cubecl-linalg/src/matmul/kernels/matmul/algorithm/standard.rs @@ -38,16 +38,10 @@ impl; + stage::single_buffer::Matmul; - type GlobalMatmul = global::full_load::Matmul< - Self::EG, - Self::ES, - Self::EA, - Self::StageMatmul, - global::full_load::CyclicLoading, - global::full_load::CyclicLoading, - >; + type GlobalMatmul = + global::buffered::pipelined::Matmul; type BatchMatmul = batch::one_to_one::Matmul; @@ -63,4 +57,12 @@ impl crate::matmul::kernels::matmul::AdvancedConfig { + crate::matmul::kernels::matmul::AdvancedConfig { + lhs_tiling_order: stage::TilingOrderConfig::ColMajor, + rhs_tiling_order: stage::TilingOrderConfig::RowMajor, + enforced_tile_layout: (None, None), + } + } } diff --git a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs index d506ac10..e8e8a752 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_macros/cmma/matmul_algorithm.rs @@ -48,8 +48,12 @@ macro_rules! matmul_test_define { Self::TileMatmul, S4x4x2, >; - type GlobalMatmul = - global::buffered::pipelined::Matmul; + type GlobalMatmul = global::buffered::pipelined::Matmul< + Self::EG, + Self::ES, + Self::EA, + Self::StageMatmul, + >; type BatchMatmul = batch::one_to_one::Matmul< Self::EG, Self::ES, @@ -109,8 +113,12 @@ macro_rules! matmul_test_define { Self::TileMatmul, S1x1x2, >; - type GlobalMatmul = - global::buffered::pipelined::Matmul; + type GlobalMatmul = global::buffered::pipelined::Matmul< + Self::EG, + Self::ES, + Self::EA, + Self::StageMatmul, + >; type BatchMatmul = batch::one_to_one::Matmul< Self::EG, Self::ES, @@ -170,8 +178,12 @@ macro_rules! matmul_test_define { Self::TileMatmul, S1x1x2, >; - type GlobalMatmul = - global::buffered::pipelined::Matmul; + type GlobalMatmul = global::buffered::pipelined::Matmul< + Self::EG, + Self::ES, + Self::EA, + Self::StageMatmul, + >; type BatchMatmul = batch::one_to_one::Matmul< Self::EG, Self::ES,