Skip to content

LuisaGroup/luisa-compute-rs

Folders and files

NameName
Last commit message
Last commit date

Latest commit

98b0ce3 · Apr 25, 2024
Oct 27, 2023
Apr 24, 2024
Nov 23, 2023
Apr 24, 2024
Apr 25, 2024
Mar 26, 2024
Dec 29, 2022
Feb 10, 2023
Sep 17, 2023
Oct 18, 2022
Jan 25, 2024
Mar 20, 2023
Sep 23, 2023

Repository files navigation

luisa-compute-rs

Rust frontend to LuisaCompute and more! Unified API and embedded DSL for high performance computing on stream architectures.

To see the use of luisa-compute-rs in a high performance offline rendering system, checkout our research renderer

Table of Contents

Example

Try cargo run --release --example path_tracer -- [cpu|cuda|dx|metal]!

Vecadd

use luisa::prelude::*;
use luisa_compute as luisa;

#[tracked]
fn main() {
    use luisa::*;
    init_logger();
    let args: Vec<String> = std::env::args().collect();
    assert!(
        args.len() <= 2,
        "Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
        args[0]
    );

    let ctx = Context::new(current_exe());
    let device = ctx.create_device(if args.len() == 2 {
        args[1].as_str()
    } else {
        "cpu"
    });
    let x = device.create_buffer::<f32>(1024);
    let y = device.create_buffer::<f32>(1024);
    let z = device.create_buffer::<f32>(1024);
    x.view(..).fill_fn(|i| i as f32);
    y.view(..).fill_fn(|i| 1000.0 * i as f32);
    let kernel = device.create_kernel::<fn(Buffer<f32>)>(&device, |buf_z| {
        // z is pass by arg
        let buf_x = &x; // x and y are captured
        let buf_y = &y;
        let tid = dispatch_id().x;
        let x = buf_x.read(tid);
        let y = buf_y.read(tid);
        let vx = 0.0f32.var(); // create a local mutable variable
        *vx += x;
        buf_z.write(tid, vx + y);
    });
    kernel.dispatch([1024, 1, 1], &z);
    let z_data = z.view(..).copy_to_vec();
    println!("{:?}", &z_data[0..16]);
}

Other examples in examples

Example Description
Atomic Atomic buffer operations
Bindless Bindless array access
Custom Aggregate Use #[derive(Aggregate)] for kernel only data types
Custom Op Custom operator for CPU backend
Polymporphism Simple usage of Polymorphic<K, T>
Advanced Polymporphism Use Polymorphic<K, T> to implement recursive polymorphic call
Ray Tracing A simple raytracing kernel with GUI
Path Tracer A path tracer with GUI

Overview

Embedded Domain-Specific Language

We provided an Rust-flavored implementation of LuisaCompute EDSL that tightly integrates with Rust language via traits and proc-macros.

Automatic Differentiation

We implemented a source-to-source reverse mode automatic differentiation that supports complex control flow.

The autodiff works tightly with builtin functions and the type system. Instead of implementing every function using basic arithmetic operations and apply autodiff on it, all builtin functions are differentiated using efficient VJP formulae.

CPU Backend

This crate also provides a CPU backend implementation in Rust that will eventually replace the current LLVM backend in LuisaCompute. This backend emphasizes on debuggability, flexibility and as well as safety.

IR Module

The EDSL and code generation are built atop of an SSA-based IR module. The IR module is in a separate crate and can be used to implement other backends and IR transformation such as autodiff.

Debuggability

The CPU backend is designed to be debuggable. If needed, it will perform runtime checks to detect errors such as out-of-bound access, bindless array type mismatch, etc. It will display error message containing the host stacktrace for pinpointing the error location.

Usage

Building

To try out the examples, clone the repo using git clone --recursive https://github.com/LuisaGroup/luisa-compute-rs.git.

To use it in your project, add the following to your Cargo.toml:

[dependencies]
luisa_compute = { git= "https://github.com/LuisaGroup/luisa-compute-rs.git"}

You need to install CMake and Ninja to build the backends. More details about prerequistes can be found in here.

In your project, the following to your files:

