make BufferView, Tex{2/3}d[View], etc deref to corresponding Var
shiinamiyuki committed Nov 16, 2023
1 parent 40d2ba3 commit cc03244
Showing 11 changed files with 158 additions and 86 deletions.
51 changes: 37 additions & 14 deletions README.md
@@ -27,6 +27,7 @@ To see the use of `luisa-compute-rs` in a high performance offline rendering sys
- [Autodiff](#autodiff)
- [Custom Operators](#custom-operators)
- [Callable](#callable)
- [Outlining](#outlining)
- [Kernel](#kernel)
- [Debugging](#debugging)
- [Advanced Usage](#advanced-usage)
@@ -65,10 +66,10 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, |buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&device, |buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let buf_x = &x; // x and y are captured
let buf_y = &y;
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
@@ -132,7 +133,7 @@ use luisa::prelude::*;
```

### `track!` and `#[tracked]` Macro
To start writing using DSL, let's first introduce the `track!` macro. `track!( expr )` rewrites `expr` and redirect operators/control flows to DSL's internal traits. It resolves the fundamental issue that Rust is unable to overload `operator=`.
To start writing using DSL, let's first introduce the `track!` macro. `track!(expr)` rewrites `expr` and redirect operators/control flows to DSL's internal traits. It resolves the fundamental issue that Rust is unable to overload `operator=`.

**Every operation involving a DSL object must be enclosed within `track!`**, except `Var<T>::store()` and `Var<T>::load()`

@@ -154,7 +155,7 @@ track!({
```
**We* highly encourage you to enclose the entire kernel inside `track!`**

Inside `track!` normal Rust syntax is still supported. Operations involving non-DSL values are still performed using native Rust operators. For example:
Inside `track!` normal Rust syntax is still supported (with a few exceptions). Operations involving non-DSL values are still performed using native Rust operators. For example:
```rust
// This is still valid
track!({
@@ -178,7 +179,7 @@ track!({
We also offer a `#[tracked]` macro that applies to a function. It transform the body of the function using `track!`.
```rust
#[tracked]
fn add(a:Expr<f32>, b:Expr<f32>)->Expr<f32> {
fn add(a:Expr<f32>, b:Expr<f32>) -> Expr<f32> {
a + b
}

@@ -187,7 +188,7 @@ However, not every kernel can be constructed using `track!` code only. We still
For example, we can use native `for` loops to unroll a DSL loop. We first starts with a native version using DSL loops.
```rust
#[tracked]
fn pow_naive(x:Expr<f32>, i:u32)->Expr<f32> {
fn pow_naive(x:Expr<f32>, i:u32) -> Expr<f32> {
let p = 1.0f32.var();
for _ in 0..i {
p *= x;
@@ -247,6 +248,7 @@ let v = 0.0f32.var();
track!(if cond {
*v += 1.0;
));
```

All operations except load/store should be performed on `Expr<T>`. `Var<T>` can only be used to load/store values.

@@ -291,7 +293,7 @@ We have extentded primitive types with methods similar to their host counterpart

`if`, `while`, `break`, `continue`, `return` and `loop` are supported via `tracked!` macro. It is also possible to construct these control flows without `track!`.

The `switch_` statement has to be constructe manually inside a `escape!` block. For example,
The `switch_` statement has to be constructe manually. For example,
```rust
let (x,y) = switch::<(Expr<i32>, Expr<f32>)>(value)
.case(1, || { ... })
@@ -334,7 +336,7 @@ track!({
let v = MyVec2::var_zeroed();
let sum = v.x +*v.y;
*v.x += 1.0;
let v = MyVec2::from_comps_expr(MyVec2Comps{x:1.0f32.expr(), y:1.0f32.expr()});
let v = MyVec2::from_comps_expr(MyVec2Comps{ x: 1.0f32.expr(), y: 1.0f32.expr()});
let v = MyVec2::new_expr(1.0f32, 2.0f32); // only if #[value_new] is present
});

@@ -418,6 +420,8 @@ let result = my_add.call(args);
```

### Callable
**Note:** Almost all usage of callables are largely replaced with `outline(...)` in the [Outlining](#outlining) section. We keep this section for reference.

Users can define device-only functions using Callables. Callables have similar type signature to kernels: `Callable<fn(Args)->Ret>`.
The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar<T>`, .etc), expressions and references (pass a `Var<T>` to the callable). For example:
```rust
@@ -474,6 +478,25 @@ let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
}),
);
```
### Outlining
The `outline(|| { .. })` extracts a code snippet and deduplicate it into a callable. The callable is then called from the original location.
```rust
let add = |a:Expr<f32>, b:Expr<f32>|{
let sum = 0.0f32.var();
// automatically generates a callable and call it
outline(|| {
let a = 1.0f32.expr();
let b = 2.0f32.expr();
*sum = a + b;
});
**sum
};
let z = add(x, y);
// the following code is deduplicated
let w = add(z, y);
```
Since `outline` works via capturing, it is possible to pass arbitrary types or ever types that are not statically known.


### Kernel
A kernel can be written in a closure or a function. The closure/function should have a `Fn(/*args*/)->()` signature, where the args are taking the `Var` type of resources, such as `BufferVar<T>`, `Tex2D<T>`, etc.
@@ -488,9 +511,9 @@ kernel.dispatch([/*dispatch size*/], &arg0, &arg1, ...);
There are two ways to pass arguments to a kernel: by arguments or by capture.
```rust
let captured:Buffer<f32> = device.create_buffer(...);
let kernel = device.create_kernel::<fn(BufferVar<f32>>(arg| {
let kernel = device.create_kernel::<fn(Buffer<f32>>(arg| {
let v = arg.read(..);
let u = captured.var().read(..);
let u = captured.read(..);
}));
```
User can pass a maximum of 16 arguments to kernel and unlimited number of captured variables. If more than 16 arguments are needed, user can pack them into a struct and pass the struct as a single argument.
@@ -511,7 +534,7 @@ let BufferPair{a, b} = packed; // unpack if you need to use them later
```
### Debugging
We provide logging through the `log` crate. Users can either setup their own logger or use the `init_logger()` and `init_logger_verbose()` for handy initialization.
For `debug` builds, oob checks are automatically inserted so that an assertion failure would occur if oob access is detected. On CPU backend, it will be accompanied by an informative message such as `assertion failed: i.cmplt(self.len()) at xx.rs:yy:zz`. Setting the environment variable `LUISA_BACKTRACE=1` would display a stacktrace containing the *DSL* code that records the kernel. For other backends, assertion with message is still *WIP*.
For `debug` builds, oob checks are automatically inserted so that an assertion failure would occur if oob access is detected. On CPU/CUDA backend, it will be accompanied by an informative message such as `assertion failed: i.cmplt(self.len()) at xx.rs:yy:zz`. Setting the environment variable `LUISA_BACKTRACE=1` would display a stacktrace containing the *DSL* code that records the kernel. For other backends, assertion with message is still *WIP*.

For `release` builds however, these checks are disabled by default for performance reasons. To enable them, set environment variable `LUISA_DEBUG=1` prior to launching the application.

@@ -524,10 +547,10 @@ TODO
### API
Host-side safety: The API aims to be 100% safe on host side. However, the safety of async operations are gauranteed via staticly know sync points (such as `Stream::submit_and_sync`). If fully dynamic async operations are needed, user need to manually lift the liftime and use unsafe code accordingly.

Device-side safety: Due to the async nature of device-side operations. It is both very difficult to propose a safe **host** API that captures **device** resource lifetime. While device-side safety isn't guaranteed at compile time, on `cpu` backend runtime checks will catch any illegal memory access/racing condition during execution. However, for other backends such check is either too expensive or impractical and memory errors would result in undefined behavior instead.
Device-side safety: Due to the async nature of device-side operations. It is both very difficult to propose a safe **host** API that captures **device** resource lifetime. While device-side safety isn't guaranteed at compile time, on `cpu` backend runtime checks will catch any illegal memory access during execution. However, for other backends such check is either too expensive or impractical and memory errors would result in undefined behavior instead.

### Backend
Safety checks such as OOB is generally not available for GPU backends. As it is difficult to produce meaningful debug message in event of a crash. However, the Rust backend provided in the crate contains full safety checks and is recommended for debugging.
Safety checks such as OOB is generally not available for many GPU backends. As it is difficult to produce meaningful debug message in event of a crash. However, the CPU backend provided in the crate contains full safety checks and is recommended for debugging.

## Citation
When using luisa-compute-rs in an academic project, we encourage you to cite
4 changes: 2 additions & 2 deletions luisa_compute/examples/atomic.rs
@@ -12,8 +12,8 @@ fn main() {
x.view(..).fill_fn(|i| i as f32);
sum.view(..).fill(0.0);
let shader = Kernel::<fn()>::new(&device, &|| {
let buf_x = x.var();
let buf_sum = sum.var();
let buf_x = &x;
let buf_sum = &sum;
let tid = dispatch_id().x;
buf_sum.atomic_fetch_add(0, buf_x.read(tid));
});
4 changes: 2 additions & 2 deletions luisa_compute/examples/autodiff.rs
@@ -33,8 +33,8 @@ fn main() {
&device,
&track!(|| {
let tid = dispatch_id().x;
let buf_x = x.var();
let buf_y = y.var();
let buf_x = &x;
let buf_y = &y;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let f = |x: Expr<f32>, y: Expr<f32>| {
4 changes: 2 additions & 2 deletions luisa_compute/examples/backtrace.rs
@@ -26,8 +26,8 @@ fn main() {
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let buf_x = &x; // x and y are captured
let buf_y = &y;
let tid = dispatch_id().x;
let x = buf_x.read(tid + 123);
let y = buf_y.read(tid);
7 changes: 2 additions & 5 deletions luisa_compute/examples/path_tracer.rs
@@ -360,11 +360,8 @@ fn main() {
break;
}

let vertex_buffer = vertex_heap.var().buffer::<[f32; 3]>(hit.inst_id);
let triangle = index_heap
.var()
.buffer::<Index>(hit.inst_id)
.read(hit.prim_id);
let vertex_buffer = vertex_heap.buffer::<[f32; 3]>(hit.inst_id);
let triangle = index_heap.buffer::<Index>(hit.inst_id).read(hit.prim_id);

let p0: Expr<Float3> = vertex_buffer.read(triangle[0]).into();
let p1: Expr<Float3> = vertex_buffer.read(triangle[1]).into();
1 change: 0 additions & 1 deletion luisa_compute/examples/ray_query.rs
@@ -123,7 +123,6 @@ fn main() {
let rt_kernel = Kernel::<fn()>::new(
&device,
&track!(|| {
let accel = accel.var();
let px = dispatch_id().xy();
let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32);
let xy = 2.0 * xy - 1.0;
4 changes: 2 additions & 2 deletions luisa_compute/examples/vecadd.rs
@@ -31,8 +31,8 @@ fn main() {
},
&|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let buf_x = &x; // x and y are captured
let buf_y = &y;
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
9 changes: 9 additions & 0 deletions luisa_compute/src/lang.rs
@@ -309,6 +309,7 @@ pub(crate) struct FnRecorder {
pub(crate) building_kernel: bool,
pub(crate) pools: CArc<ModulePools>,
pub(crate) arena: Bump,
pub(crate) dtors: Vec<(*mut u8, fn(*mut u8))>,
pub(crate) callable_ret_type: Option<CArc<Type>>,
pub(crate) const_builder: IrBuilder,
pub(crate) index_const_pool: IndexMap<i32, NodeRef>,
@@ -426,6 +427,7 @@ impl FnRecorder {
kernel_id,
parent,
index_const_pool: IndexMap::new(),
dtors: vec![],
const_builder: IrBuilder::new(pools.clone()),
rt: ResourceTracker::new(),
}
@@ -530,6 +532,13 @@ impl FnRecorder {
arg
}
}
impl Drop for FnRecorder {
fn drop(&mut self) {
for (ptr, dtor) in self.dtors.drain(..) {
dtor(ptr);
}
}
}
thread_local! {
pub(crate) static RECORDER: RefCell<Option<FnRecorderPtr>> = RefCell::new(None);
}
115 changes: 80 additions & 35 deletions luisa_compute/src/resource.rs
@@ -345,6 +345,48 @@ pub struct BufferView<T: Value> {
pub(crate) total_size_bytes: usize,
pub(crate) _marker: PhantomData<fn() -> T>,
}
#[macro_export]
macro_rules! impl_resource_deref_to_var {
($r:ident, $v:ident [T: $tr:ident]) => {
impl<T: $tr> std::ops::Deref for $r<T> {
type Target = $v<T>;
fn deref(&self) -> &Self::Target {
let v = self.var();
with_recorder(|r| {
let v = r.arena.alloc(v);
r.dtors.push((v as *mut _ as *mut u8, |v| unsafe {
std::ptr::drop_in_place(v as *mut $v<T>)
}));
unsafe { std::mem::transmute(v) }
})
}
}

};
($r:ident, $v:ident) => {
impl std::ops::Deref for $r {
type Target = $v;
fn deref(&self) -> &Self::Target {
let v = self.var();
with_recorder(|r| {
let v = r.arena.alloc(v);
r.dtors.push((v as *mut _ as *mut u8, |v| unsafe {
std::ptr::drop_in_place(v as *mut $v)
}));
unsafe { std::mem::transmute(v) }
})
}
}

};
}
impl_resource_deref_to_var!(BufferView, BufferVar [T: Value]);
impl_resource_deref_to_var!(Tex2dView, Tex2dVar [T: IoTexel]);
impl_resource_deref_to_var!(Tex3dView, Tex3dVar [T: IoTexel]);
impl_resource_deref_to_var!(Tex2d, Tex2dVar [T: IoTexel]);
impl_resource_deref_to_var!(Tex3d, Tex3dVar [T: IoTexel]);
impl_resource_deref_to_var!(BindlessArray, BindlessArrayVar);

impl<T: Value> BufferView<T> {
/// reinterpret the buffer as a different type
/// must satisfy `std::mem::size_of::<T>() * self.len() % std::mem::size_of::<U>() == 0`
@@ -1224,7 +1266,10 @@ macro_rules! impl_tex_view {
callback: None,
}
}
pub fn copy_from_buffer<U: StorageTexel<T> + Value>(&self, buffer_view: &BufferView<U>) {
pub fn copy_from_buffer<U: StorageTexel<T> + Value>(
&self,
buffer_view: &BufferView<U>,
) {
submit_default_stream_and_sync(
&self.device,
[self.copy_from_buffer_async(buffer_view)],
@@ -1319,12 +1364,12 @@ impl<T: IoTexel> Tex2d<T> {
pub fn format(&self) -> PixelFormat {
self.handle.format
}
pub fn read(&self, uv: impl AsExpr<Value = Uint2>) -> Expr<T> {
self.var().read(uv)
}
pub fn write(&self, uv: impl AsExpr<Value = Uint2>, v: impl AsExpr<Value = T>) {
self.var().write(uv, v)
}
// pub fn read(&self, uv: impl AsExpr<Value = Uint2>) -> Expr<T> {
// self.var().read(uv)
// }
// pub fn write(&self, uv: impl AsExpr<Value = Uint2>, v: impl AsExpr<Value = T>) {
// self.var().write(uv, v)
// }
}
impl<T: IoTexel> Tex3d<T> {
pub fn view(&self, level: u32) -> Tex3dView<T> {
@@ -1345,12 +1390,12 @@ impl<T: IoTexel> Tex3d<T> {
pub fn format(&self) -> PixelFormat {
self.handle.format
}
pub fn read(&self, uv: impl AsExpr<Value = Uint3>) -> Expr<T> {
self.var().read(uv)
}
pub fn write(&self, uv: impl AsExpr<Value = Uint3>, v: impl AsExpr<Value = T>) {
self.var().write(uv, v)
}
// pub fn read(&self, uv: impl AsExpr<Value = Uint3>) -> Expr<T> {
// self.var().read(uv)
// }
// pub fn write(&self, uv: impl AsExpr<Value = Uint3>, v: impl AsExpr<Value = T>) {
// self.var().write(uv, v)
// }
}
#[derive(Clone)]
pub struct BufferVar<T: Value> {
@@ -1849,28 +1894,28 @@ impl<T: Value> ToNode for Buffer<T> {
self.var().node()
}
}
impl<T: Value> IndexRead for BufferView<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
self.var().read(i)
}
}
impl<T: Value> IndexWrite for BufferView<T> {
fn write<I: IntoIndex, V: AsExpr<Value = T>>(&self, i: I, v: V) {
self.var().write(i, v)
}
}
impl<T: Value> IndexRead for Buffer<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
self.var().read(i)
}
}
impl<T: Value> IndexWrite for Buffer<T> {
fn write<I: IntoIndex, V: AsExpr<Value = T>>(&self, i: I, v: V) {
self.var().write(i, v)
}
}
// impl<T: Value> IndexRead for BufferView<T> {
// type Element = T;
// fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
// self.var().read(i)
// }
// }
// impl<T: Value> IndexWrite for BufferView<T> {
// fn write<I: IntoIndex, V: AsExpr<Value = T>>(&self, i: I, v: V) {
// self.var().write(i, v)
// }
// }
// impl<T: Value> IndexRead for Buffer<T> {
// type Element = T;
// fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
// self.var().read(i)
// }
// }
// impl<T: Value> IndexWrite for Buffer<T> {
// fn write<I: IntoIndex, V: AsExpr<Value = T>>(&self, i: I, v: V) {
// self.var().write(i, v)
// }
// }
impl<T: Value> IndexRead for BufferVar<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
Expand Down
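
Note on the `impl_resource_deref_to_var!` change in `resource.rs`: `Deref::deref` must return a reference, but `self.var()` builds a fresh value on every call. The commit resolves this by parking the value in the recorder's arena and registering a type-erased destructor in `FnRecorder::dtors`, which `Drop for FnRecorder` runs later. The following is a minimal, std-only sketch of that pattern and not the crate's code: `Recorder`, `Buf`, `BufVar`, `begin_recording`, `end_recording` and `pending_dtors` are hypothetical stand-ins, and `Box::into_raw` is assumed here as a substitute for the bump arena.

```rust
use std::cell::RefCell;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Counts dropped `BufVar`s so the deferred cleanup is observable (demo only).
pub static DROPPED: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for `FnRecorder`: owns type-erased destructors.
struct Recorder {
    dtors: Vec<(*mut u8, fn(*mut u8))>,
}

impl Drop for Recorder {
    fn drop(&mut self) {
        // Same shape as the `Drop for FnRecorder` impl added in `lang.rs`.
        for (ptr, dtor) in self.dtors.drain(..) {
            dtor(ptr);
        }
    }
}

thread_local! {
    static RECORDER: RefCell<Option<Recorder>> = RefCell::new(None);
}

pub fn begin_recording() {
    RECORDER.with(|r| *r.borrow_mut() = Some(Recorder { dtors: Vec::new() }));
}

pub fn end_recording() {
    // Take the recorder out first so it is dropped after the borrow ends.
    let rec = RECORDER.with(|r| r.borrow_mut().take());
    drop(rec);
}

pub fn pending_dtors() -> usize {
    RECORDER.with(|r| r.borrow().as_ref().map_or(0, |rec| rec.dtors.len()))
}

/// Hypothetical device-side handle (plays the role of `BufferVar<T>`).
pub struct BufVar {
    size: usize,
}

impl BufVar {
    pub fn len(&self) -> usize {
        self.size
    }
}

impl Drop for BufVar {
    fn drop(&mut self) {
        DROPPED.fetch_add(1, Ordering::SeqCst);
    }
}

/// Hypothetical host-side resource (plays the role of `BufferView<T>`).
pub struct Buf {
    pub size: usize,
}

impl Buf {
    pub fn var(&self) -> BufVar {
        BufVar { size: self.size }
    }
}

/// Type-erased destructor for a `BufVar` allocated in `deref` below.
fn drop_buf_var(p: *mut u8) {
    unsafe { drop(Box::from_raw(p as *mut BufVar)) }
}

impl Deref for Buf {
    type Target = BufVar;
    fn deref(&self) -> &BufVar {
        // Assumption: a heap box replaces the arena allocation of the real macro.
        let raw: *mut BufVar = Box::into_raw(Box::new(self.var()));
        RECORDER.with(|r| {
            r.borrow_mut()
                .as_mut()
                .expect("Buf dereferenced outside of a recording")
                .dtors
                .push((raw as *mut u8, drop_buf_var as fn(*mut u8)));
        });
        // SAFETY (assumed by this sketch): the recorder outlives the returned reference.
        unsafe { &*raw }
    }
}
```

In this sketch every auto-deref, for example `buf.len()` on a `Buf`, allocates a new `BufVar` that stays alive until `end_recording()` drops the recorder. The returned reference is therefore only sound while the recorder is alive, a constraint the sketch states but does not enforce.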