Add support for a wavefront size of 64 in the HIP compiler
syl20bnr committed Nov 20, 2024
1 parent e1eb00a commit 53e2c03
Showing 14 changed files with 98 additions and 25 deletions.
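In short: each runtime now builds a `CompilationOptions` value when it creates its device context and passes it into every kernel compilation, and the HIP backend uses it to emit the `_w64` WMMA intrinsics on 64-wide wavefront hardware instead of the previously hard-coded `_w32` variants. A minimal sketch of that flow; every name below is an illustrative stand-in, not the actual CubeCL type:

```rust
// Bird's-eye sketch of the new plumbing: the runtime decides the wavefront
// size once, and each kernel compilation on that device reuses it.
#[derive(Clone, Debug, Default)]
struct CompilationOptions {
    warp_size: u32,
}

struct Context {
    compilation_options: CompilationOptions,
}

impl Context {
    fn compile_kernel(&self, kernel_name: &str) -> String {
        // Stands in for `kernel.compile(&self.compilation_options, mode)`.
        format!(
            "{kernel_name} compiled for w{}",
            self.compilation_options.warp_size
        )
    }
}

fn main() {
    // On HIP the warp size comes from the detected architecture (32 or 64).
    let ctx = Context {
        compilation_options: CompilationOptions { warp_size: 64 },
    };
    println!("{}", ctx.compile_kernel("matmul"));
}
```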
7 changes: 6 additions & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
@@ -12,9 +12,14 @@ pub trait CompilerRepresentation: Display {
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
/// The representation for the compiled code.
type Representation: CompilerRepresentation;
type CompilationOptions: Send + Default + core::fmt::Debug;

/// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
fn compile(
kernel: KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
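Backends must now name a `CompilationOptions` associated type and accept it by reference in `compile`. A minimal sketch of the shape an implementation takes; the trait bounds are trimmed and all names are illustrative stand-ins, not the actual CubeCL definitions:

```rust
use core::fmt::Debug;

// Stand-ins for the real CubeCL items, only here so the sketch is self-contained.
pub struct KernelDefinition;
pub enum ExecutionMode {
    Checked,
    Unchecked,
}

pub trait Compiler {
    type Representation;
    type CompilationOptions: Send + Default + Debug;

    fn compile(
        kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        mode: ExecutionMode,
    ) -> Self::Representation;
}

// A hypothetical backend: even one with nothing to configure just declares
// an empty (or defaulted) options type.
#[derive(Clone, Debug, Default)]
pub struct MyOptions {
    pub warp_size: u32,
}

pub struct MyCompiler;

impl Compiler for MyCompiler {
    type Representation = String;
    type CompilationOptions = MyOptions;

    fn compile(_kernel: KernelDefinition, opts: &MyOptions, _mode: ExecutionMode) -> String {
        // A real backend would lower the kernel IR here; this only records the option.
        format!("// compiled for warp size {}", opts.warp_size)
    }
}
```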
30 changes: 23 additions & 7 deletions crates/cubecl-core/src/compute/kernel.rs
@@ -172,7 +172,11 @@ pub trait CubeTask<C: Compiler>: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> KernelId;
/// Compile the kernel into source
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C>;
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C>;
fn name(&self) -> &'static str {
core::any::type_name::<Self>()
}
@@ -186,10 +190,14 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
}

impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
let gpu_ir = self.kernel_definition.define();
let cube_dim = gpu_ir.cube_dim;
let lower_level_ir = C::compile(gpu_ir, mode);
let lower_level_ir = C::compile(gpu_ir, compilation_options, mode);
let shared_mem_bytes = lower_level_ir.shared_memory_size();

CompiledKernel {
@@ -212,8 +220,12 @@ impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
}

impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
@@ -225,8 +237,12 @@ impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
}

impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
4 changes: 2 additions & 2 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
@@ -118,6 +118,7 @@ for (uint i = 0; i < uint(8); ++i) {{
frag_b,
frag_c,
frag_d,
warp_size,
} => {
let ab_format = if let Variable::WmmaFragment { frag: inner_a, .. } = frag_a {
if let Variable::WmmaFragment { frag: inner_b, .. } = frag_b {
@@ -158,10 +159,9 @@
} else {
panic!("{frag_c} is not a WMMA fragment!")
};
// TODO: support wavefront size of 64
writeln!(
f,
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w32({frag_a}, {frag_b}, {frag_c});"
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});"
)
}
WmmaInstruction::Store {
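With `warp_size` carried on the instruction, the generated builtin name encodes the wavefront width instead of a hard-coded `w32`. A rough sketch of that substitution, assuming an `f32` accumulator with `f16` inputs (the real format strings are derived from the fragment types at compile time):

```rust
// Sketch of the name substitution performed above; warp_size is the value
// threaded through the compilation options.
fn wmma_builtin(cd_format: &str, ab_format: &str, warp_size: u32) -> String {
    format!("__builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}")
}

fn main() {
    // 32-wide wavefronts keep the previous behaviour...
    assert_eq!(
        wmma_builtin("f32", "f16", 32),
        "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32"
    );
    // ...while 64-wide targets now get the _w64 variant.
    assert_eq!(
        wmma_builtin("f32", "f16", 64),
        "__builtin_amdgcn_wmma_f32_16x16x16_f16_w64"
    );
}
```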
10 changes: 10 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
@@ -36,6 +36,11 @@ pub trait Dialect:
fn warp_any(out: &IndexedVariable<Self>) -> String;
}

#[derive(Clone, Debug, Default)]
pub struct CompilationOptions {
pub warp_size: u32,
}

#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
@@ -53,16 +58,20 @@ pub struct CppCompiler<D: Dialect> {
items: HashSet<Item<D>>,
strategy: ExecutionMode,
settings: VariableSettings,
compilation_options: CompilationOptions,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
type Representation = ComputeKernel<D>;
type CompilationOptions = CompilationOptions;

fn compile(
kernel: cubecl_core::ir::KernelDefinition,
compilation_options: &Self::CompilationOptions,
strategy: ExecutionMode,
) -> Self::Representation {
let compiler = Self {
compilation_options: compilation_options.clone(),
strategy,
..Self::default()
};
@@ -295,6 +304,7 @@ impl<D: Dialect> CppCompiler<D> {
frag_b: self.compile_variable(mat_b),
frag_c: self.compile_variable(mat_c),
frag_d: out,
warp_size: self.compilation_options.warp_size,
}),
gpu::CoopMma::Store {
mat,
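The options are cloned into the compiler next to the rest of its default-initialized state. A self-contained sketch of that pattern; `CppCompilerLike` below is a stand-in, not the real `CppCompiler`:

```rust
// Sketch of the construction pattern used above: start from Default and
// overwrite only the fields supplied by the caller.
#[derive(Clone, Debug, Default)]
struct CompilationOptions {
    warp_size: u32,
}

#[derive(Debug, Default)]
struct CppCompilerLike {
    compilation_options: CompilationOptions,
    // the remaining compiler state is elided
}

fn new_compiler(options: &CompilationOptions) -> CppCompilerLike {
    CppCompilerLike {
        compilation_options: options.clone(),
        ..Default::default()
    }
}

fn main() {
    let compiler = new_compiler(&CompilationOptions { warp_size: 64 });
    assert_eq!(compiler.compilation_options.warp_size, 64);
}
```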
4 changes: 2 additions & 2 deletions crates/cubecl-cpp/src/shared/mma.rs
@@ -51,8 +51,6 @@ pub fn register_wmma_features(
supported_combinations: SupportedWmmaCombinations,
properties: &mut DeviceProperties<Feature>,
) {
// TODO: move this commented line to register explicitly at the runtime level
// properties.register_feature(Feature::CmmaWarpSize(self.warp_size()));
for (i, o, c, tdims) in supported_combinations {
for (m, n, k) in tdims {
properties.register_feature(Feature::Cmma {
@@ -115,6 +113,7 @@ pub enum WmmaInstruction<D: Dialect> {
frag_b: Variable<D>,
frag_c: Variable<D>,
frag_d: Variable<D>,
warp_size: u32,
},
/// Store the fragment in an output variable following the stride and the layout.
Store {
@@ -256,6 +255,7 @@ pub mod wmma_api_base {
frag_b,
frag_c,
frag_d,
..
} => writeln!(
f,
"{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
8 changes: 6 additions & 2 deletions crates/cubecl-cuda/src/compute/server.rs
@@ -1,3 +1,4 @@
use cubecl_cpp::shared::CompilationOptions;
use cubecl_cpp::{formatter::format_cpp, CudaCompiler};

use super::fence::{Fence, SyncStream};
@@ -37,6 +38,7 @@ pub(crate) struct CudaContext {
memory_management: MemoryManagement<CudaStorage>,
module_names: HashMap<KernelId, CompiledKernel>,
timestamps: KernelTimestamps,
compilation_options: CompilationOptions,
pub(crate) arch: u32,
}

@@ -301,6 +303,7 @@ impl ComputeServer for CudaServer {
impl CudaContext {
pub fn new(
memory_management: MemoryManagement<CudaStorage>,
compilation_options: CompilationOptions,
stream: cudarc::driver::sys::CUstream,
context: *mut CUctx_st,
arch: u32,
@@ -312,6 +315,7 @@ impl CudaContext {
stream,
arch,
timestamps: KernelTimestamps::Disabled,
compilation_options,
}
}

@@ -336,7 +340,7 @@ impl CudaContext {
logger: &mut DebugLogger,
mode: ExecutionMode,
) {
let mut kernel_compiled = kernel.compile(mode);
let mut kernel_compiled = kernel.compile(&self.compilation_options, mode);

if logger.is_activated() {
kernel_compiled.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
Expand Down Expand Up @@ -368,7 +372,7 @@ impl CudaContext {
message += format!("\n {line}").as_str();
}
}
let source = kernel.compile(mode).source;
let source = kernel.compile(&self.compilation_options, mode).source;
panic!("{message}\n[Source] \n{source}");
};
cudarc::nvrtc::result::get_ptx(program).unwrap()
3 changes: 2 additions & 1 deletion crates/cubecl-cuda/src/runtime.rs
@@ -93,7 +93,8 @@ fn create_client(device: &CudaDevice, options: RuntimeOptions) -> ComputeClient<
options.memory_config,
);

let cuda_ctx = CudaContext::new(memory_management, stream, ctx, arch);
let comp_opts = Default::default();
let cuda_ctx = CudaContext::new(memory_management, comp_opts, stream, ctx, arch);
let mut server = CudaServer::new(cuda_ctx);
let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, hardware_props);
register_supported_types(&mut device_props);
6 changes: 5 additions & 1 deletion crates/cubecl-hip/src/compute/server.rs
@@ -1,4 +1,5 @@
use cubecl_cpp::formatter::format_cpp;
use cubecl_cpp::shared::CompilationOptions;

use crate::runtime::HipCompiler;

@@ -37,6 +38,7 @@ pub(crate) struct HipContext {
memory_management: MemoryManagement<HipStorage>,
module_names: HashMap<KernelId, HipCompiledKernel>,
timestamps: KernelTimestamps,
compilation_options: CompilationOptions,
}

#[derive(Debug)]
@@ -270,6 +272,7 @@ impl ComputeServer for HipServer {
impl HipContext {
pub fn new(
memory_management: MemoryManagement<HipStorage>,
compilation_options: CompilationOptions,
stream: cubecl_hip_sys::hipStream_t,
context: cubecl_hip_sys::hipCtx_t,
) -> Self {
@@ -279,6 +282,7 @@ impl HipContext {
stream,
context,
timestamps: KernelTimestamps::Disabled,
compilation_options,
}
}

@@ -307,7 +311,7 @@ impl HipContext {
let func_name = CString::new("kernel".to_string()).unwrap();
// CubeCL compilation
// jitc = just-in-time compiled
let mut jitc_kernel = cube_kernel.compile(mode);
let mut jitc_kernel = cube_kernel.compile(&self.compilation_options, mode);

if logger.is_activated() {
jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
9 changes: 6 additions & 3 deletions crates/cubecl-hip/src/runtime.rs
@@ -3,7 +3,7 @@ use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
use cubecl_cpp::{
hip::HipDialect,
register_supported_types,
shared::{register_wmma_features, Architecture, CppCompiler, WmmaCompiler},
shared::{register_wmma_features, Architecture, CompilationOptions, CppCompiler, WmmaCompiler},
};

use cubecl_core::{Feature, MemoryConfiguration, Runtime};
@@ -110,13 +110,16 @@ fn create_client<M: WmmaCompiler<HipDialect<M>>>(
mem_properties.clone(),
options.memory_config,
);
let hip_ctx = HipContext::new(memory_management, stream, ctx);
let server = HipServer::new(hip_ctx);
let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, topology);
register_supported_types(&mut device_props);
let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
register_wmma_features(supported_wmma_combinations, &mut device_props);

let comp_opts = CompilationOptions {
warp_size: arch.warp_size(),
};
let hip_ctx = HipContext::new(memory_management, comp_opts, stream, ctx);
let server = HipServer::new(hip_ctx);
ComputeClient::new(MutexComputeChannel::new(server), device_props)
}

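The wavefront width is read from the detected architecture rather than assumed. A hedged sketch of the kind of mapping `arch.warp_size()` implies; the enum below is hypothetical and only illustrates which AMD families commonly run 32- versus 64-wide:

```rust
// Illustrative mapping only: the real warp_size() lives on the cubecl-cpp
// Architecture type and may be derived differently.
enum AmdGpuFamily {
    Rdna,      // e.g. gfx10/gfx11 parts compiled for 32-wide wavefronts
    GcnOrCdna, // gfx9 and CDNA accelerators use 64-wide wavefronts
}

impl AmdGpuFamily {
    fn warp_size(&self) -> u32 {
        match self {
            AmdGpuFamily::Rdna => 32,
            AmdGpuFamily::GcnOrCdna => 64,
        }
    }
}

fn main() {
    // Mirrors `CompilationOptions { warp_size: arch.warp_size() }` above.
    let warp_size = AmdGpuFamily::GcnOrCdna.warp_size();
    assert_eq!(warp_size, 64);
}
```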
14 changes: 13 additions & 1 deletion crates/cubecl-spirv/src/compiler.rs
@@ -24,6 +24,9 @@ use crate::{
SpirvKernel,
};

#[derive(Clone, Debug, Default)]
pub struct CompilationOptions {}

pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
pub target: Target,
builder: Builder,
@@ -41,6 +44,7 @@ pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
pub state: LookupTables,
pub ext_meta_pos: Vec<u32>,
pub metadata: Metadata,
compilation_options: CompilationOptions,
}

unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
@@ -64,6 +68,7 @@ impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
visited: self.visited.clone(),
metadata: self.metadata.clone(),
ext_meta_pos: self.ext_meta_pos.clone(),
compilation_options: self.compilation_options.clone(),
}
}
}
@@ -85,6 +90,7 @@ impl<T: SpirvTarget> Default for SpirvCompiler<T> {
visited: Default::default(),
metadata: Default::default(),
ext_meta_pos: Default::default(),
compilation_options: Default::default(),
}
}
}
@@ -105,8 +111,13 @@ impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {

impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
type Representation = SpirvKernel;
type CompilationOptions = CompilationOptions;

fn compile(value: KernelDefinition, mode: ExecutionMode) -> Self::Representation {
fn compile(
value: KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation {
let bindings = value
.inputs
.clone()
@@ -128,6 +139,7 @@ impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
let (module, optimizer) = Self {
mode,
metadata: Metadata::new(num_meta as u32, num_ext),
compilation_options: compilation_options.clone(),
ext_meta_pos,
..Default::default()
}
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/spirv.rs
@@ -153,7 +153,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
mode
};
log::debug!("Compiling {}", kernel.name());
let compiled = kernel.compile(mode);
let compiled = kernel.compile(&server.compilation_options, mode);
#[cfg(feature = "spirv-dump")]
dump_spirv(&compiled, kernel.name(), kernel.id());
compiled
(3 of the 14 changed files are not shown above.)
