Skip to content

Commit

Permalink
Add Support for cast instruction in hip wmma intrinsic compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
syl20bnr committed Nov 28, 2024
1 parent 5296f55 commit fd9cb95
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
}

/// Writes the vector type aliases used by the generated WMMA code.
///
/// Three C `typedef` declarations are emitted, one per line, each using the
/// `ext_vector_type` attribute:
/// - `half8`: 8 lanes of `_Float16`
/// - `half16`: 16 lanes of `_Float16`
/// - `float8`: 8 lanes of `float`
///
/// Returns the first formatter error encountered, if any; later declarations
/// are not written once a write has failed.
fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    // Declarations are written in exactly this order.
    const VECTOR_TYPEDEFS: [&str; 3] = [
        "typedef _Float16 half8 __attribute__((ext_vector_type(8)));\n",
        "typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n",
        "typedef float float8 __attribute__((ext_vector_type(8)));\n",
    ];

    // `try_for_each` stops at the first failing write, like a chain of `?`.
    VECTOR_TYPEDEFS
        .iter()
        .try_for_each(|declaration| f.write_str(declaration))
}
Expand Down Expand Up @@ -51,7 +52,12 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
) -> std::fmt::Result {
match fragment.ident {
FragmentIdent::A | FragmentIdent::B => write!(f, "half16"),
FragmentIdent::Accumulator => write!(f, "float8"),
FragmentIdent::Accumulator => {
match fragment.elem {
Elem::F16 => write!(f, "half8"),
_ => write!(f, "float8"),
}
},
FragmentIdent::_Dialect(_) => Ok(()),
}
}
Expand All @@ -77,9 +83,9 @@ for (uint i = 0; i < uint(8); ++i) {{
)
}
}
WmmaInstruction::Load { frag, value, .. } => {
WmmaInstruction::Load { frag, value, layout, .. } => {
// Matrix A must be in column major layout
// Matrix B must be in row major layout
// Matrices B, C and D must be in row major layout
let item = value.item();
let mut value_ident = format!("{value}");
if item.vectorization > 1 {
Expand All @@ -89,25 +95,45 @@ for (uint i = 0; i < uint(8); ++i) {{
)?;
value_ident = format!("{value}_half");
}
let index = match frag {
let (index, length) = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
// correct layout
"i * uint(16) + wmmaLane"
} else {
// transpose
"i + wmmaLane * uint(16)"
match inner.ident {
FragmentIdent::A | FragmentIdent::B => {
let index = if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
// correct layout
"i * uint(16) + wmmaLane"
} else {
// transpose
"i + wmmaLane * uint(16)"
};
(index, 16)
},
FragmentIdent::Accumulator => {
let length = 8;
// For the acc we check layout of the source
// the acc must be in row format
match layout {
Some(FragmentLayout::ColMajor) => {
("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length)
},
Some(FragmentLayout::RowMajor) => {
("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length)
},
_ => panic!("cannot load data to an accumulator without knowing the layout of the data "),
}
},
other => panic!("unknown matrix identifier {other}")
}
}
other => panic!("{other} is not a WMMMA fragment!"),
};
write!(
f,
"for (uint i = 0; i < uint(16); ++i) {{
"for (uint i = 0; i < uint({length}); ++i) {{
{frag}[i] = {value_ident}[{index}];
}}
"
Expand Down Expand Up @@ -185,7 +211,8 @@ for (uint i = 0; i < uint(8); ++i) {{
let frag_idx = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
match inner.elem {
Elem::F16 | Elem::BF16 => "elemIdx * 2",
Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::A || inner.ident == FragmentIdent::B => "elemIdx * 2",
Elem::F16 | Elem::BF16 if inner.ident == FragmentIdent::Accumulator => "elemIdx",
Elem::F32 => "elemIdx",
other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."),
}
Expand Down

0 comments on commit fd9cb95

Please sign in to comment.