use luisa_compute as luisa;
use luisa::prelude::*;

track! and #[tracked] Macro

To start writing using DSL, let's first introduce the track! macro. track!(expr) rewrites expr and redirect operators/control flows to DSL's internal traits. It resolves the fundamental issue that Rust is unable to overload operator=.

Every operation involving a DSL object must be enclosed within track!, except Var<T>::store() and Var<T>::load()

For, example:

let a = 1.0f32.expr();
let b = 1.0f32.expr();
let c = a + b; // Compile error

let c = track!(a + b); // c is now 2.0

// Or even better,
track!({
  let a = 1.0f32.expr();
  let b = 1.0f32.expr();
  let c = a + b;
});

We highly encourage you to enclose the entire kernel inside track!*

Inside track! normal Rust syntax is still supported (with a few exceptions). Operations involving non-DSL values are still performed using native Rust operators. For example:

// This is still valid
track!({
  let mut v = vec![1.0f32, 2.0,f32];
  v.push(3.0);
  let a = RefCell::new(1.0f32);
  *a.borrow_mut() += 1.0;
  match v.len() {
    1 => println!("1"),
    _ => println!("not 1"),
  }
  if v.is_empty() {
    println!("empty");
  }
  let a = 1.0f32.expr();
  let b = a + 1.0;
});

We also offer a #[tracked] macro that applies to a function. It transform the body of the function using track!.

#[tracked]
fn add(a:Expr<f32>, b:Expr<f32>) -> Expr<f32> {
 a + b
}

However, not every kernel can be constructed using `track!` code only. We still need the ability to use native control flow directly in kernel. 

