Skip to content

Commit

Permalink
Correctly cast fragments and load fragments to registers
Browse files Browse the repository at this point in the history
  • Loading branch information
syl20bnr committed Nov 28, 2024
1 parent fe30e6b commit e709a5d
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 28 deletions.
110 changes: 85 additions & 25 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
other => panic!("unsupported type {other} for {fragment}"),
}
},
FragmentIdent::Accumulator => match fragment.elem {
Elem::F16 => write!(f, "half16_t"),
Elem::BF16 => write!(f, "bhalf16_t"),
Expand All @@ -81,7 +81,8 @@ impl WmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
f,
"// fill
{frag} = {{}};
")
"
)
} else {
write!(
f,
Expand All @@ -99,8 +100,30 @@ for (uint i = 0; i < uint(8); ++i) {{
layout,
..
} => {
// Matrix A must be in column major layout
// Matrices B, C and D must be in row major layout
// Matrix A must be in column major layout (so fragments correspond to a row)
// Matrices B, C and D must be in row major layout (so fragments correspond to a column)
//
// Each lane is a thread so each column get 8 VGPRs used to store fragments
// Here is the layout for C and D matrices and how they map to registers
//
// Lane index 0 1 2 3 ... 13 14 15 ... 17 18 ... 30 31
// --------------------------------------------------------------------------------------------------------------
// VGPR0 | 1,1 | 1,2 | 1,3 | 1,4 | ... | 1,13 | 1,14 | 1,15 | ... | 2,1 | 2,2 | ... | 2,15 | 2,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR1 | 3,1 | 3,2 | 3,3 | 3,4 | ... | 3,13 | 3,14 | 3,15 | ... | 4,1 | 4,2 | ... | 4,15 | 4,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR2 | 5,1 | 5,2 | 5,3 | 5,4 | ... | 5,13 | 5,14 | 5,15 | ... | 6,1 | 6,2 | ... | 6,15 | 6,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR3 | 7,1 | 7,2 | 7,3 | 7,4 | ... | 7,13 | 7,14 | 7,15 | ... | 8,1 | 8,2 | ... | 8,15 | 8,16 |
// --------------------------------------------------------------------------------------------------------------
// VGPR4 | 9,1 | 9,2 | 9,3 | 9,4 | ... | 9,13 | 9,14 | 9,15 | ... | 10,1 | 10,2 | ... | 10,15| 10,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR5 | 11,1 | 11,2 | 11,3 | 11,4 | ... | 11,13| 11,14| 11,15| ... | 12,1 | 12,2 | ... | 12,15| 12,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR6 | 13,1 | 13,2 | 13,3 | 13,4 | ... | 13,13| 13,14| 13,15| ... | 14,1 | 14,2 | ... | 14,15| 14,16|
// --------------------------------------------------------------------------------------------------------------
// VGPR7 | 15,1 | 15,2 | 15,3 | 15,4 | ... | 15,13| 15,14| 15,15| ... | 16,1 | 16,2 | ... | 16,15| 16,16|
// --------------------------------------------------------------------------------------------------------------
let item = value.item();
let mut value_ident = format!("{value}");
if item.vectorization > 1 {
Expand All @@ -110,37 +133,39 @@ for (uint i = 0; i < uint(8); ++i) {{
)?;
value_ident = format!("{value}_half");
}
let (index, length) = match frag {
// TODO: support iu8 and iu4
let (index, length, step) = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
match inner.ident {
FragmentIdent::A | FragmentIdent::B => {
let length = 16;
let step = 1;
// fragment a and b are always in half precision and they don't require special attention
// to how they are stored in memory as matrix A and B are also in half precision
let index = if (inner.ident == FragmentIdent::A
&& inner.layout.unwrap() == FragmentLayout::ColMajor)
|| (inner.ident == FragmentIdent::B
&& inner.layout.unwrap() == FragmentLayout::RowMajor)
{
// correct layout
"i * uint(16) + wmmaLane"
} else {
// transpose
"i + wmmaLane * uint(16)"
};
(index, 16)
}
(index.to_string(), length, step)
},
FragmentIdent::Accumulator => {
let length = 8;
// For the acc we check layout of the source
// the acc must be in row format which mean that each lane (thread) is in col format
// moreover even rows are in 1~16 first thread of the wavefront and odd rows are in 17~32
match layout {
let step = get_output_accumulator_index_step(value, inner);
let index = match layout {
Some(FragmentLayout::ColMajor) => {
("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)", length)
format!("(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * uint(16)")
},
Some(FragmentLayout::RowMajor) => {
("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane", length)
format!("(i * uint(2) + threadIdx.x / uint(16)) * uint(16) + wmmaLane")
},
_ => panic!("cannot load data to an accumulator without knowing the layout of the data "),
}
_ => panic!("cannot load data to an accumulator without knowing the layout of the data"),
};
(index, length, step)
}
other => panic!("unknown matrix identifier {other}"),
}
Expand All @@ -151,7 +176,7 @@ for (uint i = 0; i < uint(8); ++i) {{
f,
"// load
for (uint i = 0; i < uint({length}); ++i) {{
{frag}[i] = {value_ident}[{index}];
{frag}[i * {step}] = {value_ident}[{index}];
}}
"
)
Expand Down Expand Up @@ -182,11 +207,11 @@ for (uint i = 0; i < uint({length}); ++i) {{
} else {
panic!("{frag_a} is not a WMMA fragment!")
};
let (cd_format, opsel) = if let Variable::WmmaFragment { frag: mut inner_c, .. } = frag_c {
if let Variable::WmmaFragment { frag: mut inner_d, .. } = frag_d {
let (cd_format, opsel) = if let Variable::WmmaFragment { frag: inner_c, .. } =
frag_c
{
if let Variable::WmmaFragment { frag: inner_d, .. } = frag_d {
if inner_c.elem == inner_d.elem {
inner_c.amd_intrinsic_computed = true;
inner_d.amd_intrinsic_computed = true;
match inner_c.elem {
Elem::F32 => ("f32", ""),
Elem::F16 => ("f16", ", false"),
Expand Down Expand Up @@ -230,8 +255,7 @@ for (uint i = 0; i < uint({length}); ++i) {{
let frag_idx = match frag {
Variable::WmmaFragment { frag: inner, .. } => {
match inner.elem {
Elem::F16 | Elem::BF16 if inner.amd_intrinsic_computed => "elemIdx * 2",
Elem::F16 | Elem::BF16 if !inner.amd_intrinsic_computed => "elemIdx",
Elem::F16 | Elem::BF16 => "elemIdx * 2",
Elem::F32 => "elemIdx",
other => panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported."),
}
Expand All @@ -255,11 +279,20 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
)
}
WmmaInstruction::Cast { input, output } => {
let step = match output {
Variable::WmmaFragment { frag: inner, .. } => {
match inner.ident {
FragmentIdent::Accumulator => get_output_accumulator_index_step(input, inner),
_ => 1,
}
}
_ => 1,
};
write!(
f,
"// cast
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
{output}[elemIdx] = {input}[elemIdx];
{output}[elemIdx * {step}] = {input}[elemIdx];
}}
"
)
Expand Down Expand Up @@ -299,3 +332,30 @@ for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
result
}
}

fn get_output_accumulator_index_step(
input: &Variable<HipDialect<WmmaIntrinsicCompiler>>,
output: &Fragment<HipDialect<WmmaIntrinsicCompiler>>) -> u32 {
// Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either:
// - a vector of 8 floats
// - a vector of 16 halfs
// Depending on the precision used for the input, the whole 32 bits per register will be used or
// just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means
// that we only assign values to even indexes of the accumulator (0, 2, 4, ...)

assert_eq!(output.ident, FragmentIdent::<HipDialect<WmmaIntrinsicCompiler>>::Accumulator);

let step = match input.elem() {
Elem::F16 | Elem::BF16 | Elem::F32 => {
match output.elem {
// loading into accumulator of 16 half precision
Elem::F16 | Elem::BF16 => 2,
// loading into accumulator of 8 full precision
Elem::F32 => 1,
other => panic!("unsupported format {other} for {output}"),
}
},
other => panic!("unsupported format {other} for {input}"),
};
step
}
1 change: 0 additions & 1 deletion crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,6 @@ impl<D: Dialect> CppCompiler<D> {
k: matrix.k,
elem: self.compile_elem(matrix.elem),
layout: self.compile_matrix_layout(matrix.layout),
amd_intrinsic_computed: false,
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/shared/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl<D: Dialect> Display for Variable<D> {
FragmentIdent::_Dialect(_) => "",
};
write!(f, "frag_{name}_{index}_{depth}")
},
}
Variable::GridDimGlobal => f.write_str("gridDimGlobal"),
Self::Tmp { id, .. } => write!(f, "_tmp_{id}"),
}
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-cpp/src/shared/mma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ pub struct Fragment<D: Dialect> {
pub k: u8,
pub elem: Elem<D>,
pub layout: Option<FragmentLayout<D>>,
pub amd_intrinsic_computed: bool,
}

/// Warp Matrix-Multiply and Accumulate Instruction.
Expand Down

0 comments on commit e709a5d

Please sign in to comment.