Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Nov 28, 2024
1 parent 43ca9fd commit 6b4409e
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ where

let range = k_range.1 - k_range.0;
let num_stages = (range + k_step - 1) / k_step;
let num_loops = num_stages - 1; // one stage is computed outside the loop
let num_loops = num_stages;

SMM::zero_accumulator(acc, config.to_smm_config());

Expand All @@ -67,52 +67,31 @@ where
// Load A
Self::LhsLoader::fill_stage(&mut lhs_loader, config);
Self::RhsLoader::fill_stage(&mut rhs_loader, config);

let lhs_buffer_reader_a = Self::LhsLoader::as_stage_reader(&lhs_loader);
let rhs_buffer_reader_a = Self::RhsLoader::as_stage_reader(&rhs_loader);

sync_units();

///////////////
// Compute A & Load B
SMM::execute(
&lhs_buffer_reader_a,
&rhs_buffer_reader_a,
&mut lhs_tile_a,
&mut rhs_tile_a,
acc,
config.to_smm_config(),
);

// Get B
Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step);
Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step);
Self::LhsLoader::fill_stage(&mut lhs_loader, config);
Self::RhsLoader::fill_stage(&mut rhs_loader, config);

let lhs_buffer_reader_b = Self::LhsLoader::as_stage_reader(&lhs_loader);
let rhs_buffer_reader_b = Self::RhsLoader::as_stage_reader(&rhs_loader);

sync_units();

for _ in 0..num_loops {
///////////////
// Compute B & Load A
SMM::execute(
&lhs_buffer_reader_b,
&rhs_buffer_reader_b,
&mut lhs_tile_b,
&mut rhs_tile_b,
acc,
config.to_smm_config(),
);
sync_units();

Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step);
Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step);
///////////////
// Load B & Advance
Self::LhsLoader::fill_stage(&mut lhs_loader, config);
Self::RhsLoader::fill_stage(&mut rhs_loader, config);

sync_units();
Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step);
Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step);

///////////////
// Compute A & Load B
// Execute A
SMM::execute(
&lhs_buffer_reader_a,
&rhs_buffer_reader_a,
Expand All @@ -122,24 +101,27 @@ where
config.to_smm_config(),
);

Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step);
Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step);
sync_units();

///////////////
// Load Next A
Self::LhsLoader::fill_stage(&mut lhs_loader, config);
Self::RhsLoader::fill_stage(&mut rhs_loader, config);

sync_units();
}
Self::LhsLoader::advance_view(&mut lhs_loader, buffer_step);
Self::RhsLoader::advance_view(&mut rhs_loader, buffer_step);

///////////////
// Compute B
SMM::execute(
&lhs_buffer_reader_b,
&rhs_buffer_reader_b,
&mut lhs_tile_b,
&mut rhs_tile_b,
acc,
config.to_smm_config(),
);
///////////////
// Execute B
SMM::execute(
&lhs_buffer_reader_b,
&rhs_buffer_reader_b,
&mut lhs_tile_b,
&mut rhs_tile_b,
acc,
config.to_smm_config(),
);
}

sync_units();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ where
}

#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Configuration for the full load matmul
/// Configuration for the full load matmul
pub struct Config<S: stage::Config> {
smm_config: S,
check_m_bounds: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,10 @@ impl<EG: Numeric, ES: Numeric, EA: Numeric, Stage: StageSize, TMM: tile::Matmul<
type TileMatmul = TMM;

type StageMatmul =
stage::multi_buffer::Matmul<Self::ES, Self::EG, Self::EA, Self::TileMatmul, Stage>;
stage::single_buffer::Matmul<Self::ES, Self::EG, Self::EA, Self::TileMatmul, Stage>;

type GlobalMatmul = global::full_load::Matmul<
Self::EG,
Self::ES,
Self::EA,
Self::StageMatmul,
global::full_load::CyclicLoading,
global::full_load::CyclicLoading,
>;
type GlobalMatmul =
global::buffered::pipelined::Matmul<Self::EG, Self::ES, Self::EA, Self::StageMatmul>;

type BatchMatmul = batch::one_to_one::Matmul<Self::EG, Self::ES, Self::GlobalMatmul, Dispatch>;

Expand All @@ -63,4 +57,12 @@ impl<EG: Numeric, ES: Numeric, EA: Numeric, Stage: StageSize, TMM: tile::Matmul<

Dispatch::cube_count(cubes_for_m, cubes_for_n, problem.num_batches() as u32)
}

fn advanced_config() -> crate::matmul::kernels::matmul::AdvancedConfig {
crate::matmul::kernels::matmul::AdvancedConfig {
lhs_tiling_order: stage::TilingOrderConfig::ColMajor,
rhs_tiling_order: stage::TilingOrderConfig::RowMajor,
enforced_tile_layout: (None, None),
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ macro_rules! matmul_test_define {
Self::TileMatmul,
S4x4x2,
>;
type GlobalMatmul =
global::buffered::pipelined::Matmul<Self::EG, Self::ES, Self::EA, Self::StageMatmul>;
type GlobalMatmul = global::buffered::pipelined::Matmul<
Self::EG,
Self::ES,
Self::EA,
Self::StageMatmul,
>;
type BatchMatmul = batch::one_to_one::Matmul<
Self::EG,
Self::ES,
Expand Down Expand Up @@ -109,8 +113,12 @@ macro_rules! matmul_test_define {
Self::TileMatmul,
S1x1x2,
>;
type GlobalMatmul =
global::buffered::pipelined::Matmul<Self::EG, Self::ES, Self::EA, Self::StageMatmul>;
type GlobalMatmul = global::buffered::pipelined::Matmul<
Self::EG,
Self::ES,
Self::EA,
Self::StageMatmul,
>;
type BatchMatmul = batch::one_to_one::Matmul<
Self::EG,
Self::ES,
Expand Down Expand Up @@ -170,8 +178,12 @@ macro_rules! matmul_test_define {
Self::TileMatmul,
S1x1x2,
>;
type GlobalMatmul =
global::buffered::pipelined::Matmul<Self::EG, Self::ES, Self::EA, Self::StageMatmul>;
type GlobalMatmul = global::buffered::pipelined::Matmul<
Self::EG,
Self::ES,
Self::EA,
Self::StageMatmul,
>;
type BatchMatmul = batch::one_to_one::Matmul<
Self::EG,
Self::ES,
Expand Down

0 comments on commit 6b4409e

Please sign in to comment.