Skip to content

Commit

Permalink
Improve debug kernel generation (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jul 28, 2024
1 parent 6eaccca commit 2b95a9e
Show file tree
Hide file tree
Showing 18 changed files with 293 additions and 64 deletions.
1 change: 1 addition & 0 deletions crates/cubecl-common/src/stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ impl<T> RwLock<T> {
/// A unique identifier for a running thread.
///
/// This module is a stub when no std is available to swap with std::thread::ThreadId.
#[allow(dead_code)]
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(core::num::NonZeroU64);

Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ repository = "https://github.com/tracel-ai/cubecl/tree/main/cubecl-cube"
version.workspace = true

[features]
default = []
std = []
default = ["cubecl-runtime/default"]
std = ["cubecl-runtime/std"]
template = []
export_tests = []

Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ pub struct KernelExpansion {
}

/// Simply indicate the output that can be replaced by the input.
#[derive(new, Clone, Copy, Debug)]
#[derive(new, Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct InplaceMapping {
/// Input position.
pub pos_input: usize,
/// Output position.
pub pos_output: usize,
}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum VectorizationPartial {
Input {
pos: usize,
Expand All @@ -46,7 +46,7 @@ enum VectorizationPartial {
},
}

#[derive(Default, Clone)]
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelSettings {
pub mappings: Vec<InplaceMapping>,
vectorization_global: Option<Vectorization>,
Expand Down
119 changes: 93 additions & 26 deletions crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
marker::PhantomData,
};

use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId};
use alloc::sync::Arc;
use cubecl_runtime::server::{Binding, ComputeServer};

Expand All @@ -16,68 +16,135 @@ pub struct CompiledKernel {
pub cube_dim: CubeDim,
/// The number of bytes used by the share memory
pub shared_mem_bytes: usize,
pub lang_tag: Option<&'static str>,
/// Extra debugging information about the compiled kernel.
pub debug_info: Option<DebugInformation>,
}

/// Extra debugging information about the compiled kernel.
#[derive(new)]
pub struct DebugInformation {
/// The language tag of the source..
pub lang_tag: &'static str,
/// The compilation id.
pub id: KernelId,
}

impl Display for CompiledKernel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n======== Compiled Kernel ========")?;
f.write_str("\n[START_KERNEL_COMPILATION]")?;

if let Some(name) = self.name {
if name.len() <= 32 {
f.write_fmt(format_args!("\nname: {name}"))?;
} else {
let name = format_type_name(name);
let name = format_str(name, &[('<', '>')], false);
f.write_fmt(format_args!("\nname: {name}"))?;
}
}

f.write_fmt(format_args!(
"
cube_dim: ({}, {}, {})
shared_memory: {} bytes
shared_memory: {} bytes",
self.cube_dim.x, self.cube_dim.y, self.cube_dim.z, self.shared_mem_bytes,
))?;

if let Some(info) = &self.debug_info {
f.write_fmt(format_args!(
"\ninfo: {}",
format_str(
format!("{}", info.id).as_str(),
&[('(', ')'), ('[', ']'), ('{', '}')],
true
)
))?;
}

f.write_fmt(format_args!(
"
source:
```{}
{}
```
=================================
[END_KERNEL_COMPILATION]
",
self.cube_dim.x,
self.cube_dim.y,
self.cube_dim.z,
self.shared_mem_bytes,
self.lang_tag.unwrap_or(""),
self.debug_info
.as_ref()
.map(|info| info.lang_tag)
.unwrap_or(""),
self.source
))
}
}

fn format_type_name(type_name: &str) -> String {
fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
let kernel_id = kernel_id.to_string();
let mut result = String::new();
let mut depth = 0;
let indendation = 4;

for c in type_name.chars() {
let mut prev = ' ';

for c in kernel_id.chars() {
if c == ' ' {
continue;
}

if c == '<' {
depth += 1;
result.push_str("<\n");
result.push_str(&" ".repeat(indendation * depth));
continue;
} else if c == '>' {
depth -= 1;
result.push_str(",\n>");
let mut found_marker = false;

for (start, end) in markers {
let (start, end) = (*start, *end);

if c == start {
depth += 1;
if prev != ' ' && include_space {
result.push(' ');
}
result.push(start);
result.push('\n');
result.push_str(&" ".repeat(indendation * depth));
found_marker = true;
} else if c == end {
depth -= 1;
if prev != start {
if prev == ' ' {
result.pop();
}
result.push_str(",\n");
result.push_str(&" ".repeat(indendation * depth));
result.push(end);
} else {
for _ in 0..(&" ".repeat(indendation * depth).len()) + 1 + indendation {
result.pop();
}
result.push(end);
}
found_marker = true;
}
}

if found_marker {
prev = c;
continue;
}

if c == ',' && depth > 0 {
if prev == ' ' {
result.pop();
}

result.push_str(",\n");
result.push_str(&" ".repeat(indendation * depth));
continue;
}

if c == ':' && include_space {
result.push(c);
result.push(' ');
prev = ' ';
} else {
result.push(c);
prev = c;
}
}

Expand All @@ -88,7 +155,7 @@ fn format_type_name(type_name: &str) -> String {
/// provided id.
pub trait CubeTask: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> String;
fn id(&self) -> KernelId;
/// Compile the kernel into source
fn compile(&self) -> CompiledKernel;
}
Expand All @@ -113,11 +180,11 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
source,
cube_dim,
shared_mem_bytes,
lang_tag: None,
debug_info: None,
}
}

fn id(&self) -> String {
fn id(&self) -> KernelId {
self.kernel_definition.id().clone()
}
}
Expand All @@ -127,7 +194,7 @@ impl CubeTask for Arc<dyn CubeTask> {
self.as_ref().compile()
}

fn id(&self) -> String {
fn id(&self) -> KernelId {
self.as_ref().id()
}
}
Expand All @@ -137,7 +204,7 @@ impl CubeTask for Box<dyn CubeTask> {
self.as_ref().compile()
}

fn id(&self) -> String {
fn id(&self) -> KernelId {
self.as_ref().id()
}
}
Expand Down
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ pub trait Int:

macro_rules! impl_int {
($type:ident, $primitive:ty) => {
#[derive(Clone, Copy)]
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Copy, Hash)]
pub struct $type {
pub val: $primitive,
pub vectorization: u8,
Expand Down
13 changes: 12 additions & 1 deletion crates/cubecl-core/src/frontend/element/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,25 @@ use super::{
ScalarArgSettings, Vectorized, __expand_new, __expand_vectorized,
};

#[derive(Clone, Copy, Debug)]
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Copy, Hash)]
/// An unsigned int.
/// Preferred for indexing operations
pub struct UInt {
pub val: u32,
pub vectorization: u8,
}

impl core::fmt::Debug for UInt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.vectorization == 1 {
f.write_fmt(format_args!("{}", self.val))
} else {
f.write_fmt(format_args!("{}-{}", self.val, self.vectorization))
}
}
}

impl CubeType for UInt {
type ExpandType = ExpandElementTyped<Self>;
}
Expand Down
Loading

0 comments on commit 2b95a9e

Please sign in to comment.