Skip to content

Commit

Permalink
improving callable/kernel syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 11, 2023
1 parent 6e34d93 commit e5e92e9
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 85 deletions.
2 changes: 1 addition & 1 deletion luisa_compute/examples/path_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ fn main() {
vertex_heap.emplace_buffer_async(index, vertex_buffers.last().unwrap());
index_heap.emplace_buffer_async(index, index_buffers.last().unwrap());
cmds.push(mesh.build_async(AccelBuildRequest::ForceBuild));
accel.push_mesh(&mesh, glam::Mat4::IDENTITY.into(), u8::MAX, true);
accel.push_mesh(&mesh, glam::Mat4::IDENTITY.into(), u32::MAX, true);
}
cmds.push(vertex_heap.update_async());
cmds.push(index_heap.update_async());
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ fn main() {
} else {
glam::Mat4::IDENTITY
};
accel.push_mesh(&mesh, m.into(), u8::MAX, false);
accel.push_mesh(&mesh, m.into(), u32::MAX, false);
}
cmds.push(vertex_heap.update_async());
cmds.push(index_heap.update_async());
Expand Down
81 changes: 81 additions & 0 deletions luisa_compute/examples/shadertoy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use luisa::prelude::*;
use luisa_compute as luisa;
use std::env::current_exe;
fn main() {
use luisa::*;
init_logger();

std::env::set_var("WINIT_UNIX_BACKEND", "x11");

let args: Vec<String> = std::env::args().collect();
assert!(
args.len() <= 2,
"Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
args[0]
);

let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device(if args.len() == 2 {
args[1].as_str()
} else {
"cpu"
});

let palette = device.create_callable::<(Expr<f32>,), Expr<Float3>>(&|d| {
make_float3(0.2, 0.7, 0.9).lerp(make_float3(1.0, 0.0, 1.0), Float3Expr::splat(d))
});
let rotate = device.create_callable::<(Expr<Float2>, Expr<f32>), Expr<Float2>>(&|mut p, a| {
let c = a.cos();
let s = a.sin();
make_float2(p.dot(make_float2(c, s)), p.dot(make_float2(-s, c)))
});
let map = device.create_callable::<(Expr<Float3>, Expr<f32>), Expr<f32>>(&|mut p, time| {
for i in 0..8 {
let t = time * 0.2;
let r = rotate.call(p.xz(), t);
p = make_float3(r.x(), r.y(), p.y()).xzy();
let r = rotate.call(p.xy(), t * 1.89);
p = make_float3(r.x(), r.y(), p.z());
p = make_float3(p.x().abs() - 0.5, p.y(), p.z().abs() - 0.5)
}
Float3Expr::splat(1.0).copysign(p).dot(p) * 0.2
});
let rm = device.create_callable::<(Expr<Float3>, Expr<Float3>, Expr<f32>), Expr<Float4>>(
&|ro, rd, time| {
let t = var!(f32, 0.0);
let col = var!(Float3);
let d = var!(f32);
for_range(0i32..64, |i| {
let p = ro + rd * *t;
*d.get_mut() = map.call(p, time) * 0.5;
if_!(d.cmplt(0.02) | d.cmpgt(100.0), { break_() });
*col.get_mut() += palette.call(p.length() * 0.1 / (400.0 * *d));
*t.get_mut() += *d;
});
let col = *col;
make_float4(col.x(), col.y(), col.z(), 1.0 / (100.0 * *d))
},
);
let clear_kernel = device.create_kernel::<(Tex2d<Float4>,)>(&|img| {
let coord = dispatch_id().xy();
img.write(coord, make_float4(0.3, 0.4, 0.5, 1.0));
});
let main_kernel = device.create_kernel::<(Tex2d<Float4>, f32)>(&|img, time| {
let xy = dispatch_id().xy();
let resolution = dispatch_size().xy();
let uv = (xy.float() - resolution.float() * 0.5) / resolution.x().float();
let r = rotate.call(make_float2(0.0, -50.0), time);
let ro = make_float3(r.x(), r.y(), 0.0).xzy();
let cf = (-ro).normalize();
let cs = cf.cross(make_float3(0.0, 10.0, 0.0)).normalize();
let cu = cf.cross(cs).normalize();
let uuv = ro + cf * 3.0 + uv.x() * cs + uv.y() * cu;
let rd = (uuv - ro).normalize();
let col = rm.call(ro, rd, time);
let color = col.xyz();
let alpha = col.w();
let old = img.read(xy).xyz();
let accum = color.lerp(old, Float3Expr::splat(alpha));
img.write(xy, make_float4(accum.x(), accum.y(), accum.z(), 1.0));
});
}
47 changes: 25 additions & 22 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2171,30 +2171,32 @@ unsafe impl<T: ExprProxy> CallableRet for T {
}
}

