Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for cast instruction in hip wmma intrinsic compiler #317

Merged
merged 12 commits into from
Nov 29, 2024
82 changes: 76 additions & 6 deletions crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cubecl::{
ir::{Elem, FloatKind},
prelude::*,
};
use half::f16;
use half::{bf16, f16};

#[cube(launch)]
/// Executes Out = Lhs @ Rhs.T
Expand Down Expand Up @@ -88,7 +88,7 @@ pub fn kernel_simple_tf32(lhs: &Array<tf32>, rhs: &Array<tf32>, out: &mut Array<
}

#[cube(launch)]
pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
pub fn cast_matrix_f16(input: &Array<f32>, out: &mut Array<f16>) {
let acc = unsafe {
cmma::Matrix::<f32>::uninitialized(
cmma::MatrixIdent::Accumulator,
Expand All @@ -110,6 +110,29 @@ pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
);
}

#[cube(launch)]
/// Loads a 16x16 `f32` tile into an accumulator fragment, casts the fragment
/// to `bf16`, and stores the result back out in row-major order.
pub fn cast_matrix_bf16(input: &Array<f32>, out: &mut Array<bf16>) {
    // The accumulator is only a destination for the load below, so it may
    // start uninitialized; `MatrixLayout::Undefined` defers layout choice to
    // the load/store calls.
    let acc = unsafe {
        cmma::Matrix::<f32>::uninitialized(
            cmma::MatrixIdent::Accumulator,
            16,
            16,
            16,
            cmma::MatrixLayout::Undefined,
        )
    };
    // Stride of 16 = one full row of the 16x16 tile.
    cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor);

    // Element-wise cast of the whole fragment: f32 -> bf16.
    let output = cmma::cast::<f32, bf16>(&acc);

    cmma::store(
        &mut out.to_slice_mut(),
        &output,
        16,
        cmma::MatrixLayout::RowMajor,
    );
}

pub fn test_simple_1<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
Expand Down Expand Up @@ -174,7 +197,7 @@ pub fn test_simple_1<R: Runtime>(
assert_eq!(expected, actual);
}

