From e709a5de5cb6609c0f8a4dd15759838402edf650 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Thu, 28 Nov 2024 18:10:01 -0500 Subject: [PATCH] Correctly cast fragments and load fragments to registers --- .../src/hip/wmma/intrinsic_compiler.rs | 110 ++++++++++++++---- crates/cubecl-cpp/src/shared/base.rs | 1 - crates/cubecl-cpp/src/shared/element.rs | 2 +- crates/cubecl-cpp/src/shared/mma.rs | 1 - 4 files changed, 86 insertions(+), 28 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index 03fbac94..746a12e3 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -57,7 +57,7 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { Elem::F16 => write!(f, "half16_t"), Elem::BF16 => write!(f, "bhalf16_t"), other => panic!("unsupported type {other} for {fragment}"), - } + }, FragmentIdent::Accumulator => match fragment.elem { Elem::F16 => write!(f, "half16_t"), Elem::BF16 => write!(f, "bhalf16_t"), @@ -81,7 +81,8 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { f, "// fill {frag} = {{}}; -") +" + ) } else { write!( f, @@ -99,8 +100,30 @@ for (uint i = 0; i < uint(8); ++i) {{ layout, .. } => { - // Matrix A must be in column major layout - // Matrices B, C and D must be in row major layout + // Matrix A must be in column major layout (so fragments correspond to a row) + // Matrices B, C and D must be in row major layout (so fragments correspond to a column) + // + // Each lane is a thread so each column get 8 VGPRs used to store fragments + // Here is the layout for C and D matrices and how they map to registers + // + // Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31 + // -------------------------------------------------------------------------------------------------------------- + // VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... | 2,15 | 2,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 | + // -------------------------------------------------------------------------------------------------------------- + // VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16| + // -------------------------------------------------------------------------------------------------------------- + // VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16| + // -------------------------------------------------------------------------------------------------------------- let item = value.item(); let mut value_ident = format!("{value}"); if item.vectorization > 1 { @@ -110,37 +133,39 @@ for (uint i = 0; i < uint(8); ++i) {{ )?; value_ident = format!("{value}_half"); } - let (index, length) = match frag { + // TODO: support iu8 and iu4 + let (index, length, step) = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.ident { FragmentIdent::A | FragmentIdent::B => { + let length = 16; + let step = 1; + // fragment a and b are always in half precision and they don't require special attention + // to how they are stored in memory as matrix A and B are also in half precision let index = if (inner.ident == FragmentIdent::A && inner.layout.unwrap() == FragmentLayout::ColMajor) || (inner.ident == FragmentIdent::B && inner.layout.unwrap() == FragmentLayout::RowMajor) { - // correct layout "i * uint(16) + wmmaLane" } else { - // transpose "i + wmmaLane * uint(16)" }; - (index, 16) - } + (index.to_string(), length, step) + }, FragmentIdent::Accumulator => { let length = 8; - // For the acc we check layout of the source - // the acc must be in row format which mean that each lane (thread) is in col format - // moreover even rows are in 1~16 first thread of the wavefront and odd rows are in 17~32 - match layout { + let step = get_output_accumulator_index_step(value, inner); + let index = match layout { Some(FragmentLayout::ColMajor) => { - ("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length) + format!("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)") }, Some(FragmentLayout::RowMajor) => { - ("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length) + format!("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane") }, - _ => panic!("cannot load data to an accumulator without knowing the layout of the data "), - } + _ => panic!("cannot load data to an accumulator without knowing the layout of the data"), + }; + (index, length, step) } other => panic!("unknown matrix identifier {other}"), } @@ -151,7 +176,7 @@ for (uint i = 0; i < uint(8); ++i) {{ f, "// load for (uint i = 0; i < uint({length}); ++i) {{ - {frag}[i] = {value_ident}[{index}]; + {frag}[i * {step}] = {value_ident}[{index}]; }} " ) @@ -182,11 +207,11 @@ for (uint i = 0; i < uint({length}); ++i) {{ } else { panic!("{frag_a} is not a WMMA fragment!") }; - let (cd_format, opsel) = if let Variable::WmmaFragment { frag: mut inner_c, .. } = frag_c { - if let Variable::WmmaFragment { frag: mut inner_d, .. } = frag_d { + let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } = + frag_c + { + if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d { if inner_c.elem == inner_d.elem { - inner_c.amd_intrinsic_computed = true; - inner_d.amd_intrinsic_computed = true; match inner_c.elem { Elem::F32 => ("f32", ""), Elem::F16 => ("f16", ", false"), @@ -230,8 +255,7 @@ for (uint i = 0; i < uint({length}); ++i) {{ let frag_idx = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.elem { - Elem::F16 | Elem::BF16 if inner.amd_intrinsic_computed => "elemIdx * 2", - Elem::F16 | Elem::BF16 if !inner.amd_intrinsic_computed => "elemIdx", + Elem::F16 | Elem::BF16 => "elemIdx * 2", Elem::F32 => "elemIdx", other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."), } @@ -255,11 +279,20 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ ) } WmmaInstruction::Cast { input, output } => { + let step = match output { + Variable::WmmaFragment { frag: inner, .. } => { + match inner.ident { + FragmentIdent::Accumulator => get_output_accumulator_index_step(input, inner), + _ => 1, + } + } + _ => 1, + }; write!( f, "// cast for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ - {output}[elemIdx] = {input}[elemIdx]; + {output}[elemIdx * {step}] = {input}[elemIdx]; }} " ) @@ -299,3 +332,30 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{ result } } + +fn get_output_accumulator_index_step( + input: &Variable>, + output: &Fragment>) -> u32 { + // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either: + // - a vector of 8 floats + // - a vector of 16 halfs + // Depending on the precision used for the input, the whole 32 bits per register will be used or + // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means + // that we only assign values to even indexes of the accumulator (0, 2, 4, ...) + + assert_eq!(output.ident, FragmentIdent::>::Accumulator); + + let step = match input.elem() { + Elem::F16 | Elem::BF16 | Elem::F32 => { + match output.elem { + // loading into accumulator of 16 half precision + Elem::F16 | Elem::BF16 => 2, + // loading into accumulator of 8 full precision + Elem::F32 => 1, + other => panic!("unsupported format {other} for {output}"), + } + }, + other => panic!("unsupported format {other} for {input}"), + }; + step +} diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index dcda9405..64fedb65 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -892,7 +892,6 @@ impl CppCompiler { k: matrix.k, elem: self.compile_elem(matrix.elem), layout: self.compile_matrix_layout(matrix.layout), - amd_intrinsic_computed: false, } } diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 2d821f51..cd80b2b7 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -300,7 +300,7 @@ impl Display for Variable { FragmentIdent::_Dialect(_) => "", }; write!(f, "frag_{name}_{index}_{depth}") - }, + } Variable::GridDimGlobal => f.write_str("gridDimGlobal"), Self::Tmp { id, .. } => write!(f, "_tmp_{id}"), } diff --git a/crates/cubecl-cpp/src/shared/mma.rs b/crates/cubecl-cpp/src/shared/mma.rs index e9d6d7a6..3e06db52 100644 --- a/crates/cubecl-cpp/src/shared/mma.rs +++ b/crates/cubecl-cpp/src/shared/mma.rs @@ -88,7 +88,6 @@ pub struct Fragment { pub k: u8, pub elem: Elem, pub layout: Option>, - pub amd_intrinsic_computed: bool, } /// Warp Matrix-Multiply and Accumulate Instruction.