For example, we can use native `for` loops to unroll a DSL loop. We first starts with a native version using DSL loops.
```rust
#[tracked]
fn pow_naive(x:Expr<f32>, i:u32) -> Expr<f32> {
 let p = 1.0f32.var();
 for _ in 0..i {
   p *= x;
 }
 **p // converts Var<f32> to Expr<f32>, only required when passing a Var<T> to fn(Expr<T>) and return from fn(...)->Expr<T>
}

To unroll the loop, we basically just what the DSL to produce p*=x for i times. We can use the escape!(expr) macro so that it leaves expr as is, preserving the native loop.

#[tracked]
fn pow_unrolled(x:Expr<f32>, i:u32)->Expr<f32> {
  let p = 1.0f32.var();
  escape!({
    for _ in 0..i {
      track!({
        p *= x;
      });
  });
  **p 
}

Of course this can be tedius if you just want to unroll a loop. Thus we provide a for_unrolled function that unrolls a loop for you.

#[tracked]
fn pow_unrolled(x:Expr<f32>, i:u32)->Expr<f32> {
  let p = 1.0f32.var();
  for_unrolled(0..i, |_|{
      p *= x;
  });
  /*
  Equivalently, you can use
  (0..i).for_each(|_|{
      p *= x;
  });
  */
  **p 
}

Variables and Expressions

We support the following primitive types on backend bool, i32, u32, i64, u64, f32. Additional primitive types such as u8, i8, i16, u16, and f64 are supported on some backends. For each type, there are two EDSL proxy objects Expr<T> and Var<T>. Expr<T> is an immutable object that represents a value. Var<T> is also an immutable object that represents a variable (mutable value).

Warning: Every DSL object in host code must be immutable due to Rust unable to overload operator =. Attempting to circumvent this limitation using Cell and RefCell would likely result in uncompilable kernels/wrong results. For example:

// **no good**
let v = Cell::new(0.0f32.expr());
track!(if cond {
  v.set(v.get() + 1.0);
));

// **good**
let v = 0.0f32.var();
track!(if cond {
  *v += 1.0;
));

All operations except load/store should be performed on Expr<T>. Var<T> can only be used to load/store values.

As in the C++ EDSL, we additionally supports the vector of length 2-4 for all primitives and float square matrices with dimension 2-4 such as

luisa_compute::lang::types::vector::alias::{
  Bool2 
  Bool3 
  Bool4 
  Float2 
  Float3 
  Float4
  Int2
  Int3 
  Int4 
  Uint2
  Uint3 
  Uint4
};

luisa_compute::lang::types::vector::{Mat2, Mat3, Mat4};

Array types [T;N] are also supported. Call arr.read(i) and arr.write(i, value) on ArrayVar<T, N> for element access. ArrayExpr<T,N> can be stored to and loaded from ArrayVar<T, N>. The limitation is however the array length must be determined during host compile time. If runtime length is required, use VLArrayVar<T>. VLArrayVar<T>::zero(length: usize) would create a zero initialized array. Similarly you can use read and write methods as well. To query the length of a VLArrayVar<T> in host, use VLArrayVar<T>::static_len()->usize. To query the length in kernel, use VLArrayVar<T>::len()->Expr<u32>

Most operators are already overloaded with the only exception is comparision. We cannot overload comparision operators as PartialOrd cannot return a DSL type. Instead, use cmpxx methods such as cmpgt, cmpeq, etc. To cast a primitive/vector into another type, use v.as_::<Type>(), v.as_Type() and v.as_PrimitiveType(). For example:

let iv = Int2::expr(1, 1, 1);
let fv = iv.as_::<Float2>(); //fv is Expr<Float2>
let also_fv = iv.as_float2();
let also_fv = iv.cast_f32(); 

To perform a bitwise cast, use the bitcast function. let fv:Expr<f32> = bitcast::<u32, f32>(0u32);

Builtin Functions

We have extentded primitive types with methods similar to their host counterpart: v.sin(), luisa::max(a, b), etc. Most methods accepts both a Expr<T> or a literal such as 0.0. However, the select function is slightly different as it does not accept literals. You need to use select(cond, f_var, 1.0f32.expr()).

Control Flow

Note, you cannot modify outer scope variables inside a control flow block by declaring the variable as mut. To modify outer scope variables, use Var<T> instead and store the value back to the outer scope.

if, while, break, continue, return and loop are supported via tracked! macro. It is also possible to construct these control flows without track!.

The switch_ statement has to be constructe manually. For example,

let (x,y) = switch::<(Expr<i32>, Expr<f32>)>(value)
    .case(1, || { ... })
    .case(2, || { ... })
    .default(|| { ... })
    .finish();

Warning: due to backend generates C-like source code instead of LLVM IR/PTX/DXIL directly, it is not possible to use break inside switch cases.

Custom Data Types

To add custom data types to the EDSL, simply derive from Value macro. Note that #[repr(C)] is required for the struct to be compatible with C ABI. #[derive(Value)] would generate two proxies types: XXExpr and XXVar. Implement your methods on these proxies instead of Expr<T> and Var<T> directly.

#[derive(Copy, Clone, Default, Debug, Value)]
#[repr(C)]
#[value_new(pub)]
pub struct MyVec2 {
    pub x: f32,
    pub y: f32,
}

impl MyVec2Expr {
  // pass arguments using `AsExpr` so that they accept both Var and Expr
  #[tracked]
  pub fn dot(&self, other: impl AsExpr<Value=MyVec2>) {
    self.x * other.x + self.y * other.y
  }
}
impl MyVec2Var {
  #[tracked]
  pub fn set_to_one(&self) {
    // you can access the current `Var<Self>` using `self_`
    self.self_ = MyVec2::new_expr(1.0, 1.0);
  }
}

track!({
  let v = MyVec2::var_zeroed();
  let sum = v.x +*v.y; 
  *v.x += 1.0;
  let v = MyVec2::from_comps_expr(MyVec2Comps{ x: 1.0f32.expr(), y: 1.0f32.expr()});
  let v = MyVec2::new_expr(1.0f32, 2.0f32); // only if #[value_new] is present
});

// You can also control the order of arguments in `#[value_new]`
#[derive(Copy, Clone, Default, Debug, Value)]
#[repr(C)]
#[value_new(pub y, x)]
pub struct Foo {
    pub x: f32,
    pub y: i32,
}
let v = MyVec2::new_expr(1.0fi32, 2.0f32);
// v.x == 2.0
// v.y == 1

Polymorphism

We prvoide a powerful Polymorphic<DevirtualizationKey, dyn Trait> construct as in the C++ DSL. See examples for more detail

