diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index e0bc581..16bd522 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -11,7 +11,7 @@ use winit::event_loop::EventLoop; use luisa::prelude::*; use luisa::rtx::{ offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, - RayQuery, TriangleCandidate, + RayQuery, SurfaceCandidate, }; use luisa_compute as luisa; @@ -359,7 +359,7 @@ fn main() { let light_area = light_u.cross(light_v).length(); let light_normal = light_u.cross(light_v).normalize(); - let filter = |c: &TriangleCandidate| { + let filter = |c: &SurfaceCandidate| { let valid = true.var(); if c.inst == 5u32 { *valid = (c.bary.y * 6.0f32).fract() < 0.6f32; @@ -377,7 +377,7 @@ fn main() { ray, 255, RayQuery { - on_triangle_hit: |c: TriangleCandidate| { + on_triangle_hit: |c: SurfaceCandidate| { if filter(&c) { c.commit(); } @@ -440,7 +440,7 @@ fn main() { shadow_ray, 255, RayQuery { - on_triangle_hit: |c: TriangleCandidate| { + on_triangle_hit: |c: SurfaceCandidate| { if_!(filter(&c), { c.commit(); }); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 545c2d2..6d492d7 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -5,7 +5,7 @@ use luisa::lang::types::vector::alias::*; use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ - Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, TriangleCandidate, + Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, SurfaceCandidate, }; use luisa_compute as luisa; use winit::event::{Event as WinitEvent, WindowEvent}; @@ -140,7 +140,7 @@ fn main() { ray, 255, RayQuery { - on_triangle_hit: |candidate: TriangleCandidate| { + on_triangle_hit: |candidate: SurfaceCandidate| { let bary = candidate.bary; let uvw = Float3::expr(1.0 - 
bary.x - bary.y, bary.x, bary.y); let t = candidate.committed_ray_t; diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index a65fead..7958db2 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -11,6 +11,7 @@ use crate::internal_prelude::*; use bumpalo::Bump; use indexmap::IndexMap; +use luisa_compute_ir::ir::CurveBasisSet; use crate::runtime::{RawCallable, WeakDevice}; @@ -314,9 +315,13 @@ pub(crate) struct FnRecorder { pub(crate) const_builder: IrBuilder, pub(crate) index_const_pool: IndexMap, pub(crate) rt: ResourceTracker, + pub(crate) curve_bases: CurveBasisSet, } pub(crate) type FnRecorderPtr = Rc>; impl FnRecorder { + pub(crate) fn add_required_curve_basis(&mut self, basis: CurveBasisSet) { + self.curve_bases.merge(basis); + } pub(crate) fn make_index_const(&mut self, idx: i32) -> NodeRef { if let Some(node) = self.index_const_pool.get(&idx) { return *node; @@ -412,6 +417,7 @@ impl FnRecorder { .map(|p| p.borrow().inaccessible.clone()) .unwrap_or_else(|| Rc::new(RefCell::new(HashSet::new()))), scopes: vec![], + curve_bases: CurveBasisSet::empty(), captured_resources: IndexMap::new(), cpu_custom_ops: IndexMap::new(), callables: IndexMap::new(), @@ -561,6 +567,7 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef { if node.node.is_user_data() { return node; } + if r.inaccessible.borrow().contains(&node.node) { panic!( r#"Detected using node outside of its scope. 
It is possible that you use `RefCell` or `Cell` to store an `Expr` or `Var` @@ -573,6 +580,10 @@ Please define a `Var` in the parent scope and assign to it instead!"# if ptr == node.recorder { return node; } + let ty = node.node.type_(); + if ty.is_opaque("LC_RayQueryAny") || ty.is_opaque("LC_RayQueryAll") { + panic!("RayQuery cannot be captured!"); + } r.map_captured_vars(node) }) } diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index e7da0a3..17115c2 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -6,6 +6,7 @@ use crate::internal_prelude::*; use crate::runtime::*; use crate::{ResourceTracker, *}; +use luisa_compute_ir::ir::CurveBasisSet; use luisa_compute_ir::ir::{AccelBinding, Binding, Func, Instruction, IrBuilder, Node, Type}; use parking_lot::RwLock; use std::ops::Deref; @@ -360,14 +361,14 @@ pub struct ProceduralHit { #[repr(C)] #[derive(Clone, Copy, Value, Debug, Soa)] -pub struct CommittedHit { +pub struct SurfaceHit { pub inst_id: u32, pub prim_id: u32, pub bary: Float2, pub hit_type: u32, pub committed_ray_t: f32, } -impl CommittedHitExpr { +impl SurfaceHitExpr { pub fn miss(&self) -> Expr { self.hit_type.eq(HitType::Miss as u32) } @@ -385,33 +386,32 @@ pub enum HitType { Triangle = 1, Procedural = 2, } - +#[tracked] pub fn offset_ray_origin( p: impl AsExpr, n: impl AsExpr, ) -> Expr { - lazy_static! 
{ - static ref F: Callable, Expr) -> Expr> = - Callable::, Expr) -> Expr>::new_static(|p, n| { - const ORIGIN: f32 = 1.0f32 / 32.0f32; - const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32; - const INT_SCALE: f32 = 256.0f32; - track!(unsafe { - let of_i = (INT_SCALE * n).as_int3(); - let p_i = p.bitcast::() + p.lt(0.0f32).select(-of_i, of_i); - (p.abs() < ORIGIN).select(p + FLOAT_SCALE * n, p_i.bitcast::()) - }) - }); - } - let p: Expr = p.as_expr(); - let n: Expr = n.as_expr(); - F.call(p, n) + let ret = Var::::zeroed(); + outline(|| { + let p: Expr = p.as_expr(); + let n: Expr = n.as_expr(); + let origin: f32 = 1.0f32 / 32.0f32; + let float_scale: f32 = 1.0f32 / 65536.0f32; + let int_scale: f32 = 256.0f32; + unsafe { + let of_i = (int_scale * n).as_int3(); + let p_i = p.bitcast::() + p.lt(0.0f32).select(-of_i, of_i); + *ret = (p.abs() < origin).select(p + float_scale * n, p_i.bitcast::()) + } + }); + ret.load() } pub type Index = [u32; 3]; #[repr(C)] #[repr(align(8))] #[derive(Clone, Copy, Value, Debug, Soa)] +#[deprecated(note = "Use `SurfaceHit` instead")] pub struct Hit { pub inst_id: u32, pub prim_id: u32, @@ -419,6 +419,7 @@ pub struct Hit { pub v: f32, pub t: f32, } +pub type CommittedHit = SurfaceHit; #[cfg(test)] mod test { @@ -427,12 +428,12 @@ mod test { use super::*; assert_eq!(std::mem::align_of::(), 16); assert_eq!(std::mem::size_of::(), 32); - assert_eq!(std::mem::size_of::(), 24); - assert_eq!(std::mem::align_of::(), 8); + assert_eq!(std::mem::size_of::(), 24); + assert_eq!(std::mem::align_of::(), 8); assert_eq!(std::mem::size_of::(), 12); } } - +#[deprecated] impl HitExpr { pub fn valid(&self) -> Expr { self.inst_id.ne(u32::MAX) @@ -443,16 +444,19 @@ impl HitExpr { } #[derive(Clone, Copy)] -pub struct TriangleCandidate { +pub struct SurfaceCandidate { query: SafeNodeRef, hit: Expr, } +#[deprecated(note = "Use `SurfaceCandidate` instead")] +pub type TriangleCandidate = SurfaceCandidate; + #[derive(Clone, Copy)] pub struct ProceduralCandidate { query: 
SafeNodeRef, hit: Expr, } -impl TriangleCandidate { +impl SurfaceCandidate { pub fn commit(&self) { let query = self.query.get(); __current_scope(|b| b.call(Func::RayQueryCommitTriangle, &[query], Type::void())); @@ -469,7 +473,7 @@ impl TriangleCandidate { ) } } -impl Deref for TriangleCandidate { +impl Deref for SurfaceCandidate { type Target = TriangleHitExpr; fn deref(&self) -> &Self::Target { &self.hit @@ -504,8 +508,84 @@ pub struct RayQuery { pub on_triangle_hit: T, pub on_procedural_hit: P, } +#[derive(Clone, Copy)] +pub struct AccelTraceOptions { + pub curve_bases: CurveBasisSet, + pub mask: Expr, +} +pub struct RayQueryBase { + query: SafeNodeRef, + on_surface_hit: Option>, + on_procedural_hit: Option>, +} +impl RayQueryBase { + pub fn on_surface_hit(mut self, f: impl Fn(SurfaceCandidate)) -> Self { + assert!( + self.on_surface_hit.is_none(), + "Surface hit already recorded" + ); + with_recorder(|r| { + let pools = r.pools.clone(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + let query = self.query.get(); + let candidate = SurfaceCandidate { + query: self.query, + hit: FromNode::from_node( + __current_scope(|b| { + b.call( + Func::RayQueryTriangleCandidateHit, + &[query], + TriangleHit::type_(), + ) + }) + .into(), + ), + }; + (f)(candidate); + let on_surface_hit = __pop_scope(); + Self { + on_surface_hit: Some(on_surface_hit), + ..self + } + } + pub fn on_procedural_hit(mut self, f: impl Fn(ProceduralCandidate)) -> Self { + assert!( + self.on_procedural_hit.is_none(), + "Procedural hit already recorded" + ); + with_recorder(|r| { + let pools = r.pools.clone(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + let query = self.query.get(); + let procedural_candidate = ProceduralCandidate { + query: self.query, + hit: FromNode::from_node( + __current_scope(|b| { + b.call( + Func::RayQueryProceduralCandidateHit, + &[query], + ProceduralHit::type_(), + ) + }) + .into(), + ), + }; + (f)(procedural_candidate); + let 
on_procedural_hit = __pop_scope(); + Self { + on_procedural_hit: Some(on_procedural_hit), + ..self + } + } + pub fn trace(self) -> Expr { + todo!() + } +} impl AccelVar { - #[inline] pub fn instance_transform(&self, index: Expr) -> Expr { let index = index.node().get(); let self_node = self.node.get(); @@ -521,7 +601,52 @@ impl AccelVar { ) } - #[inline] + pub fn intersect( + &self, + ray: impl AsExpr, + options: AccelTraceOptions, + ) -> Expr { + let ray = ray.as_expr().node().get(); + let mask = options.mask.node().get(); + let self_node = self.node.get(); + with_recorder(|r| { + r.add_required_curve_basis(options.curve_bases); + }); + FromNode::from_node( + __current_scope(|b| { + b.call( + Func::RayTracingTraceClosest, + &[self_node, ray, mask], + SurfaceHit::type_(), + ) + }) + .into(), + ) + } + pub fn intersect_any( + &self, + ray: impl AsExpr, + options: AccelTraceOptions, + ) -> Expr { + let ray = ray.as_expr().node().get(); + let mask = options.mask.node().get(); + let self_node = self.node.get(); + with_recorder(|r| { + r.add_required_curve_basis(options.curve_bases); + }); + FromNode::from_node( + __current_scope(|b| { + b.call( + Func::RayTracingTraceAny, + &[self_node, ray, mask], + bool::type_(), + ) + }) + .into(), + ) + } + + #[deprecated(note = "Use `intersect` instead")] pub fn trace_closest_masked( &self, ray: impl AsExpr, @@ -541,7 +666,8 @@ impl AccelVar { .into(), ) } - #[inline] + + #[deprecated(note = "Use `intersect_any` instead")] pub fn trace_any_masked( &self, ray: impl AsExpr, @@ -561,50 +687,100 @@ impl AccelVar { .into(), ) } - #[inline] + #[deprecated(note = "Use `intersect` instead")] + #[allow(deprecated)] pub fn trace_closest(&self, ray: impl AsExpr) -> Expr { self.trace_closest_masked(ray, u32::MAX.expr()) } - #[inline] + #[deprecated(note = "Use `intersect_any` instead")] + #[allow(deprecated)] pub fn trace_any(&self, ray: impl AsExpr) -> Expr { self.trace_any_masked(ray, u32::MAX.expr()) } - #[inline] + fn make_rq( + &self, + 
ray: impl AsExpr, + options: AccelTraceOptions, + ) -> RayQueryBase { + let ray = ray.as_expr().node().get(); + let mask = options.mask.node().get(); + let self_node = self.node.get(); + let query = __current_scope(|b| { + b.call( + if TERMINATE_ON_FIRST { + Func::RayTracingQueryAny + } else { + Func::RayTracingQueryAll + }, + &[self_node, ray, mask], + Type::opaque( + if TERMINATE_ON_FIRST { + "LC_RayQueryAny" + } else { + "LC_RayQueryAll" + } + .into(), + ), + ) + }); + let recorder = with_recorder(|r| r as *const _); + RayQueryBase { + query: query.into(), + on_procedural_hit: None, + on_surface_hit: None, + } + } + pub fn traverse( + &self, + ray: impl AsExpr, + options: AccelTraceOptions, + ) -> RayQueryBase { + self.make_rq(ray, options) + } + pub fn traverse_any( + &self, + ray: impl AsExpr, + options: AccelTraceOptions, + ) -> RayQueryBase { + self.make_rq(ray, options) + } + #[deprecated(note = "Use `traverse` instead")] pub fn query_all( &self, ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where - T: FnOnce(TriangleCandidate), + T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), { self._query(false, ray, mask, ray_query) } - #[inline] + + #[deprecated(note = "Use `traverse_any` instead")] pub fn query_any( &self, ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where - T: FnOnce(TriangleCandidate), + T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), { self._query(true, ray, mask, ray_query) } - #[inline] + fn _query( &self, terminate_on_first: bool, ray: impl AsExpr, mask: impl AsExpr, ray_query: RayQuery, - ) -> Expr + ) -> Expr where - T: FnOnce(TriangleCandidate), + T: FnOnce(SurfaceCandidate), P: FnOnce(ProceduralCandidate), { let ray = ray.as_expr().node().get(); @@ -634,7 +810,7 @@ impl AccelVar { let s = &mut r.scopes; s.push(IrBuilder::new(pools)); }); - let triangle_candidate = TriangleCandidate { + let triangle_candidate = SurfaceCandidate { query: 
query.into(), hit: FromNode::from_node( __current_scope(|b| { @@ -672,7 +848,7 @@ impl AccelVar { FromNode::from_node( __current_scope(|b| { b.ray_query(query, on_triangle_hit, on_procedural_hit, Type::void()); - b.call(Func::RayQueryCommittedHit, &[query], CommittedHit::type_()) + b.call(Func::RayQueryCommittedHit, &[query], SurfaceHit::type_()) }) .into(), ) diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index f8618e0..d4400cf 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit f8618e094981b82d35b561eef4038e6435fdf74c +Subproject commit d4400cf49b6ca40854d6cbb5191f9b8033ffce82