Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 64 wavefront size in HIP compiler #282

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ pub trait CompilerRepresentation: Display {
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
/// The representation for the compiled code.
type Representation: CompilerRepresentation;
type CompilationOptions: Send + Default + core::fmt::Debug;

/// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
fn compile(
kernel: KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
Expand Down
30 changes: 23 additions & 7 deletions crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ pub trait CubeTask<C: Compiler>: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> KernelId;
/// Compile the kernel into source
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C>;
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C>;
fn name(&self) -> &'static str {
core::any::type_name::<Self>()
}
Expand All @@ -186,10 +190,14 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
}

impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
let gpu_ir = self.kernel_definition.define();
let cube_dim = gpu_ir.cube_dim;
let lower_level_ir = C::compile(gpu_ir, mode);
let lower_level_ir = C::compile(gpu_ir, compilation_options, mode);
let shared_mem_bytes = lower_level_ir.shared_memory_size();

CompiledKernel {
Expand All @@ -212,8 +220,12 @@ impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
}

impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
Expand All @@ -225,8 +237,12 @@ impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
}

impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-cpp/src/hip/wmma/intrinsic_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ for (uint i = 0; i < uint(8); ++i) {{
frag_b,
frag_c,
frag_d,
warp_size,
} => {
let ab_format = if let Variable::WmmaFragment { frag: inner_a, .. } = frag_a {
if let Variable::WmmaFragment { frag: inner_b, .. } = frag_b {
Expand Down Expand Up @@ -158,10 +159,9 @@ for (uint i = 0; i < uint(8); ++i) {{
} else {
panic!("{frag_c} is not a WMMA fragment!")
};
// TODO: support wavefront size of 64
writeln!(
f,
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w32({frag_a}, {frag_b}, {frag_c});"
"{frag_d} = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}({frag_a}, {frag_b}, {frag_c});"
)
}
WmmaInstruction::Store {
Expand Down
16 changes: 16 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ pub trait Dialect:
fn warp_any(out: &IndexedVariable<Self>) -> String;
}

/// Options supplied to the C++ compiler backend at kernel-compilation time.
///
/// Carried from the runtime (which can query the device) down into code
/// generation, so device-specific values such as the wavefront width do not
/// need to be hard-coded in the generated source.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    /// Number of threads in a hardware warp/wavefront
    /// (32 on CUDA and RDNA GPUs, 64 on CDNA wavefronts).
    pub warp_size: u32,
}

impl Default for CompilationOptions {
    /// Falls back to the common 32-wide warp used by CUDA devices.
    fn default() -> Self {
        CompilationOptions { warp_size: 32 }
    }
}

#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
Expand All @@ -53,16 +64,20 @@ pub struct CppCompiler<D: Dialect> {
items: HashSet<Item<D>>,
strategy: ExecutionMode,
settings: VariableSettings,
compilation_options: CompilationOptions,
}

impl<D: Dialect> Compiler for CppCompiler<D> {
type Representation = ComputeKernel<D>;
type CompilationOptions = CompilationOptions;

fn compile(
kernel: cubecl_core::ir::KernelDefinition,
compilation_options: &Self::CompilationOptions,
strategy: ExecutionMode,
) -> Self::Representation {
let compiler = Self {
compilation_options: compilation_options.clone(),
strategy,
..Self::default()
};
Expand Down Expand Up @@ -295,6 +310,7 @@ impl<D: Dialect> CppCompiler<D> {
frag_b: self.compile_variable(mat_b),
frag_c: self.compile_variable(mat_c),
frag_d: out,
warp_size: self.compilation_options.warp_size,
}),
gpu::CoopMma::Store {
mat,
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-cpp/src/shared/mma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ pub fn register_wmma_features(
supported_combinations: SupportedWmmaCombinations,
properties: &mut DeviceProperties<Feature>,
) {
// TODO: move this commented line to register explicitly at the runtime level
// properties.register_feature(Feature::CmmaWarpSize(self.warp_size()));
for (i, o, c, tdims) in supported_combinations {
for (m, n, k) in tdims {
properties.register_feature(Feature::Cmma {
Expand Down Expand Up @@ -115,6 +113,7 @@ pub enum WmmaInstruction<D: Dialect> {
frag_b: Variable<D>,
frag_c: Variable<D>,
frag_d: Variable<D>,
warp_size: u32,
},
/// Store the fragment in an output variable following the stride and the layout.
Store {
Expand Down Expand Up @@ -256,6 +255,7 @@ pub mod wmma_api_base {
frag_b,
frag_c,
frag_d,
..
} => writeln!(
f,
"{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
Expand Down
12 changes: 8 additions & 4 deletions crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use cubecl_cpp::cuda::arch::CudaArchitecture;
use cubecl_cpp::{formatter::format_cpp, CudaCompiler};
use cubecl_cpp::{
cuda::arch::CudaArchitecture, formatter::format_cpp, shared::CompilationOptions, CudaCompiler,
};

use super::fence::{Fence, SyncStream};
use super::storage::CudaStorage;
Expand Down Expand Up @@ -39,6 +40,7 @@ pub(crate) struct CudaContext {
module_names: HashMap<KernelId, CompiledKernel>,
timestamps: KernelTimestamps,
pub(crate) arch: CudaArchitecture,
compilation_options: CompilationOptions,
}

#[derive(Debug)]
Expand Down Expand Up @@ -305,6 +307,7 @@ impl ComputeServer for CudaServer {
impl CudaContext {
pub fn new(
memory_management: MemoryManagement<CudaStorage>,
compilation_options: CompilationOptions,
stream: cudarc::driver::sys::CUstream,
context: *mut CUctx_st,
arch: CudaArchitecture,
Expand All @@ -316,6 +319,7 @@ impl CudaContext {
stream,
arch,
timestamps: KernelTimestamps::Disabled,
compilation_options,
}
}

Expand All @@ -340,7 +344,7 @@ impl CudaContext {
logger: &mut DebugLogger,
mode: ExecutionMode,
) {
let mut kernel_compiled = kernel.compile(mode);
let mut kernel_compiled = kernel.compile(&self.compilation_options, mode);

if logger.is_activated() {
kernel_compiled.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
Expand Down Expand Up @@ -372,7 +376,7 @@ impl CudaContext {
message += format!("\n {line}").as_str();
}
}
let source = kernel.compile(mode).source;
let source = kernel.compile(&self.compilation_options, mode).source;
panic!("{message}\n[Source] \n{source}");
};
cudarc::nvrtc::result::get_ptx(program).unwrap()
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ fn create_client(device: &CudaDevice, options: RuntimeOptions) -> ComputeClient<
options.memory_config,
);

// register device props
let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, hardware_props);
register_supported_types(&mut device_props);
device_props.register_feature(Feature::Type(Elem::Float(FloatKind::TF32)));
let supported_wmma_combinations = CudaWmmaCompiler::supported_wmma_combinations(&arch);
register_wmma_features(supported_wmma_combinations, &mut device_props);

let cuda_ctx = CudaContext::new(memory_management, stream, ctx, arch);
let comp_opts = Default::default();
let cuda_ctx = CudaContext::new(memory_management, comp_opts, stream, ctx, arch);
let server = CudaServer::new(cuda_ctx);
ComputeClient::new(MutexComputeChannel::new(server), device_props)
}
Expand Down
6 changes: 5 additions & 1 deletion crates/cubecl-hip/src/compute/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use cubecl_cpp::formatter::format_cpp;
use cubecl_cpp::shared::CompilationOptions;

use crate::runtime::HipCompiler;

Expand Down Expand Up @@ -37,6 +38,7 @@ pub(crate) struct HipContext {
memory_management: MemoryManagement<HipStorage>,
module_names: HashMap<KernelId, HipCompiledKernel>,
timestamps: KernelTimestamps,
compilation_options: CompilationOptions,
}

#[derive(Debug)]
Expand Down Expand Up @@ -276,6 +278,7 @@ impl ComputeServer for HipServer {
impl HipContext {
pub fn new(
memory_management: MemoryManagement<HipStorage>,
compilation_options: CompilationOptions,
stream: cubecl_hip_sys::hipStream_t,
context: cubecl_hip_sys::hipCtx_t,
) -> Self {
Expand All @@ -285,6 +288,7 @@ impl HipContext {
stream,
context,
timestamps: KernelTimestamps::Disabled,
compilation_options,
}
}

Expand Down Expand Up @@ -313,7 +317,7 @@ impl HipContext {
let func_name = CString::new("kernel".to_string()).unwrap();
// CubeCL compilation
// jitc = just-in-time compiled
let mut jitc_kernel = cube_kernel.compile(mode);
let mut jitc_kernel = cube_kernel.compile(&self.compilation_options, mode);

if logger.is_activated() {
jitc_kernel.debug_info = Some(DebugInformation::new("cpp", kernel_id.clone()));
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ pub mod runtime;
pub use device::*;
#[cfg(target_os = "linux")]
pub use runtime::HipRuntime;
#[cfg(target_os = "linux")]
#[cfg(feature = "rocwmma")]
pub(crate) type HipWmmaCompiler = cubecl_cpp::hip::wmma::RocWmmaCompiler;
#[cfg(target_os = "linux")]
#[cfg(not(feature = "rocwmma"))]
pub(crate) type HipWmmaCompiler = cubecl_cpp::hip::wmma::WmmaIntrinsicCompiler;

Expand Down
9 changes: 6 additions & 3 deletions crates/cubecl-hip/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{ffi::CStr, mem::MaybeUninit, str::FromStr};
use cubecl_cpp::{
hip::HipDialect,
register_supported_types,
shared::{register_wmma_features, Architecture, CppCompiler, WmmaCompiler},
shared::{register_wmma_features, Architecture, CompilationOptions, CppCompiler, WmmaCompiler},
};

use cubecl_core::{Feature, MemoryConfiguration, Runtime};
Expand Down Expand Up @@ -110,13 +110,16 @@ fn create_client<M: WmmaCompiler<HipDialect<M>>>(
mem_properties.clone(),
options.memory_config,
);
let hip_ctx = HipContext::new(memory_management, stream, ctx);
let server = HipServer::new(hip_ctx);
let mut device_props = DeviceProperties::new(&[Feature::Plane], mem_properties, topology);
register_supported_types(&mut device_props);
let supported_wmma_combinations = M::supported_wmma_combinations(&arch);
register_wmma_features(supported_wmma_combinations, &mut device_props);

let comp_opts = CompilationOptions {
warp_size: arch.warp_size(),
};
let hip_ctx = HipContext::new(memory_management, comp_opts, stream, ctx);
let server = HipServer::new(hip_ctx);
ComputeClient::new(MutexComputeChannel::new(server), device_props)
}

Expand Down
14 changes: 13 additions & 1 deletion crates/cubecl-spirv/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ use crate::{
SpirvKernel,
};

/// Compilation options for the SPIR-V compiler.
///
/// Currently carries no settings; it exists so the SPIR-V backend can satisfy
/// the `Compiler::CompilationOptions` associated type required by the shared
/// `Compiler` trait.
#[derive(Clone, Debug, Default)]
pub struct CompilationOptions {}

pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
pub target: Target,
builder: Builder,
Expand All @@ -41,6 +44,7 @@ pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
pub state: LookupTables,
pub ext_meta_pos: Vec<u32>,
pub metadata: Metadata,
compilation_options: CompilationOptions,
}

unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
Expand All @@ -64,6 +68,7 @@ impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
visited: self.visited.clone(),
metadata: self.metadata.clone(),
ext_meta_pos: self.ext_meta_pos.clone(),
compilation_options: self.compilation_options.clone(),
}
}
}
Expand All @@ -85,6 +90,7 @@ impl<T: SpirvTarget> Default for SpirvCompiler<T> {
visited: Default::default(),
metadata: Default::default(),
ext_meta_pos: Default::default(),
compilation_options: Default::default(),
}
}
}
Expand All @@ -105,8 +111,13 @@ impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {

impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
type Representation = SpirvKernel;
type CompilationOptions = CompilationOptions;

fn compile(value: KernelDefinition, mode: ExecutionMode) -> Self::Representation {
fn compile(
value: KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation {
let bindings = value
.inputs
.clone()
Expand All @@ -128,6 +139,7 @@ impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
let (module, optimizer) = Self {
mode,
metadata: Metadata::new(num_meta as u32, num_ext),
compilation_options: compilation_options.clone(),
ext_meta_pos,
..Default::default()
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
mode
};
log::debug!("Compiling {}", kernel.name());
let compiled = kernel.compile(mode);
let compiled = kernel.compile(&server.compilation_options, mode);
#[cfg(feature = "spirv-dump")]
dump_spirv(&compiled, kernel.name(), kernel.id());
compiled
Expand Down
Loading