From 08416704a16f5a76ccd7085527ea69ebd520a7fe Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Wed, 27 Nov 2024 21:09:27 -0500 Subject: [PATCH 01/12] Add Support for cast instruction in hip wmma intrinsic compiler --- .../src/hip/wmma/intrinsic_compiler.rs | 60 ++++++++++++++----- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index accb853e..37d5e1ed 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -18,6 +18,7 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { } fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("typedef _Float16 half8 __attribute__((ext_vector_type(8)));\n")?; f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?; f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n") } @@ -51,7 +52,12 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { ) -> std::fmt::Result { match fragment.ident { FragmentIdent::A | FragmentIdent::B => write!(f, "half16"), - FragmentIdent::Accumulator => write!(f, "float8"), + FragmentIdent::Accumulator => { + match fragment.elem { + Elem::F16 => write!(f, "half8"), + _ => write!(f, "float8"), + } + }, FragmentIdent::_Dialect(_) => Ok(()), } } @@ -77,9 +83,9 @@ for (uint i = 0; i < uint(8); ++i) {{ ) } } - WmmaInstruction::Load { frag, value, .. } => { + WmmaInstruction::Load { frag, value, layout, .. } => { // Matrix A must be in column major layout - // Matrix B must be in row major layout + // Matrices B, C and D must be in row major layout let item = value.item(); let mut value_ident = format!("{value}"); if item.vectorization > 1 { @@ -89,25 +95,46 @@ for (uint i = 0; i < uint(8); ++i) {{ )?; value_ident = format!("{value}_half"); } - let index = match frag { + let (index, length) = match frag { Variable::WmmaFragment { frag: inner, .. } => { - if (inner.ident == FragmentIdent::A - && inner.layout.unwrap() == FragmentLayout::ColMajor) - || (inner.ident == FragmentIdent::B - && inner.layout.unwrap() == FragmentLayout::RowMajor) - { - // correct layout - "i * uint(16) + wmmaLane" - } else { - // transpose - "i + wmmaLane * uint(16)" + match inner.ident { + FragmentIdent::A | FragmentIdent::B => { + let index = if (inner.ident == FragmentIdent::A + && inner.layout.unwrap() == FragmentLayout::ColMajor) + || (inner.ident == FragmentIdent::B + && inner.layout.unwrap() == FragmentLayout::RowMajor) + { + // correct layout + "i * uint(16) + wmmaLane" + } else { + // transpose + "i + wmmaLane * uint(16)" + }; + (index, 16) + }, + FragmentIdent::Accumulator => { + let length = 8; + // For the acc we check layout of the source + // the acc must be in row format which mean that each lane (thread) is in col format + // moreover even rows are in 1~16 first thread of the wavefront and odd rows are in 17~32 + match layout { + Some(FragmentLayout::ColMajor) => { + ("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length) + }, + Some(FragmentLayout::RowMajor) => { + ("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length) + }, + _ => panic!("cannot load data to an accumulator without knowing the layout of the data "), + } + }, + other => panic!("unknown matrix identifier {other}") } } other => panic!("{other} is not a WMMMA fragment!"), }; write!( f, - "for (uint i = 0; i < uint(16); ++i) {{ + "for (uint i = 0; i < uint({length}); ++i) {{ {frag}[i] = {value_ident}[{index}]; }} " @@ -185,7 +212,8 @@ for (uint i = 0; i < uint(8); ++i) {{ let frag_idx = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.elem { - Elem::F16 | Elem::BF16 => "elemIdx * 2", + Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::A || inner.ident == FragmentIdent::B => "elemIdx * 2", + Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::Accumulator => "elemIdx", Elem::F32 => "elemIdx", other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."), } From 042850681f3958906fbae070c7c9c061814b040c Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Wed, 27 Nov 2024 21:15:38 -0500 Subject: [PATCH 02/12] Fix format --- .../src/hip/wmma/intrinsic_compiler.rs | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 37d5e1ed..0d2e715f 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -52,11 +52,9 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { ) -> std::fmt::Result { match fragment.ident { FragmentIdent::A | FragmentIdent::B => write!(f, "half16"), - FragmentIdent::Accumulator => { - match fragment.elem { - Elem::F16 => write!(f, "half8"), - _ => write!(f, "float8"), - } + FragmentIdent::Accumulator => match fragment.elem { + Elem::F16 => write!(f, "half8"), + _ => write!(f, "float8"), }, FragmentIdent::_Dialect(_) => Ok(()), } @@ -83,7 +81,12 @@ for (uint i = 0; i < uint(8); ++i) {{ ) } } - WmmaInstruction::Load { frag, value, layout, .. } => { + WmmaInstruction::Load { + frag, + value, + layout, + .. + } => { // Matrix A must be in column major layout // Matrices B, C and D must be in row major layout let item = value.item(); @@ -100,7 +103,7 @@ for (uint i = 0; i < uint(8); ++i) {{ match inner.ident { FragmentIdent::A | FragmentIdent::B => { let index = if (inner.ident == FragmentIdent::A - && inner.layout.unwrap() == FragmentLayout::ColMajor) + && inner.layout.unwrap() == FragmentLayout::ColMajor) || (inner.ident == FragmentIdent::B && inner.layout.unwrap() == FragmentLayout::RowMajor) { @@ -111,7 +114,7 @@ for (uint i = 0; i < uint(8); ++i) {{ "i + wmmaLane * uint(16)" }; (index, 16) - }, + } FragmentIdent::Accumulator => { let length = 8; // For the acc we check layout of the source @@ -126,8 +129,8 @@ for (uint i = 0; i < uint(8); ++i) {{ }, _ => panic!("cannot load data to an accumulator without knowing the layout of the data "), } - }, - other => panic!("unknown matrix identifier {other}") + } + other => panic!("unknown matrix identifier {other}"), } } other => panic!("{other} is not a WMMMA fragment!"), From c9a8fbb5ce84c0e412c4567dc1e2718ba0524bc5 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Wed, 27 Nov 2024 21:27:27 -0500 Subject: [PATCH 03/12] Fix typo in comment --- crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 0d2e715f..dfa0003b 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -218,7 +218,7 @@ for (uint i = 0; i < uint(8); ++i) {{ Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::A || inner.ident == FragmentIdent::B => "elemIdx * 2", Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::Accumulator => "elemIdx", Elem::F32 => "elemIdx", - other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."), + other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."), } }, other => panic!("{frag} is not a WMMA fragment (it is a {other})!") From fac3354c2c4a4302f3e7f1cc3c6dcac5956dfeb1 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Wed, 27 Nov 2024 22:14:30 -0500 Subject: [PATCH 04/12] Track whether frag c has been used or not by the intrinsic --- .../src/hip/wmma/intrinsic_compiler.rs | 18 ++++++++++-------- crates/cubecl-cpp/src/shared/base.rs | 1 + crates/cubecl-cpp/src/shared/mma.rs | 1 + .../tests/cmma_matmul/matmul_test_launcher.rs | 1 + 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index dfa0003b..17a52cfa 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -169,13 +169,15 @@ for (uint i = 0; i < uint(8); ++i) {{ } else { panic!("{frag_a} is not a WMMA fragment!") }; - let cd_format = if let Variable::WmmaFragment { frag: inner_c, .. } = frag_c { - if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d { + let (cd_format, opsel) = if let Variable::WmmaFragment { frag: mut inner_c, .. } = frag_c { + if let Variable::WmmaFragment { frag: mut inner_d, .. } = frag_d { if inner_c.elem == inner_d.elem { + inner_c.amd_intrinsic_computed = true; + inner_d.amd_intrinsic_computed = true; match inner_c.elem { - Elem::F32 => "f32", - Elem::F16 => "f16", - Elem::BF16 => "bf16", + Elem::F32 => ("f32", ""), + Elem::F16 => ("f16", ", false"), + Elem::BF16 => ("bf16", ", false"), other => { panic!("{other} format not supported for {frag_c} and {frag_d}") } @@ -191,7 +193,7 @@ for (uint i = 0; i < uint(8); ++i) {{ }; writeln!( f, - "{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});" + "{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c}{opsel});" ) } WmmaInstruction::Store { @@ -215,8 +217,8 @@ for (uint i = 0; i < uint(8); ++i) {{ let frag_idx = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.elem { - Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::A || inner.ident == FragmentIdent::B => "elemIdx * 2", - Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::Accumulator => "elemIdx", + Elem::F16 | Elem::BF16 if inner.amd_intrinsic_computed => "elemIdx * 2", + Elem::F16 | Elem::BF16 if !inner.amd_intrinsic_computed => "elemIdx", Elem::F32 => "elemIdx", other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."), } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 64fedb65..dcda9405 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -892,6 +892,7 @@ impl CppCompiler { k: matrix.k, elem: self.compile_elem(matrix.elem), layout: self.compile_matrix_layout(matrix.layout), + amd_intrinsic_computed: false, } } diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs index 3e06db52..e9d6d7a6 100644 --- a/crates/cubecl-cpp/src/shared/mma.rs +++ b/crates/cubecl-cpp/src/shared/mma.rs @@ -88,6 +88,7 @@ pub struct Fragment { pub k: u8, pub elem: Elem, pub layout: Option>, + pub amd_intrinsic_computed: bool, } /// Warp Matrix-Multiply and Accumulate Instruction. diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs index 2f584726..ae216e52 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs @@ -38,6 +38,7 @@ where if A::check_availability::(&client).is_err() { // Can't execute the test. + println!("Skipped - not supported!"); return; } From e24715a0e6ca50192db390e7b9ea321e15200ed7 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 11:46:49 -0500 Subject: [PATCH 05/12] Add more supported types --- .../src/hip/wmma/intrinsic_compiler.rs | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 17a52cfa..58b84951 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -18,9 +18,11 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { } fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("typedef _Float16 half8 __attribute__((ext_vector_type(8)));\n")?; - f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?; - f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n") + f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?; + f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?; + f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?; + f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?; + f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n") } fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -51,10 +53,16 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { match fragment.ident { - FragmentIdent::A | FragmentIdent::B => write!(f, "half16"), + FragmentIdent::A | FragmentIdent::B => match fragment.elem { + Elem::F16 => write!(f, "half16_t"), + Elem::BF16 => write!(f, "bhalf16_t"), + other => panic!("unsupported type {other} for {fragment}"), + } FragmentIdent::Accumulator => match fragment.elem { - Elem::F16 => write!(f, "half8"), - _ => write!(f, "float8"), + Elem::F16 => write!(f, "half16_t"), + Elem::BF16 => write!(f, "bhalf16_t"), + Elem::F32 => write!(f, "float8_t"), + other => panic!("unsupported type {other} for {fragment}"), }, FragmentIdent::_Dialect(_) => Ok(()), } From d3862f94018f25ed760c077ddb5c4de198162a02 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 12:11:44 -0500 Subject: [PATCH 06/12] Additional cast test from Nat --- crates/cubecl-core/src/runtime_tests/cmma.rs | 82 ++++++++++++++++++-- 1 file changed, 76 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs index 8737c736..992bce9f 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -5,7 +5,7 @@ use cubecl::{ ir::{Elem, FloatKind}, prelude::*, }; -use half::f16; +use half::{bf16, f16}; #[cube(launch)] /// Executes Out = Lhs @ Rhs.T @@ -88,7 +88,7 @@ pub fn kernel_simple_tf32(lhs: &Array, rhs: &Array, out: &mut Array< } #[cube(launch)] -pub fn cast_matrix(input: &Array, out: &mut Array) { +pub fn cast_matrix_f16(input: &Array, out: &mut Array) { let acc = unsafe { cmma::Matrix::::uninitialized( cmma::MatrixIdent::Accumulator, @@ -110,6 +110,29 @@ pub fn cast_matrix(input: &Array, out: &mut Array) { ); } +#[cube(launch)] +pub fn cast_matrix_bf16(input: &Array, out: &mut Array) { + let acc = unsafe { + cmma::Matrix::::uninitialized( + cmma::MatrixIdent::Accumulator, + 16, + 16, + 16, + cmma::MatrixLayout::Undefined, + ) + }; + cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor); + + let output = cmma::cast::(&acc); + + cmma::store( + &mut out.to_slice_mut(), + &output, + 16, + cmma::MatrixLayout::RowMajor, + ); +} + pub fn test_simple_1( client: ComputeClient, cube_dimensions: CubeDim, @@ -174,7 +197,7 @@ pub fn test_simple_1( assert_eq!(expected, actual); } -pub fn test_cmma_cast_acc( +pub fn test_cmma_cast_f16( client: ComputeClient, cube_dimensions: CubeDim, ) { @@ -195,7 +218,7 @@ pub fn test_cmma_cast_acc( let out = client.empty(core::mem::size_of::() * 256); unsafe { - cast_matrix::launch::( + cast_matrix_f16::launch::( &client, CubeCount::Static(1, 1, 1), cube_dimensions, @@ -211,6 +234,43 @@ pub fn test_cmma_cast_acc( assert_eq!(actual, expected); } +pub fn test_cmma_cast_bf16( + client: ComputeClient, + cube_dimensions: CubeDim, +) { + if !client.properties().feature_enabled(Feature::Cmma { + a: Elem::Float(FloatKind::BF16), + b: Elem::Float(FloatKind::BF16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }) { + // We can't execute the test, skip. + return; + } + + let input: Vec = (0..256).map(|i| i as f32).collect(); + let input = client.create(f32::as_bytes(&input)); + let out = client.empty(core::mem::size_of::() * 256); + + unsafe { + cast_matrix_bf16::launch::( + &client, + CubeCount::Static(1, 1, 1), + cube_dimensions, + ArrayArg::from_raw_parts::(&input, 256, 1), + ArrayArg::from_raw_parts::(&out, 256, 1), + ) + }; + + let actual = client.read_one(out.binding()); + let actual = bf16::from_bytes(&actual); + let expected: Vec = (0..256).map(|i| bf16::from_f32(i as f32)).collect(); + + assert_eq!(actual, expected); +} + pub fn test_simple_tf32( client: ComputeClient, cube_dimensions: CubeDim, @@ -305,10 +365,20 @@ macro_rules! testgen_cmma { } #[test] - fn test_cmma_cast_acc() { + fn test_cmma_cast_f16() { + let client = TestRuntime::client(&Default::default()); + let cube_dimensions = CubeDim::new(32, 1, 1); + cubecl_core::runtime_tests::cmma::test_cmma_cast_f16::( + client, + cube_dimensions, + ); + } + + #[test] + fn test_cmma_cast_bf16() { let client = TestRuntime::client(&Default::default()); let cube_dimensions = CubeDim::new(32, 1, 1); - cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::( + cubecl_core::runtime_tests::cmma::test_cmma_cast_bf16::( client, cube_dimensions, ); From eb94f6ef801f457d2021063123c10cb6d08cba9b Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 13:47:27 -0500 Subject: [PATCH 07/12] More explicit fragment variable names in CPP compiler --- crates/cubecl-cpp/src/shared/element.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 5d774cbe..2d821f51 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -5,7 +5,7 @@ use cubecl_core::{ use half::{bf16, f16}; use std::fmt::Display; -use super::{Dialect, Fragment, COUNTER_TMP_VAR}; +use super::{Dialect, Fragment, FragmentIdent, COUNTER_TMP_VAR}; #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum Elem { @@ -290,9 +290,17 @@ impl Display for Variable { Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"), Variable::WmmaFragment { id: index, - frag: _, + frag, depth, - } => write!(f, "frag_{index}_{depth}"), + } => { + let name = match frag.ident { + FragmentIdent::A => "a", + FragmentIdent::B => "b", + FragmentIdent::Accumulator => "acc", + FragmentIdent::_Dialect(_) => "", + }; + write!(f, "frag_{name}_{index}_{depth}") + }, Variable::GridDimGlobal => f.write_str("gridDimGlobal"), Self::Tmp { id, .. } => write!(f, "_tmp_{id}"), } From fe30e6b10075009e3e3b9fdfc54a776e0cc6e75e Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 13:47:50 -0500 Subject: [PATCH 08/12] Add comments to identify instruction names in hip intrinsic compiler --- .../src/hip/wmma/intrinsic_compiler.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 58b84951..03fbac94 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -77,11 +77,15 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { let fill_with_zeros = matches!(value, Variable::ConstantScalar(number, _) if number.is_zero()); if fill_with_zeros { - writeln!(f, "{frag} = {{}};") + write!( + f, + "// fill +{frag} = {{}}; +") } else { write!( f, - " + "// fill for (uint i = 0; i < uint(8); ++i) {{ {frag}[i] = {value}; }} @@ -145,7 +149,8 @@ for (uint i = 0; i < uint(8); ++i) {{ }; write!( f, - "for (uint i = 0; i < uint({length}); ++i) {{ + "// load +for (uint i = 0; i < uint({length}); ++i) {{ {frag}[i] = {value_ident}[{index}]; }} " @@ -241,7 +246,8 @@ for (uint i = 0; i < uint(8); ++i) {{ }; write!( f, - "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ + "// store +for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16); {output_ident}[{output_idx}] = {frag}[{frag_idx}]; }} @@ -251,7 +257,8 @@ for (uint i = 0; i < uint(8); ++i) {{ WmmaInstruction::Cast { input, output } => { write!( f, - "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ + "// cast +for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ {output}[elemIdx] = {input}[elemIdx]; }} " From e709a5de5cb6609c0f8a4dd15759838402edf650 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 18:10:01 -0500 Subject: [PATCH 09/12] Correctly cast fragments and load fragments to registers --- .../src/hip/wmma/intrinsic_compiler.rs | 110 ++++++++++++++---- crates/cubecl-cpp/src/shared/base.rs | 1 - crates/cubecl-cpp/src/shared/element.rs | 2 +- crates/cubecl-cpp/src/shared/mma.rs | 1 - 4 files changed, 86 insertions(+), 28 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 03fbac94..746a12e3 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -57,7 +57,7 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { Elem::F16 => write!(f, "half16_t"), Elem::BF16 => write!(f, "bhalf16_t"), other => panic!("unsupported type {other} for {fragment}"), - } + }, FragmentIdent::Accumulator => match fragment.elem { Elem::F16 => write!(f, "half16_t"), Elem::BF16 => write!(f, "bhalf16_t"), @@ -81,7 +81,8 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { f, "// fill {frag} = {{}}; -") +" + ) } else { write!( f, @@ -99,8 +100,30 @@ for (uint i = 0; i < uint(8); ++i) {{ layout, .. } => { - // Matrix A must be in column major layout - // Matrices B, C and D must be in row major layout + // Matrix A must be in column major layout (so fragments correspond to a row) + // Matrices B, C and D must be in row major layout (so fragments correspond to a column) + // + // Each lane is a thread so each column get 8 VGPRs used to store fragments + // Here is the layout for C and D matrices and how they map to registers + // + // Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31 + // -------------------------------------------------------------------------------------------------------------- + // VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... | 2,15 | 2,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16| + // -------------------------------------------------------------------------------------------------------------- let item = value.item(); let mut value_ident = format!("{value}"); if item.vectorization > 1 { @@ -110,37 +133,39 @@ for (uint i = 0; i < uint(8); ++i) {{ )?; value_ident = format!("{value}_half"); } - let (index, length) = match frag { + // TODO: support iu8 and iu4 + let (index, length, step) = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.ident { FragmentIdent::A | FragmentIdent::B => { + let length = 16; + let step = 1; + // fragment a and b are always in half precision and they don't require special attention + // to how they are stored in memory as matrix A and B are also in half precision let index = if (inner.ident == FragmentIdent::A && inner.layout.unwrap() == FragmentLayout::ColMajor) || (inner.ident == FragmentIdent::B && inner.layout.unwrap() == FragmentLayout::RowMajor) { - // correct layout "i * uint(16) + wmmaLane" } else { - // transpose "i + wmmaLane * uint(16)" }; - (index, 16) - } + (index.to_string(), length, step) + }, FragmentIdent::Accumulator => { let length = 8; - // For the acc we check layout of the source - // the acc must be in row format which mean that each lane (thread) is in col format - // moreover even rows are in 1~16 first thread of the wavefront and odd rows are in 17~32 - match layout { + let step = get_output_accumulator_index_step(value, inner); + let index = match layout { Some(FragmentLayout::ColMajor) => { - ("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length) + format!("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)") }, Some(FragmentLayout::RowMajor) => { - ("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length) + format!("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane") }, - _ => panic!("cannot load data to an accumulator without knowing the layout of the data "), - } + _ => panic!("cannot load data to an accumulator without knowing the layout of the data"), + }; + (index, length, step) } other => panic!("unknown matrix identifier {other}"), } @@ -151,7 +176,7 @@ for (uint i = 0; i < uint(8); ++i) {{ f, "// load for (uint i = 0; i < uint({length}); ++i) {{ - {frag}[i] = {value_ident}[{index}]; + {frag}[i * {step}] = {value_ident}[{index}]; }} " ) @@ -182,11 +207,11 @@ for (uint i = 0; i < uint({length}); ++i) {{ } else { panic!("{frag_a} is not a WMMA fragment!") }; - let (cd_format, opsel) = if let Variable::WmmaFragment { frag: mut inner_c, .. } = frag_c { - if let Variable::WmmaFragment { frag: mut inner_d, .. } = frag_d { + let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } = + frag_c + { + if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d { if inner_c.elem == inner_d.elem { - inner_c.amd_intrinsic_computed = true; - inner_d.amd_intrinsic_computed = true; match inner_c.elem { Elem::F32 => ("f32", ""), Elem::F16 => ("f16", ", false"), @@ -230,8 +255,7 @@ for (uint i = 0; i < uint({length}); ++i) {{ let frag_idx = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.elem { - Elem::F16 | Elem::BF16 if inner.amd_intrinsic_computed => "elemIdx * 2", - Elem::F16 | Elem::BF16 if !inner.amd_intrinsic_computed => "elemIdx", + Elem::F16 | Elem::BF16 => "elemIdx * 2", Elem::F32 => "elemIdx", other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."), } @@ -255,11 +279,20 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ ) } WmmaInstruction::Cast { input, output } => { + let step = match output { + Variable::WmmaFragment { frag: inner, .. } => { + match inner.ident { + FragmentIdent::Accumulator => get_output_accumulator_index_step(input, inner), + _ => 1, + } + } + _ => 1, + }; write!( f, "// cast for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ - {output}[elemIdx] = {input}[elemIdx]; + {output}[elemIdx * {step}] = {input}[elemIdx]; }} " ) @@ -299,3 +332,30 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ result } } + +fn get_output_accumulator_index_step( + input: &Variable>, + output: &Fragment>) -> u32 { + // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either: + // - a vector of 8 floats + // - a vector of 16 halfs + // Depending on the precision used for the input, the whole 32 bits per register will be used or + // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means + // that we only assign values to even indexes of the accumulator (0, 2, 4, ...) + + assert_eq!(output.ident, FragmentIdent::>::Accumulator); + + let step = match input.elem() { + Elem::F16 | Elem::BF16 | Elem::F32 => { + match output.elem { + // loading into accumulator of 16 half precision + Elem::F16 | Elem::BF16 => 2, + // loading into accumulator of 8 full precision + Elem::F32 => 1, + other => panic!("unsupported format {other} for {output}"), + } + }, + other => panic!("unsupported format {other} for {input}"), + }; + step +} diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index dcda9405..64fedb65 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -892,7 +892,6 @@ impl CppCompiler { k: matrix.k, elem: self.compile_elem(matrix.elem), layout: self.compile_matrix_layout(matrix.layout), - amd_intrinsic_computed: false, } } diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 2d821f51..cd80b2b7 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -300,7 +300,7 @@ impl Display for Variable { FragmentIdent::_Dialect(_) => "", }; write!(f, "frag_{name}_{index}_{depth}") - }, + } Variable::GridDimGlobal => f.write_str("gridDimGlobal"), Self::Tmp { id, .. } => write!(f, "_tmp_{id}"), } diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs index e9d6d7a6..3e06db52 100644 --- a/crates/cubecl-cpp/src/shared/mma.rs +++ b/crates/cubecl-cpp/src/shared/mma.rs @@ -88,7 +88,6 @@ pub struct Fragment { pub k: u8, pub elem: Elem, pub layout: Option>, - pub amd_intrinsic_computed: bool, } /// Warp Matrix-Multiply and Accumulate Instruction. From 4a2a2dfb350eba387499cdff939b285ba5ec768b Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 18:12:07 -0500 Subject: [PATCH 10/12] Fix format --- .../src/hip/wmma/intrinsic_compiler.rs | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 746a12e3..266469c6 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -152,7 +152,7 @@ for (uint i = 0; i < uint(8); ++i) {{ "i + wmmaLane * uint(16)" }; (index.to_string(), length, step) - }, + } FragmentIdent::Accumulator => { let length = 8; let step = get_output_accumulator_index_step(value, inner); @@ -280,12 +280,12 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ } WmmaInstruction::Cast { input, output } => { let step = match output { - Variable::WmmaFragment { frag: inner, .. } => { - match inner.ident { - FragmentIdent::Accumulator => get_output_accumulator_index_step(input, inner), - _ => 1, + Variable::WmmaFragment { frag: inner, .. } => match inner.ident { + FragmentIdent::Accumulator => { + get_output_accumulator_index_step(input, inner) } - } + _ => 1, + }, _ => 1, }; write!( @@ -335,7 +335,8 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ fn get_output_accumulator_index_step( input: &Variable>, - output: &Fragment>) -> u32 { + output: &Fragment>, +) -> u32 { // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either: // - a vector of 8 floats // - a vector of 16 halfs @@ -343,7 +344,10 @@ fn get_output_accumulator_index_step( // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means // that we only assign values to even indexes of the accumulator (0, 2, 4, ...) - assert_eq!(output.ident, FragmentIdent::>::Accumulator); + assert_eq!( + output.ident, + FragmentIdent::>::Accumulator + ); let step = match input.elem() { Elem::F16 | Elem::BF16 | Elem::F32 => { @@ -354,7 +358,7 @@ fn get_output_accumulator_index_step( Elem::F32 => 1, other => panic!("unsupported format {other} for {output}"), } - }, + } other => panic!("unsupported format {other} for {input}"), }; step From e4323a8bc423b24e9ee8c1fc287eecfbb6406fbd Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 18:15:34 -0500 Subject: [PATCH 11/12] Fix linting --- .../cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 266469c6..8d9c242d 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -151,18 +151,14 @@ for (uint i = 0; i < uint(8); ++i) {{ } else { "i + wmmaLane * uint(16)" }; - (index.to_string(), length, step) + (index, length, step) } FragmentIdent::Accumulator => { let length = 8; let step = get_output_accumulator_index_step(value, inner); let index = match layout { - Some(FragmentLayout::ColMajor) => { - format!("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)") - }, - Some(FragmentLayout::RowMajor) => { - format!("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane") - }, + Some(FragmentLayout::ColMajor) => "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", + Some(FragmentLayout::RowMajor) => "(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", _ => panic!("cannot load data to an accumulator without knowing the layout of the data"), }; (index, length, step) @@ -349,7 +345,7 @@ fn get_output_accumulator_index_step( FragmentIdent::>::Accumulator ); - let step = match input.elem() { + match input.elem() { Elem::F16 | Elem::BF16 | Elem::F32 => { match output.elem { // loading into accumulator of 16 half precision @@ -360,6 +356,5 @@ fn get_output_accumulator_index_step( } } other => panic!("unsupported format {other} for {input}"), - }; - step + } } From fbfc7b6147d4ff27168d1718bd649206dbe228f6 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 18:16:50 -0500 Subject: [PATCH 12/12] Fix typos --- crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 8d9c242d..8b5ceee5 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -334,8 +334,8 @@ fn get_output_accumulator_index_step( output: &Fragment>, ) -> u32 { // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either: - // - a vector of 8 floats - // - a vector of 16 halfs + // - a vector of 8 float + // - a vector of 16 half // Depending on the precision used for the input, the whole 32 bits per register will be used or // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means // that we only assign values to even indexes of the accumulator (0, 2, 4, ...)