diff --git a/luisa_compute/examples/curve.rs b/luisa_compute/examples/curve.rs new file mode 100644 index 0000000..e71fdf5 --- /dev/null +++ b/luisa_compute/examples/curve.rs @@ -0,0 +1 @@ +fn main() {} \ No newline at end of file diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index b8915de..2dbc645 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -10,7 +10,7 @@ use luisa::lang::types::vector::alias::*; use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ - offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, + offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, AccelTraceOptions, }; use luisa_compute as luisa; @@ -354,20 +354,25 @@ fn main() { let depth = Var::::zeroed(); while depth < 10u32 { - let hit = accel.trace_closest(**ray); + let trace_options = AccelTraceOptions { + mask:0xff.expr(), + ..Default::default() + }; + let hit = accel.intersect(ray, trace_options); if !hit.valid() { break; } - let vertex_buffer = vertex_heap.buffer::<[f32; 3]>(hit.inst_id); - let triangle = index_heap.buffer::(hit.inst_id).read(hit.prim_id); + let vertex_buffer = vertex_heap.buffer::<[f32; 3]>(hit.inst); + let triangle = index_heap.buffer::(hit.inst).read(hit.prim); let p0: Expr = vertex_buffer.read(triangle[0]).into(); let p1: Expr = vertex_buffer.read(triangle[1]).into(); let p2: Expr = vertex_buffer.read(triangle[2]).into(); - let p = p0 * (1.0f32 - hit.u - hit.v) + p1 * hit.u + p2 * hit.v; + let bary = hit.triangle_barycentric_coord(); + let p = p0 * (1.0f32 - bary.x - bary.y) + p1 * bary.x + p2 * bary.y; let n = (p1 - p0).cross(p2 - p0).normalize(); let origin: Expr = (**ray.orig).into(); @@ -377,9 +382,9 @@ fn main() { break; } let pp = offset_ray_origin(p, n); - let albedo = cbox_materials.read(hit.inst_id); + let albedo = cbox_materials.read(hit.inst); // hit light - if hit.inst_id == 7u32 { + if hit.inst == 7u32 { if depth == 0u32 { radiance.store(radiance + light_emission); } else { @@ -399,7 +404,7 @@ fn main() { let wi_light = (pp_light - pp).normalize(); let shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.0f32.expr(), d_light); - let occluded = accel.trace_any(shadow_ray); + let occluded = accel.intersect_any(shadow_ray, trace_options); let cos_wi_light = wi_light.dot(n); let cos_light = -light_normal.dot(wi_light); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 16bd522..b516988 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -1,6 +1,7 @@ use image::Rgb; use luisa::lang::types::vector::alias::*; use luisa::lang::types::vector::Mat4; +use luisa::rtx::AccelTraceOptions; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; @@ -372,35 +373,31 @@ fn main() { let depth = Var::::zeroed(); while depth < 10 { - // let hit = accel.trace_closest(ray); - let hit = accel.query_all( - ray, - 255, - RayQuery { - on_triangle_hit: |c: SurfaceCandidate| { - if filter(&c) { - c.commit(); - } - }, - on_procedural_hit: |_c| {}, - }, - ); + let trace_options = AccelTraceOptions { + mask: 0xff.expr(), + ..Default::default() + }; + let hit = accel + .traverse(ray, trace_options) + .on_surface_hit(|c: SurfaceCandidate| { + if filter(&c) { + c.commit(); + } + }) + .trace(); if hit.miss() { break; } - let vertex_buffer = vertex_heap.var().buffer::<[f32; 3]>(hit.inst_id); - let triangle = index_heap - .var() - .buffer::(hit.inst_id) - .read(hit.prim_id); + let vertex_buffer = vertex_heap.var().buffer::<[f32; 3]>(hit.inst); + let triangle = index_heap.var().buffer::(hit.inst).read(hit.prim); let p0: Expr = vertex_buffer.read(triangle[0]).into(); let p1: Expr = vertex_buffer.read(triangle[1]).into(); let p2: Expr = vertex_buffer.read(triangle[2]).into(); - let m = accel.instance_transform(hit.inst_id); + let m = accel.instance_transform(hit.inst); let p = p0 * (1.0f32 - hit.bary.x - hit.bary.y) + p1 * hit.bary.x + p2 * hit.bary.y; let p = (m * Float4::expr(p.x, p.y, p.z, 1.0f32)).xyz(); @@ -414,9 +411,9 @@ fn main() { break; } let pp = offset_ray_origin(p, n); - let albedo = cbox_materials.read(hit.inst_id); + let albedo = cbox_materials.read(hit.inst); // hit light - if hit.inst_id == 7u32 { + if hit.inst == 7u32 { if depth == 0u32 { radiance.store(radiance + light_emission); } else { @@ -436,18 +433,14 @@ fn main() { let wi_light = (pp_light - pp).normalize(); let shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.0f32.expr(), d_light); - let occluded = accel.query_any( - shadow_ray, - 255, - RayQuery { - on_triangle_hit: |c: SurfaceCandidate| { - if_!(filter(&c), { - c.commit(); - }); - }, - on_procedural_hit: |_c| {}, - }, - ); + let occluded = accel + .traverse_any(shadow_ray, trace_options) + .on_surface_hit(|c: SurfaceCandidate| { + if filter(&c) { + c.commit(); + } + }) + .trace(); let occluded = !occluded.miss(); let cos_wi_light = wi_light.dot(n); let cos_light = -light_normal.dot(wi_light); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 6d492d7..fbf77aa 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -191,7 +191,7 @@ fn main() { uvw } else { if hit.procedural_hit() { - let prim = hit.prim_id; + let prim = hit.prim; let sphere = spheres.var().read(prim); let normal = (Expr::::from(ray.orig) + Expr::::from(ray.dir) * hit.committed_ray_t diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 97c2c3e..2404916 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -6,7 +6,7 @@ use luisa::lang::types::vector::*; use luisa::lang::types::*; use luisa::prelude::*; -use luisa::rtx::{AccelBuildRequest, AccelOption, Ray}; +use luisa::rtx::{AccelBuildRequest, AccelOption, AccelTraceOptions, Ray}; use luisa_compute as luisa; use winit::event::{Event as WinitEvent, WindowEvent}; use winit::event_loop::{ControlFlow, EventLoop}; @@ -54,11 +54,17 @@ fn main() { Expr::<[f32; 3]>::from(d), 1e9, ); - let hit = accel.trace_closest(ray); + let hit = accel.intersect( + ray, + AccelTraceOptions { + mask: 0xff.expr(), + ..Default::default() + }, + ); let img = img.view(0).var(); let color = select( hit.valid(), - Float3::expr(hit.u, hit.v, 1.0), + hit.triangle_barycentric_coord().extend(1.0), Float3::expr(0.0, 0.0, 0.0), ); img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 7958db2..fdc68d7 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -320,7 +320,7 @@ pub(crate) struct FnRecorder { pub(crate) type FnRecorderPtr = Rc>; impl FnRecorder { pub(crate) fn add_required_curve_basis(&mut self, basis: CurveBasisSet) { - self.curve_bases.merge(basis); + self.curve_bases.insert(basis); } pub(crate) fn make_index_const(&mut self, idx: i32) -> NodeRef { if let Some(node) = self.index_const_pool.get(&idx) { diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 17115c2..ce7b0da 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -345,9 +345,12 @@ pub struct Aabb { #[repr(C)] #[derive(Clone, Copy, Value, Debug, Soa)] -pub struct TriangleHit { +pub struct SurfaceHit { pub inst: u32, pub prim: u32, + /// Don't use directly + /// + /// use [`SurfaceHitExpr::triangle_barycentric_coord`] and [`SurfaceHitExpr::curve_parameter`] to access pub bary: Float2, pub committed_ray_t: f32, } @@ -361,29 +364,37 @@ pub struct ProceduralHit { #[repr(C)] #[derive(Clone, Copy, Value, Debug, Soa)] -pub struct SurfaceHit { - pub inst_id: u32, - pub prim_id: u32, +pub struct CommittedHit { + pub inst: u32, + pub prim: u32, pub bary: Float2, pub hit_type: u32, pub committed_ray_t: f32, } -impl SurfaceHitExpr { +impl CommittedHitExpr { pub fn miss(&self) -> Expr { self.hit_type.eq(HitType::Miss as u32) } + #[tracked] pub fn triangle_hit(&self) -> Expr { - self.hit_type.eq(HitType::Triangle as u32) + self.hit_type.eq(HitType::Surface as u32) & self.bary.y.ge(0.0) } + #[tracked] pub fn procedural_hit(&self) -> Expr { - self.hit_type.eq(HitType::Procedural as u32) + self.hit_type.eq(HitType::Procedural as u32) & self.bary.y.lt(0.0) + } + pub fn curve_parameter(&self) -> Expr { + self.bary.x + } + pub fn triangle_barycentric_coord(&self) -> Expr { + self.bary } } #[derive(Clone, Copy)] #[repr(u32)] pub enum HitType { Miss = 0, - Triangle = 1, + Surface = 1, Procedural = 2, } #[tracked] @@ -419,7 +430,6 @@ pub struct Hit { pub v: f32, pub t: f32, } -pub type CommitedHit = SurfaceHit; #[cfg(test)] mod test { @@ -443,11 +453,33 @@ impl HitExpr { } } +impl SurfaceHitExpr { + pub fn valid(&self) -> Expr { + self.inst.ne(u32::MAX) + } + pub fn miss(&self) -> Expr { + self.prim.eq(u32::MAX) + } + pub fn is_curve(&self) -> Expr { + self.bary.y.lt(0.0) + } + pub fn is_triangle(&self) -> Expr { + !self.is_curve() + } + pub fn curve_parameter(&self) -> Expr { + self.bary.x + } + pub fn triangle_barycentric_coord(&self) -> Expr { + self.bary + } +} + #[derive(Clone, Copy)] pub struct SurfaceCandidate { query: SafeNodeRef, - hit: Expr, + hit: Expr, } +pub type TriangleHit = SurfaceHit; #[deprecated(note = "Use `SurfaceCandidate` instead")] pub type TriangleCandidate = SurfaceCandidate; @@ -474,7 +506,7 @@ impl SurfaceCandidate { } } impl Deref for SurfaceCandidate { - type Target = TriangleHitExpr; + type Target = SurfaceHitExpr; fn deref(&self) -> &Self::Target { &self.hit } @@ -513,13 +545,21 @@ pub struct AccelTraceOptions { pub curve_bases: CurveBasisSet, pub mask: Expr, } +impl Default for AccelTraceOptions { + fn default() -> Self { + Self { + curve_bases: CurveBasisSet::empty(), + mask: u32::MAX.expr(), + } + } +} pub struct RayQueryBase { query: SafeNodeRef, on_surface_hit: Option>, on_procedural_hit: Option>, } impl RayQueryBase { - pub fn on_surface_hit(mut self, f: impl Fn(SurfaceCandidate)) -> Self { + pub fn on_surface_hit(self, f: impl Fn(SurfaceCandidate)) -> Self { assert!( self.on_surface_hit.is_none(), "Surface hit already recorded" @@ -550,7 +590,7 @@ impl RayQueryBase { ..self } } - pub fn on_procedural_hit(mut self, f: impl Fn(ProceduralCandidate)) -> Self { + pub fn on_procedural_hit(self, f: impl Fn(ProceduralCandidate)) -> Self { assert!( self.on_procedural_hit.is_none(), "Procedural hit already recorded" @@ -581,10 +621,24 @@ impl RayQueryBase { ..self } } - pub fn trace(self) -> Expr { - todo!() + pub fn trace(self) -> Expr { + let query = self.query.get(); + let on_surface_hit = self + .on_surface_hit + .unwrap_or_else(|| IrBuilder::new(__module_pools().clone()).finish()); + let on_procedural_hit = self + .on_procedural_hit + .unwrap_or_else(|| IrBuilder::new(__module_pools().clone()).finish()); + FromNode::from_node( + __current_scope(|b| { + b.ray_query(query, on_surface_hit, on_procedural_hit, Type::void()); + b.call(Func::RayQueryCommittedHit, &[query], CommittedHit::type_()) + }) + .into(), + ) } } + impl AccelVar { pub fn instance_transform(&self, index: Expr) -> Expr { let index = index.node().get(); @@ -647,6 +701,7 @@ impl AccelVar { } #[deprecated(note = "Use `intersect` instead")] + #[allow(deprecated)] pub fn trace_closest_masked( &self, ray: impl AsExpr, @@ -723,7 +778,6 @@ impl AccelVar { ), ) }); - let recorder = with_recorder(|r| r as *const _); RayQueryBase { query: query.into(), on_procedural_hit: None, @@ -750,7 +804,7 @@ impl AccelVar { ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), @@ -764,7 +818,7 @@ impl AccelVar { ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), @@ -778,7 +832,7 @@ impl AccelVar { ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), @@ -848,7 +902,7 @@ impl AccelVar { FromNode::from_node( __current_scope(|b| { b.ray_query(query, on_triangle_hit, on_procedural_hit, Type::void()); - b.call(Func::RayQueryCommittedHit, &[query], SurfaceHit::type_()) + b.call(Func::RayQueryCommittedHit, &[query], CommittedHit::type_()) }) .into(), ) diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index ce8eb78..94e0175 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -379,6 +379,7 @@ impl KernelBuilder { cpu_custom_ops_set.insert(CArc::as_ptr(op) as u64); } for c in &callables { + r.add_required_curve_basis(c.0.module.curve_basis_set); for capture in c.0.captures.as_ref() { if !captured_set.contains(capture) { captured_set.insert(*capture); @@ -428,6 +429,7 @@ impl KernelBuilder { let entry = const_block; r.add_block_to_inaccessible(&entry); let ir_module = Module { + curve_basis_set: r.curve_bases, entry, kind: ModuleKind::Kernel, pools: r.pools.clone(), @@ -487,6 +489,7 @@ impl KernelBuilder { let entry = const_block; assert!(r.captured_vars.is_empty()); let ir_module = Module { + curve_basis_set: r.curve_bases, entry, kind: ModuleKind::Kernel, pools: r.pools.clone(), diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index d4400cf..150037c 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit d4400cf49b6ca40854d6cbb5191f9b8033ffce82 +Subproject commit 150037c996ccae86a3d8d42b438dc33baab6754d