pub trait CallableSignature<'a, R: CallableRet> {
pub trait CallableSignature {
type Callable;
type DynCallable;
type Fn: CallableBuildFn;
type Fn: CallableBuildFn + ?Sized;
type StaticFn: StaticCallableBuildFn;
type DynFn: CallableBuildFn + 'static;
type Ret: CallableRet;
fn wrap_raw_callable(callable: RawCallable) -> Self::Callable;
fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable;
}

pub trait KernelSignature<'a> {
type Fn: KernelBuildFn;
pub trait KernelSignature {
type Fn: KernelBuildFn + ?Sized;
type Kernel;

fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel;
}
macro_rules! impl_callable_signature {
()=>{
impl<'a, R: CallableRet +'static> CallableSignature<'a, R> for () {
type Fn = &'a dyn Fn() ->R;
impl<R: CallableRet +'static> CallableSignature for fn()->R {
type Fn = dyn Fn() ->R;
type DynFn = Box<dyn Fn() ->R>;
type StaticFn = fn() -> R;
type Callable = Callable<(), R>;
type DynCallable = DynCallable<(), R>;
type Callable = Callable<fn()->R>;
type DynCallable = DynCallable<fn()->R>;
type Ret = R;
fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{
Callable {
inner: callable,
Expand All @@ -2210,12 +2212,13 @@ macro_rules! impl_callable_signature {
}
};
($first:ident $($rest:ident)*) => {
impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a, R> for ($first, $($rest,)*) {
type Fn = &'a dyn Fn($first, $($rest),*)->R;
impl<R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature for fn($first, $($rest,)*)->R {
type Fn = dyn Fn($first, $($rest),*)->R;
type DynFn = Box<dyn Fn($first, $($rest),*)->R>;
type Callable = Callable<($first, $($rest,)*), R>;
type Callable = Callable<fn($first, $($rest,)*)->R>;
type StaticFn = fn($first, $($rest,)*)->R;
type DynCallable = DynCallable<($first, $($rest,)*), R>;
type DynCallable = DynCallable<fn($first, $($rest,)*)->R>;
type Ret = R;
fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{
Callable {
inner: callable,
Expand All @@ -2235,9 +2238,9 @@ macro_rules! impl_callable_signature {
impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
macro_rules! impl_kernel_signature {
()=>{
impl<'a> KernelSignature<'a> for () {
type Fn = &'a dyn Fn();
type Kernel = Kernel<()>;
impl KernelSignature for fn() {
type Fn = dyn Fn();
type Kernel = Kernel<fn()>;
fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel {
Self::Kernel{
inner:kernel,
Expand All @@ -2247,9 +2250,9 @@ macro_rules! impl_kernel_signature {
}
};
($first:ident $($rest:ident)*) => {
impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for ($first, $($rest,)*) {
type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*);
type Kernel = Kernel<($first, $($rest,)*)>;
impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature for fn($first, $($rest,)*) {
type Fn = dyn Fn($first::Parameter, $($rest::Parameter),*);
type Kernel = Kernel<fn($first, $($rest,)*)>;
fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel {
Self::Kernel{
inner:kernel,
Expand All @@ -2264,7 +2267,7 @@ impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);

macro_rules! impl_callable_build_for_fn {
()=>{
impl<R:CallableRet +'static> CallableBuildFn for &dyn Fn()->R {
impl<R:CallableRet +'static> CallableBuildFn for dyn Fn()->R {
fn build_callable(&self, _args: Option<Rc<dyn Any>>, builder: &mut KernelBuilder)->RawCallable {
builder.build_callable( |_| {
self()
Expand All @@ -2288,7 +2291,7 @@ macro_rules! impl_callable_build_for_fn {
impl <R:CallableRet +'static> StaticCallableBuildFn for fn()->R {}
};
($first:ident $($rest:ident)*) => {
impl<R:CallableRet +'static, $first:CallableParameter, $($rest: CallableParameter),*> CallableBuildFn for &dyn Fn($first, $($rest,)*)->R {
impl<R:CallableRet +'static, $first:CallableParameter, $($rest: CallableParameter),*> CallableBuildFn for dyn Fn($first, $($rest,)*)->R {
#[allow(non_snake_case)]
fn build_callable(&self, args: Option<Rc<dyn Any>>, builder: &mut KernelBuilder)->RawCallable {
builder.build_callable( |builder| {
Expand Down Expand Up @@ -2346,7 +2349,7 @@ macro_rules! impl_callable_build_for_fn {
impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
macro_rules! impl_kernel_build_for_fn {
()=>{
impl KernelBuildFn for &dyn Fn() {
impl KernelBuildFn for dyn Fn() {
fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel {
builder.build_kernel(options, |_| {
self()
Expand All @@ -2355,7 +2358,7 @@ macro_rules! impl_kernel_build_for_fn {
}
};
($first:ident $($rest:ident)*) => {
impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) {
impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for dyn Fn($first, $($rest,)*) {
#[allow(non_snake_case)]
fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel {
builder.build_kernel(options, |builder| {
Expand Down
49 changes: 35 additions & 14 deletions luisa_compute/src/rtx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,14 @@ impl Mesh {
}

impl Accel {
fn push_handle(&self, handle: InstanceHandle, transform: Mat4, visible: u8, opaque: bool) {
fn push_handle(
&self,
handle: InstanceHandle,
transform: Mat4,
ray_mask: u32,
opaque: bool,
user_id: u32,
) {
let mut flags =
api::AccelBuildModificationFlags::PRIMITIVE | AccelBuildModificationFlags::TRANSFORM;

Expand All @@ -154,8 +161,9 @@ impl Accel {
mesh: handle.handle(),
affine: transform.into_affine3x4(),
flags,
visibility: visible,
visibility: ray_mask,
index,
user_id,
},
);

Expand All @@ -166,8 +174,9 @@ impl Accel {
index: usize,
handle: InstanceHandle,
transform: Mat4,
visible: u8,
ray_mask: u32,
opaque: bool,
user_id: u32,
) {
let mut flags = api::AccelBuildModificationFlags::PRIMITIVE;
dbg!(flags);
Expand All @@ -185,56 +194,68 @@ impl Accel {
mesh: handle.handle(),
affine: transform.into_affine3x4(),
flags,
visibility: visible,
visibility: ray_mask,
index: index as u32,
user_id,
},
);
let mut instance_handles = self.instance_handles.write();
instance_handles[index] = Some(handle);
}
pub fn push_mesh(&self, mesh: &Mesh, transform: Mat4, visible: u8, opaque: bool) {
pub fn push_mesh(&self, mesh: &Mesh, transform: Mat4, ray_mask: u32, opaque: bool) {
self.push_handle(
InstanceHandle::Mesh(mesh.handle.clone()),
transform,
visible,
ray_mask,
opaque,
0,
)
}
pub fn push_procedural_primitive(
&self,
prim: &ProceduralPrimitive,
transform: Mat4,
visible: u8,
ray_mask: u32,
) {
self.push_handle(
InstanceHandle::Procedural(prim.handle.clone()),
transform,
visible,
ray_mask,
false,
0,
)
}
pub fn set_mesh(&self, index: usize, mesh: &Mesh, transform: Mat4, visible: u8, opaque: bool) {
pub fn set_mesh(
&self,
index: usize,
mesh: &Mesh,
transform: Mat4,
ray_mask: u32,
opaque: bool,
) {
self.set_handle(
index,
InstanceHandle::Mesh(mesh.handle.clone()),
transform,
visible,
ray_mask,
opaque,
0,
)
}
pub fn set_procedural_primitive(
&self,
index: usize,
prim: &ProceduralPrimitive,
transform: Mat4,
visible: u8,
ray_mask: u32,
) {
self.set_handle(
index,
InstanceHandle::Procedural(prim.handle.clone()),
transform,
visible,
ray_mask,
false,
0,
)
}
pub fn pop(&self) {
Expand Down Expand Up @@ -348,8 +369,8 @@ pub enum HitType {

pub fn offset_ray_origin(p: Expr<Float3>, n: Expr<Float3>) -> Expr<Float3> {
lazy_static! {
static ref F: Callable<(Expr<Float3>, Expr<Float3>), Expr<Float3>> =
create_static_callable::<(Expr<Float3>, Expr<Float3>), Expr<Float3>>(|p, n| {
static ref F: Callable<fn(Expr<Float3>, Expr<Float3>)-> Expr<Float3>> =
create_static_callable::<fn(Expr<Float3>, Expr<Float3>)->Expr<Float3>>(|p, n| {
const ORIGIN: f32 = 1.0f32 / 32.0f32;
const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32;
const INT_SCALE: f32 = 256.0f32;
Expand Down
Loading

0 comments on commit e5e92e9

Please sign in to comment.