WIP
nathanielsimard committed Jul 18, 2024
1 parent ee7693d commit 3f80b21
Showing 3 changed files with 27 additions and 8 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -109,6 +109,21 @@ This is just the beginning.
 We plan to include more utilities such as convolutions, random number generation, fast Fourier transforms, and other essential algorithms.
 We are a small team also building [Burn](https://burn.dev), so don't hesitate to contribute and port algorithms; it can help more than you would imagine!
 
+## How it works
+
+CubeCL leverages Rust's proc macro system in a unique two-step process:
+
+1. Parsing: The proc macro parses the GPU kernel code using the syn crate.
+2. Expansion: Instead of immediately generating an Intermediate Representation (IR), the macro generates a new Rust function.
+
+The generated function, semantically similar to the original, is responsible for creating the IR when called.
+This approach differs from traditional compilers, which typically generate IR directly after parsing.
+Our method enables several key features:
+
+- **Comptime**: By not transforming the original code, it becomes remarkably easy to integrate compile-time optimizations.
+- **Automatic Vectorization**: By simply vectorizing the inputs of a CubeCL function, we can determine the vectorization factor of each intermediate variable during the expansion.
+- **Rust Integration**: The generated code remains valid Rust code, allowing it to be bundled without any dependency on the specific runtime.
+
 ## Design
 
 CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
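
The "How it works" text added in this README hunk describes a macro that emits a new Rust function, and that function builds the IR when it is called. The toy sketch below illustrates the idea in plain Rust without any macro. `Instr`, `Context`, and `scale_add_expand` are hypothetical names invented for illustration; they are not CubeCL's IR or its generated code.

```rust
// Toy IR (hypothetical, for illustration only).
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    Mul { lhs: String, rhs: String, out: String },
    Add { lhs: String, rhs: String, out: String },
}

// Collects instructions while an "expand" function runs.
#[derive(Default)]
struct Context {
    instrs: Vec<Instr>,
    next_var: usize,
}

impl Context {
    // Returns a fresh variable name: v0, v1, ...
    fn fresh(&mut self) -> String {
        let name = format!("v{}", self.next_var);
        self.next_var += 1;
        name
    }
}

// Kernel as a user might write it:
//
//     fn scale_add(x: f32, y: f32, double: bool) -> f32 {
//         if double { x * 2.0 + y } else { x + y }
//     }
//
// Hand-written stand-in for what a macro could generate: the same control
// flow, but each arithmetic operation records an instruction instead of
// computing a value. `double` is an ordinary Rust bool here, so the branch
// is resolved while the IR is being built and never appears in the IR.
fn scale_add_expand(ctx: &mut Context, x: &str, y: &str, double: bool) -> String {
    let lhs = if double {
        let out = ctx.fresh();
        ctx.instrs.push(Instr::Mul {
            lhs: x.to_string(),
            rhs: "2".to_string(),
            out: out.clone(),
        });
        out
    } else {
        x.to_string()
    };
    let out = ctx.fresh();
    ctx.instrs.push(Instr::Add {
        lhs,
        rhs: y.to_string(),
        out: out.clone(),
    });
    out
}
```

Calling `scale_add_expand` with `double = true` records a `Mul` followed by an `Add`; with `double = false` only the `Add` is recorded. That is one way to read the "Comptime" bullet: because the expansion is ordinary Rust, host-side values can shape the IR that gets emitted.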
1 change: 1 addition & 0 deletions examples/gelu/Cargo.toml
@@ -13,3 +13,4 @@ cuda = ["cubecl/cuda"]
 
 [dependencies]
 cubecl = { path = "../../crates/cubecl", version = "0.1.0" }
+half = { workspace = true }
19 changes: 11 additions & 8 deletions examples/gelu/src/lib.rs
@@ -13,22 +13,25 @@ fn gelu_scalar<F: Float>(x: F) -> F {
 }
 
 pub fn launch<R: Runtime>(device: &R::Device) {
+    type Primitive = half::f16;
+    type CubeType = F16;
+
     let client = R::client(device);
-    let input = &[-1., 0., 1., 5.];
+    let input = &[-1., 0., 1., 5.].map(|f| Primitive::from_f32(f));
 
-    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
-    let input_handle = client.create(f32::as_bytes(input));
+    let output_handle = client.empty(input.len() * core::mem::size_of::<Primitive>());
+    let input_handle = client.create(Primitive::as_bytes(input));
 
-    gelu_array::launch::<F32, R>(
+    gelu_array::launch::<CubeType, R>(
         client.clone(),
         CubeCount::Static(1, 1, 1),
-        CubeDim::new(input.len() as u32, 1, 1),
-        ArrayArg::new(&input_handle, input.len()),
-        ArrayArg::new(&output_handle, input.len()),
+        CubeDim::new(input.len() as u32 / 4, 1, 1),
+        ArrayArg::vectorized(4, &input_handle, input.len()),
+        ArrayArg::vectorized(4, &output_handle, input.len()),
     );
 
     let bytes = client.read(output_handle.binding());
-    let output = f32::from_bytes(&bytes);
+    let output = Primitive::from_bytes(&bytes);
 
     // Should be [-0.1587, 0.0000, 0.8413, 5.0000]
     println!("Executed gelu with runtime {:?} => {output:?}", R::name());
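
In this hunk the x dimension of `CubeDim` changes from `input.len()` to `input.len() / 4`, while both arrays are passed with a vectorization factor of 4, so each unit presumably handles four elements. The arithmetic can be sketched with a small standalone helper. `units_for` is a hypothetical function written for this note, not part of CubeCL, and it assumes the length must be an exact multiple of the vectorization factor.

```rust
/// Hypothetical helper (not a CubeCL API): number of units needed along x
/// when each unit processes `vectorization` elements at once.
/// Returns `None` if the factor is zero, if `len` is not an exact multiple
/// of the factor, or if the result does not fit in a `u32`.
fn units_for(len: usize, vectorization: u8) -> Option<u32> {
    let v = usize::from(vectorization);
    if v == 0 || len % v != 0 {
        return None;
    }
    u32::try_from(len / v).ok()
}
```

For the four-element input in the example, `units_for(4, 4)` yields `Some(1)`, matching `input.len() as u32 / 4`. The plain integer division in the diff truncates, so with that expression a length such as 6 would give one unit covering only four elements; the helper returns `None` in that case instead.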