trait Area {
    fn area(&self) -> Expr<f32>;
}
#[derive(Value, Clone, Copy)]
#[repr(C)]
pub struct Circle {
    radius: f32,
}
impl Area for CircleExpr {
    fn area(&self) -> Expr<f32> {
        PI * self.radius * self.radius
    }
}
impl_polymorphic!(Area, Circle);

let circles = device.create_buffer(..);
let mut poly_area: Polymorphic<(), dyn Area> = Polymorphic::new();
poly_area.register((), &circles);
let area = poly_area.dispatch(tag, index, |obj|{
    obj.area()
});

Autodiff

Autodiff code should be enclosed in the autodiff call. The requires_grad call is used to mark the variables that need to be differentiated. Any type including user defined ones can receive gradients. The backward call triggers the backward pass. Subsequent calls to gradient will return the gradient of the variable passed in. User can also supply custom gradients with backward_with_grad.

Note: Only one backward call is allowed in a single autodiff block. The autodiff block does not return any value. To store any side effects, use of local variables or buffers is required.

autodiff(||{
    let v: Expr<Vec3> = buf_v.read(..);
    let m: Expr<Mat3> = buf_m.read(..);
    requires_grad(v);
    requires_grad(m);
    let z = v.dot(m * v) * 0.5;
    backward(z);
    let dv = gradient(dv);
    let dm = gradient(dm);
    buf_dv.write(.., dv);
    buf_dm.write(.., dm);
});

Custom Operators

LuisaCompute supports injecting arbitrary user code to implement a custom operator. This is handled differently on different backends. On CPU backends, user can directly pass a closure to the kernel. The closure needs to have a Fn(&mut T) signature where it modifies the argument inplace. The EDSL frontend would then wrap the closure into a T->T function object.

#[derive(Clone, Copy, Value, Debug)]
#[repr(C)]
pub struct MyAddArgs {
    pub x: f32,
    pub y: f32,
    pub result: f32,
}
let my_add = CpuFn::new(|args: &mut MyAddArgs| {
    args.result = args.x + args.y;
});

let args = MyAddArgsExpr::new(x, y, 0.0.expr());
let result = my_add.call(args);

Callable

Note: Almost all usage of callables are largely replaced with outline(...) in the Outlining section. We keep this section for reference.

Users can define device-only functions using Callables. Callables have similar type signature to kernels: Callable<fn(Args)->Ret>. The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using Device::create_callable. To invoke a Callable, use Callable::call(args...). Callables accepts arguments such as resources (BufferVar<T>, .etc), expressions and references (pass a Var<T> to the callable). For example:

let add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
    a + b
}));
let z = add.call(x, y);
let pass_by_ref =  Callable::<fn(Var<f32>)>::new(&device, track!(|a| {
   a += 1.0;
}));
let a = 1.0f32.var();
pass_by_ref.call(a);
cpu_dbg!(*a); // prints 2.0

Note: You create callables within another callable and even capture variables from the outer callable. The only limitation is that you cannot return a callable from another callable.

let add = track!(Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(
    &device,
    |a, b| {
        // callables can be defined within callables
        let partial_add = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device, |y| a + y);
        partial_add.call(b)
    }
));

Static callables

A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use create_static_callable(fn). For example,

lazy_static! {
    static ref ADD:Callable<fn(Expr<f32>, Expr<f32>)->Expr<f32>> = Callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>::new_static(|a, b| {
    track!(a + b)
});
}
ADD.call(x, y);

Dynamic callables (templates)

Use DynCallable. These are callables that defer recording until being called. As a result, it requires you to pass a 'static closure, avoiding the capture issue. To create a DynCallable, use Device::create_dyn_callable(Box::new(fn)). The syntax is the same as create_callable. Furthermore, DynCallable supports DynExpr and DynVar, which provides some capablitiy of implementing template/overloading inside EDSL.

