Skip to content

Commit

Permalink
Merge pull request #25 from LuisaGroup/nested-callable
Browse files Browse the repository at this point in the history
Nested callable
  • Loading branch information
shiinamiyuki authored Oct 10, 2023
2 parents 14eb9e5 + ca58236 commit 7ce0576
Show file tree
Hide file tree
Showing 26 changed files with 1,643 additions and 1,174 deletions.
45 changes: 27 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,19 +432,20 @@ let a = 1.0f32.var();
pass_by_ref.call(a);
cpu_dbg!(*a); // prints 2.0
```
***Note***: You cannot record a callable when recording another kernel or callable. This is because a callable can capture outer variables such as buffers. However, capturing local variables defined in another callable is undefined behavior. To avoid this, we disallow recording a callable when recording another callable or kernel.
***Note***: You can create callables within another callable and even capture variables from the outer callable. The only limitation is that you cannot return a callable from another callable.
```rust
let add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
// runtime error!
let another_add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
a + b
}));
a + b
}));
let add = track!(Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(
&device,
|a, b| {
// callables can be defined within callables
let partial_add = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device, |y| a + y);
partial_add.call(b)
}
));
```

***However, we acknowledge that recording a callable inside another callable/kernel is a useful feature***. Thus we provide two ways to work around this limitation:
1. Use static callables. A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example,
#### Static callables
A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example,
```rust
lazy_static! {
static ref ADD:Callable<fn(Expr<f32>, Expr<f32>)->Expr<f32>> = Callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>::new_static(|a, b| {
Expand All @@ -454,16 +455,24 @@ lazy_static! {
ADD.call(x, y);
```

2. Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provide some capability of implementing template/overloading inside EDSL.
#### Dynamic callables (templates)
Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provide some capability of implementing template/overloading inside EDSL.

```rust
let add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
// no error!
let another_add = DynCallable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(Box::new(|a, b| {
a + b
})));
a + b
}));
let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
&device,
Box::new(|a: DynExpr, b: DynExpr| -> DynExpr {
if let Some(a) = a.downcast::<f32>() {
let b = b.downcast::<f32>().unwrap();
return DynExpr::new(track!(a + b));
} else if let Some(a) = a.downcast::<i32>() {
let b = b.downcast::<i32>().unwrap();
return DynExpr::new(track!(a + b));
} else {
unreachable!()
}
}),
);
```

### Kernel
Expand Down
12 changes: 0 additions & 12 deletions luisa_compute/examples/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,6 @@ fn main() {
}),
);

let init_grid = Kernel::<fn()>::new_async(&device, &|| {
let idx = index(dispatch_id().xy());
u0.var().write(idx, Float2::expr(0.0f32, 0.0f32));
u1.var().write(idx, Float2::expr(0.0f32, 0.0f32));

rho0.var().write(idx, 0.0f32);
rho1.var().write(idx, 0.0f32);

p0.var().write(idx, 0.0f32);
p1.var().write(idx, 0.0f32);
div.var().write(idx, 0.0f32);
});

let clear_pressure = Kernel::<fn()>::new_async(&device, &|| {
let idx = index(dispatch_id().xy());
Expand Down
49 changes: 49 additions & 0 deletions luisa_compute/examples/nested_callable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use luisa::prelude::*;
use luisa_compute as luisa;
use std::env::current_exe;

/// Example: a callable that defines and calls another callable inside itself.
///
/// Takes an optional backend name as the single command-line argument and
/// falls back to `"cpu"` when none is given. Two 1024-element input buffers
/// are filled on the host, a kernel writes `add(x[i], y[i])` for every
/// dispatch index into a third buffer, and the first 16 results are printed.
fn main() {
    luisa::init_logger();

    let argv: Vec<String> = std::env::args().collect();
    assert!(
        argv.len() <= 2,
        "Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
        argv[0]
    );

    let ctx = Context::new(current_exe().unwrap());
    // Exactly one extra argument selects the backend; otherwise use the CPU.
    let backend = match argv.len() {
        2 => argv[1].as_str(),
        _ => "cpu",
    };
    let device = ctx.create_device(backend);

    let add = track!(Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(
        &device,
        |lhs, rhs| {
            // The inner callable is created while the outer one is being
            // defined and captures `lhs` from the enclosing closure.
            let add_to_lhs = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device, |v| lhs + v);
            add_to_lhs.call(rhs)
        }
    ));

    // Host-side setup: two inputs and one output, 1024 floats each.
    let input_a = device.create_buffer::<f32>(1024);
    let input_b = device.create_buffer::<f32>(1024);
    let output = device.create_buffer::<f32>(1024);
    input_a.view(..).fill_fn(|i| i as f32);
    input_b.view(..).fill_fn(|i| 1000.0 * i as f32);

    let kernel = Kernel::<fn(Buffer<f32>)>::new(
        &device,
        &track!(|out| {
            // The input buffers are captured from the enclosing scope; only
            // the output buffer is passed as a kernel argument.
            let a_var = input_a.var();
            let b_var = input_b.var();
            let tid = dispatch_id().x;
            let a = a_var.read(tid);
            let b = b_var.read(tid);

            out.write(tid, add.call(a, b));
        }),
    );

    kernel.dispatch([1024, 1, 1], &output);
    let results = output.view(..).copy_to_vec();
    println!("{:?}", &results[0..16]);
}
3 changes: 0 additions & 3 deletions luisa_compute/examples/vecadd.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use std::env::current_exe;

use luisa::lang::types::vector::alias::*;
use luisa::prelude::*;
use luisa::runtime::{Kernel, KernelDef};
use luisa_compute as luisa;
#[tracked]
fn main() {
Expand Down
Loading

0 comments on commit 7ce0576

Please sign in to comment.