Skip to content

Commit

Permalink
Add Support for cast instruction in hip wmma intrinsic compiler (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
syl20bnr authored Nov 29, 2024
1 parent 140ab04 commit a4e2b77
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 34 deletions.
169 changes: 138 additions & 31 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
}

fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n")
f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
}

fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -50,8 +53,17 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match fragment.ident {
FragmentIdent::A | FragmentIdent::B => write!(f, "half16"),
FragmentIdent::Accumulator => write!(f, "float8"),
FragmentIdent::A | FragmentIdent::B => match fragment.elem {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
other => panic!("unsupported type {other} for {fragment}"),
},
FragmentIdent::Accumulator => match fragment.elem {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
Elem::F32 => write!(f, "float8_t"),
other => panic!("unsupported type {other} for {fragment}"),
},
FragmentIdent::_Dialect(_) => Ok(()),
}
}
Expand All @@ -65,21 +77,53 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
let fill_with_zeros =
matches!(value, Variable::ConstantScalar(number, _) if number.is_zero());
if fill_with_zeros {
writeln!(f, "{frag} = {{}};")
write!(
f,
"// fill
{frag} = {{}};
"
)
} else {
write!(
f,
"
"// fill
for (uint i = 0; i < uint(8); ++i) {{
{frag}[i] = {value};
}}
"
)
}
}
WmmaInstruction::Load { frag, value, .. } => {
// Matrix A must be in column major layout
// Matrix B must be in row major layout
WmmaInstruction::Load {
frag,
value,
layout,
..
} => {
// Matrix A must be in column major layout (so fragments correspond to a row)
// Matrices B, C and D must be in row major layout (so fragments correspond to a column)
//
// Each lane is a thread so each column gets 8 VGPRs used to store fragments
// Here is the layout for C and D matrices and how they map to registers
//
// Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31
// --------------------------------------------------------------------------------------------------------------
// VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... | 2,15 | 2,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16|
// --------------------------------------------------------------------------------------------------------------
let item = value.item();
let mut value_ident = format!("{value}");
if item.vectorization > 1 {
Expand All @@ -89,26 +133,46 @@ for (uint i = 0; i < uint(8); ++i) {{
)?;
value_ident = format!("{value}_half");
}
let index = match frag {
// TODO: support iu8 and iu4
let (index, length, step) = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
// correct layout
"i * uint(16) + wmmaLane"
} else {
// transpose
"i + wmmaLane * uint(16)"
match inner.ident {
FragmentIdent::A | FragmentIdent::B => {
let length = 16;
let step = 1;
// fragment a and b are always in half precision and they don't require special attention
// to how they are stored in memory as matrix A and B are also in half precision
let index = if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
"i * uint(16) + wmmaLane"
} else {
"i + wmmaLane * uint(16)"
};
(index, length, step)
}
FragmentIdent::Accumulator => {
let length = 8;
let step = get_output_accumulator_index_step(value, inner);
let index = match layout {
Some(FragmentLayout::ColMajor) => "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)",
Some(FragmentLayout::RowMajor) => "(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane",
_ => panic!("cannot load data to an accumulator without knowing the layout of the data"),
};
(index, length, step)
}
other => panic!("unknown matrix identifier {other}"),
}
}
other => panic!("{other} is not a WMMMA fragment!"),
};
write!(
f,
"for (uint i = 0; i < uint(16); ++i) {{
{frag}[i] = {value_ident}[{index}];
"// load
for (uint i = 0; i < uint({length}); ++i) {{
{frag}[i * {step}] = {value_ident}[{index}];
}}
"
)
Expand Down Expand Up @@ -139,13 +203,15 @@ for (uint i = 0; i < uint(8); ++i) {{
} else {
panic!("{frag_a} is not a WMMA fragment!")
};
let cd_format = if let Variable::WmmaFragment { frag: inner_c, .. } = frag_c {
let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } =
frag_c
{
if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d {
if inner_c.elem == inner_d.elem {
match inner_c.elem {
Elem::F32 => "f32",
Elem::F16 => "f16",
Elem::BF16 => "bf16",
Elem::F32 => ("f32", ""),
Elem::F16 => ("f16", ", false"),
Elem::BF16 => ("bf16", ", false"),
other => {
panic!("{other} format not supported for {frag_c} and {frag_d}")
}
Expand All @@ -161,7 +227,7 @@ for (uint i = 0; i < uint(8); ++i) {{
};
writeln!(
f,
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});"
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c}{opsel});"
)
}
WmmaInstruction::Store {
Expand All @@ -187,7 +253,7 @@ for (uint i = 0; i < uint(8); ++i) {{
match inner.elem {
Elem::F16 | Elem::BF16 => "elemIdx * 2",
Elem::F32 => "elemIdx",
other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."),
other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."),
}
},
other => panic!("{frag} is not a WMMA fragment (it is a {other})!")
Expand All @@ -200,18 +266,29 @@ for (uint i = 0; i < uint(8); ++i) {{
};
write!(
f,
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
"// store
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
{output_ident}[{output_idx}] = {frag}[{frag_idx}];
}}
"
)
}
WmmaInstruction::Cast { input, output } => {
let step = match output {
Variable::WmmaFragment { frag: inner, .. } => match inner.ident {
FragmentIdent::Accumulator => {
get_output_accumulator_index_step(input, inner)
}
_ => 1,
},
_ => 1,
};
write!(
f,
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx] = {input}[elemIdx];
"// cast
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx * {step}] = {input}[elemIdx];
}}
"
)
Expand Down Expand Up @@ -251,3 +328,33 @@ for (uint i = 0; i < uint(8); ++i) {{
result
}
}