pub fn test_cmma_cast_acc<R: Runtime>(
pub fn test_cmma_cast_f16<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
) {
Expand All @@ -195,7 +218,7 @@ pub fn test_cmma_cast_acc<R: Runtime>(
let out = client.empty(core::mem::size_of::<f16>() * 256);

unsafe {
cast_matrix::launch::<R>(
cast_matrix_f16::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
cube_dimensions,
Expand All @@ -211,6 +234,43 @@ pub fn test_cmma_cast_acc<R: Runtime>(
assert_eq!(actual, expected);
}

/// Runtime test: launches `cast_matrix_bf16` on a 16x16 `f32` input and
/// checks that every element round-trips through the accumulator cast as
/// `bf16::from_f32(i)`.
///
/// Skips silently when the runtime does not advertise the 16x16x16
/// bf16 x bf16 -> f32 CMMA feature.
pub fn test_cmma_cast_bf16<R: Runtime>(
    client: ComputeClient<R::Server, R::Channel>,
    cube_dimensions: CubeDim,
) {
    if !client.properties().feature_enabled(Feature::Cmma {
        a: Elem::Float(FloatKind::BF16),
        b: Elem::Float(FloatKind::BF16),
        c: Elem::Float(FloatKind::F32),
        m: 16,
        k: 16,
        n: 16,
    }) {
        // We can't execute the test, skip.
        return;
    }

    let input: Vec<f32> = (0..256).map(|i| i as f32).collect();
    let input = client.create(f32::as_bytes(&input));
    // The kernel writes bf16, so size the output buffer in bf16 elements.
    // (Was `f16`, which only worked because f16 and bf16 are both 2 bytes.)
    let out = client.empty(core::mem::size_of::<bf16>() * 256);

    unsafe {
        cast_matrix_bf16::launch::<R>(
            &client,
            CubeCount::Static(1, 1, 1),
            cube_dimensions,
            ArrayArg::from_raw_parts::<f32>(&input, 256, 1),
            // Output element type must match the kernel's `Array<bf16>` parameter.
            ArrayArg::from_raw_parts::<bf16>(&out, 256, 1),
        )
    };

    let actual = client.read_one(out.binding());
    let actual = bf16::from_bytes(&actual);
    let expected: Vec<bf16> = (0..256).map(|i| bf16::from_f32(i as f32)).collect();

    assert_eq!(actual, expected);
}

pub fn test_simple_tf32<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
Expand Down Expand Up @@ -305,10 +365,20 @@ macro_rules! testgen_cmma {
}

#[test]
fn test_cmma_cast_acc() {
fn test_cmma_cast_f16() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(32, 1, 1);
cubecl_core::runtime_tests::cmma::test_cmma_cast_f16::<TestRuntime>(
client,
cube_dimensions,
);
}

#[test]
fn test_cmma_cast_bf16() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(32, 1, 1);
cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::<TestRuntime>(
cubecl_core::runtime_tests::cmma::test_cmma_cast_bf16::<TestRuntime>(
client,
cube_dimensions,
);
Expand Down
169 changes: 138 additions & 31 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
}

fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("typedef _Float16 half16 __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef float float8 __attribute__((ext_vector_type(8)));\n")
f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
}

fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -50,8 +53,17 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match fragment.ident {
FragmentIdent::A | FragmentIdent::B => write!(f, "half16"),
FragmentIdent::Accumulator => write!(f, "float8"),
FragmentIdent::A | FragmentIdent::B => match fragment.elem {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
other => panic!("unsupported type {other} for {fragment}"),
},
FragmentIdent::Accumulator => match fragment.elem {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
Elem::F32 => write!(f, "float8_t"),
other => panic!("unsupported type {other} for {fragment}"),
},
FragmentIdent::_Dialect(_) => Ok(()),
}
}
Expand All @@ -65,21 +77,53 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
let fill_with_zeros =
matches!(value, Variable::ConstantScalar(number, _) if number.is_zero());
if fill_with_zeros {
writeln!(f, "{frag} = {{}};")
write!(
f,
"// fill
{frag} = {{}};
"
)
} else {
write!(
f,
"
"// fill
for (uint i = 0; i < uint(8); ++i) {{
{frag}[i] = {value};
}}
"
)
}
}
WmmaInstruction::Load { frag, value, .. } => {
// Matrix A must be in column major layout
// Matrix B must be in row major layout
WmmaInstruction::Load {
frag,
value,
layout,
..
} => {
// Matrix A must be in column major layout (so fragments correspond to a row)
// Matrices B, C and D must be in row major layout (so fragments correspond to a column)
//
// Each lane is a thread so each column get 8 VGPRs used to store fragments
// Here is the layout for C and D matrices and how they map to registers
//
// Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31
// --------------------------------------------------------------------------------------------------------------
// VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... | 2,15 | 2,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16|
// --------------------------------------------------------------------------------------------------------------
let item = value.item();
let mut value_ident = format!("{value}");
if item.vectorization > 1 {
Expand All @@ -89,26 +133,46 @@ for (uint i = 0; i < uint(8); ++i) {{
)?;
value_ident = format!("{value}_half");
}
let index = match frag {
// TODO: support iu8 and iu4
let (index, length, step) = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
// correct layout
"i * uint(16) + wmmaLane"
} else {
// transpose
"i + wmmaLane * uint(16)"
match inner.ident {
FragmentIdent::A | FragmentIdent::B => {
let length = 16;
let step = 1;
// fragment a and b are always in half precision and they don't require special attention
// to how they are stored in memory as matrix A and B are also in half precision
let index = if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
"i * uint(16) + wmmaLane"
} else {
"i + wmmaLane * uint(16)"
};
(index, length, step)
}
FragmentIdent::Accumulator => {
let length = 8;
let step = get_output_accumulator_index_step(value, inner);
let index = match layout {
Some(FragmentLayout::ColMajor) => "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)",
Some(FragmentLayout::RowMajor) => "(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane",
_ => panic!("cannot load data to an accumulator without knowing the layout of the data"),
};
(index, length, step)
}
other => panic!("unknown matrix identifier {other}"),
}
}
other => panic!("{other} is not a WMMMA fragment!"),
};
write!(
f,
"for (uint i = 0; i < uint(16); ++i) {{
{frag}[i] = {value_ident}[{index}];
"// load
for (uint i = 0; i < uint({length}); ++i) {{
{frag}[i * {step}] = {value_ident}[{index}];
}}
"
)
Expand Down Expand Up @@ -139,13 +203,15 @@ for (uint i = 0; i < uint(8); ++i) {{
} else {
panic!("{frag_a} is not a WMMA fragment!")
};
let cd_format = if let Variable::WmmaFragment { frag: inner_c, .. } = frag_c {
let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } =
frag_c
{
if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d {
if inner_c.elem == inner_d.elem {
match inner_c.elem {
Elem::F32 => "f32",
Elem::F16 => "f16",
Elem::BF16 => "bf16",
Elem::F32 => ("f32", ""),
Elem::F16 => ("f16", ", false"),
Elem::BF16 => ("bf16", ", false"),
other => {
panic!("{other} format not supported for {frag_c} and {frag_d}")
}
Expand All @@ -161,7 +227,7 @@ for (uint i = 0; i < uint(8); ++i) {{
};
writeln!(
f,
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});"
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c}{opsel});"
)
}
WmmaInstruction::Store {
Expand All @@ -187,7 +253,7 @@ for (uint i = 0; i < uint(8); ++i) {{
match inner.elem {
Elem::F16 | Elem::BF16 => "elemIdx * 2",
Elem::F32 => "elemIdx",
other => panic!("C fragment format can be {other}. Only f16, bf16 and f32 are supported."),
other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."),
}
},
other => panic!("{frag} is not a WMMA fragment (it is a {other})!")
Expand All @@ -200,18 +266,29 @@ for (uint i = 0; i < uint(8); ++i) {{
};
write!(
f,
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
"// store
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
{output_ident}[{output_idx}] = {frag}[{frag_idx}];
}}
"
)
}
WmmaInstruction::Cast { input, output } => {
let step = match output {
Variable::WmmaFragment { frag: inner, .. } => match inner.ident {
FragmentIdent::Accumulator => {
get_output_accumulator_index_step(input, inner)
}
_ => 1,
},
_ => 1,
};
write!(
f,
"for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx] = {input}[elemIdx];
"// cast
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx * {step}] = {input}[elemIdx];
}}
"
)
Expand Down Expand Up @@ -251,3 +328,33 @@ for (uint i = 0; i < uint(8); ++i) {{
result
}
}

/// Returns the index stride to use when writing `input` elements into an
/// accumulator fragment `output`.
///
/// Each lane owns 8 VGPRs of 32 bits, so an accumulator fragment is either
/// 8 x f32 (full register used) or 16 x half-precision values. For a
/// half-precision accumulator only the lower 16 bits of each register are
/// written (opsel = false), i.e. only the even fragment indexes
/// (0, 2, 4, ...), hence a step of 2; an f32 accumulator is written densely
/// with a step of 1.
///
/// # Panics
/// Panics when `output` is not an accumulator fragment, or when either
/// element type is not one of f16 / bf16 / f32.
fn get_output_accumulator_index_step(
    input: &Variable<HipDialect<WmmaIntrinsicCompiler>>,
    output: &Fragment<HipDialect<WmmaIntrinsicCompiler>>,
) -> u32 {
    assert_eq!(
        output.ident,
        FragmentIdent::<HipDialect<WmmaIntrinsicCompiler>>::Accumulator
    );

    // Reject unsupported input element types up front.
    let input_elem = input.elem();
    if !matches!(input_elem, Elem::F16 | Elem::BF16 | Elem::F32) {
        panic!("unsupported format {input_elem} for {input}");
    }

    match output.elem {
        // 16-element half-precision accumulator: lower 16 bits only -> even slots.
        Elem::F16 | Elem::BF16 => 2,
        // 8-element full-precision accumulator: every slot is used.
        Elem::F32 => 1,
        other => panic!("unsupported format {other} for {output}"),
    }
}
Loading
Loading