diff --git a/README.md b/README.md index 593f375f..5787d587 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # luisa-compute-rs Rust frontend to LuisaCompute and more! Unified API and embedded DSL for high performance computing on stream architectures. +*Warning:* while the project is already usable, it is not stable and **breaking changes** can happen at any time without notification. + To see the use of `luisa-compute-rs` in a high performance offline rendering system, checkout [our research renderer](https://github.com/shiinamiyuki/akari_render) ## Table of Contents - [luisa-compute-rs](#luisa-compute-rs) @@ -15,10 +17,11 @@ To see the use of `luisa-compute-rs` in a high performance offline rendering sys - [Debuggability](#debuggability) - [Usage](#usage) - [Building](#building) + - [`track!` and #[tracked] Macro](#track-and-tracked-macro) - [Variables and Expressions](#variables-and-expressions) - [Builtin Functions](#builtin-functions) - [Control Flow](#control-flow) - - [`track!` Mcro](#track-mcro) + - [Custom Data Types](#custom-data-types) - [Polymorphism](#polymorphism) - [Autodiff](#autodiff) @@ -61,16 +64,16 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&|buf_z| { + let kernel = Kernel::)>::new(&device, |buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); - let tid = dispatch_id().x(); + let tid = dispatch_id().x; let x = buf_x.read(tid); let y = buf_y.read(tid); let vx = 0.0f32.var(); // create a local mutable variable - *vx.get_mut() += x; - buf_z.write(tid, vx.load() + y); + *vx += x; + buf_z.write(tid, vx + y); }); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); @@ -126,95 +129,127 @@ In your project, the following to your files: use luisa_compute as luisa; use luisa::prelude::*; ``` -### Variables and Expressions -There are six basic types in EDSL. 
`bool`, `i32`, `u32`, `i64`, `u64`, `f32`. (`f64` support might be added to CPU backend). -For each type, there are two EDSL proxy objects `Expr` and `Var`. `Expr` is an immutable object that represents a value. `Var` is a mutable object that represents a variable. To load values from `Var`, use `*var` and to obtain a mutable reference for assignment, use `v.get_mut()`. E.g. `*v.get_mut() = f(*u)`. -*Note*: Every DSL object in host code **must** be immutable due to Rust unable to overload `operator =`. For example: +### `track!` and `#[tracked]` Macro +To start writing with the DSL, let's first introduce the `track!` macro. `track!( expr )` rewrites `expr` and redirects operators and control flow to the DSL's internal traits. It resolves the fundamental issue that Rust is unable to overload `operator=`. + +**Every operation involving a DSL object must be enclosed within `track!`**, except `Var::store()` and `Var::load()`. + +For example: ```rust -// **no good** -let mut v = 0.0f32.expr(); -if_!(cond, { - v += 1.0; -}); +let a = 1.0f32.expr(); +let b = 1.0f32.expr(); +let c = a + b; // Compile error -// also **not good** -let v = Cell::new(0.0f32.expr()); -if_!(cond, { - v.set(v.get() + 1.0); -}); +let c = track!(a + b); // c is now 2.0 -// **good** -let v = 0.0f32.var(); -if_!(cond, { - *v.get_mut() += 1.0; +// Or even better, +track!({ + let a = 1.0f32.expr(); + let b = 1.0f32.expr(); + let c = a + b; }); ``` -*Note*: You should not store the referene obtained by `v.get_mut()` for repeated use, as the assigned value is only updated when `v.get_mut()` is dropped. For example,: +We also offer a `#[tracked]` macro that applies to a function. It transforms the body of the function using `track!`. + ```rust +#[tracked] +fn add(a:Expr, b:Expr)->Expr { + a + b +} + +However, not every kernel can be constructed using `track!` code only. We still need the ability to use native control flow directly in a kernel. + +For example, we can use native `for` loops to unroll a DSL loop. 
We first start with a naive version using DSL loops. ```rust -let v = 0.0f32.var(); -let bad = v.get_mut(); -*bad = 1.0; -let u = *v; -drop(bad); -cpu_dbg!(u); // prints 0.0 -cpu_dbg!(*v); // prints now 1.0 +#[tracked] +fn pow_naive(x:Expr, i:u32)->Expr { + let p = 1.0f32.var(); + for _ in 0..i { + p *= x; + } + **p // converts Var to Expr, only required when passing a Var to fn(Expr) and when returning from fn(...)->Expr +} +``` +To unroll the loop, we basically just want the DSL to produce `p*=x` `i` times. We can use the `escape!(expr)` macro so that it leaves `expr` as is, preserving the native loop. +```rust +#[tracked] +fn pow_unrolled(x:Expr, i:u32)->Expr { + let p = 1.0f32.var(); + escape!({ + for _ in 0..i { + track!({ + p *= x; + }); + } }); + **p +} ``` + + +### Variables and Expressions +We support the following primitive types on all backends: `bool`, `i32`, `u32`, `i64`, `u64`, `f32`. Additional primitive types such as `u8`, `i8`, `i16`, `u16`, and `f64` are supported on some backends. +For each type, there are two EDSL proxy objects `Expr` and `Var`. `Expr` is an immutable object that represents a value. `Var` is also an **immutable** object that represents a variable (mutable value). + +**Warning**: Every DSL object in host code **must** be immutable because Rust is unable to overload `operator =`. Attempting to circumvent this limitation using `Cell` and `RefCell` would likely result in uncompilable kernels/wrong results. +For example: +```rust +// **no good** +let v = Cell::new(0.0f32.expr()); +track!(if cond { + v.set(v.get() + 1.0); +}); + +// **good** +let v = 0.0f32.var(); +track!(if cond { + *v += 1.0; +}); + All operations except load/store should be performed on `Expr`. `Var` can only be used to load/store values. -As in the C++ EDSL, we additionally supports the following vector/matrix types. 
Their proxy types are `XXXExpr` and `XXXVar`: +As in the C++ EDSL, we additionally support vectors of length 2-4 for all primitive types and square float matrices of dimension 2-4, such as: ```rust -Bool2 // bool2 in C++ -Bool3 // bool3 in C++ -Bool4 // bool4 in C++ -Float2 // float2 in C++ -Float3 // float3 in C++ -Float4 // float4 in C++ -Int2 // int2 in C++ -Int3 // int3 in C++ -Int4 // int4 in C++ -Uint2 // uint2 in C++ -Uint3 // uint3 in C++ -Uint4 // uint4 in C++ -Mat2 // float2x2 in C++ -Mat3 // float3x3 in C++ -Mat4 // float4x4 in C++ +luisa_compute::lang::types::vector::alias::{ + Bool2, + Bool3, + Bool4, + Float2, + Float3, + Float4, + Int2, + Int3, + Int4, + Uint2, + Uint3, + Uint4, +}; + +luisa_compute::lang::types::vector::{Mat2, Mat3, Mat4}; ``` -Array types `[T;N]` are also supported and their proxy types are `ArrayExpr` and `ArrayVar`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar` for element access. `ArrayExpr` can be stored to and loaded from `ArrayVar`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar`. `VLArrayVar::zero(length: usize)` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar` in host, use `VLArrayVar::static_len()->usize`. To query the length in kernel, use `VLArrayVar::len()->Expr` -Most operators are already overloaded with the only exception is comparision. We cannot overload comparision operators as `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt, cmpeq`, etc. To cast a primitive/vector into another type, use `v.type()`. For example: +Array types `[T;N]` are also supported. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar` for element access. `ArrayExpr` can be stored to and loaded from `ArrayVar`. The limitation, however, is that the array length must be known at host compile time. 
If runtime length is required, use `VLArrayVar`. `VLArrayVar::zero(length: usize)` creates a zero-initialized array. Similarly, you can use the `read` and `write` methods as well. To query the length of a `VLArrayVar` on the host, use `VLArrayVar::static_len()->usize`. To query the length in a kernel, use `VLArrayVar::len()->Expr`. + +Most operators are already overloaded; the only exception is comparison. We cannot overload comparison operators as `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt, cmpeq`, etc. To cast a primitive/vector into another type, use `v.as_::()`, `v.as_Type()` and `v.as_PrimitiveType()`. For example: ```rust let iv = Int2::expr(1, 1, 1); -let fv = iv.float(); //fv is Expr -let bv = fv.bool(); // bv is Expr +let fv = iv.as_::(); //fv is Expr +let also_fv = iv.as_float2(); +let also_fv = iv.cast_f32(); ``` To perform a bitwise cast, use the `bitcast` function. `let fv:Expr = bitcast::(0u32);` ### Builtin Functions -We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr` or a literal like `0.0`. However, the `select` function is slightly different as it does not accept literals. You need to use `select(cond, f_var, 1.0f32.expr())`. +We have extended primitive types with methods similar to their host counterparts: `v.sin(), luisa::max(a, b)`, etc. Most methods accept either an `Expr` or a literal such as `0.0`. However, the `select` function is slightly different as it does not accept literals. You need to use `select(cond, f_var, 1.0f32.expr())`. ### Control Flow -*Note*, you cannot modify outer scope variables inside a control flow block by declaring the variable as `mut`. To modify outer scope variables, use `Var` instead and call *var.get_mut() = value` to store the value back to the outer scope. - -If, while, break, continue are supported. 
Note that `if` and `switch` works similar to native Rust `if` and `match` in that values can be returned at the end of the block. +*Note*: you cannot modify outer scope variables inside a control flow block by declaring the variable as `mut`. To modify outer scope variables, use `Var` instead and store the value back to the outer scope. +`if`, `while`, `break`, `continue`, `return` and `loop` are supported via the `track!` macro. It is also possible to construct these control flows without `track!`. +The `switch_` statement has to be constructed manually inside an `escape!` block. For example: ```rust -if_!(cond, { /* then */}); -if_!(cond, { /* then */}, { /* else */}); -if_!(cond, { value_a }, { value_b }) -while_!(cond, { /* body */}); -for_range(start..end, |i| { /* body */}); -/* For loops in C-style are mapped to generic loops -for(init; cond; update) { body } is mapped to: -init; -generic_loop(cond, body, update) -*/ -generic_loop(|| -> Expr{ /*cond*/ }, || { /* body */}, || { /* update after each iteration */}) -break_(); -continue_(); let (x,y) = switch::<(Expr, Expr)>(value) .case(1, || { ... }) .case(2, || { ... }) @@ -222,73 +257,57 @@ let (x,y) = switch::<(Expr, Expr)>(value) .finish(); ``` -### `track!` Mcro - -We also offer a `track!` macro that automatically rewrites control flow primitves and comparison operators. For example (from [`examples/mpm.rs`](luisa_compute/examples/mpm.rs)): - -```rust -track!(|| { - // ... - let vx = select( - coord.x() < BOUND && (vx < 0.0f32) - || coord.x() + BOUND > N_GRID as u32 && (vx > 0.0f32), - 0.0f32.into(), - vx, - ); - let vy = select( - coord.y() < BOUND && (vy < 0.0f32) - || coord.y() + BOUND > N_GRID as u32 && (vy > 0.0f32), - 0.0f32.into(), - vy, - ); - // ... -}) -``` -is equivalent to: -```rust -|| { - // ... 
- let vx = select( - (coord.x().cmplt(BOUND) & vx.cmplt(0.0f32)) - | (coord.x() + BOUND).cmpgt(N_GRID as u32) & vx.cmpgt(0.0f32), - 0.0f32.into(), - vx, - ); - let vy = select( - (coord.y().cmplt(BOUND) & vy.cmplt(0.0f32)) - | (coord.y() + BOUND).cmpgt(N_GRID as u32) & vy.cmpgt(0.0f32), - 0.0f32.into(), - vy, - ); - // ... -} -``` -Similarily, -```rust -track!(if cond { foo } else if bar { baz } else { qux }) -``` -will be converted to -```rust -if_!(cond, { foo }, { if_!(bar, { baz }, { qux }) }) -``` - -Note that this macro will rewrite `while`, `for _ in x..y`, and `loop` expressions to versions using functions, which will then break the `break` and `continue` expressions. In order to avoid this, it's possible to use the `escape!` macro within a `track!` context to disable rewriting for an expression. +**Warning**: because the backend generates C-like source code instead of emitting LLVM IR/PTX/DXIL directly, it is not possible to use `break` inside switch cases. ### Custom Data Types -To add custom data types to the EDSL, simply derive from `Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`: +To add custom data types to the EDSL, simply derive the `Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with the C ABI. +`#[derive(Value)]` generates two proxy types: `XXExpr` and `XXVar`. Implement your methods on these proxies instead of `Expr` and `Var` directly. 
```rust #[derive(Copy, Clone, Default, Debug, Value)] #[repr(C)] +#[value_new(pub)] pub struct MyVec2 { pub x: f32, pub y: f32, } -let v = MyVec2.var(); -let sum = *v.x() + *v.y(); -*v.x().get_mut() += 1.0; +impl MyVec2Expr { + // pass arguments using `AsExpr` so that they accept both Var and Expr + #[tracked] + pub fn dot(&self, other: impl AsExpr) { + self.x * other.x + self.y * other.y + } +} +impl MyVec2Var { + #[tracked] + pub fn set_to_one(&self) { + // you can access the current `Var` using `self_` + self.self_ = MyVec2::new_expr(1.0, 1.0); + } +} + +track!({ + let v = MyVec2::var_zeroed(); + let sum = v.x +*v.y; + *v.x += 1.0; + let v = MyVec2::from_comps_expr(MyVec2Comps{x:1.0f32.expr(), y:1.0f32.expr()}); + let v = MyVec2::new_expr(1.0f32, 2.0f32); // only if #[value_new] is present +}); + +// You can also control the order of arguments in `#[value_new]` +#[derive(Copy, Clone, Default, Debug, Value)] +#[repr(C)] +#[value_new(pub y, x)] +pub struct Foo { + pub x: f32, + pub y: i32, +} +let v = Foo::new_expr(1i32, 2.0f32); +// v.x == 2.0 +// v.y == 1 ``` + ### Polymorphism We prvoide a powerful `Polymorphic` construct as in the C++ DSL. See examples for more detail ```rust @@ -302,7 +321,7 @@ pub struct Circle { } impl Area for CircleExpr { fn area(&self) -> Float32 { - PI * self.radius() * self.radius() + PI * self.radius * self.radius } } impl_polymorphic!(Area, Circle); @@ -359,34 +378,34 @@ let result = my_add.call(args); Users can define device-only functions using Callables. Callables have similar type signature to kernels: `CallableRet>`. The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar`, .etc), expressions and references (pass a `Var` to the callable). 
For example: ```rust -let add = device.create_callable::, Expr)-> Expr>(&|a, b| { +let add = Callable::, Expr)-> Expr>::new(&device, track!(|a, b| { a + b -}); +})); let z = add.call(x, y); -let pass_by_ref = device.create_callable::)>(&|a| { - *a.get_mut() += 1.0; -}); +let pass_by_ref = Callable::)>::new(&device, track!(|a| { + a += 1.0; +})); let a = 1.0f32.var(); pass_by_ref.call(a); cpu_dbg!(*a); // prints 2.0 ``` ***Note***: You cannot record a callable when recording another kernel or callables. This is because a callable can capture outer variables such as buffers. However, capturing local variables define in another callable is undefined behavior. To avoid this, we disallow recording a callable when recording another callable or kernel. ```rust -let add = device.create_callable::, Expr)-> Expr>(&|a, b| { +let add = Callable::, Expr)-> Expr>::new(&device, track!(|a, b| { // runtime error! - let another_add = device.create_callable::, Expr)-> Expr>(&|a, b| { + let another_add = Callable::, Expr)-> Expr>::new(&device, track!(|a, b| { a + b - }); + })); a + b -}); +})); ``` ***However, we acknowledge that recording a callable inside another callable/kernel is a useful feature***. Thus we provide two ways to workaround this limitation: 1. Use static callables. A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example, ```rust lazy_static! { - static ref ADD:Callable, Expr)->Expr> = create_static_callable::, Expr)->Expr>(|a, b| { - a + b + static ref ADD:Callable, Expr)->Expr> = Callable::, Expr)->Expr>::new_static(|a, b| { + track!(a + b) }); } ADD.call(x, y); @@ -395,13 +414,13 @@ ADD.call(x, y); 2. Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. 
The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provides some capablitiy of implementing template/overloading inside EDSL. ```rust -let add = device.create_callable::, Expr)->Expr>(&|a, b| { +let add = Callable::, Expr)-> Expr>::new(&device, track!(|a, b| { // no error! - let another_add = device.create_dyn_callable::, Expr)->Expr>(Box::new(|a, b| { + let another_add = DynCallable::, Expr)-> Expr>::new(&device, track!(Box::new(|a, b| { a + b - })); + }))); a + b -}); +})); ``` ### Kernel @@ -462,7 +481,7 @@ Safety checks such as OOB is generally not available for GPU backends. As it is When using luisa-compute-rs in an academic project, we encourage you to cite ```bibtex @misc{LuisaComputeRust - author = {Xiaochun Tong}, + author = {Xiaochun Tong, et al}, year = {2023}, note = {https://github.com/LuisaGroup/luisa-compute-rs}, title = {Rust frontend to LuisaCompute} diff --git a/luisa_compute/examples/atomic.rs b/luisa_compute/examples/atomic.rs index f6651cef..82f01919 100644 --- a/luisa_compute/examples/atomic.rs +++ b/luisa_compute/examples/atomic.rs @@ -10,12 +10,15 @@ fn main() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_sum = sum.var(); - let tid = dispatch_id().x; - buf_sum.atomic_fetch_add(0, buf_x.read(tid)); - })); + let shader = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_sum = sum.var(); + let tid = dispatch_id().x; + buf_sum.atomic_fetch_add(0, buf_x.read(tid)); + }), + ); shader.dispatch([x.len() as u32, 1, 1]); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 15842cdc..57932e06 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -29,7 +29,7 @@ fn main() { let dy_gt = 
device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = device.create_kernel::(track!(&|| { + let shader = Kernel::::new(&device, track!(|| { let tid = dispatch_id().x; let buf_x = x.var(); let buf_y = y.var(); diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index fdaa0361..5ab81fa7 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -24,7 +24,7 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(track!(&|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); diff --git a/luisa_compute/examples/bindgroup.rs b/luisa_compute/examples/bindgroup.rs index a2965fb8..77df3ed8 100644 --- a/luisa_compute/examples/bindgroup.rs +++ b/luisa_compute/examples/bindgroup.rs @@ -22,6 +22,6 @@ fn main() { y, exclude: 42.0, }; - let shader = device.create_kernel::)>(&|_args| {}); + let shader = Kernel::)>::new(&device, |_args| {}); shader.dispatch([1024, 1, 1], &my_args); } diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index 0ee206dd..48e37a7f 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -62,7 +62,7 @@ fn main() { bindless.emplace_buffer_async(1, &y); bindless.emplace_tex2d_async(0, &img, Sampler::default()); bindless.update(); - let kernel = device.create_kernel::)>(&track!(|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { let bindless = bindless.var(); let tid = dispatch_id().x; let buf_x = bindless.buffer::(0_u32.expr()); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 23255297..fc9c7457 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -18,13 +18,13 @@ fn 
main() { "cpu" }); let add = - device.create_callable::, Expr) -> Expr>(&|a, b| track!(a + b)); + Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x; diff --git a/luisa_compute/examples/callable_advanced.rs b/luisa_compute/examples/callable_advanced.rs index 5de40244..f8630ab3 100644 --- a/luisa_compute/examples/callable_advanced.rs +++ b/luisa_compute/examples/callable_advanced.rs @@ -18,8 +18,9 @@ fn main() { } else { "cpu" }); - let add = device.create_dyn_callable:: DynExpr>(Box::new( - |a: DynExpr, b: DynExpr| -> DynExpr { + let add = DynCallable:: DynExpr>::new( + &device, + Box::new(|a: DynExpr, b: DynExpr| -> DynExpr { if let Some(a) = a.downcast::() { let b = b.downcast::().unwrap(); return DynExpr::new(track!(a + b)); @@ -29,28 +30,31 @@ fn main() { } else { unreachable!() } - }, - )); + }), + ); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - buf_z.write(tid, add.call(x.into(), y.into()).get::()); - w.var().write( - tid, - add.call(x.as_::().into(), y.as_::().into()) - .get::(), - ); - })); + buf_z.write(tid, 
add.call(x.into(), y.into()).get::()); + w.var().write( + tid, + add.call(x.as_::().into(), y.as_::().into()) + .get::(), + ); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index f039264d..65bcbcf9 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -28,21 +28,24 @@ fn main() { println!("Hello from thread 0!"); } }); - let shader = device.create_kernel::)>(&track!(|buf_z: BufferVar| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let args = MyAddArgs::new_expr(x, y, 0.0f32.expr()); - let result = my_add.call(args); - let _ = my_print.call(tid); - if tid == 0 { - cpu_dbg!(args); - } - buf_z.write(tid, result.result); - })); + let shader = Kernel::)>::new( + &device, + track!(|buf_z: BufferVar| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let args = MyAddArgs::new_expr(x, y, 0.0f32.expr()); + let result = my_add.call(args); + let _ = my_print.call(tid); + if tid == 0 { + cpu_dbg!(args); + } + buf_z.write(tid, result.result); + }), + ); shader.dispatch([1024, 1, 1], &z); let mut z_data = vec![0.0; 1024]; z.view(..).copy_to(&mut z_data); diff --git a/luisa_compute/examples/fluid.rs b/luisa_compute/examples/fluid.rs index cf5d1eed..c17ac49d 100644 --- a/luisa_compute/examples/fluid.rs +++ b/luisa_compute/examples/fluid.rs @@ -118,24 +118,25 @@ fn main() { } ); - let advect = device - .create_kernel_async::, Buffer, Buffer, Buffer)>( - track!(&|u0, u1, rho0, rho1| { - let coord = dispatch_id().xy(); - let u = u0.read(index(coord)); - - // trace backward - let mut p = Float2::expr(coord.x.as_f32(), 
coord.y.as_f32()); - p = p - u * dt; - - // advect - u1.write(index(coord), sample_vel(u0, p.x, p.y)); - rho1.write(index(coord), sample_float(rho0, p.x, p.y)); - }), - ); + let advect = Kernel::, Buffer, Buffer, Buffer)>::new_async( + &device, + track!(|u0, u1, rho0, rho1| { + let coord = dispatch_id().xy(); + let u = u0.read(index(coord)); + + // trace backward + let mut p = Float2::expr(coord.x.as_f32(), coord.y.as_f32()); + p = p - u * dt; + + // advect + u1.write(index(coord), sample_vel(u0, p.x, p.y)); + rho1.write(index(coord), sample_float(rho0, p.x, p.y)); + }), + ); - let divergence = - device.create_kernel_async::, Buffer)>(track!(&|u, div| { + let divergence = Kernel::, Buffer)>::new_async( + &device, + track!(|u, div| { let coord = dispatch_id().xy(); if coord.x < (N_GRID as u32 - 1) && coord.y < (N_GRID as u32 - 1) { let dx = (u.read(index(Uint2::expr(coord.x + 1, coord.y))).x @@ -146,10 +147,12 @@ fn main() { * 0.5; div.write(index(coord), dx + dy); } - })); + }), + ); - let pressure_solve = device.create_kernel_async::, Buffer, Buffer)>( - track!(&|p0, p1, div| { + let pressure_solve = Kernel::, Buffer, Buffer)>::new_async( + &device, + track!(|p0, p1, div| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -166,8 +169,9 @@ fn main() { }), ); - let pressure_apply = - device.create_kernel_async::, Buffer)>(track!(&|p, u| { + let pressure_apply = Kernel::, Buffer)>::new_async( + &device, + track!(|p, u| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -184,10 +188,12 @@ fn main() { u.write(ij, u.read(ij) - f_p); } - })); + }), + ); - let integrate = - device.create_kernel_async::, Buffer)>(track!(&|u, rho| { + let integrate = Kernel::, Buffer)>::new_async( + &device, + track!(|u, rho| { let coord = dispatch_id().xy(); let ij = index(coord); @@ -199,10 +205,12 @@ fn main() { // fade rho.write(ij, rho.read(ij) * (1.0f32 - 0.1f32 * dt)); - })); + }), + ); - let init = 
device.create_kernel_async::, Buffer, Float2)>(track!( - &|rho, u, dir| { + let init = Kernel::, Buffer, Float2)>::new_async( + &device, + track!(|rho, u, dir| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -214,10 +222,10 @@ fn main() { rho.write(ij, 1.0f32); u.write(ij, dir); } - } - )); + }), + ); - let init_grid = device.create_kernel_async::(&|| { + let init_grid = Kernel::::new_async(&device, || { let idx = index(dispatch_id().xy()); u0.var().write(idx, Float2::expr(0.0f32, 0.0f32)); u1.var().write(idx, Float2::expr(0.0f32, 0.0f32)); @@ -230,21 +238,24 @@ fn main() { div.var().write(idx, 0.0f32); }); - let clear_pressure = device.create_kernel_async::(&|| { + let clear_pressure = Kernel::::new_async(&device, || { let idx = index(dispatch_id().xy()); p0.var().write(idx, 0.0f32); p1.var().write(idx, 0.0f32); }); - let draw_rho = device.create_kernel_async::(&track!(|| { - let coord = dispatch_id().xy(); - let ij = index(coord); - let value = rho0.var().read(ij); - display.var().write( - Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y), - Float4::expr(value, 0.0f32, 0.0f32, 1.0f32), - ); - })); + let draw_rho = Kernel::::new_async( + &device, + track!(|| { + let coord = dispatch_id().xy(); + let ij = index(coord); + let value = rho0.var().read(ij); + display.var().write( + Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y), + Float4::expr(value, 0.0f32, 0.0f32, 1.0f32), + ); + }), + ); event_loop.run(move |event, _, control_flow| { control_flow.set_poll(); diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index 8044ccf4..b731b8e0 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -90,137 +90,154 @@ fn main() { p.x + p.y * N_GRID as u32 }); - let clear_grid = device.create_kernel_async::(track!(&|| { - let idx = index(dispatch_id().xy()); - grid_v.var().write(idx * 2, 0.0f32); - grid_v.var().write(idx * 2 + 1, 0.0f32); - grid_m.var().write(idx, 0.0f32); - })); - 
- let point_to_grid = device.create_kernel_async::(track!(&|| { - let p = dispatch_id().x; - let xp = x.var().read(p) / DX; - let base = (xp - 0.5f32).cast_i32(); - let fx = xp - base.cast_f32(); - - let w = [ - 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), - 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), - 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), - ]; - let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX); - let affine = - Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); - let vp = v.var().read(p); - escape!(for ii in 0..9 { - let (i, j) = (ii % 3, ii / 3); - track!({ - let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX; - let weight = w[i].x * w[j].y; - let vadd = weight * (P_MASS * vp + affine * dpos); - let idx = index((base + offset).cast_u32()); - grid_v.var().atomic_fetch_add(idx * 2, vadd.x); - grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); - grid_m.var().atomic_fetch_add(idx, weight * P_MASS); - }); - }); - let _ = (); // WHAT? 
- })); - - let simulate_grid = device.create_kernel_async::(&track!(|| { - let coord = dispatch_id().xy(); - let i = index(coord); - let v = Var::::zeroed(); - v.store(Float2::expr( - grid_v.var().read(i * 2u32), - grid_v.var().read(i * 2u32 + 1u32), - )); - let m = grid_m.var().read(i); - - v.store(select(m > 0.0f32, v.load() / m, v.load())); - let vx = v.load().x; - let vy = v.load().y - DT * GRAVITY; - let vx = select( - coord.x < BOUND && (vx < 0.0f32) || coord.x + BOUND > N_GRID as u32 && (vx > 0.0f32), - 0.0f32.expr(), - vx, - ); - let vy = select( - coord.y < BOUND && (vy < 0.0f32) || coord.y + BOUND > N_GRID as u32 && (vy > 0.0f32), - 0.0f32.expr(), - vy, - ); - grid_v.var().write(i * 2, vx); - grid_v.var().write(i * 2 + 1, vy); - })); + let clear_grid = Kernel::::new( + &device, + track!(|| { + let idx = index(dispatch_id().xy()); + grid_v.var().write(idx * 2, 0.0f32); + grid_v.var().write(idx * 2 + 1, 0.0f32); + grid_m.var().write(idx, 0.0f32); + }), + ); - let grid_to_point = device.create_kernel_async::(track!(&|| { - let p = dispatch_id().x; - let xp = x.var().read(p) / DX; - let base = (xp - 0.5f32).cast_i32(); - let fx = xp - base.cast_f32(); + let point_to_grid = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + let xp = x.var().read(p) / DX; + let base = (xp - 0.5f32).cast_i32(); + let fx = xp - base.cast_f32(); - let w = [ - 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), - 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), - 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), - ]; - let new_v = Var::::zeroed(); - let new_C = Var::::zeroed(); - new_v.store(Float2::expr(0.0f32, 0.0f32)); - new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); - escape!({ - for ii in 0..9 { + let w = [ + 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), + 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), + 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), + ]; + let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX); + let affine = + 
Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); + let vp = v.var().read(p); + escape!(for ii in 0..9 { let (i, j) = (ii % 3, ii / 3); track!({ let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX.expr(); + let dpos = (offset.cast_f32() - fx) * DX; let weight = w[i].x * w[j].y; + let vadd = weight * (P_MASS * vp + affine * dpos); let idx = index((base + offset).cast_u32()); - let g_v = Float2::expr( - grid_v.var().read(idx * 2u32), - grid_v.var().read(idx * 2u32 + 1u32), - ); - new_v.store(new_v.load() + weight * g_v); - new_C.store( - new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX), - ); + grid_v.var().atomic_fetch_add(idx * 2, vadd.x); + grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); + grid_m.var().atomic_fetch_add(idx, weight * P_MASS); }); - } - }); + }); + let _ = (); // WHAT? + }), + ); - v.var().write(p, new_v); - x.var().write(p, x.var().read(p) + new_v.load() * DT); - J.var() - .write(p, J.var().read(p) * (1.0f32 + DT * trace(new_C.load()))); - C.var().write(p, new_C); - })); + let simulate_grid = Kernel::::new( + &device, + track!(|| { + let coord = dispatch_id().xy(); + let i = index(coord); + let v = Var::::zeroed(); + v.store(Float2::expr( + grid_v.var().read(i * 2u32), + grid_v.var().read(i * 2u32 + 1u32), + )); + let m = grid_m.var().read(i); - let clear_display = device.create_kernel_async::(&|| { + v.store(select(m > 0.0f32, v.load() / m, v.load())); + let vx = v.load().x; + let vy = v.load().y - DT * GRAVITY; + let vx = select( + coord.x < BOUND && (vx < 0.0f32) + || coord.x + BOUND > N_GRID as u32 && (vx > 0.0f32), + 0.0f32.expr(), + vx, + ); + let vy = select( + coord.y < BOUND && (vy < 0.0f32) + || coord.y + BOUND > N_GRID as u32 && (vy > 0.0f32), + 0.0f32.expr(), + vy, + ); + grid_v.var().write(i * 2, vx); + grid_v.var().write(i * 2 + 1, vy); + }), + ); + + let grid_to_point = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + let xp 
= x.var().read(p) / DX; + let base = (xp - 0.5f32).cast_i32(); + let fx = xp - base.cast_f32(); + + let w = [ + 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), + 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), + 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), + ]; + let new_v = Var::::zeroed(); + let new_C = Var::::zeroed(); + new_v.store(Float2::expr(0.0f32, 0.0f32)); + new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); + escape!({ + for ii in 0..9 { + let (i, j) = (ii % 3, ii / 3); + track!({ + let offset = Int2::expr(i as i32, j as i32); + let dpos = (offset.cast_f32() - fx) * DX.expr(); + let weight = w[i].x * w[j].y; + let idx = index((base + offset).cast_u32()); + let g_v = Float2::expr( + grid_v.var().read(idx * 2u32), + grid_v.var().read(idx * 2u32 + 1u32), + ); + new_v.store(new_v.load() + weight * g_v); + new_C.store( + new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX), + ); + }); + } + }); + + v.var().write(p, new_v); + x.var().write(p, x.var().read(p) + new_v.load() * DT); + J.var() + .write(p, J.var().read(p) * (1.0f32 + DT * trace(new_C.load()))); + C.var().write(p, new_C); + }), + ); + + let clear_display = Kernel::::new(&device, || { display.var().write( dispatch_id().xy(), Float4::expr(0.1f32, 0.2f32, 0.3f32, 1.0f32), ); }); - let draw_particles = device.create_kernel_async::(&track!(|| { - let p = dispatch_id().x; - for i in -1..=1 { - for j in -1..=1 { - let pos = (x.var().read(p) * RESOLUTION as f32).cast_i32() + Int2::expr(i, j); - if pos.x >= (0i32) - && pos.x < (RESOLUTION as i32) - && pos.y >= (0i32) - && pos.y < (RESOLUTION as i32) - { - display.var().write( - Uint2::expr(pos.x.cast_u32(), RESOLUTION - 1u32 - pos.y.cast_u32()), - Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32), - ); + let draw_particles = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + for i in -1..=1 { + for j in -1..=1 { + let pos = (x.var().read(p) * RESOLUTION as f32).cast_i32() + Int2::expr(i, j); + if pos.x >= (0i32) + && pos.x < 
(RESOLUTION as i32) + && pos.y >= (0i32) + && pos.y < (RESOLUTION as i32) + { + display.var().write( + Uint2::expr(pos.x.cast_u32(), RESOLUTION - 1u32 - pos.y.cast_u32()), + Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32), + ); + } } } - } - })); + }), + ); event_loop.run(move |event, _, control_flow| { control_flow.set_poll(); match event { diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index 7278494b..712f7b40 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -8,7 +8,9 @@ use winit::event_loop::EventLoop; use luisa::lang::types::vector::{alias::*, *}; use luisa::prelude::*; -use luisa::rtx::{offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps}; +use luisa::rtx::{ + offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, +}; use luisa_compute as luisa; #[derive(Value, Clone, Copy)] @@ -244,11 +246,12 @@ fn main() { }); // use create_kernel_async to compile multiple kernels in parallel - let path_tracer = device.create_kernel_async::, Tex2d, Accel, Uint2)>( - track!(&|image: Tex2dVar, - seed_image: Tex2dVar, - accel: AccelVar, - resolution: Expr| { + let path_tracer = Kernel::, Tex2d, Accel, Uint2)>::new_async( + &device, + track!(|image: Tex2dVar, + seed_image: Tex2dVar, + accel: AccelVar, + resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); let cbox_materials = ([ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor @@ -263,7 +266,7 @@ fn main() { .expr(); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::) -> Expr>(|state: Var| { + let lcg = Callable::) -> Expr>::new_static(|state: Var| { const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state = LCG_A * state + LCG_C; @@ -278,7 +281,7 @@ fn main() { orig: o.into(), tmin: tmin, dir: d.into(), - tmax: tmax + tmax: tmax, }) }; @@ -455,8 +458,9 @@ fn main() { ); }), ); - let display = - device.create_kernel_async::, 
Tex2d)>(track!(&|acc, display| { + let display = Kernel::, Tex2d)>::new_async( + &device, + track!(|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); @@ -468,7 +472,8 @@ fn main() { let srgb = radiance.lt(0.0031308).select(radiance * 12.92, r); display.write(coord, Float4::expr(srgb.x, srgb.y, srgb.z, 1.0f32)); - })); + }), + ); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 92cd10de..e7a325e6 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -256,11 +256,12 @@ fn main() { }); // use create_kernel_async to compile multiple kernels in parallel - let path_tracer = device.create_kernel_async::, Tex2d, Accel, Uint2)>( - track!(&|image: Tex2dVar, - seed_image: Tex2dVar, - accel: AccelVar, - resolution: Expr| { + let path_tracer = Kernel::, Tex2d, Accel, Uint2)>::new_async( + &device, + track!(|image: Tex2dVar, + seed_image: Tex2dVar, + accel: AccelVar, + resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); let cbox_materials = [ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor @@ -275,7 +276,7 @@ fn main() { .expr(); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::) -> Expr>(|state: Var| { + let lcg = Callable::) -> Expr>::new_static(|state: Var| { const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state = LCG_A * state + LCG_C; @@ -505,8 +506,9 @@ fn main() { ); }), ); - let display = - device.create_kernel_async::, Tex2d)>(track!(&|acc, display| { + let display = Kernel::, Tex2d)>::new_async( + &device, + track!(|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); @@ -518,7 +520,8 @@ fn main() { let srgb = radiance.lt(0.0031308).select(radiance * 12.92, r); display.write(coord, 
Float4::expr(srgb.x, srgb.y, srgb.z, 1.0f32)); - })); + }), + ); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/polymorphism.rs b/luisa_compute/examples/polymorphism.rs index 9a9c4126..8539c4e4 100644 --- a/luisa_compute/examples/polymorphism.rs +++ b/luisa_compute/examples/polymorphism.rs @@ -50,15 +50,18 @@ fn main() { poly_area.register((), &circles); poly_area.register((), &squares); let areas = device.create_buffer::(4); - let shader = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let tag = tid / 2; - let index = tid % 2; - let area = poly_area - .get(TagIndex::new_expr(tag, index)) - .dispatch(|_tag, _key, obj| obj.area()); - areas.var().write(tid, area); - })); + let shader = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let tag = tid / 2; + let index = tid % 2; + let area = poly_area + .get(TagIndex::new_expr(tag, index)) + .dispatch(|_tag, _key, obj| obj.area()); + areas.var().write(tid, area); + }), + ); shader.dispatch([4, 1, 1]); let areas = areas.view(..).copy_to_vec(); println!("{:?}", areas); diff --git a/luisa_compute/examples/polymorphism_advanced.rs b/luisa_compute/examples/polymorphism_advanced.rs index 316600a6..a62a03a4 100644 --- a/luisa_compute/examples/polymorphism_advanced.rs +++ b/luisa_compute/examples/polymorphism_advanced.rs @@ -132,19 +132,22 @@ fn main() { ); let poly_shader = builder.build(); let result = device.create_buffer::(100); - let kernel = device.create_kernel::(&track!(|| { - let i = dispatch_id().x; - let x = i.as_f32() / 100.0 * PI; - let ctx = ShaderEvalContext { - poly_shader: &poly_shader, - key: &shader_final_key, - }; - let tag_index = TagIndex::new_expr(shader_final.tag, shader_final.index); - let v = poly_shader - .get(tag_index) - .dispatch(|_, _, shader| shader.evaluate(x, &ctx)); - result.var().write(i, v); - })); + let kernel = Kernel::::new( + &device, + track!(|| { 
+ let i = dispatch_id().x; + let x = i.as_f32() / 100.0 * PI; + let ctx = ShaderEvalContext { + poly_shader: &poly_shader, + key: &shader_final_key, + }; + let tag_index = TagIndex::new_expr(shader_final.tag, shader_final.index); + let v = poly_shader + .get(tag_index) + .dispatch(|_, _, shader| shader.evaluate(x, &ctx)); + result.var().write(i, v); + }), + ); kernel.dispatch([100, 1, 1]); let result = result.copy_to_vec(); for i in 0..100 { diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index af173ae7..9312a2b1 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -21,7 +21,7 @@ fn main() { "cpu" }); let printer = Printer::new(&device, 65536); - let kernel = device.create_kernel::(track!(&|| { + let kernel = Kernel::::new(&device, track!(|| { let id = dispatch_id().xy(); if id.x == id.y { lc_info!(printer, "id = {:?}", id); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 1b6d63fc..b04007d1 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -119,93 +119,98 @@ fn main() { let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); let debug_hit_t = device.create_buffer::(4); - let rt_kernel = device.create_kernel::(&track!(|| { - let accel = accel.var(); - let px = dispatch_id().xy(); - let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32); - let xy = 2.0 * xy - 1.0; - let o = Float3::expr(0.0, 0.0, -2.0); - let d = Float3::expr(xy.x, xy.y, 0.0) - o; - let d = d.normalize(); + let rt_kernel = Kernel::::new( + &device, + track!(|| { + let accel = accel.var(); + let px = dispatch_id().xy(); + let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32); + let xy = 2.0 * xy - 1.0; + let o = Float3::expr(0.0, 0.0, -2.0); + let d = Float3::expr(xy.x, xy.y, 0.0) - o; + let d = d.normalize(); - let ray = Ray::new_expr( - Expr::<[f32; 3]>::from(o + translate.expr()), - 
1e-3f32, - Expr::<[f32; 3]>::from(d), - 1e9f32, - ); - let hit = accel.query_all( - ray, - 255, - RayQuery { - on_triangle_hit: |candidate: TriangleCandidate| { - let bary = candidate.bary; - let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); - let t = candidate.committed_ray_t; - if (px == Uint2::expr(400, 400)).all() { - debug_hit_t.write(0, t); - debug_hit_t.write(1, candidate.ray().tmax); - }; - // if (uvw.xy().length() < 0.8) - // & (uvw.yz().length() < 0.8) - // & (uvw.xz().length() < 0.8) - if uvw.xy().length() < 0.8 && uvw.yz().length() < 0.8 && uvw.xz().length() < 0.8 - { - candidate.commit(); - } - }, - on_procedural_hit: |candidate: ProceduralCandidate| { - let ray = candidate.ray(); - let prim = candidate.prim; - let sphere = spheres.var().read(prim); - let o: Expr = ray.orig.into(); - let d: Expr = ray.dir.into(); - let t = Var::::zeroed(); + let ray = Ray::new_expr( + Expr::<[f32; 3]>::from(o + translate.expr()), + 1e-3f32, + Expr::<[f32; 3]>::from(d), + 1e9f32, + ); + let hit = accel.query_all( + ray, + 255, + RayQuery { + on_triangle_hit: |candidate: TriangleCandidate| { + let bary = candidate.bary; + let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); + let t = candidate.committed_ray_t; + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(0, t); + debug_hit_t.write(1, candidate.ray().tmax); + }; + // if (uvw.xy().length() < 0.8) + // & (uvw.yz().length() < 0.8) + // & (uvw.xz().length() < 0.8) + if uvw.xy().length() < 0.8 + && uvw.yz().length() < 0.8 + && uvw.xz().length() < 0.8 + { + candidate.commit(); + } + }, + on_procedural_hit: |candidate: ProceduralCandidate| { + let ray = candidate.ray(); + let prim = candidate.prim; + let sphere = spheres.var().read(prim); + let o: Expr = ray.orig.into(); + let d: Expr = ray.dir.into(); + let t = Var::::zeroed(); - for _ in 0..100 { - let dist = (o + d * t - (sphere.center + translate.expr())).length() - - sphere.radius; - if dist < 0.001 { - if (px == Uint2::expr(400, 
400)).all() { - debug_hit_t.write(2, t); - debug_hit_t.write(3, candidate.ray().tmax); - } - if t < ray.tmax { - candidate.commit(t); + for _ in 0..100 { + let dist = (o + d * t - (sphere.center + translate.expr())).length() + - sphere.radius; + if dist < 0.001 { + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(2, t); + debug_hit_t.write(3, candidate.ray().tmax); + } + if t < ray.tmax { + candidate.commit(t); + } + break; } - break; + *t += dist; } - *t += dist; - } + }, }, - }, - ); - let img = img.view(0).var(); - let color = if hit.triangle_hit() { - let bary = hit.bary; - let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); - uvw - } else { - if hit.procedural_hit() { - let prim = hit.prim_id; - let sphere = spheres.var().read(prim); - let normal = (Expr::::from(ray.orig) - + Expr::::from(ray.dir) * hit.committed_ray_t - - sphere.center) - .normalize(); - let light_dir = Float3::expr(1.0, 0.6, -0.2).normalize(); - let light = Float3::expr(1.0, 1.0, 1.0); - let ambient = Float3::expr(0.1, 0.1, 0.1); - let diffuse = luisa::max(light * normal.dot(light_dir), 0.0); - let color = ambient + diffuse; - color + ); + let img = img.view(0).var(); + let color = if hit.triangle_hit() { + let bary = hit.bary; + let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); + uvw } else { - Float3::expr(0.0, 0.0, 0.0) - } - }; + if hit.procedural_hit() { + let prim = hit.prim_id; + let sphere = spheres.var().read(prim); + let normal = (Expr::::from(ray.orig) + + Expr::::from(ray.dir) * hit.committed_ray_t + - sphere.center) + .normalize(); + let light_dir = Float3::expr(1.0, 0.6, -0.2).normalize(); + let light = Float3::expr(1.0, 1.0, 1.0); + let ambient = Float3::expr(0.1, 0.1, 0.1); + let diffuse = luisa::max(light * normal.dot(light_dir), 0.0); + let color = ambient + diffuse; + color + } else { + Float3::expr(0.0, 0.0, 0.0) + } + }; - img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); - })); + img.write(px, Float4::expr(color.x, color.y, 
color.z, 1.0)); + }), + ); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() .with_title("Luisa Compute Rust - Ray Query") diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 0adb7cb3..672e7246 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -36,7 +36,7 @@ fn main() { let img_w = 800; let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); - let rt_kernel = device.create_kernel::(&track!(|| { + let rt_kernel = Kernel::::new(&device,track!(|| { let accel = accel.var(); let px = dispatch_id().xy(); let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index 448023d5..a90882f2 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -2,8 +2,9 @@ use std::env::current_exe; use luisa::lang::types::vector::alias::*; use luisa::prelude::*; -use std::cell::RefCell; +use luisa::runtime::{Kernel, KernelDef}; use luisa_compute as luisa; +use std::cell::RefCell; fn main() { luisa::init_logger(); let args: Vec = std::env::args().collect(); @@ -24,18 +25,21 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(track!(&|buf_z| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let vx = 2.0_f32.var(); // create a local mutable variable - *vx += x; // store to vx - *vx = vx; - buf_z.write(tid, vx + y); - })); + + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let vx = 2.0_f32.var(); 
// create a local mutable variable + *vx += x; // store to vx + buf_z.write(tid, vx + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index ca15ac88..354761e8 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -24,6 +24,8 @@ use ir::{ Instruction, IrBuilder, ModulePools, Pooled, Type, TypeOf, UserNodeData, }; +use self::index::IntoIndex; + pub mod control_flow; pub mod debug; pub mod diff; @@ -538,3 +540,41 @@ pub(crate) fn need_runtime_check() -> bool { } || debug::__env_need_backtrace() } +fn try_eval_const_index(index: NodeRef) -> Option { + let inst = &index.get().instruction; + match inst.as_ref() { + Instruction::Const(c) => match c { + Const::Int8(i) => Some(*i as usize), + Const::Int16(i) => Some(*i as usize), + Const::Int32(i) => Some(*i as usize), + Const::Int64(i) => Some(*i as usize), + Const::Uint8(i) => Some(*i as usize), + Const::Uint16(i) => Some(*i as usize), + Const::Uint32(i) => Some(*i as usize), + Const::Uint64(i) => Some(*i as usize), + _ => None, + }, + Instruction::Call(f, args) => match f { + Func::Cast => try_eval_const_index(args[0]), + Func::Add => { + let a = try_eval_const_index(args[0]); + let b = try_eval_const_index(args[1]); + match (a, b) { + (Some(a), Some(b)) => Some(a + b), + _ => None, + } + } + _ => None, + }, + _ => None, + } +} +pub(crate) fn check_index_lt_usize(index: impl IntoIndex, size: usize) { + let index = index.to_u64(); + let i: Option = try_eval_const_index(index.node()); + if let Some(i) = i { + assert!(i < size, "Index out of bound, index: {}, size: {}", i, size); + } else { + lc_assert!(index.lt(size as u64)); + } +} diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index 30a25663..a25d6a71 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -33,7 +33,7 @@ impl 
Index for ArrayExpr { // TODO: Add need_runtime_check()? if need_runtime_check() { - lc_assert!(i.lt((N as u64).expr())); + check_index_lt_usize(i, N); } Expr::::from_node(__current_scope(|b| { @@ -49,7 +49,7 @@ impl Index for ArrayAtomicRef { // TODO: Add need_runtime_check()? if need_runtime_check() { - lc_assert!(i.lt((N as u64).expr())); + check_index_lt_usize(i, N); } AtomicRef::::from_node(__current_scope(|b| { @@ -92,7 +92,7 @@ impl IndexRead for Expr<[T; N]> { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(N as u64)); + check_index_lt_usize(i, N); } Expr::::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.node(), i.node()], T::type_()) @@ -104,7 +104,7 @@ impl IndexRead for Var<[T; N]> { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(N as u64)); + check_index_lt_usize(i, N); } Expr::::from_node(__current_scope(|b| { let gep = b.call(Func::GetElementPtr, &[self.node(), i.node()], T::type_()); @@ -117,7 +117,7 @@ impl IndexWrite for Var<[T; N]> { let i = i.to_u64(); let value = value.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(N as u64)); + check_index_lt_usize(i, N); } __current_scope(|b| { let gep = b.call(Func::GetElementPtr, &[self.node(), i.node()], T::type_()); @@ -259,7 +259,7 @@ impl VLArrayExpr { }); Self::from_node(node) } - pub fn static_len(&self) -> usize { + pub fn len(&self) -> usize { match self.node.type_().as_ref() { Type::Array(ArrayType { element: _, length }) => *length, _ => unreachable!(), @@ -268,14 +268,14 @@ impl VLArrayExpr { pub fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + check_index_lt_usize(i, self.len()); } Expr::::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) })) } - pub fn len(&self) -> Expr { + pub fn len_expr(&self) -> Expr { match self.node.type_().as_ref() { Type::Array(ArrayType { 
element: _, length }) => (*length as u64).expr(), _ => unreachable!(), diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs index 624bc387..61e387a4 100644 --- a/luisa_compute/src/lang/types/vector/impls.rs +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -76,7 +76,7 @@ macro_rules! impl_sized { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(($N as u64).expr())); + check_index_lt_usize(i, $N); } Expr::::from_node(__current_scope(|s| { @@ -94,7 +94,7 @@ macro_rules! impl_sized { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(($N as u64).expr())); + check_index_lt_usize(i, $N); } Var::::from_node(__current_scope(|s| { @@ -333,7 +333,7 @@ macro_rules! impl_mat_proxy { fn index(&self, i: X) -> &Self::Output { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(($N as u64).expr())); + check_index_lt_usize(i, $N); } Expr::<$V>::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.0.node, i.node()], <$V>::type_()) diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 72ab439f..2476f2d8 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -40,7 +40,8 @@ pub mod prelude { pub use crate::resource::{IoTexel, StorageTexel, *}; pub use crate::runtime::api::StreamTag; pub use crate::runtime::{ - create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + Callable, Command, Device, DynCallable, Kernel, KernelBuildOptions, KernelDef, Scope, + Stream, }; pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, while_, Context}; @@ -155,6 +156,7 @@ impl Context { } } +#[derive(Clone)] pub struct ResourceTracker { resources: Vec>, } diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index eab2eb67..56c4d97e 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -149,31 +149,8 @@ impl Printer { count_per_arg: args.count_per_arg, }); } - pub fn reset(&self) -> 
PrinterReset { - PrinterReset { inner: self } - } - pub fn print(&self) -> PrinterPrint { - PrinterPrint { inner: self } - } -} -pub struct PrinterPrint<'a> { - inner: &'a Printer, -} -pub struct PrinterReset<'a> { - inner: &'a Printer, -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - fn shl(self, rhs: PrinterPrint<'a>) -> Self::Output { - self.print(rhs.inner) - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - fn shl(self, rhs: PrinterReset<'a>) -> Self::Output { - self.reset_printer(rhs.inner) - } } + impl<'a> Scope<'a> { pub fn reset_printer(&self, printer: &Printer) -> &Self { printer diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 303e58b7..4bce93d1 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -391,7 +391,7 @@ pub enum HitType { pub fn offset_ray_origin(p: Expr, n: Expr) -> Expr { lazy_static! { static ref F: Callable, Expr) -> Expr> = - create_static_callable::, Expr) -> Expr>(|p, n| { + Callable::, Expr) -> Expr>::new_static(|p, n| { const ORIGIN: f32 = 1.0f32 / 32.0f32; const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32; const INT_SCALE: f32 = 256.0f32; @@ -473,7 +473,7 @@ impl Deref for TriangleCandidate { } } impl ProceduralCandidate { - pub fn commit(&self, t: impl AsExpr) { + pub fn commit(&self, t: impl AsExpr) { let t = t.as_expr(); __current_scope(|b| { b.call( diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 218afa93..6c1fb73f 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -202,7 +202,7 @@ impl Device { }), _marker: PhantomData {}, len: count, - _is_byte_buffer:false, + _is_byte_buffer: false, }; buffer } @@ -400,93 +400,59 @@ impl Device { modifications: RwLock::new(HashMap::new()), } } - pub fn create_callable<'a, S: CallableSignature<'a>>(&self, f: S::Fn) -> S::Callable { - let mut builder = KernelBuilder::new(Some(self.clone()), false); - let raw_callable = CallableBuildFn::build_callable(&f, 
None, &mut builder); - S::wrap_raw_callable(raw_callable) - } - pub fn create_dyn_callable<'a, S: CallableSignature<'a>>(&self, f: S::DynFn) -> S::DynCallable { - S::create_dyn_callable(self.clone(), false, f) + + /// Compile a [`KernelDef`] into a [`Kernel`]. See [`Kernel`] for more details on kernel creation + pub fn compile_kernel(&self, k: &KernelDef) -> Kernel { + self.compile_kernel_with_options(k, KernelBuildOptions::default()) } - pub fn create_dyn_callable_once<'a, S: CallableSignature<'a>>( - &self, - f: S::DynFn, - ) -> S::DynCallable { - S::create_dyn_callable(self.clone(), true, f) - } - pub fn create_kernel<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = - KernelBuildFn::build_kernel(&f, &mut builder, KernelBuildOptions::default()); - S::wrap_raw_kernel(raw_kernel) - } - pub fn create_kernel_async<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel( - &f, - &mut builder, + + /// Compile a [`KernelDef`] into a [`Kernel`] asynchronously. See [`Kernel`] for more details on kernel creation + pub fn compile_kernel_async(&self, k: &KernelDef) -> Kernel { + self.compile_kernel_with_options( + k, KernelBuildOptions { async_compile: true, ..Default::default() }, - ); - S::wrap_raw_kernel(raw_kernel) + ) } - pub fn create_kernel_with_options<'a, S: KernelSignature<'a>>( + + /// Compile a [`KernelDef`] into a [`Kernel`] with options. 
See [`Kernel`] for more details on kernel creation + pub fn compile_kernel_with_options( &self, - f: S::Fn, + k: &KernelDef, options: KernelBuildOptions, - ) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel(&f, &mut builder, options); - S::wrap_raw_kernel(raw_kernel) - } -} - -pub fn create_static_callable<'a, S: CallableSignature<'a>>(f: S::StaticFn) -> S::Callable { - let r_backup = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - std::mem::replace(&mut *r, Recorder::new()) - }); - let mut builder = KernelBuilder::new(None, false); - let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); - let callable = S::wrap_raw_callable(raw_callable); - RECORDER.with(|r| { - *r.borrow_mut() = r_backup; - }); - callable -} -#[macro_export] -macro_rules! fn_n_args { - (0)=>{ dyn Fn()}; - (1)=>{ dyn Fn(_)}; - (2)=>{ dyn Fn(_,_)}; - (3)=>{ dyn Fn(_,_,_)}; - (4)=>{ dyn Fn(_,_,_,_)}; - (5)=>{ dyn Fn(_,_,_,_,_)}; - (6)=>{ dyn Fn(_,_,_,_,_,_)}; - (7)=>{ dyn Fn(_,_,_,_,_,_,_)}; - (8)=>{ dyn Fn(_,_,_,_,_,_,_,_)}; - (9)=>{ dyn Fn(_,_,_,_,_,_,_,_,_)}; - (10)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_)}; - (11)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_)}; - (12)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_)}; - (13)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_)}; - (14)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_,_)}; - (15)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)}; -} -#[macro_export] -macro_rules! wrap_fn { - ($arg_count:tt, $f:expr) => { - &$f as &fn_n_args!($arg_count) - }; -} -#[macro_export] -macro_rules! 
create_kernel { - ($device:expr, $arg_count:tt, $f:expr) => {{ - let kernel: fn_n_args!($arg_count) = Box::new($f); - $device.create_kernel(kernel) - }}; + ) -> Kernel { + let name = options.name.unwrap_or("".to_string()); + let name = Arc::new(CString::new(name).unwrap()); + let shader_options = api::ShaderOption { + enable_cache: options.enable_cache, + enable_fast_math: options.enable_fast_math, + enable_debug_info: options.enable_debug_info, + compile_only: false, + name: name.as_ptr(), + }; + let module = k.inner.module.clone(); + let artifact = if options.async_compile { + ShaderArtifact::Async(AsyncShaderArtifact::new( + self.clone(), + module.clone(), + shader_options, + name, + )) + } else { + ShaderArtifact::Sync(self.inner.create_shader(&module, &shader_options)) + }; + Kernel { + inner: RawKernel { + device: self.clone(), + artifact, + module, + resource_tracker: k.inner.resource_tracker.clone(), + }, + _marker: PhantomData {}, + } + } } pub(crate) enum StreamHandle { Default { @@ -793,31 +759,6 @@ impl<'a> Scope<'a> { self } } -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: Command<'a>) -> Self::Output { - self.submit(std::iter::once(rhs)); - self - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: EventSignal<'a>) -> Self::Output { - self.signal(rhs.event, rhs.ticket) - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: EventWait<'a>) -> Self::Output { - self.wait(rhs.event, rhs.ticket) - } -} impl<'a> Drop for Scope<'a> { fn drop(&mut self) { if !self.synchronized.get() { @@ -939,7 +880,7 @@ pub struct RawKernel { pub(crate) artifact: ShaderArtifact, #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, - pub(crate) module: CArc, + pub(crate) module: CArc, } pub struct CallableArgEncoder { @@ -1111,17 
+1052,17 @@ macro_rules! impl_kernel_arg_for_tuple { fn encode(&self, _: &mut KernelArgEncoder) { } } }; - ($first:ident $($rest:ident) *) => { - impl<$first:KernelArg, $($rest: KernelArg),*> KernelArg for ($first, $($rest,)*) { - type Parameter = ($first::Parameter, $($rest::Parameter),*); + ($first:ident $($Ts:ident) *) => { + impl<$first:KernelArg, $($Ts: KernelArg),*> KernelArg for ($first, $($Ts,)*) { + type Parameter = ($first::Parameter, $($Ts::Parameter),*); #[allow(non_snake_case)] fn encode(&self, encoder: &mut KernelArgEncoder) { - let ($first, $($rest,)*) = self; + let ($first, $($Ts,)*) = self; $first.encode(encoder); - $($rest.encode(encoder);)* + $($Ts.encode(encoder);)* } } - impl_kernel_arg_for_tuple!($($rest)*); + impl_kernel_arg_for_tuple!($($Ts)*); }; } @@ -1142,6 +1083,7 @@ impl RawKernel { } } } + pub fn dispatch_async(&self, args: KernelArgEncoder, dispatch_size: [u32; 3]) -> Command { let mut rt = ResourceTracker::new(); rt.add(Arc::new(args.uniform_data)); @@ -1165,23 +1107,23 @@ impl RawKernel { } } -pub struct Callable> { +pub struct Callable { #[allow(dead_code)] pub(crate) inner: RawCallable, pub(crate) _marker: PhantomData, } -pub(crate) struct DynCallableInner> { +pub(crate) struct DynCallableInner { builder: Box, &mut KernelBuilder) -> Callable>, callables: Vec>, } -pub struct DynCallable> { +pub struct DynCallable { #[allow(dead_code)] pub(crate) inner: RefCell>, pub(crate) device: Device, pub(crate) init_once: bool, } -impl> DynCallable { - pub(crate) fn new( +impl DynCallable { + pub(crate) fn _new( device: Device, init_once: bool, builder: Box, &mut KernelBuilder) -> Callable>, @@ -1253,14 +1195,50 @@ pub struct RawCallable { #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, } +pub struct RawKernelDef { + #[allow(dead_code)] + pub(crate) device: Option, + pub(crate) module: CArc, + #[allow(dead_code)] + pub(crate) resource_tracker: ResourceTracker, +} -pub struct Kernel> { +/// A kernel definition +/// See 
[`Kernel`] for more information +pub struct KernelDef { + pub(crate) inner: RawKernelDef, + pub(crate) _marker: PhantomData, +} + +/// An executable kernel +/// Kernel creation can be done in multiple ways: +/// - Seperate recording and compilation: +/// ```no_run +//// // Recording: +/// use luisa_compute::prelude::*; +/// let ctx = Context::new(std::env::current_exe().unwrap()); +/// let device = ctx.create_device("cpu"); +/// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); +/// // Compilation: +/// let kernel = device.compile_kernel(&kernel); +/// ``` +/// - Recording and compilation in one step: +/// ```no_run +/// use luisa_compute::prelude::*; +/// let ctx = Context::new(std::env::current_exe().unwrap()); +/// let device = ctx.create_device("cpu"); +/// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); +/// ``` +/// - Asynchronous compilation use [`Kernel::::new_async`] +/// - Custom build options using [`Kernel::::new_with_options`] +/// +pub struct Kernel { pub(crate) inner: RawKernel, pub(crate) _marker: PhantomData, } -unsafe impl> Send for Kernel {} -unsafe impl> Sync for Kernel {} -impl> Kernel { +unsafe impl Send for Kernel {} +unsafe impl Sync for Kernel {} +impl Kernel { pub fn cache_dir(&self) -> Option { let handle = self.inner.unwrap(); let device = &self.inner.device; @@ -1302,80 +1280,101 @@ impl AsKernelArg> for Tex3d {} impl AsKernelArg for BindlessArray {} impl AsKernelArg for Accel {} + macro_rules! 
impl_call_for_callable { - ($first:ident $($rest:ident)*) => { - impl CallableR> { + ( $($Ts:ident)*) => { + impl CallableR> { #[allow(non_snake_case)] - pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { + #[allow(unused_mut)] + pub fn call(&self, $($Ts:$Ts),*) -> R { let mut encoder = CallableArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* CallableRet::_from_return( crate::lang::__invoke_callable(&self.inner.module, &encoder.args)) } } - impl DynCallableR> { + impl DynCallableR> { #[allow(non_snake_case)] - pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { + #[allow(unused_mut)] + pub fn call(&self, $($Ts:$Ts),*) -> R { let mut encoder = CallableArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* - self.call_impl(std::rc::Rc::new(($first, $($rest,)*)), &encoder.args) + $($Ts.encode(&mut encoder);)* + self.call_impl(std::rc::Rc::new(($($Ts,)*)), &encoder.args) } } - impl_call_for_callable!($($rest)*); }; - ()=>{ - impl CallableR> { - pub fn call(&self)->R { - CallableRet::_from_return( - crate::lang::__invoke_callable(&self.inner.module, &[])) - } - } - impl DynCallableR> { - pub fn call(&self)-> R{ - self.call_impl(std::rc::Rc::new(()), &[]) - } - } - } } +impl_call_for_callable!(); +impl_call_for_callable!(T0); +impl_call_for_callable!(T0 T1 ); +impl_call_for_callable!(T0 T1 T2 ); +impl_call_for_callable!(T0 T1 T2 T3 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_call_for_callable!(T0 T1 T2 T3 
T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + macro_rules! impl_dispatch_for_kernel { - ($first:ident $($rest:ident)*) => { - impl <$first:KernelArg+'static, $($rest: KernelArg+'static),*> Kernel { + ($($Ts:ident)*) => { + impl <$($Ts: KernelArg+'static),*> Kernel { #[allow(non_snake_case)] - pub fn dispatch(&self, dispatch_size: [u32; 3], $first:&impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),*) { + #[allow(unused_mut)] + pub fn dispatch(&self, dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),*) { let mut encoder = KernelArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* self.inner.dispatch(encoder, dispatch_size) } #[allow(non_snake_case)] + #[allow(unused_mut)] pub fn dispatch_async<'a>( &'a self, - dispatch_size: [u32; 3], $first: &impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),* + dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),* ) -> Command<'a> { let mut encoder = KernelArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* self.inner.dispatch_async(encoder, dispatch_size) } + /// Blocks until the kernel is compiled + pub fn ensure_ready(&self) { + self.inner.unwrap(); + } } - impl_dispatch_for_kernel!($($rest)*); }; - ()=>{ - impl Kernel { - pub fn dispatch(&self, dispatch_size: [u32; 3]) { - self.inner.dispatch(KernelArgEncoder::new(), dispatch_size) - } - pub fn dispatch_async<'a>( - &'a self, - dispatch_size: [u32; 3], - ) -> Command<'a> { - self.inner.dispatch_async(KernelArgEncoder::new(), dispatch_size) - } - } -} } + +impl_dispatch_for_kernel!(); +impl_dispatch_for_kernel!(T0); +impl_dispatch_for_kernel!(T0 T1 ); +impl_dispatch_for_kernel!(T0 T1 T2 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 ); 
+impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +#[macro_export] +macro_rules! remove_last { + ($x:ident) => { + + }; + ($first:ident $($xs:ident)*) => { + $first remove_last!($($xs)*) + }; +} diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index ea118b9c..bd738927 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -348,14 +348,13 @@ impl KernelBuilder { } }) } - fn build_kernel( + fn build_kernel( &mut self, - options: KernelBuildOptions, body: impl FnOnce(&mut Self), - ) -> crate::runtime::RawKernel { + ) -> crate::runtime::KernelDef { body(self); let (rt, cpu_custom_ops, captures) = self.collect_module_info(); - RECORDER.with(|r| -> crate::runtime::RawKernel { + RECORDER.with(|r| -> crate::runtime::KernelDef { let mut r = r.borrow_mut(); assert!(r.lock); r.lock = false; @@ -380,45 +379,28 @@ impl KernelBuilder { block_size: r.block_size.unwrap_or([64, 1, 1]), pools: r.pools.clone().unwrap(), }; - - let module = CArc::new(module); - let name = options.name.unwrap_or("".to_string()); - let name = Arc::new(CString::new(name).unwrap()); - let shader_options = api::ShaderOption { - enable_cache: options.enable_cache, - enable_fast_math: options.enable_fast_math, - enable_debug_info: options.enable_debug_info, - compile_only: false, - 
name: name.as_ptr(), - }; - let artifact = if options.async_compile { - ShaderArtifact::Async(AsyncShaderArtifact::new( - self.device.clone().unwrap(), - module.clone(), - shader_options, - name, - )) - } else { - ShaderArtifact::Sync( - self.device - .as_ref() - .unwrap() - .inner - .create_shader(&module, &shader_options), - ) - }; - // r.reset(); - RawKernel { - artifact, - device: self.device.clone().unwrap(), - resource_tracker: rt, - module, + + KernelDef { + inner: RawKernelDef { + device: self.device.clone(), + resource_tracker: rt, + module: CArc::new(module), + }, + _marker: PhantomData, } }) } } +/// Build options for kernel compilation +/// * `enable_debug_info`: enable debug info, default true on debug build +/// * `enable_optimization`: enable optimization, default true +/// * `async_compile`: compile the kernel asynchronously +/// * `enable_cache`: enable cache for the compiled kernel +/// * `enable_fast_math`: enable fast math in the compiled kernel +/// * `name`: name of the compiled kernel. 
On CUDA backend, this is the name of the generated PTX kernel +/// #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct KernelBuildOptions { pub enable_debug_info: bool, @@ -445,21 +427,12 @@ impl Default for KernelBuildOptions { } } } - -pub trait KernelBuildFn { - fn build_kernel( - &self, - builder: &mut KernelBuilder, - options: KernelBuildOptions, - ) -> crate::runtime::RawKernel; -} - -pub trait CallableBuildFn { +pub trait CallableBuildFn { fn build_callable(&self, args: Option>, builder: &mut KernelBuilder) -> RawCallable; } -pub trait StaticCallableBuildFn: CallableBuildFn {} +pub trait StaticCallableBuildFn: CallableBuildFn {} // @FIXME: this looks redundant pub unsafe trait CallableRet { @@ -486,204 +459,206 @@ unsafe impl CallableRet for Expr { } } -pub trait CallableSignature<'a> { - type Callable; - type DynCallable; - type Fn: CallableBuildFn; - type StaticFn: StaticCallableBuildFn; - type DynFn: CallableBuildFn + 'static; +pub trait CallableSignature { type Ret: CallableRet; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable; - fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; } -pub trait KernelSignature<'a> { - type Fn: KernelBuildFn; - type Kernel; - - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; -} -macro_rules! impl_callable_signature { - ()=>{ - impl<'a, R: CallableRet +'static> CallableSignature<'a> for fn()->R { - type Fn = &'a dyn Fn() ->R; - type DynFn = BoxR>; - type StaticFn = fn() -> R; - type Callable = CallableR>; - type DynCallable = DynCallableR>; +pub trait KernelSignature {} +macro_rules! 
impl_callable { + ($($Ts:ident)*) => { + impl CallableSignature for fn($($Ts,)*)->R { type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:PhantomData, + } + impl CallableR> { + pub fn newR>(device: &Device, f:F)->Self where F:CallableBuildFnR> { + let mut builder = KernelBuilder::new(Some(device.clone()), false); + let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); + Self{ + inner: raw_callable, + _marker: PhantomData, } } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { - let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) - })) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a> for fn($first, $($rest,)*)->R { - type Fn = &'a dyn Fn($first, $($rest),*)->R; - type DynFn = BoxR>; - type Callable = CallableR>; - type StaticFn = fn($first, $($rest,)*)->R; - type DynCallable = DynCallableR>; - type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:PhantomData, + pub fn new_static(f:fn($($Ts,)*)->R)->Self where fn($($Ts,)*)->R :CallableBuildFnR> { + let r_backup = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + std::mem::replace(&mut *r, Recorder::new()) + }); + let mut builder = KernelBuilder::new(None, false); + let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); + RECORDER.with(|r| { + *r.borrow_mut() = r_backup; + }); + Self{ + inner: raw_callable, + _marker: PhantomData, } } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { + } + impl DynCallableR> { + pub fn 
new(device: &Device, f:BoxR>)->Self where BoxR> : CallableBuildFnR> { + DynCallable::_new(device.clone(), false, Box::new(move |arg, builder| { let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) + Callable { + inner: raw_callable, + _marker: PhantomData, + } })) } } - impl_callable_signature!($($rest)*); - }; -} -impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! impl_kernel_signature { - ()=>{ - impl<'a> KernelSignature<'a> for fn() { - type Fn = &'a dyn Fn(); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:PhantomData, - } - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for fn($first, $($rest,)*) { - type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:PhantomData, - } - } - } - impl_kernel_signature!($($rest)*); }; } -impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! 
impl_callable_build_for_fn { - ()=>{ - impl CallableBuildFn for &dyn Fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() +impl_callable!(); +impl_callable!(T0); +impl_callable!(T0 T1 ); +impl_callable!(T0 T1 T2 ); +impl_callable!(T0 T1 T2 T3 ); +impl_callable!(T0 T1 T2 T3 T4 ); +impl_callable!(T0 T1 T2 T3 T4 T5 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +macro_rules! impl_kernel { + ($($Ts:ident)*) => { + impl<$($Ts: KernelArg +'static),*> KernelSignature for fn($($Ts,)*) {} + impl<$($Ts: KernelArg +'static),*> KernelDef { + #[allow(non_snake_case)] + #[allow(unused_variables)] + pub fn new_maybe_device(device: Option<&Device>, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let mut builder = KernelBuilder::new(device.cloned(), true); + builder.build_kernel(move |builder| { + $(let $Ts = <$Ts::Parameter as KernelParameter>::def_param(builder);)* + (f)($($Ts,)*) }) } - } - impl CallableBuildFn for fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) + pub fn new(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + Self::new_maybe_device(Some(device), f) } - } - impl CallableBuildFn for BoxR> { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) + pub fn new_static(f:fn($($Ts::Parameter,)*))->Self { 
+ Self::new_maybe_device(None, f) } } - impl StaticCallableBuildFn for fn()->R {} - }; - ($first:ident $($rest:ident)*) => { - impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) + impl<$($Ts: KernelArg +'static),*> Kernel { + /// Compile a kernel with given recording function `f`. + pub fn new(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel(&def) } - } - impl CallableBuildFn for BoxR> { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) + /// Compile a kernel asynchronously with given recording function `f`. 
+ /// This function returns immediately after `f` returns + + pub fn new_async(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel_async(&def) + } + + // Compile a kernel with given recording function `f` and build options [`KernelBuildOptions`] + pub fn new_with_options(device: &Device, options: KernelBuildOptions, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel_with_options(&def, options) } } - impl CallableBuildFn for fn($first, $($rest,)*)->R { + }; +} + +impl_kernel!(); +impl_kernel!(T0); +impl_kernel!(T0 T1 ); +impl_kernel!(T0 T1 T2 ); +impl_kernel!(T0 T1 T2 T3 ); +impl_kernel!(T0 T1 T2 T3 T4 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +macro_rules! 
impl_callable_build_for_fn { + ($($Ts:ident)*) => { + impl CallableBuildFnR> for T + where T: Fn($($Ts,)*)->R + 'static { #[allow(non_snake_case)] + #[allow(unused_variables)] fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |builder| { if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) + let ($($Ts,)*) = args.downcast_ref::<($($Ts,)*)>().cloned().unwrap(); + $(let $Ts = $Ts::def_param(Some(Rc::new($Ts)), builder);)* + self($($Ts,)*) } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) + $(let $Ts = $Ts::def_param(None, builder);)* + self($($Ts,)*) } }) } } - impl StaticCallableBuildFn for fn($first, $($rest,)*)->R {} - impl_callable_build_for_fn!($($rest)*); + // impl CallableBuildFn for BoxR> { + // #[allow(non_snake_case)] + // fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + // builder.build_callable( |builder| { + // if let Some(args) = args { + // let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + // let $first = $first::def_param(Some(Rc::new($first)), builder); + // $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + // self($first, $($rest,)*) + // } else { + // let $first = $first::def_param(None, builder); + // $(let $rest = $rest::def_param(None, builder);)* + // self($first, $($rest,)*) + // } + // }) + // } + // } + // impl CallableBuildFn for fn($first, $($rest,)*)->R { + // #[allow(non_snake_case)] + // fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + // builder.build_callable( |builder| { + // if let Some(args) = args { + // let ($first, $($rest,)*) = 
args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + // let $first = $first::def_param(Some(Rc::new($first)), builder); + // $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + // self($first, $($rest,)*) + // } else { + // let $first = $first::def_param(None, builder); + // $(let $rest = $rest::def_param(None, builder);)* + // self($first, $($rest,)*) + // } + // }) + // } + // } + impl StaticCallableBuildFnR> for fn($($Ts,)*)->R + where fn($($Ts,)*)->R : CallableBuildFnR> {} }; } + +impl_callable_build_for_fn!(); +impl_callable_build_for_fn!(T0); +impl_callable_build_for_fn!(T0 T1 ); +impl_callable_build_for_fn!(T0 T1 T2 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! 
impl_kernel_build_for_fn { - ()=>{ - impl KernelBuildFn for &dyn Fn() { - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |_| { - self() - }) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { - #[allow(non_snake_case)] - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |builder| { - let $first = $first::def_param(builder); - $(let $rest = $rest::def_param(builder);)* - self($first, $($rest,)*) - }) - } - } - impl_kernel_build_for_fn!($($rest)*); - }; -} -impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 1888326e..98b43734 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -74,7 +74,7 @@ fn autodiff_helper]) -> Expr>( // inputs[i].view(..).copy_from(&tmp); // } println!("init time: {:?}", tic.elapsed()); - let kernel = device.create_kernel_async::(&|| { + let kernel = Kernel::::new_async(&device, || { let input_vars = inputs.iter().map(|input| input.var()).collect::>(); let grad_fd_vars = grad_fd.iter().map(|grad| grad.var()).collect::>(); let grad_ad_vars = grad_ad.iter().map(|grad| grad.var()).collect::>(); @@ -707,23 +707,26 @@ fn autodiff_select() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = select(x > y, x * 4.0, y * 0.5); - backward(z); - buf_dx.write(tid, 
gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = select(x > y, x * 4.0, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -751,24 +754,27 @@ fn autodiff_detach() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let k = detach(x * y); - let z = (x + y) * k; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let k = detach(x * y); + let z = (x + y) * k; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -801,25 +807,28 @@ fn autodiff_select_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = 
dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x > y; - let a = (x - y).sqrt(); - let z = select(cond, a, y * 0.5); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x > y; + let a = (x - y).sqrt(); + let z = select(cond, a, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -842,30 +851,33 @@ fn autodiff_if_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x > y; - let z = if cond { - let a = (x - y).sqrt(); - a - } else { - y * 0.5 - }; - // cpu_dbg!(f32, z); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x > y; + let z = if cond { + let a = (x - y).sqrt(); + a + } else { + y * 0.5 + }; + // cpu_dbg!(f32, z); + 
backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -893,25 +905,28 @@ fn autodiff_if_phi() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - if true.expr() { - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if x > y { x * 4.0 } else { y * 0.5 }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - } - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + if true.expr() { + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = if x > y { x * 4.0 } else { y * 0.5 }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + } + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -939,31 +954,34 @@ fn autodiff_if_phi2() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if x > y { - if x > 3.0 { - x * 4.0 + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = 
dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = if x > y { + if x > 3.0 { + x * 4.0 + } else { + x * 2.0 + } } else { - x * 2.0 - } - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -995,37 +1013,40 @@ fn autodiff_if_phi3() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let const_two = 2.0_f32.var(); - let const_three = 3.0_f32.var(); - let const_four = f32::var_zeroed(); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let const_two = 2.0_f32.var(); + let const_three = 3.0_f32.var(); + let const_four = f32::var_zeroed(); - autodiff(|| { - requires_grad(x); - requires_grad(y); - const_four.store(4.0); - let c = (x > const_three).as_::(); - let z = if x > y { - switch::>(c) - .case(0, || x * const_two) - .default(|| x * const_four) - .finish() - * const_two - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + autodiff(|| { + requires_grad(x); + requires_grad(y); + const_four.store(4.0); + let c = (x > const_three).as_::(); + let z = if x > y { + switch::>(c) + .case(0, || x * const_two) + .default(|| x * const_four) + .finish() + * 
const_two + } else { + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1057,38 +1078,41 @@ fn autodiff_if_phi4() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - let consts = Float3::var_zeroed(); - autodiff(|| { - requires_grad(x); - requires_grad(y); - *consts = Float3::expr(2.0, 3.0, 4.0); - let const_two = consts.x; - let const_three = consts.y; - let const_four = consts.z; - let c = (x > const_three).as_::(); - let z = if x > y { - switch::>(c) - .case(0, || x * const_two) - .default(|| x * const_four) - .finish() - * const_two - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let consts = Float3::var_zeroed(); + autodiff(|| { + requires_grad(x); + requires_grad(y); + *consts = Float3::expr(2.0, 3.0, 4.0); + let const_two = consts.x; + let const_three = consts.y; + let const_four = consts.z; + let c = (x > const_three).as_::(); + let z = if x > y { + switch::>(c) + .case(0, || x * const_two) + .default(|| x * const_four) + .finish() + * const_two + } else { + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1122,29 +1146,32 @@ fn 
autodiff_switch() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_t = t.var(); - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let t = buf_t.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = switch::>(t) - .case(0, || x * 4.0) - .case(1, || x * 2.0) - .case(2, || y * 0.5) - .finish(); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = switch::>(t) + .case(0, || x * 4.0) + .case(1, || x * 2.0) + .case(2, || y * 0.5) + .finish(); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1182,7 +1209,7 @@ fn autodiff_callable() { x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); let callable = - device.create_callable::, Var, Expr)>(track!(&|vx, vy, t| { + Callable::, Var, Expr)>::new(&device, track!(|vx, vy, t| { let x = **vx; let y = **vy; autodiff(|| { @@ -1198,22 +1225,25 @@ fn autodiff_callable() { *vy = gradient(y); }); })); - let kernel = device.create_kernel::(&track!(|| { - let buf_t = t.var(); - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let t = buf_t.read(tid); - let dx = x.var(); - let 
dy = y.var(); - callable.call(dx, dy, t); - buf_dx.write(tid, dx); - buf_dy.write(tid, dy); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + let dx = x.var(); + let dy = y.var(); + callable.call(dx, dy, t); + buf_dx.write(tid, dx); + buf_dy.write(tid, dy); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index d01802de..a5d1f5a8 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1,5 +1,6 @@ use luisa::lang::ops::AddMaybeExpr; use luisa::lang::types::array::VLArrayVar; +use luisa::lang::types::dynamic::*; use luisa::lang::types::vector::alias::*; use luisa::lang::types::{core::*, ExprProxy}; use luisa::prelude::*; @@ -16,46 +17,32 @@ fn event() { let a: Buffer = device.create_buffer_from_slice(&[0]); let b: Buffer = device.create_buffer_from_slice(&[0]); // compute (1 + 3) * (4 + 5) - let add = device.create_kernel::, i32)>(&|buf: BufferVar, v: Expr| { + let add = Kernel::, i32)>::new(&device, |buf: BufferVar, v: Expr| { track!(buf.write(0, buf.read(0) + v)); }); - let mul = device.create_kernel::, Buffer)>( - &|a: BufferVar, b: BufferVar| { - track!(a.write(0, a.read(0) * b.read(0))); - }, - ); + let mul = Kernel::, Buffer)>::new(&device, |a, b| { + track!(a.write(0, a.read(0) * b.read(0))); + }); let stream_a = device.create_stream(StreamTag::Compute); let stream_b = device.create_stream(StreamTag::Compute); { let scope_a = stream_a.scope(); let scope_b = stream_b.scope(); let event = device.create_event(); - let _ = &scope_a - << add.dispatch_async([1, 1, 1], &a, &1) - << add.dispatch_async([1, 1, 1], &b, &4) - << event.signal(1); - let _ = &scope_b - << 
event.wait(1) - << add.dispatch_async([1, 1, 1], &a, &3) - << add.dispatch_async([1, 1, 1], &b, &5) - << event.signal(2); - let _ = - &scope_a << event.wait(2) << mul.dispatch_async([1, 1, 1], &a, &b) << event.signal(3); + scope_a + .submit([add.dispatch_async([1, 1, 1], &a, &1)]) + .submit([add.dispatch_async([1, 1, 1], &b, &4)]) + .signal(&event, 1); + scope_b + .wait(&event, 1) + .submit([add.dispatch_async([1, 1, 1], &a, &3)]) + .submit([add.dispatch_async([1, 1, 1], &b, &5)]) + .signal(&event, 2); + scope_a + .wait(&event, 2) + .submit([mul.dispatch_async([1, 1, 1], &a, &b)]) + .signal(&event, 3); event.synchronize(3); - // scope_a - // .submit([add.dispatch_async([1, 1, 1], &a, &1)]) - // .submit([add.dispatch_async([1, 1, 1], &b, &4)]) - // .signal(&event, 1); - // scope_b - // .wait(&event, 1) - // .submit([add.dispatch_async([1, 1, 1], &a, &3)]) - // .submit([add.dispatch_async([1, 1, 1], &b, &5)]) - // .signal(&event, 2); - // scope_a - // .wait(&event, 2) - // .submit([mul.dispatch_async([1, 1, 1], &a, &b)]) - // .signal(&event, 3); - // event.synchronize(3); } let v = a.copy_to_vec(); assert_eq!(v[0], (1 + 3) * (4 + 5)); @@ -64,57 +51,68 @@ fn event() { #[should_panic] fn callable_return_mismatch() { let device = get_device(); - let _abs = device.create_callable::) -> Expr>(&track!(|x| { - if x > 0.0 { - return true.expr(); - } - -x - })); + let _abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return true.expr(); + } + -x + }), + ); } #[test] #[should_panic] fn callable_return_mismatch2() { let device = get_device(); - let _abs = device.create_callable::) -> Expr>(&track!(|x| { - if x > 0.0 { - return; - } - -x - })); + let _abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return; + } + -x + }), + ); } #[test] #[should_panic] fn callable_return_void_mismatch() { let device = get_device(); - let _abs = device.create_callable::)>(&track!(|x| { - if x > 0.0 { - return true.expr(); - } - *x = -x; - })); 
+ let _abs = Callable::)>::new( + &device, + track!(|x| { + if x > 0.0 { + return true.expr(); + } + *x = -x; + }), + ); } #[test] fn callable_early_return() { let device = get_device(); - let abs = device.create_callable::) -> Expr>(track!(&|x| { - if x > 0.0 { - return x; - } - -x - })); + let abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return x; + } + -x + }), + ); let x = device.create_buffer::(1024); let mut rng = StdRng::seed_from_u64(0); x.fill_fn(|_| rng.gen()); let y = device.create_buffer::(1024); - device - .create_kernel::(&|| { - let i = dispatch_id().x; - let x = x.var().read(i); - let y = y.var(); - y.write(i, abs.call(x)); - }) - .dispatch([x.len() as u32, 1, 1]); + Kernel::::new(&device, || { + let i = dispatch_id().x; + let x = x.var().read(i); + let y = y.var(); + y.write(i, abs.call(x)); + }) + .dispatch([x.len() as u32, 1, 1]); let x = x.copy_to_vec(); let y = y.copy_to_vec(); for i in 0..x.len() { @@ -124,31 +122,34 @@ fn callable_early_return() { #[test] fn callable() { let device = get_device(); - let write = device.create_callable::, Expr, Var)>( - &|buf: BufferVar, i: Expr, v: Var| { + let write = Callable::, Expr, Var)>::new( + &device, + |buf: BufferVar, i: Expr, v: Var| { buf.write(i, v.load()); track!(*v+=1;) }, ); - let add = - device.create_callable::, Expr) -> Expr>(&|a, b| track!(a + b)); + let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as u32); y.view(..).fill_fn(|i| 1000 * i as u32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_w = w.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let z = add.call(x, y).var(); - write.call(buf_z, tid, z); - buf_w.write(tid, z); - })); + let kernel = 
Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_w = w.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let z = add.call(x, y).var(); + write.call(buf_z, tid, z); + buf_w.write(tid, z); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); let w_data = w.view(..).copy_to_vec(); @@ -164,7 +165,12 @@ fn vec_cast() { let i: Buffer = device.create_buffer(1024); f.view(..) .fill_fn(|i| Float2::new(i as f32 + 0.5, i as f32 + 1.5)); - let kernel = device.create_kernel_with_options::( + let kernel = Kernel::::new_with_options( + &device, + KernelBuildOptions { + name: Some("vec_cast".to_string()), + ..KernelBuildOptions::default() + }, &|| { let f = f.var(); let i = i.var(); @@ -172,10 +178,6 @@ fn vec_cast() { let v = f.read(tid); i.write(tid, v.as_int2()); }, - KernelBuildOptions { - name: Some("vec_cast".to_string()), - ..KernelBuildOptions::default() - }, ); kernel.dispatch([1024, 1, 1]); let mut i_data = vec![Int2::new(0, 0); 1024]; @@ -197,19 +199,22 @@ fn bool_op() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = x.var().read(tid); - let y = y.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -238,19 
+243,22 @@ fn bvec_op() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = x.var().read(tid); - let y = y.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -276,16 +284,19 @@ fn test_var_replace() { let device = get_device(); let xs: Buffer = device.create_buffer(1024); let ys: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = xs.var().read(tid).var(); - *x = Int4::expr(1, 2, 3, 4); - let y = **x; - *x.y = 10; - *x.z = 20; - xs.write(tid, x); - ys.write(tid, y); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = xs.var().read(tid).var(); + *x = Int4::expr(1, 2, 3, 4); + let y = **x; + *x.y = 10; + *x.z = 20; + xs.write(tid, x); + ys.write(tid, y); + }), + ); kernel.dispatch([1024, 1, 1]); let xs = xs.view(..).copy_to_vec(); let ys = ys.view(..).copy_to_vec(); @@ -317,26 +328,29 @@ fn vec_bit_minmax() { x.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); z.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = 
x.var().read(tid); - let y = y.var().read(tid); - let z = z.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - let min = min.var(); - let max = max.var(); - let clamp = clamp.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - min.write(tid, luisa::min(x, y)); - max.write(tid, luisa::max(x, y)); - clamp.write(tid, z.clamp(luisa::min(x, y), luisa::max(x, y))); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let z = z.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + let min = min.var(); + let max = max.var(); + let clamp = clamp.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + min.write(tid, luisa::min(x, y)); + max.write(tid, luisa::max(x, y)); + clamp.write(tid, z.clamp(luisa::min(x, y), luisa::max(x, y))); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -376,7 +390,7 @@ fn vec_permute() { let v3: Buffer = device.create_buffer(1024); v2.view(..) 
.fill_fn(|i| Int2::new(i as i32 + 0, i as i32 + 1)); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let v2 = v2.var(); let v3 = v3.var(); let tid = dispatch_id().x; @@ -399,18 +413,21 @@ fn if_phi() { let x: Buffer = device.create_buffer(1024); let even: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::(&track!(|| { - let x = x.var(); - let even = even.var(); - let tid = dispatch_id().x; - let v = x.read(tid); - let result = if v % 2 == 0 { - true.expr() - } else { - false.expr() - }; - even.write(tid, result); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let x = x.var(); + let even = even.var(); + let tid = dispatch_id().x; + let v = x.read(tid); + let result = if v % 2 == 0 { + true.expr() + } else { + false.expr() + }; + even.write(tid, result); + }), + ); kernel.dispatch([1024, 1, 1]); let mut i_data = vec![false; 1024]; even.view(..).copy_to(&mut i_data); @@ -426,7 +443,7 @@ fn switch_phi() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -473,7 +490,7 @@ fn switch_unreachable() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32 % 3); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -515,17 +532,20 @@ fn switch_unreachable() { fn array_read_write() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() 
+ i); - *i += 1; - } - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -539,15 +559,18 @@ fn array_read_write() { fn array_read_write3() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - for_range(0..4u32, |i| { - arr.write(i, tid.as_i32() + i.as_i32()); - }); - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + for_range(0..4u32, |i| { + arr.write(i, tid.as_i32() + i.as_i32()); + }); + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -561,17 +584,20 @@ fn array_read_write3() { fn array_read_write4() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - for_range(0..6u32, |_| { - for_range(0..4u32, |i| { - arr.write(i, arr.read(i) + tid.as_i32() + i.as_i32()); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + for_range(0..6u32, |_| { + for_range(0..4u32, |i| { + arr.write(i, arr.read(i) + tid.as_i32() + i.as_i32()); + }); }); - }); - buf_x.write(tid, arr); - })); + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i 
in 0..1024 { @@ -591,19 +617,22 @@ fn array_read_write2() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - buf_x.write(tid, arr); - buf_y.write(tid, arr.read(0)); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + buf_y.write(tid, arr.read(0)); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); let y_data = y.view(..).copy_to_vec(); @@ -620,25 +649,28 @@ fn array_read_write_vla() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let vl = VLArrayVar::::zero(4); - let i = i32::var_zeroed(); - while i < 4 { - vl.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), vl.read(i.as_u32())); - *i += 1; - } - buf_x.write(tid, arr); - buf_y.write(tid, arr.read(0)); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let vl = VLArrayVar::::zero(4); + let i = i32::var_zeroed(); + while i < 4 { + vl.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), 
vl.read(i.as_u32())); + *i += 1; + } + buf_x.write(tid, arr); + buf_y.write(tid, arr.read(0)); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); let y_data = y.view(..).copy_to_vec(); @@ -654,17 +686,20 @@ fn array_read_write_vla() { fn array_read_write_async_compile() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -681,19 +716,22 @@ fn capture_same_buffer_multiple_view() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let buf_x_lo = x.view(0..64).var(); - let buf_x_hi = x.view(64..).var(); - let x = if tid < 64 { - buf_x_lo.read(tid) - } else { - buf_x_hi.read(tid - 64) - }; - let buf_sum = sum.var(); + let shader = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let buf_x_lo = x.view(0..64).var(); + let buf_x_hi = x.view(64..).var(); + let x = if tid < 64 { + buf_x_lo.read(tid) + } else { + buf_x_hi.read(tid - 64) + }; + let buf_sum = sum.var(); - buf_sum.atomic_fetch_add(0, x); - })); + buf_sum.atomic_fetch_add(0, x); + }), + ); shader.dispatch([x.len() as u32, 1, 1]); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); @@ -709,19 +747,22 @@ fn uniform() { let sum = 
device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|v: Expr| { - let tid = dispatch_id().x; - let buf_x_lo = x.view(0..64).var(); - let buf_x_hi = x.view(64..).var(); - let x = if tid < 64 { - buf_x_lo.read(tid) - } else { - buf_x_hi.read(tid - 64) - }; - let buf_sum = sum.var(); - let x = x * v.reduce_prod(); - buf_sum.atomic_fetch_add(0, x); - })); + let shader = Kernel::::new( + &device, + track!(|v: Expr| { + let tid = dispatch_id().x; + let buf_x_lo = x.view(0..64).var(); + let buf_x_hi = x.view(64..).var(); + let x = if tid < 64 { + buf_x_lo.read(tid) + } else { + buf_x_hi.read(tid - 64) + }; + let buf_sum = sum.var(); + let x = x * v.reduce_prod(); + buf_sum.atomic_fetch_add(0, x); + }), + ); shader.dispatch([x.len() as u32, 1, 1], &Float3::new(1.0, 2.0, 3.0)); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); @@ -758,8 +799,9 @@ fn byte_buffer() { let i1 = push!(Big, big); let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); - device - .create_kernel::(&track!(|| unsafe { + Kernel::::new( + &device, + track!(|| unsafe { let buf = buf.var(); let i0 = i0 as u64; let i1 = i1 as u64; @@ -779,8 +821,9 @@ fn byte_buffer() { buf.write_as::(i1, v1.load()); buf.write_as::(i2, v2.load()); buf.write_as::(i3, v3.load()); - })) - .dispatch([1, 1, 1]); + }), + ) + .dispatch([1, 1, 1]); let data = buf.copy_to_vec(); macro_rules! 
pop { ($t:ty, $offset:expr) => {{ @@ -833,8 +876,9 @@ fn bindless_byte_buffer() { let i1 = push!(Big, big); let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); - device - .create_kernel::(&track!(|out: ByteBufferVar| unsafe { + Kernel::::new( + &device, + track!(|out: ByteBufferVar| unsafe { let heap = heap.var(); let buf = heap.byte_address_buffer(0u32); let i0 = i0 as u64; @@ -855,8 +899,9 @@ fn bindless_byte_buffer() { out.write_as::(i1, v1.load()); out.write_as::(i2, v2.load()); out.write_as::(i3, v3.load()); - })) - .dispatch([1, 1, 1], &out); + }), + ) + .dispatch([1, 1, 1], &out); let data = out.copy_to_vec(); macro_rules! pop { ($t:ty, $offset:expr) => {{ @@ -944,25 +989,28 @@ fn atomic() { }; let foo_max = device.create_buffer_from_slice(&[foo_max_init]); let foo_min = device.create_buffer_from_slice(&[foo_min_init]); - let kernel = device.create_kernel::(&track!(|| { - let i = dispatch_id().x; - let foos = foos.var(); - let foo = foos.read(i); - let foo_max = foo_max.var().atomic_ref(0); - let foo_min = foo_min.var().atomic_ref(0); - foo_max.i.fetch_max(foo.i); - foo_max.v.x.fetch_max(foo.v.x); - foo_max.v.y.fetch_max(foo.v.y); - for i in 0..4u32 { - foo_max.a[i].fetch_max(foo.a[i]); - } - foo_min.i.fetch_min(foo.i); - foo_min.v.x.fetch_min(foo.v.x); - foo_min.v.y.fetch_min(foo.v.y); - for i in 0..4u32 { - foo_min.a[i].fetch_min(foo.a[i]); - } - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let i = dispatch_id().x; + let foos = foos.var(); + let foo = foos.read(i); + let foo_max = foo_max.var().atomic_ref(0); + let foo_min = foo_min.var().atomic_ref(0); + foo_max.i.fetch_max(foo.i); + foo_max.v.x.fetch_max(foo.v.x); + foo_max.v.y.fetch_max(foo.v.y); + for i in 0..4u32 { + foo_max.a[i].fetch_max(foo.a[i]); + } + foo_min.i.fetch_min(foo.i); + foo_min.v.x.fetch_min(foo.v.x); + foo_min.v.y.fetch_min(foo.v.y); + for i in 0..4u32 { + foo_min.a[i].fetch_min(foo.a[i]); + } + }), + ); kernel.dispatch([foos.len() as u32, 1, 1]); let foos = 
foos.view(..).copy_to_vec(); let foo_min = foo_min.view(..).copy_to_vec()[0]; @@ -986,3 +1034,52 @@ fn atomic() { assert_eq!(foo_max, expected_foo_max); assert_eq!(foo_min, expected_foo_min); } + +#[test] +fn dyn_callable() { + let device = get_device(); + let add = DynCallable:: DynExpr>::new( + &device, + Box::new(|a: DynExpr, b: DynExpr| -> DynExpr { + if let Some(a) = a.downcast::() { + let b = b.downcast::().unwrap(); + return DynExpr::new(track!(a + b)); + } else if let Some(a) = a.downcast::() { + let b = b.downcast::().unwrap(); + return DynExpr::new(track!(a + b)); + } else { + unreachable!() + } + }), + ); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + let w = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + + buf_z.write(tid, add.call(x.into(), y.into()).get::()); + w.var().write( + tid, + add.call(x.as_::().into(), y.as_::().into()) + .get::(), + ); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + let w_data = w.view(..).copy_to_vec(); + for i in 0..1024 { + assert_eq!(z_data[i], i as f32 + 1000.0 * i as f32); + assert_eq!(w_data[i], i as i32 + 1000 * i as i32); + } +}