diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 9f96666..9b6201d 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -5,7 +5,8 @@ use luisa::lang::types::vector::alias::*; use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ - Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, SurfaceCandidate, + Aabb, AccelBuildRequest, AccelOption, AccelTraceOptions, ProceduralCandidate, Ray, + SurfaceCandidate, }; use luisa_compute as luisa; use winit::event::{Event as WinitEvent, WindowEvent}; @@ -136,54 +137,55 @@ fn main() { Expr::<[f32; 3]>::from(d), 1e9f32, ); - let hit = accel.query_all( - ray, - 255, - RayQuery { - on_triangle_hit: |candidate: SurfaceCandidate| { - let bary = candidate.bary; - let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); - let t = candidate.committed_ray_t; - if (px == Uint2::expr(400, 400)).all() { - debug_hit_t.write(0, t); - debug_hit_t.write(1, candidate.ray().tmax); - }; - // if (uvw.xy().length() < 0.8) - // & (uvw.yz().length() < 0.8) - // & (uvw.xz().length() < 0.8) - if uvw.xy().length() < 0.8 - && uvw.yz().length() < 0.8 - && uvw.xz().length() < 0.8 - { - candidate.commit(); - } + let hit = accel + .traverse( + ray, + AccelTraceOptions { + mask: 0xffu32.expr(), + ..Default::default() }, - on_procedural_hit: |candidate: ProceduralCandidate| { - let ray = candidate.ray(); - let prim = candidate.prim; - let sphere = spheres.var().read(prim); - let o: Expr = ray.orig.into(); - let d: Expr = ray.dir.into(); - let t = Var::::zeroed(); + ) + .on_surface_hit(|candidate: SurfaceCandidate| { + let bary = candidate.bary; + let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); + let t = candidate.committed_ray_t; + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(0, t); + debug_hit_t.write(1, candidate.ray().tmax); + }; + // if (uvw.xy().length() < 0.8) + // & (uvw.yz().length() < 0.8) + // & (uvw.xz().length() < 0.8) + if uvw.xy().length() < 0.8 && uvw.yz().length() < 0.8 && uvw.xz().length() < 0.8 + { + candidate.commit(); + } + }) + .on_procedural_hit(|candidate: ProceduralCandidate| { + let ray = candidate.ray(); + let prim = candidate.prim; + let sphere = spheres.var().read(prim); + let o: Expr = ray.orig.into(); + let d: Expr = ray.dir.into(); + let t = Var::::zeroed(); - for _ in 0..100 { - let dist = (o + d * t - (sphere.center + translate.expr())).length() - - sphere.radius; - if dist < 0.001 { - if (px == Uint2::expr(400, 400)).all() { - debug_hit_t.write(2, t); - debug_hit_t.write(3, candidate.ray().tmax); - } - if t < ray.tmax { - candidate.commit(t); - } - break; + for _ in 0..100 { + let dist = (o + d * t - (sphere.center + translate.expr())).length() + - sphere.radius; + if dist < 0.001 { + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(2, t); + debug_hit_t.write(3, candidate.ray().tmax); + } + if t < ray.tmax { + candidate.commit(t); } - *t += dist; + break; } - }, - }, - ); + *t += dist; + } + }) + .trace(); let img = img.view(0).var(); let color = if hit.triangle_hit() { let bary = hit.bary; diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index a939495..5177df7 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -493,7 +493,11 @@ impl CommittedHitExpr { } #[tracked] pub fn procedural_hit(&self) -> Expr { - self.hit_type.eq(HitType::Procedural as u32) & self.bary.y.lt(0.0) + self.hit_type.eq(HitType::Procedural as u32) & self.bary.y.ge(0.0) + } + #[tracked] + pub fn curve_hit(&self) -> Expr { + self.hit_type.eq(HitType::Surface as u32) & self.bary.y.lt(0.0) } pub fn curve_parameter(&self) -> Expr { self.bary.x