From 3f80b21f526862bfeb602a90059be045baddbe40 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 18 Jul 2024 15:01:16 -0400 Subject: [PATCH] WIP --- README.md | 15 +++++++++++++++ examples/gelu/Cargo.toml | 1 + examples/gelu/src/lib.rs | 19 +++++++++++-------- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index e32e9a83..567cfbf8 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,21 @@ This is just the beginning. We plan to include more utilities such as convolutions, random number generation, fast Fourier transforms, and other essential algorithms. We are a small team also building [Burn](https://burn.dev), so don't hesitate to contribute and port algorithms; it can help more than you would imagine! +## How it works + +CubeCL leverages Rust's proc macro system in a unique two-step process: + +1. Parsing: The proc macro parses the GPU kernel code using the `syn` crate. +2. Expansion: Instead of immediately generating an Intermediate Representation (IR), the macro generates a new Rust function. + +The generated function, semantically similar to the original, is responsible for creating the IR when called. +This approach differs from traditional compilers, which typically generate IR directly after parsing. +Our method enables several key features: + +- **Comptime**: Because the original code is not transformed, it is remarkably easy to integrate compile-time optimizations. +- **Automatic Vectorization**: By simply vectorizing the inputs of a CubeCL function, we can determine the vectorization factor of each intermediate variable during the expansion. +- **Rust Integration**: The generated code remains valid Rust code, allowing it to be bundled without any dependency on a specific runtime. + ## Design CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size. 
diff --git a/examples/gelu/Cargo.toml b/examples/gelu/Cargo.toml index 0acaf9c6..8875481c 100644 --- a/examples/gelu/Cargo.toml +++ b/examples/gelu/Cargo.toml @@ -13,3 +13,4 @@ cuda = ["cubecl/cuda"] [dependencies] cubecl = { path = "../../crates/cubecl", version = "0.1.0" } +half = { workspace = true } diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 6917c890..3ec81305 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -13,22 +13,25 @@ fn gelu_scalar(x: F) -> F { } pub fn launch(device: &R::Device) { + type Primitive = half::f16; + type CubeType = F16; + let client = R::client(device); - let input = &[-1., 0., 1., 5.]; + let input = &[-1., 0., 1., 5.].map(|f| Primitive::from_f32(f)); - let output_handle = client.empty(input.len() * core::mem::size_of::()); - let input_handle = client.create(f32::as_bytes(input)); + let output_handle = client.empty(input.len() * core::mem::size_of::()); + let input_handle = client.create(Primitive::as_bytes(input)); - gelu_array::launch::( + gelu_array::launch::( client.clone(), CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32, 1, 1), - ArrayArg::new(&input_handle, input.len()), - ArrayArg::new(&output_handle, input.len()), + CubeDim::new(input.len() as u32 / 4, 1, 1), + ArrayArg::vectorized(4, &input_handle, input.len()), + ArrayArg::vectorized(4, &output_handle, input.len()), ); let bytes = client.read(output_handle.binding()); - let output = f32::from_bytes(&bytes); + let output = Primitive::from_bytes(&bytes); // Should be [-0.1587, 0.0000, 0.8413, 5.0000] println!("Executed gelu with runtime {:?} => {output:?}", R::name());