let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
    &device,
    Box::new(|a: DynExpr, b: DynExpr| -> DynExpr {
        if let Some(a) = a.downcast::<f32>() {
            let b = b.downcast::<f32>().unwrap();
            return DynExpr::new(track!(a + b));
        } else if let Some(a) = a.downcast::<i32>() {
            let b = b.downcast::<i32>().unwrap();
            return DynExpr::new(track!(a + b));
        } else {
            unreachable!()
        }
    }),
);

Outlining

The outline(|| { .. }) extracts a code snippet and deduplicate it into a callable. The callable is then called from the original location.

let add = |a:Expr<f32>, b:Expr<f32>|{
  let sum = 0.0f32.var();
  // automatically generates a callable and call it
  outline(|| {
      let a = 1.0f32.expr();
      let b = 2.0f32.expr();
      *sum = a + b;
  });
  **sum
};
let z = add(x, y);
// the following code is deduplicated
let w = add(z, y);

Since outline works via capturing, it is possible to pass arbitrary types or ever types that are not statically known.

Kernel

A kernel can be written in a closure or a function. The closure/function should have a Fn(/*args*/)->() signature, where the args are taking the Var type of resources, such as BufferVar<T>, Tex2D<T>, etc.

let kernel = device.create_kernel::<fn(Arg0, Arg1, ...)>(&|/*args*/| {
    /*body*/
});
kernel.dispatch([/*dispatch size*/], &arg0, &arg1, ...);

There are two ways to pass arguments to a kernel: by arguments or by capture.

let captured:Buffer<f32> = device.create_buffer(...);
let kernel = device.create_kernel::<fn(Buffer<f32>>(arg| {
    let v = arg.read(..);
    let u = captured.read(..);
}));

User can pass a maximum of 16 arguments to kernel and unlimited number of captured variables. If more than 16 arguments are needed, user can pack them into a struct and pass the struct as a single argument.

#[derive(BindGroup)]
pub struct BufferPair {
    a:Buffer<f32>,
    b:Buffer<f32>
}
let kernel = device.create_kernel::<fn(BufferPair)>(&|| {
    // ...
});
let a = device.create_buffer(...);
let b = device.create_buffer(...);
let pair = BufferPair{a,b};
kernel.dispatch([...], &packed);
let BufferPair{a, b} = packed; // unpack if you need to use them later

Debugging

We provide logging through the log crate. Users can either setup their own logger or use the init_logger() and init_logger_verbose() for handy initialization. For debug builds, oob checks are automatically inserted so that an assertion failure would occur if oob access is detected. On CPU/CUDA backend, it will be accompanied by an informative message such as assertion failed: i.cmplt(self.len()) at xx.rs:yy:zz. Setting the environment variable LUISA_BACKTRACE=1 would display a stacktrace containing the DSL code that records the kernel. For other backends, assertion with message is still WIP.

For release builds however, these checks are disabled by default for performance reasons. To enable them, set environment variable LUISA_DEBUG=1 prior to launching the application.

Advanced Usage

Note that the IR module has a public interface. If needed, user can implement their own DSL syntax sugar. Every EDSL object implements either Aggregate or FromNode trait, which allows any EDSL type to be destructured into its underlying IR nodes and reconstructed from them.

TODO

Safety

API

Host-side safety: The API aims to be 100% safe on host side. However, the safety of async operations are gauranteed via staticly know sync points (such as Stream::submit_and_sync). If fully dynamic async operations are needed, user need to manually lift the liftime and use unsafe code accordingly.

Device-side safety: Due to the async nature of device-side operations. It is both very difficult to propose a safe host API that captures device resource lifetime. While device-side safety isn't guaranteed at compile time, on cpu backend runtime checks will catch any illegal memory access during execution. However, for other backends such check is either too expensive or impractical and memory errors would result in undefined behavior instead.

Backend

Safety checks such as OOB is generally not available for many GPU backends. As it is difficult to produce meaningful debug message in event of a crash. However, the CPU backend provided in the crate contains full safety checks and is recommended for debugging.

Citation

When using luisa-compute-rs in an academic project, we encourage you to cite

@misc{LuisaComputeRust
    author = {Xiaochun Tong, et al},
    year = {2023},
    note = {https://github.com/LuisaGroup/luisa-compute-rs},
    title = {Rust frontend to LuisaCompute}
}