/// Returns the stride to apply to element indices when writing into an
/// accumulator fragment.
///
/// Callers multiply the loop index by this value to pick the destination slot
/// of the accumulator (`frag[i * step]`).
///
/// Background, as described by the original author: each VGPR is 32 bits wide
/// and there are 8 VGPRs per lane, so an accumulator is either a vector of
/// 8 floats or a vector of 16 halves. With a half precision accumulator only
/// the lower 16 bits of each register are used (opsel set to false), which
/// means only the even indexes (0, 2, 4, ...) of the accumulator are assigned.
///
/// # Panics
///
/// - if `output` is not an accumulator fragment,
/// - if the element type of `input` is not one of `F16`, `BF16` or `F32`,
/// - if the element type of `output` is not one of `F16`, `BF16` or `F32`.
fn get_output_accumulator_index_step(
    input: &Variable<HipDialect<WmmaIntrinsicCompiler>>,
    output: &Fragment<HipDialect<WmmaIntrinsicCompiler>>,
) -> u32 {
    // This helper is only meaningful for accumulator fragments.
    assert_eq!(
        output.ident,
        FragmentIdent::<HipDialect<WmmaIntrinsicCompiler>>::Accumulator
    );

    // Validate the source element type first so that an unsupported input is
    // reported before an unsupported output.
    match input.elem() {
        Elem::F16 | Elem::BF16 | Elem::F32 => {}
        other => panic!("unsupported format {other} for {input}"),
    }

    // The stride only depends on the precision of the accumulator itself.
    match output.elem {
        // accumulator of 16 half precision values: write every other slot
        Elem::F16 | Elem::BF16 => 2,
        // accumulator of 8 full precision values: write contiguous slots
        Elem::F32 => 1,
        other => panic!("unsupported format {other} for {output}"),
    }
}
14 changes: 11 additions & 3 deletions crates/cubecl-cpp/src/shared/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cubecl_core::{
use half::{bf16, f16};
use std::fmt::Display;

use super::{Dialect, Fragment, COUNTER_TMP_VAR};
use super::{Dialect, Fragment, FragmentIdent, COUNTER_TMP_VAR};

#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum Elem<D: Dialect> {
Expand Down Expand Up @@ -290,9 +290,17 @@ impl<D: Dialect> Display for Variable<D> {
Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"),
Variable::WmmaFragment {
id: index,
frag: _,
frag,
depth,
} => write!(f, "frag_{index}_{depth}"),
} => {
let name = match frag.ident {
FragmentIdent::A => "a",
FragmentIdent::B => "b",
FragmentIdent::Accumulator => "acc",
FragmentIdent::_Dialect(_) => "",
};
write!(f, "frag_{name}_{index}_{depth}")
}
Variable::GridDimGlobal => f.write_str("gridDimGlobal"),
Self::Tmp { id, .. } => write!(f, "_tmp_{id}"),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ where

if A::check_availability::<R>(&client).is_err() {
// Can't execute the test.
println!("Skipped - not supported!");
return;
}

Expand Down

0 comments on commit a4e2b77

Please sign in to comment.