From fd9cb95d6d38ed71dc2d064064eea269328ad490 Mon Sep 17 00:00:00 2001 From: syl20bnr Date: Wed, 27 Nov 2024 21:09:27 -0500 Subject: [PATCH] Add Support for cast instruction in hip wmma intrinsic compiler --- .../src/hip/wmma/intrinsic_compiler.rs | 59 ++++++++++++++----- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs index accb853e..d05b857f 100644 --- a/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs +++ b/crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs @@ -18,6 +18,7 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { } fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("typedef _Float16 half8 __attribute__((ext_vector_type(8)));\n")?; f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?; f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n") } @@ -51,7 +52,12 @@ impl WmmaCompiler> for WmmaIntrinsicCompiler { ) -> std::fmt::Result { match fragment.ident { FragmentIdent::A | FragmentIdent::B => write!(f, "half16"), - FragmentIdent::Accumulator => write!(f, "float8"), + FragmentIdent::Accumulator => { + match fragment.elem { + Elem::F16 => write!(f, "half8"), + _ => write!(f, "float8"), + } + }, FragmentIdent::_Dialect(_) => Ok(()), } } @@ -77,9 +83,9 @@ for (uint i = 0; i < uint(8); ++i) {{ ) } } - WmmaInstruction::Load { frag, value, .. } => { + WmmaInstruction::Load { frag, value, layout, .. } => { // Matrix A must be in column major layout - // Matrix B must be in row major layout + // Matrices B, C and D must be in row major layout let item = value.item(); let mut value_ident = format!("{value}"); if item.vectorization > 1 { @@ -89,25 +95,45 @@ for (uint i = 0; i < uint(8); ++i) {{ )?; value_ident = format!("{value}_half"); } - let index = match frag { + let (index, length) = match frag { Variable::WmmaFragment { frag: inner, .. } => { - if (inner.ident == FragmentIdent::A - && inner.layout.unwrap() == FragmentLayout::ColMajor) - || (inner.ident == FragmentIdent::B - && inner.layout.unwrap() == FragmentLayout::RowMajor) - { - // correct layout - "i * uint(16) + wmmaLane" - } else { - // transpose - "i + wmmaLane * uint(16)" + match inner.ident { + FragmentIdent::A | FragmentIdent::B => { + let index = if (inner.ident == FragmentIdent::A + && inner.layout.unwrap() == FragmentLayout::ColMajor) + || (inner.ident == FragmentIdent::B + && inner.layout.unwrap() == FragmentLayout::RowMajor) + { + // correct layout + "i * uint(16) + wmmaLane" + } else { + // transpose + "i + wmmaLane * uint(16)" + }; + (index, 16) + }, + FragmentIdent::Accumulator => { + let length = 8; + // For the acc we check layout of the source + // the acc must be in row format + match layout { + Some(FragmentLayout::ColMajor) => { + ("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length) + }, + Some(FragmentLayout::RowMajor) => { + ("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length) + }, + _ => panic!("cannot load data to an accumulator without knowing the layout of the data "), + } + }, + other => panic!("unknown matrix identifier {other}") } } other => panic!("{other} is not a WMMMA fragment!"), }; write!( f, - "for (uint i = 0; i < uint(16); ++i) {{ + "for (uint i = 0; i < uint({length}); ++i) {{ {frag}[i] = {value_ident}[{index}]; }} " @@ -185,7 +211,8 @@ for (uint i = 0; i < uint(8); ++i) {{ let frag_idx = match frag { Variable::WmmaFragment { frag: inner, .. } => { match inner.elem { - Elem::F16 | Elem::BF16 => "elemIdx * 2", + Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::A || inner.ident == FragmentIdent::B => "elemIdx * 2", + Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::Accumulator => "elemIdx", Elem::F32 => "elemIdx", other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."), }