From a4e2b77dcc1c91e31ca95a8d55454f8c49e1f4f6 Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Thu, 28 Nov 2024 19:43:00 -0500 Subject: [PATCH] Add Support for cast instruction in hip wmma intrinsic compiler (#317) --- .../src/hip/wmma/intrinsic_compiler.rs | 169 ++++++++++++++---- crates/cubecl-cpp/src/shared/element.rs | 14 +- .../tests/cmma_matmul/matmul_test_launcher.rs | 1 + 3 files changed, 150 insertions(+), 34 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index accb853e5..8b5ceee56 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -18,8 +18,11 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { } fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?; - f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n") + f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?; + f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?; + f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?; + f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?; + f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n") } fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -50,8 +53,17 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { match fragment.ident { - FragmentIdent::A | FragmentIdent::B => write!(f, "half16"), - FragmentIdent::Accumulator => write!(f, "float8"), + FragmentIdent::A | FragmentIdent::B => match fragment.elem { + Elem::F16 => write!(f, "half16_t"), + Elem::BF16 => write!(f, "bhalf16_t"), + other => panic!("unsupported type {other} for {fragment}"), + }, + FragmentIdent::Accumulator => match fragment.elem { + Elem::F16 => write!(f, "half16_t"), + Elem::BF16 => write!(f, "bhalf16_t"), + Elem::F32 => write!(f, "float8_t"), + other => panic!("unsupported type {other} for {fragment}"), + }, FragmentIdent::_Dialect(_) => Ok(()), } } @@ -65,11 +77,16 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { let fill_with_zeros = matches!(value, Variable::ConstantScalar(number, _) if number.is_zero()); if fill_with_zeros { - writeln!(f, "{frag} = {{}};") + write!( + f, + "// fill +{frag} = {{}}; +" + ) } else { write!( f, - " + "// fill for (uint i = 0; i < uint(8); ++i) {{ {frag}[i] = {value}; }} @@ -77,9 +94,36 @@ for (uint i = 0; i < uint(8); ++i) {{ ) } } - WmmaInstruction::Load { frag, value, .. } => { - // Matrix A must be in column major layout - // Matrix B must be in row major layout + WmmaInstruction::Load { + frag, + value, + layout, + .. + } => { + // Matrix A must be in column major layout (so fragments correspond to a row) + // Matrices B, C and D must be in row major layout (so fragments correspond to a column) + // + // Each lane is a thread so each column get 8 VGPRs used to store fragments + // Here is the layout for C and D matrices and how they map to registers + // + // Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31 + // -------------------------------------------------------------------------------------------------------------- + // VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... 
| 2,15 | 2,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16| + // -------------------------------------------------------------------------------------------------------------- let item = value.item(); let mut value_ident = format!("{value}"); if item.vectorization > 1 { @@ -89,26 +133,46 @@ for (uint i = 0; i < uint(8); ++i) {{ )?; value_ident = format!("{value}_half"); } - let index = match frag { + // TODO: support iu8 and iu4 + let (index, length, step) = match frag { Variable::WmmaFragment { frag: inner, .. 
} => { - if (inner.ident == FragmentIdent::A - && inner.layout.unwrap() == FragmentLayout::ColMajor) - || (inner.ident == FragmentIdent::B - && inner.layout.unwrap() == FragmentLayout::RowMajor) - { - // correct layout - "i * uint(16) + wmmaLane" - } else { - // transpose - "i + wmmaLane * uint(16)" + match inner.ident { + FragmentIdent::A | FragmentIdent::B => { + let length = 16; + let step = 1; + // fragment a and b are always in half precision and they don't require special attention + // to how they are stored in memory as matrix A and B are also in half precision + let index = if (inner.ident == FragmentIdent::A + && inner.layout.unwrap() == FragmentLayout::ColMajor) + || (inner.ident == FragmentIdent::B + && inner.layout.unwrap() == FragmentLayout::RowMajor) + { + "i * uint(16) + wmmaLane" + } else { + "i + wmmaLane * uint(16)" + }; + (index, length, step) + } + FragmentIdent::Accumulator => { + let length = 8; + let step = get_output_accumulator_index_step(value, inner); + let index = match layout { + Some(FragmentLayout::ColMajor) => "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", + Some(FragmentLayout::RowMajor) => "(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", + _ => panic!("cannot load data to an accumulator without knowing the layout of the data"), + }; + (index, length, step) + } + other => panic!("unknown matrix identifier {other}"), } } other => panic!("{other} is not a WMMMA fragment!"), }; write!( f, - "for (uint i = 0; i < uint(16); ++i) {{ - {frag}[i] = {value_ident}[{index}]; + "// load +for (uint i = 0; i < uint({length}); ++i) {{ + {frag}[i * {step}] = {value_ident}[{index}]; }} " ) @@ -139,13 +203,15 @@ for (uint i = 0; i < uint(8); ++i) {{ } else { panic!("{frag_a} is not a WMMA fragment!") }; - let cd_format = if let Variable::WmmaFragment { frag: inner_c, .. } = frag_c { + let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } = + frag_c + { if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d { if inner_c.elem == inner_d.elem { match inner_c.elem { - Elem::F32 => "f32", - Elem::F16 => "f16", - Elem::BF16 => "bf16", + Elem::F32 => ("f32", ""), + Elem::F16 => ("f16", ", false"), + Elem::BF16 => ("bf16", ", false"), other => { panic!("{other} format not supported for {frag_c} and {frag_d}") } @@ -161,7 +227,7 @@ for (uint i = 0; i < uint(8); ++i) {{ }; writeln!( f, - "{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});" + "{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c}{opsel});" ) } WmmaInstruction::Store { @@ -187,7 +253,7 @@ for (uint i = 0; i < uint(8); ++i) {{ match inner.elem { Elem::F16 | Elem::BF16 => "elemIdx * 2", Elem::F32 => "elemIdx", - other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."), + other => panic!("C fragment format cannot be {other}. 
Only f16, bf16 and f32 are supported."), } }, other => panic!("{frag} is not a WMMA fragment (it is a {other})!") @@ -200,7 +266,8 @@ for (uint i = 0; i < uint(8); ++i) {{ }; write!( f, - "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ + "// store +for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16); {output_ident}[{output_idx}] = {frag}[{frag_idx}]; }} @@ -208,10 +275,20 @@ for (uint i = 0; i < uint(8); ++i) {{ ) } WmmaInstruction::Cast { input, output } => { + let step = match output { + Variable::WmmaFragment { frag: inner, .. } => match inner.ident { + FragmentIdent::Accumulator => { + get_output_accumulator_index_step(input, inner) + } + _ => 1, + }, + _ => 1, + }; write!( f, - "for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ - {output}[elemIdx] = {input}[elemIdx]; + "// cast +for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ + {output}[elemIdx * {step}] = {input}[elemIdx]; }} " ) @@ -251,3 +328,33 @@ for (uint i = 0; i < uint(8); ++i) {{ result } } + +fn get_output_accumulator_index_step( + input: &Variable>, + output: &Fragment>, +) -> u32 { + // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either: + // - a vector of 8 float + // - a vector of 16 half + // Depending on the precision used for the input, the whole 32 bits per register will be used or + // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means + // that we only assign values to even indexes of the accumulator (0, 2, 4, ...) + + assert_eq!( + output.ident, + FragmentIdent::>::Accumulator + ); + + match input.elem() { + Elem::F16 | Elem::BF16 | Elem::F32 => { + match output.elem { + // loading into accumulator of 16 half precision + Elem::F16 | Elem::BF16 => 2, + // loading into accumulator of 8 full precision + Elem::F32 => 1, + other => panic!("unsupported format {other} for {output}"), + } + } + other => panic!("unsupported format {other} for {input}"), + } +} diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 5d774cbe7..cd80b2b70 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -5,7 +5,7 @@ use cubecl_core::{ use half::{bf16, f16}; use std::fmt::Display; -use super::{Dialect, Fragment, COUNTER_TMP_VAR}; +use super::{Dialect, Fragment, FragmentIdent, COUNTER_TMP_VAR}; #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum Elem { @@ -290,9 +290,17 @@ impl Display for Variable { Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"), Variable::WmmaFragment { id: index, - frag: _, + frag, depth, - } => write!(f, "frag_{index}_{depth}"), + } => { + let name = match frag.ident { + FragmentIdent::A => "a", + FragmentIdent::B => "b", + FragmentIdent::Accumulator => "acc", + FragmentIdent::_Dialect(_) => "", + }; + write!(f, "frag_{name}_{index}_{depth}") + } Variable::GridDimGlobal => f.write_str("gridDimGlobal"), Self::Tmp { id, .. } => write!(f, "_tmp_{id}"), } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs index 2f5847260..ae216e52d 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma_matmul/matmul_test_launcher.rs @@ -38,6 +38,7 @@ where if A::check_availability::(&client).is_err() { // Can't execute the test. 
+ println!("Skipped - not supported!"); return; }
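
Two illustrative sketches of the HIP code the format strings above expand to follow; the fragment and buffer names (frag_a_0_0, frag_acc_0_0, buffer_0, ...) are placeholders in the frag_{name}_{index}_{depth} naming scheme introduced in element.rs, not output captured from the compiler.

First, loading a row-major f32 tile into an f32 accumulator (length = 8, step = 1), using the row/lane mapping documented in the VGPR table:

    // load: accumulator fragment, row-major f32 source into an f32 accumulator (step = 1)
    for (uint i = 0; i < uint(8); ++i) {
      frag_acc_0_0[i * 1] = buffer_0[(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane];
    }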
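
Second, casting an f32 accumulator into an f16 accumulator and feeding it to the wmma builtin, assuming a warp size of 32. Because f16/bf16 accumulators keep their data in the lower 16 bits of each 32-bit VGPR (opsel = false), get_output_accumulator_index_step returns 2 and only even slots are written:

    // cast: float8_t -> half16_t accumulator (step = 2, even indexes only)
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {
      frag_acc_1_0[elemIdx * 2] = frag_acc_0_0[elemIdx];
    }
    // execute: f16 output accumulator, opsel = false selects the lower 16 bits
    frag_acc_1_0 = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(frag_a_0_0, frag_b_0_0, frag_acc_1_0, false);

The store path applies the same step (elemIdx * 2 for f16/bf16 fragments), which keeps load, cast, and store consistent with that register layout.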