diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index 50feb5c..dae9afc 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -236,7 +236,7 @@ fn main() { vertex_heap.emplace_buffer_async(index, vertex_buffers.last().unwrap()); index_heap.emplace_buffer_async(index, index_buffers.last().unwrap()); cmds.push(mesh.build_async(AccelBuildRequest::ForceBuild)); - accel.push_mesh(&mesh, glam::Mat4::IDENTITY.into(), u8::MAX, true); + accel.push_mesh(&mesh, glam::Mat4::IDENTITY.into(), u32::MAX, true); } cmds.push(vertex_heap.update_async()); cmds.push(index_heap.update_async()); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index cd2f9db..aed14ef 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -242,7 +242,7 @@ fn main() { } else { glam::Mat4::IDENTITY }; - accel.push_mesh(&mesh, m.into(), u8::MAX, false); + accel.push_mesh(&mesh, m.into(), u32::MAX, false); } cmds.push(vertex_heap.update_async()); cmds.push(index_heap.update_async()); diff --git a/luisa_compute/examples/shadertoy.rs b/luisa_compute/examples/shadertoy.rs new file mode 100644 index 0000000..ff7be00 --- /dev/null +++ b/luisa_compute/examples/shadertoy.rs @@ -0,0 +1,81 @@ +use luisa::prelude::*; +use luisa_compute as luisa; +use std::env::current_exe; +fn main() { + use luisa::*; + init_logger(); + + std::env::set_var("WINIT_UNIX_BACKEND", "x11"); + + let args: Vec = std::env::args().collect(); + assert!( + args.len() <= 2, + "Usage: {} . : cpu, cuda, dx, metal, remote", + args[0] + ); + + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device(if args.len() == 2 { + args[1].as_str() + } else { + "cpu" + }); + + let palette = device.create_callable::<(Expr,), Expr>(&|d| { + make_float3(0.2, 0.7, 0.9).lerp(make_float3(1.0, 0.0, 1.0), Float3Expr::splat(d)) + }); + let rotate = device.create_callable::<(Expr, Expr), Expr>(&|mut p, a| { + let c = a.cos(); + let s = a.sin(); + make_float2(p.dot(make_float2(c, s)), p.dot(make_float2(-s, c))) + }); + let map = device.create_callable::<(Expr, Expr), Expr>(&|mut p, time| { + for i in 0..8 { + let t = time * 0.2; + let r = rotate.call(p.xz(), t); + p = make_float3(r.x(), r.y(), p.y()).xzy(); + let r = rotate.call(p.xy(), t * 1.89); + p = make_float3(r.x(), r.y(), p.z()); + p = make_float3(p.x().abs() - 0.5, p.y(), p.z().abs() - 0.5) + } + Float3Expr::splat(1.0).copysign(p).dot(p) * 0.2 + }); + let rm = device.create_callable::<(Expr, Expr, Expr), Expr>( + &|ro, rd, time| { + let t = var!(f32, 0.0); + let col = var!(Float3); + let d = var!(f32); + for_range(0i32..64, |i| { + let p = ro + rd * *t; + *d.get_mut() = map.call(p, time) * 0.5; + if_!(d.cmplt(0.02) | d.cmpgt(100.0), { break_() }); + *col.get_mut() += palette.call(p.length() * 0.1 / (400.0 * *d)); + *t.get_mut() += *d; + }); + let col = *col; + make_float4(col.x(), col.y(), col.z(), 1.0 / (100.0 * *d)) + }, + ); + let clear_kernel = device.create_kernel::<(Tex2d,)>(&|img| { + let coord = dispatch_id().xy(); + img.write(coord, make_float4(0.3, 0.4, 0.5, 1.0)); + }); + let main_kernel = device.create_kernel::<(Tex2d, f32)>(&|img, time| { + let xy = dispatch_id().xy(); + let resolution = dispatch_size().xy(); + let uv = (xy.float() - resolution.float() * 0.5) / resolution.x().float(); + let r = rotate.call(make_float2(0.0, -50.0), time); + let ro = make_float3(r.x(), r.y(), 0.0).xzy(); + let cf = (-ro).normalize(); + let cs = cf.cross(make_float3(0.0, 10.0, 0.0)).normalize(); + let cu = cf.cross(cs).normalize(); + let uuv = ro + cf * 3.0 + uv.x() * cs + uv.y() * cu; + let rd = (uuv - ro).normalize(); + let col = rm.call(ro, rd, time); + let color = col.xyz(); + let alpha = col.w(); + let old = img.read(xy).xyz(); + let accum = color.lerp(old, Float3Expr::splat(alpha)); + img.write(xy, make_float4(accum.x(), accum.y(), accum.z(), 1.0)); + }); +} diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index ff43f5e..8e28be8 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -2171,30 +2171,32 @@ unsafe impl CallableRet for T { } } -pub trait CallableSignature<'a, R: CallableRet> { +pub trait CallableSignature { type Callable; type DynCallable; - type Fn: CallableBuildFn; + type Fn: CallableBuildFn + ?Sized; type StaticFn: StaticCallableBuildFn; type DynFn: CallableBuildFn + 'static; + type Ret: CallableRet; fn wrap_raw_callable(callable: RawCallable) -> Self::Callable; fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; } -pub trait KernelSignature<'a> { - type Fn: KernelBuildFn; +pub trait KernelSignature { + type Fn: KernelBuildFn + ?Sized; type Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; } macro_rules! impl_callable_signature { ()=>{ - impl<'a, R: CallableRet +'static> CallableSignature<'a, R> for () { - type Fn = &'a dyn Fn() ->R; + impl CallableSignature for fn()->R { + type Fn = dyn Fn() ->R; type DynFn = BoxR>; type StaticFn = fn() -> R; - type Callable = Callable<(), R>; - type DynCallable = DynCallable<(), R>; + type Callable = CallableR>; + type DynCallable = DynCallableR>; + type Ret = R; fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ Callable { inner: callable, @@ -2210,12 +2212,13 @@ macro_rules! impl_callable_signature { } }; ($first:ident $($rest:ident)*) => { - impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a, R> for ($first, $($rest,)*) { - type Fn = &'a dyn Fn($first, $($rest),*)->R; + impl CallableSignature for fn($first, $($rest,)*)->R { + type Fn = dyn Fn($first, $($rest),*)->R; type DynFn = BoxR>; - type Callable = Callable<($first, $($rest,)*), R>; + type Callable = CallableR>; type StaticFn = fn($first, $($rest,)*)->R; - type DynCallable = DynCallable<($first, $($rest,)*), R>; + type DynCallable = DynCallableR>; + type Ret = R; fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ Callable { inner: callable, @@ -2235,9 +2238,9 @@ macro_rules! impl_callable_signature { impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_kernel_signature { ()=>{ - impl<'a> KernelSignature<'a> for () { - type Fn = &'a dyn Fn(); - type Kernel = Kernel<()>; + impl KernelSignature for fn() { + type Fn = dyn Fn(); + type Kernel = Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { Self::Kernel{ inner:kernel, @@ -2247,9 +2250,9 @@ macro_rules! impl_kernel_signature { } }; ($first:ident $($rest:ident)*) => { - impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for ($first, $($rest,)*) { - type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); - type Kernel = Kernel<($first, $($rest,)*)>; + impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature for fn($first, $($rest,)*) { + type Fn = dyn Fn($first::Parameter, $($rest::Parameter),*); + type Kernel = Kernel; fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { Self::Kernel{ inner:kernel, @@ -2264,7 +2267,7 @@ impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_callable_build_for_fn { ()=>{ - impl CallableBuildFn for &dyn Fn()->R { + impl CallableBuildFn for dyn Fn()->R { fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |_| { self() @@ -2288,7 +2291,7 @@ macro_rules! impl_callable_build_for_fn { impl StaticCallableBuildFn for fn()->R {} }; ($first:ident $($rest:ident)*) => { - impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { + impl CallableBuildFn for dyn Fn($first, $($rest,)*)->R { #[allow(non_snake_case)] fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |builder| { @@ -2346,7 +2349,7 @@ macro_rules! impl_callable_build_for_fn { impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_kernel_build_for_fn { ()=>{ - impl KernelBuildFn for &dyn Fn() { + impl KernelBuildFn for dyn Fn() { fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { builder.build_kernel(options, |_| { self() @@ -2355,7 +2358,7 @@ macro_rules! impl_kernel_build_for_fn { } }; ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { + impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for dyn Fn($first, $($rest,)*) { #[allow(non_snake_case)] fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { builder.build_kernel(options, |builder| { diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 5b54ba6..8bbd5d1 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -134,7 +134,14 @@ impl Mesh { } impl Accel { - fn push_handle(&self, handle: InstanceHandle, transform: Mat4, visible: u8, opaque: bool) { + fn push_handle( + &self, + handle: InstanceHandle, + transform: Mat4, + ray_mask: u32, + opaque: bool, + user_id: u32, + ) { let mut flags = api::AccelBuildModificationFlags::PRIMITIVE | AccelBuildModificationFlags::TRANSFORM; @@ -154,8 +161,9 @@ impl Accel { mesh: handle.handle(), affine: transform.into_affine3x4(), flags, - visibility: visible, + visibility: ray_mask, index, + user_id, }, ); @@ -166,8 +174,9 @@ impl Accel { index: usize, handle: InstanceHandle, transform: Mat4, - visible: u8, + ray_mask: u32, opaque: bool, + user_id: u32, ) { let mut flags = api::AccelBuildModificationFlags::PRIMITIVE; dbg!(flags); @@ -185,41 +194,52 @@ impl Accel { mesh: handle.handle(), affine: transform.into_affine3x4(), flags, - visibility: visible, + visibility: ray_mask, index: index as u32, + user_id, }, ); let mut instance_handles = self.instance_handles.write(); instance_handles[index] = Some(handle); } - pub fn push_mesh(&self, mesh: &Mesh, transform: Mat4, visible: u8, opaque: bool) { + pub fn push_mesh(&self, mesh: &Mesh, transform: Mat4, ray_mask: u32, opaque: bool) { self.push_handle( InstanceHandle::Mesh(mesh.handle.clone()), transform, - visible, + ray_mask, opaque, + 0, ) } pub fn push_procedural_primitive( &self, prim: &ProceduralPrimitive, transform: Mat4, - visible: u8, + ray_mask: u32, ) { self.push_handle( InstanceHandle::Procedural(prim.handle.clone()), transform, - visible, + ray_mask, false, + 0, ) } - pub fn set_mesh(&self, index: usize, mesh: &Mesh, transform: Mat4, visible: u8, opaque: bool) { + pub fn set_mesh( + &self, + index: usize, + mesh: &Mesh, + transform: Mat4, + ray_mask: u32, + opaque: bool, + ) { self.set_handle( index, InstanceHandle::Mesh(mesh.handle.clone()), transform, - visible, + ray_mask, opaque, + 0, ) } pub fn set_procedural_primitive( @@ -227,14 +247,15 @@ impl Accel { index: usize, prim: &ProceduralPrimitive, transform: Mat4, - visible: u8, + ray_mask: u32, ) { self.set_handle( index, InstanceHandle::Procedural(prim.handle.clone()), transform, - visible, + ray_mask, false, + 0, ) } pub fn pop(&self) { @@ -348,8 +369,8 @@ pub enum HitType { pub fn offset_ray_origin(p: Expr, n: Expr) -> Expr { lazy_static! { - static ref F: Callable<(Expr, Expr), Expr> = - create_static_callable::<(Expr, Expr), Expr>(|p, n| { + static ref F: Callable, Expr)-> Expr> = + create_static_callable::, Expr)->Expr>(|p, n| { const ORIGIN: f32 = 1.0f32 / 32.0f32; const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32; const INT_SCALE: f32 = 256.0f32; diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 6304364..15f20b6 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -367,36 +367,27 @@ impl Device { modifications: RwLock::new(HashMap::new()), } } - pub fn create_callable<'a, S: CallableSignature<'a, R>, R: CallableRet>( - &self, - f: S::Fn, - ) -> S::Callable { + pub fn create_callable(&self, f: &S::Fn) -> S::Callable { let mut builder = KernelBuilder::new(Some(self.clone()), false); - let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); + let raw_callable = CallableBuildFn::build_callable(f, None, &mut builder); S::wrap_raw_callable(raw_callable) } - pub fn create_dyn_callable<'a, S: CallableSignature<'a, R>, R: CallableRet>( - &self, - f: S::DynFn, - ) -> S::DynCallable { + pub fn create_dyn_callable(&self, f: S::DynFn) -> S::DynCallable { S::create_dyn_callable(self.clone(), false, f) } - pub fn create_dyn_callable_once<'a, S: CallableSignature<'a, R>, R: CallableRet>( - &self, - f: S::DynFn, - ) -> S::DynCallable { + pub fn create_dyn_callable_once(&self, f: S::DynFn) -> S::DynCallable { S::create_dyn_callable(self.clone(), true, f) } - pub fn create_kernel<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { + pub fn create_kernel(&self, f: &S::Fn) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); let raw_kernel = - KernelBuildFn::build_kernel(&f, &mut builder, KernelBuildOptions::default()); + KernelBuildFn::build_kernel(f, &mut builder, KernelBuildOptions::default()); S::wrap_raw_kernel(raw_kernel) } - pub fn create_kernel_async<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { + pub fn create_kernel_async(&self, f: &S::Fn) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); let raw_kernel = KernelBuildFn::build_kernel( - &f, + f, &mut builder, KernelBuildOptions { async_compile: true, @@ -405,20 +396,18 @@ impl Device { ); S::wrap_raw_kernel(raw_kernel) } - pub fn create_kernel_with_options<'a, S: KernelSignature<'a>>( + pub fn create_kernel_with_options( &self, - f: S::Fn, + f: &S::Fn, options: KernelBuildOptions, ) -> S::Kernel { let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel(&f, &mut builder, options); + let raw_kernel = KernelBuildFn::build_kernel(f, &mut builder, options); S::wrap_raw_kernel(raw_kernel) } } -pub fn create_static_callable, R: CallableRet>( - f: S::StaticFn, -) -> S::Callable { +pub fn create_static_callable(f: S::StaticFn) -> S::Callable { let r_backup = RECORDER.with(|r| { let mut r = r.borrow_mut(); std::mem::replace(&mut *r, Recorder::new()) @@ -1163,26 +1152,26 @@ impl RawKernel { } } -pub struct Callable { +pub struct Callable { #[allow(dead_code)] pub(crate) inner: RawCallable, - pub(crate) _marker: std::marker::PhantomData<(T, R)>, + pub(crate) _marker: std::marker::PhantomData, } -pub(crate) struct DynCallableInner { - builder: Box, &mut KernelBuilder) -> Callable>, - callables: Vec>, +pub(crate) struct DynCallableInner { + builder: Box, &mut KernelBuilder) -> Callable>, + callables: Vec>, } -pub struct DynCallable { +pub struct DynCallable { #[allow(dead_code)] - pub(crate) inner: RefCell>, + pub(crate) inner: RefCell>, pub(crate) device: Device, pub(crate) init_once: bool, } -impl DynCallable { +impl DynCallable { pub(crate) fn new( device: Device, init_once: bool, - builder: Box, &mut KernelBuilder) -> Callable>, + builder: Box, &mut KernelBuilder) -> Callable>, ) -> Self { Self { device, @@ -1193,7 +1182,7 @@ impl DynCallable { init_once, } } - fn call_impl(&self, args: std::rc::Rc, nodes: &[NodeRef]) -> R { + fn call_impl(&self, args: std::rc::Rc, nodes: &[NodeRef]) -> S::Ret { RECORDER.with(|r| { if let Some(device) = r.borrow().device.as_ref() { assert!( @@ -1249,13 +1238,13 @@ pub struct RawCallable { pub(crate) resource_tracker: ResourceTracker, } -pub struct Kernel { +pub struct Kernel { pub(crate) inner: RawKernel, pub(crate) _marker: std::marker::PhantomData, } -unsafe impl Send for Kernel {} -unsafe impl Sync for Kernel {} -impl Kernel { +unsafe impl Send for Kernel {} +unsafe impl Sync for Kernel {} +impl Kernel { pub fn cache_dir(&self) -> Option { let handle = self.inner.unwrap(); let device = &self.inner.device; @@ -1307,7 +1296,7 @@ impl AsKernelArg for BindlessArray {} impl AsKernelArg for Accel {} macro_rules! impl_call_for_callable { ($first:ident $($rest:ident)*) => { - impl Callable<($first, $($rest,)*), R> { + impl CallableR> { #[allow(non_snake_case)] pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { let mut encoder = CallableArgEncoder::new(); @@ -1317,7 +1306,7 @@ macro_rules! impl_call_for_callable { lang::__invoke_callable(&self.inner.module, &encoder.args)) } } - impl DynCallable<($first, $($rest,)*), R> { + impl DynCallableR> { #[allow(non_snake_case)] pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { let mut encoder = CallableArgEncoder::new(); @@ -1329,13 +1318,13 @@ macro_rules! impl_call_for_callable { impl_call_for_callable!($($rest)*); }; ()=>{ - impl Callable<(), R> { + impl CallableR> { pub fn call(&self)->R { CallableRet::_from_return( lang::__invoke_callable(&self.inner.module, &[])) } } - impl DynCallable<(), R> { + impl DynCallableR> { pub fn call(&self)-> R{ self.call_impl(std::rc::Rc::new(()), &[]) } @@ -1346,7 +1335,7 @@ impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_dispatch_for_kernel { ($first:ident $($rest:ident)*) => { - impl <$first:KernelArg, $($rest: KernelArg),*> Kernel<($first, $($rest,)*)> { + impl <$first:KernelArg, $($rest: KernelArg),*> Kernel { #[allow(non_snake_case)] pub fn dispatch(&self, dispatch_size: [u32; 3], $first:&impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),*) { let mut encoder = KernelArgEncoder::new(); @@ -1368,7 +1357,7 @@ macro_rules! impl_dispatch_for_kernel { impl_dispatch_for_kernel!($($rest)*); }; ()=>{ - impl Kernel<()> { + impl Kernel { pub fn dispatch(&self, dispatch_size: [u32; 3]) { self.inner.dispatch(KernelArgEncoder::new(), dispatch_size) } diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 3a52972..442caed 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -618,7 +618,7 @@ fn autodiff_select() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -662,7 +662,7 @@ fn autodiff_detach() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); @@ -712,7 +712,7 @@ fn autodiff_select_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::<()>(&|| { + let kernel = device.create_kernel::(&|| { let buf_x = x.var(); let buf_y = y.var(); let buf_dx = dx.var(); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index feb8432..5dbfcd9 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit feb8432199d4d405e8aacca37cf97adac8ea72fc +Subproject commit 5dbfcd971e969d66dc4c5abd7ba69ab1231c1336