From 2f056e48aa76d576e8f35a9b718a14df122c7935 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Fri, 22 Sep 2023 17:33:23 -0400 Subject: [PATCH] fixed incorrect proxy generation --- luisa_compute/examples/path_tracer_cutout.rs | 453 ++++++++++--------- luisa_compute/src/lang.rs | 56 ++- luisa_compute/tests/misc.rs | 11 + luisa_compute_sys/LuisaCompute | 2 +- 4 files changed, 300 insertions(+), 222 deletions(-) diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 94ae453..92cd10d 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -1,5 +1,5 @@ use image::Rgb; -use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::{alias::*, Mat4}; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; @@ -9,8 +9,8 @@ use winit::event_loop::EventLoop; use luisa::prelude::*; use luisa::rtx::{ - offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayQuery, - TriangleCandidate, + offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, + RayQuery, TriangleCandidate, }; use luisa_compute as luisa; @@ -206,7 +206,7 @@ fn main() { let vertex_heap = device.create_bindless_array(65536); let index_heap = device.create_bindless_array(65536); - let mut vertex_buffers: Vec> = vec![]; + let mut vertex_buffers: Vec> = vec![]; let mut index_buffers: Vec> = vec![]; let accel = device.create_accel(AccelOption::default()); let stream = device.create_stream(StreamTag::Graphics); @@ -223,7 +223,7 @@ fn main() { cmds.push(vertex_buffer.view(..).copy_from_async(unsafe { let vertex_ptr = model.mesh.positions.as_ptr(); std::slice::from_raw_parts( - vertex_ptr as *const PackedFloat3, + vertex_ptr as *const [f32; 3], model.mesh.positions.len() / 3, ) })); @@ -242,7 +242,11 @@ fn main() { } else { glam::Mat4::IDENTITY }; - accel.push_mesh(&mesh, m.into(), u32::MAX, false); + 
#[cfg(feature = "glam")] + let m: Mat4 = m.into(); + #[cfg(not(feature = "glam"))] + let m: Mat4 = unsafe { std::mem::transmute(m) }; + accel.push_mesh(&mesh, m, u32::MAX, false); } cmds.push(vertex_heap.update_async()); cmds.push(index_heap.update_async()); @@ -252,238 +256,269 @@ fn main() { }); // use create_kernel_async to compile multiple kernels in parallel - let path_tracer = device - .create_kernel_async::, Tex2d, Accel, Uint2)>( - &|image: Tex2dVar, - seed_image: Tex2dVar, - accel: AccelVar, - resolution: Expr| { - set_block_size([16u32, 16u32, 1u32]); - let cbox_materials = [ - Float3::new(0.725f32, 0.710f32, 0.680f32), // floor - Float3::new(0.725f32, 0.710f32, 0.680f32), // ceiling - Float3::new(0.725f32, 0.710f32, 0.680f32), // back wall - Float3::new(0.140f32, 0.450f32, 0.091f32), // right wall - Float3::new(0.630f32, 0.065f32, 0.050f32), // left wall - Float3::new(0.725f32, 0.710f32, 0.680f32), // short box - Float3::new(0.725f32, 0.710f32, 0.680f32), // tall box - Float3::new(0.000f32, 0.000f32, 0.000f32), // light - ].expr(); - - let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::)->Expr>(|state: Var| { - const LCG_A: u32 = 1664525u32; - const LCG_C: u32 = 1013904223u32; - *state.get_mut() = LCG_A * *state + LCG_C; - (*state & 0x00ffffffu32).float() * (1.0f32 / 0x01000000u32 as f32) - }); - lcg.call(state) - }; + let path_tracer = device.create_kernel_async::, Tex2d, Accel, Uint2)>( + track!(&|image: Tex2dVar, + seed_image: Tex2dVar, + accel: AccelVar, + resolution: Expr| { + set_block_size([16u32, 16u32, 1u32]); + let cbox_materials = [ + Float3::new(0.725f32, 0.710f32, 0.680f32), // floor + Float3::new(0.725f32, 0.710f32, 0.680f32), // ceiling + Float3::new(0.725f32, 0.710f32, 0.680f32), // back wall + Float3::new(0.140f32, 0.450f32, 0.091f32), // right wall + Float3::new(0.630f32, 0.065f32, 0.050f32), // left wall + Float3::new(0.725f32, 0.710f32, 0.680f32), // short box + Float3::new(0.725f32, 0.710f32, 0.680f32), // tall box 
+ Float3::new(0.000f32, 0.000f32, 0.000f32), // light + ] + .expr(); + + let lcg = |state: Var| -> Expr { + let lcg = create_static_callable::) -> Expr>(|state: Var| { + const LCG_A: u32 = 1664525u32; + const LCG_C: u32 = 1013904223u32; + *state = LCG_A * state + LCG_C; + (state & 0x00ffffffu32).as_f32() * (1.0f32 / 0x01000000u32 as f32) + }); + lcg.call(state) + }; - let make_ray = |o: Expr, d: Expr, tmin: Expr, tmax: Expr| -> Expr { - struct_!(Ray { + let make_ray = + |o: Expr, d: Expr, tmin: Expr, tmax: Expr| -> Expr { + Ray::from_comps_expr(RayComps { orig: o.into(), tmin: tmin, - dir:d.into(), - tmax: tmax + dir: d.into(), + tmax: tmax, }) }; - let generate_ray = |p: Expr| -> Expr { + let generate_ray = |p: Expr| -> Expr { + let fov = escape!({ const FOV: f32 = 27.8f32 * std::f32::consts::PI / 180.0f32; - let origin = Float3::expr(-0.01f32, 0.995f32, 5.0f32); + FOV + }); + let origin = Float3::expr(-0.01f32, 0.995f32, 5.0f32); - let pixel = origin - + Float3::expr( - p.x() * f32::tan(0.5f32 * FOV), - p.y() * f32::tan(0.5f32 * FOV), + let pixel = origin + + Float3::expr( + p.x * escape!(f32::tan(0.5f32 * fov)), + p.y * escape!(f32::tan(0.5f32 * fov)), -1.0f32, ); - let direction = (pixel - origin).normalize(); - make_ray(origin, direction, 0.0f32.into(), f32::MAX.into()) - }; - - let balanced_heuristic = |pdf_a: Expr, pdf_b: Expr| { - pdf_a / (pdf_a + pdf_b).max(1e-4f32) - }; + let direction = (pixel - origin).normalize(); + make_ray(origin, direction, 0.0f32.expr(), f32::MAX.expr()) + }; - let make_onb = |normal: Expr| -> Expr { - let binormal = if_!( - normal.x().abs().cmpgt(normal.z().abs()), { - Float3::expr(-normal.y(), normal.x(), 0.0f32) - }, else { - Float3::expr(0.0f32, -normal.z(), normal.y()) - } - ); - let tangent = binormal.cross(normal).normalize(); - OnbExpr::new(tangent, binormal, normal) - }; + let balanced_heuristic = + |pdf_a: Expr, pdf_b: Expr| pdf_a / luisa::max(pdf_a + pdf_b, 1e-4f32); - let cosine_sample_hemisphere = |u: Expr| { - let r = 
u.x().sqrt(); - let phi = 2.0f32 * std::f32::consts::PI * u.y(); - Float3::expr(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt()) + let make_onb = |normal: Expr| -> Expr { + let binormal = if normal.x.abs() > normal.z.abs() { + Float3::expr(-normal.y, normal.x, 0.0f32) + } else { + Float3::expr(0.0f32, -normal.z, normal.y) }; - let coord = dispatch_id().xy(); - let frame_size = resolution.x().min(resolution.y()).float(); - let state = Var::::zeroed(); - state.store(seed_image.read(coord)); + let tangent = binormal.cross(normal).normalize(); + Onb::new_expr(tangent, binormal, normal) + }; - let rx = lcg(state); - let ry = lcg(state); + let cosine_sample_hemisphere = |u: Expr| { + let r = u.x.sqrt(); + let phi = 2.0f32 * std::f32::consts::PI * u.y; + Float3::expr(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x).sqrt()) + }; - let pixel = (coord.float() + Float2::expr(rx, ry)) / frame_size * 2.0f32 - 1.0f32; + let coord = dispatch_id().xy(); + let frame_size = luisa::min(resolution.x, resolution.y).as_f32(); + let state = Var::::zeroed(); + state.store(seed_image.read(coord)); + + let rx = lcg(state); + let ry = lcg(state); + + let pixel = (coord.cast_f32() + Float2::expr(rx, ry)) / frame_size * 2.0f32 - 1.0f32; + + let radiance = Var::::zeroed(); + radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); + for _ in 0..SPP_PER_DISPATCH as u32 { + let init_ray = generate_ray(pixel * Float2::expr(1.0f32, -1.0f32)); + let ray = Var::::zeroed(); + ray.store(init_ray); + + let beta = Var::::zeroed(); + beta.store(Float3::expr(1.0f32, 1.0f32, 1.0f32)); + let pdf_bsdf = Var::::zeroed(); + pdf_bsdf.store(0.0f32); + + let light_position = Float3::expr(-0.24f32, 1.98f32, 0.16f32); + let light_u = Float3::expr(-0.24f32, 1.98f32, -0.22f32) - light_position; + let light_v = Float3::expr(0.23f32, 1.98f32, 0.16f32) - light_position; + let light_emission = Float3::expr(17.0f32, 12.0f32, 4.0f32); + let light_area = light_u.cross(light_v).length(); + let light_normal = 
light_u.cross(light_v).normalize(); + + let filter = |c: &TriangleCandidate| { + let valid = true.var(); + if c.inst == 5u32 { + *valid = (c.bary.y * 6.0f32).fract() < 0.6f32; + } + if c.inst == 6u32 { + *valid = (c.bary.y * 5.0f32).fract() < 0.5f32; + } + **valid + }; - let radiance = Var::::zeroed(); - radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); - for_range(0..SPP_PER_DISPATCH as u32, |_| { - let init_ray = generate_ray(pixel * Float2::expr(1.0f32, -1.0f32)); - let ray = Var::::zeroed(); - ray.store(init_ray); - - let beta = Var::::zeroed(); - beta.store(Float3::expr(1.0f32, 1.0f32, 1.0f32)); - let pdf_bsdf = Var::::zeroed(); - pdf_bsdf.store(0.0f32); - - let light_position = Float3::expr(-0.24f32, 1.98f32, 0.16f32); - let light_u = Float3::expr(-0.24f32, 1.98f32, -0.22f32) - light_position; - let light_v = Float3::expr(0.23f32, 1.98f32, 0.16f32) - light_position; - let light_emission = Float3::expr(17.0f32, 12.0f32, 4.0f32); - let light_area = light_u.cross(light_v).length(); - let light_normal = light_u.cross(light_v).normalize(); - - let filter = |c: &TriangleCandidate| { - let valid = true.var(); - if_!(c.inst().cmpeq(5u32), { valid.store((c.bary().y() * 6.0f32).fract().cmplt(0.6f32)); }); - if_!(c.inst().cmpeq(6u32), { valid.store((c.bary().y() * 5.0f32).fract().cmplt(0.5f32)); }); - valid.load() - }; - - let depth = Var::::zeroed(); - while_!(depth.load().cmplt(10u32), { - // let hit = accel.trace_closest(ray); - let hit = accel.query_all(ray, 255, RayQuery { + let depth = Var::::zeroed(); + while depth < 10 { + // let hit = accel.trace_closest(ray); + let hit = accel.query_all( + ray, + 255, + RayQuery { on_triangle_hit: |c: TriangleCandidate| { - if_!(filter(&c), { c.commit(); }); + if filter(&c) { + c.commit(); + } }, - on_procedural_hit: |_c| {} - }); - - if_!(hit.miss(), { - break_(); - }); - - let vertex_buffer = vertex_heap.var().buffer::(hit.inst_id()); - let triangle = index_heap - .var() - .buffer::(hit.inst_id()) - .read(hit.prim_id()); 
- - let p0: Expr = vertex_buffer.read(triangle.x()).into(); - let p1: Expr = vertex_buffer.read(triangle.y()).into(); - let p2: Expr = vertex_buffer.read(triangle.z()).into(); - let m = accel.instance_transform(hit.inst_id()); - let p = p0 * (1.0f32 - hit.bary().x() - hit.bary().y()) + p1 * hit.bary().x() + p2 * hit.bary().y(); - let p = (m * Float4::expr(p.x(), p.y(), p.z(), 1.0f32)).xyz(); - let n = (p1 - p0).cross(p2 - p0); - let n = (m * Float4::expr(n.x(), n.y(), n.z(), 0.0f32)).xyz().normalize(); - - let origin: Expr = ray.load().orig().into(); - let direction: Expr = ray.load().dir().into(); - let cos_wi = -direction.dot(n); - if_!(cos_wi.abs().cmplt(1e-4f32), { - break_(); - }); - let pp = offset_ray_origin(p, n); - let albedo = cbox_materials.read(hit.inst_id()); - // hit light - if_!(hit.inst_id().cmpeq(7u32), { - if_!(depth.load().cmpeq(0u32), { - radiance.store(radiance.load() + light_emission); - }, else { - let pdf_light = (p - origin).length_squared() / (light_area * cos_wi); - if_!(pdf_light.cmpgt(0.0f32), { - let mis_weight = balanced_heuristic(pdf_bsdf.load(), pdf_light); - radiance.store(radiance.load() + mis_weight * *beta * light_emission); - }); - }); - break_(); - }, else{ - - // sample light - let ux_light = lcg(state); - let uy_light = lcg(state); - let p_light = light_position + ux_light * light_u + uy_light * light_v; - - let pp_light = offset_ray_origin(p_light, light_normal); - let d_light = (pp - pp_light).length(); - let wi_light = (pp_light - pp).normalize(); - let shadow_ray = make_ray(offset_ray_origin(pp, n), wi_light, 0.0f32.into(), d_light); - let occluded = accel.query_any(shadow_ray, 255, RayQuery { + on_procedural_hit: |_c| {}, + }, + ); + + if hit.miss() { + break; + } + + let vertex_buffer = vertex_heap.var().buffer::<[f32; 3]>(hit.inst_id); + let triangle = index_heap + .var() + .buffer::(hit.inst_id) + .read(hit.prim_id); + + let p0: Expr = vertex_buffer.read(triangle[0]).into(); + let p1: Expr = 
vertex_buffer.read(triangle[1]).into(); + let p2: Expr = vertex_buffer.read(triangle[2]).into(); + + let m = accel.instance_transform(hit.inst_id); + let p = + p0 * (1.0f32 - hit.bary.x - hit.bary.y) + p1 * hit.bary.x + p2 * hit.bary.y; + let p = (m * Float4::expr(p.x, p.y, p.z, 1.0f32)).xyz(); + let n = (p1 - p0).cross(p2 - p0); + let n = (m * Float4::expr(n.x, n.y, n.z, 0.0f32)).xyz().normalize(); + + let origin: Expr = (**ray.orig).into(); + let direction: Expr = (**ray.dir).into(); + let cos_wi = -direction.dot(n); + if cos_wi.abs() < 1e-4f32 { + break; + } + let pp = offset_ray_origin(p, n); + let albedo = cbox_materials.read(hit.inst_id); + // hit light + if hit.inst_id == 7u32 { + if depth == 0u32 { + radiance.store(radiance + light_emission); + } else { + let pdf_light = (p - origin).length_squared() / (light_area * cos_wi); + let mis_weight = balanced_heuristic(**pdf_bsdf, pdf_light); + radiance.store(radiance + mis_weight * beta * light_emission); + } + break; + } else { + // sample light + let ux_light = lcg(state); + let uy_light = lcg(state); + let p_light = light_position + ux_light * light_u + uy_light * light_v; + + let pp_light = offset_ray_origin(p_light, light_normal); + let d_light = (pp - pp_light).length(); + let wi_light = (pp_light - pp).normalize(); + let shadow_ray = + make_ray(offset_ray_origin(pp, n), wi_light, 0.0f32.expr(), d_light); + let occluded = accel.query_any( + shadow_ray, + 255, + RayQuery { on_triangle_hit: |c: TriangleCandidate| { - if_!(filter(&c), { c.commit(); }); + if_!(filter(&c), { + c.commit(); + }); + }, + on_procedural_hit: |_c| {}, }, - on_procedural_hit: |_c| {} - }); - let occluded = !occluded.miss(); - let cos_wi_light = wi_light.dot(n); - let cos_light = -light_normal.dot(wi_light); - - if_!(!occluded & (cos_wi_light * cos_wi).cmpgt(0.0f32) & cos_light.cmpgt(1e-4f32), { - let pdf_light = (d_light * d_light) / (light_area * cos_light); - let pdf_bsdf = cos_wi_light * std::f32::consts::FRAC_1_PI; - let mis_weight 
= balanced_heuristic(pdf_light, pdf_bsdf); - let bsdf = albedo * std::f32::consts::FRAC_1_PI * cos_wi_light; - radiance.store(*radiance+ *beta * bsdf * mis_weight * light_emission / pdf_light.max(1e-4f32)); - }); - }); - // sample BSDF - let onb = make_onb(n); - let ux = lcg(state); - let uy = lcg(state); - let new_direction = onb.to_world(cosine_sample_hemisphere(Float2::expr(ux, uy))); - *ray.get_mut() = make_ray(pp, new_direction, 0.0f32.into(), std::f32::MAX.into()); - *beta.get_mut() *= albedo; - pdf_bsdf.store(cos_wi.abs() * std::f32::consts::FRAC_1_PI); - - // russian roulette - let l = Float3::expr(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta); - if_!(l.cmpeq(0.0f32), { break_(); }); - let q = l.max(0.05f32); - let r = lcg(state); - if_!(r.cmpgt(q), { break_(); }); - *beta.get_mut() = *beta / q; - - *depth.get_mut() += 1; - }); - }); - radiance.store(radiance.load() / SPP_PER_DISPATCH as f32); - seed_image.write(coord, *state); - if_!(radiance.load().is_nan().any(), { radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); }); - let radiance = radiance.load().clamp(0.0f32, 30.0f32); - let old = image.read(dispatch_id().xy()); - let spp = old.w(); - let radiance = radiance + old.xyz(); - image.write(dispatch_id().xy(), Float4::expr(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32)); - }, - ) - ; + ); + let occluded = !occluded.miss(); + let cos_wi_light = wi_light.dot(n); + let cos_light = -light_normal.dot(wi_light); + + if !occluded && cos_wi_light > 1e-4f32 && cos_light > 1e-4f32 { + let pdf_light = (d_light * d_light) / (light_area * cos_light); + let pdf_bsdf = cos_wi_light * std::f32::consts::FRAC_1_PI; + let mis_weight = balanced_heuristic(pdf_light, pdf_bsdf); + let bsdf = albedo * std::f32::consts::FRAC_1_PI * cos_wi_light; + *radiance += beta * bsdf * mis_weight * light_emission + / luisa::max(pdf_light, 1e-4f32); + } + } + // sample BSDF + let onb = make_onb(n); + let ux = lcg(state); + let uy = lcg(state); + let new_direction = + 
onb.to_world(cosine_sample_hemisphere(Float2::expr(ux, uy))); + *ray = make_ray(pp, new_direction, 0.0f32.expr(), std::f32::MAX.expr()); + *beta *= albedo; + pdf_bsdf.store(cos_wi.abs() * std::f32::consts::FRAC_1_PI); + + // russian roulette + let l = Float3::expr(0.212671f32, 0.715160f32, 0.072169f32).dot(beta); + if l == 0.0f32 { + break; + } + let q = luisa::max(l, 0.05f32); + let r = lcg(state); + if r > q { + break; + } + *beta = beta / q; + + *depth += 1; + } + } + radiance.store(radiance / SPP_PER_DISPATCH as f32); + seed_image.write(coord, state); + if radiance.is_nan().any() { + radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); + } + let radiance = radiance.clamp( + Float3::splat_expr(0.0f32.expr()), + Float3::splat_expr(30.0f32.expr()), + ); + let old = image.read(dispatch_id().xy()); + let spp = old.w; + let radiance = radiance + old.xyz(); + image.write( + dispatch_id().xy(), + Float4::expr(radiance.x, radiance.y, radiance.z, spp + 1.0f32), + ); + }), + ); let display = - device.create_kernel_async::, Tex2d)>(&|acc, display| { + device.create_kernel_async::, Tex2d)>(track!(&|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); - let spp = radiance.w(); + let spp = radiance.w; let radiance = radiance.xyz() / spp; // workaround a rust-analyzer bug - let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055; + let r = 1.055f32 * radiance.powf(1.0 / 2.4f32) - 0.055; - let srgb = Expr::::select(radiance.cmplt(0.0031308), radiance * 12.92, r); - display.write(coord, Float4::expr(srgb.x(), srgb.y(), srgb.z(), 1.0f32)); - }); + let srgb = radiance.lt(0.0031308).select(radiance * 12.92, r); + display.write(coord, Float4::expr(srgb.x, srgb.y, srgb.z, 1.0f32)); + })); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index bc7f1ec..ca15ac8 100644 --- a/luisa_compute/src/lang.rs +++ 
b/luisa_compute/src/lang.rs @@ -375,20 +375,46 @@ pub fn __module_pools() -> &'static CArc { unsafe { std::mem::transmute(pool) } }) } -// pub fn __load(node: NodeRef) -> Expr { -// __current_scope(|b| { -// let node = b.load(node); -// Expr::::from_node(node) -// }) -// } -// pub fn __store(var:NodeRef, value:NodeRef) { -// let inst = &var.get().instruction; -// } +/// Don't call this function directly unless you know what you are doing +/** This function is solely for constructing proxies + * Given a node, __extract selects the correct Func based on the node's type + * It then inserts the extract(node, i) call *at the point where the node is defined* + * *Note*, after insertion, the IrBuilder in the correct/parent scope might not be up to date + * Thus, for IrBuilder of each scope, it updates the insertion point to the end of the current basic block + */ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { let inst = &node.get().instruction; - __current_scope(|b| { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + + let pools = { + let cur_builder = r.scopes.last_mut().unwrap(); + cur_builder.pools() + }; + let mut b = IrBuilder::new_without_bb(pools.clone()); + + if !node.is_argument() && !node.is_uniform() && !node.is_atomic_ref() { + // These nodes are not attached to any BB + // however, we need to generate the index node + // We generate them at the top of current module + b.set_insert_point(node); + } else { + let first_scope = &r.scopes[0]; + let first_scope_bb = first_scope.bb(); + b.set_insert_point(first_scope_bb.first()); + } + let i = b.const_(Const::Int32(index as i32)); + // Since we have inserted something, the insertion point in cur_builder might not be up to date + // So we need to set it to the end of the current basic block + macro_rules! update_builders { + () => { + for scope in &mut r.scopes { + scope.set_insert_point_to_end(); + } + }; + } let op = match inst.as_ref() { Instruction::Local { .. 
} => Func::GetElementPtr, Instruction::Argument { by_value } => { @@ -402,18 +428,24 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { Func::AtomicRef => { let mut indices = args.to_vec(); indices.push(i); - return b.call_no_append(Func::AtomicRef, &indices, ::type_()); + let n = b.call_no_append(Func::AtomicRef, &indices, ::type_()); + update_builders!(); + return n; } Func::GetElementPtr => { let mut indices = args.to_vec(); indices.push(i); - return b.call(Func::GetElementPtr, &indices, ::type_()); + let n = b.call(Func::GetElementPtr, &indices, ::type_()); + update_builders!(); + return n; } _ => Func::ExtractElement, }, _ => Func::ExtractElement, }; let node = b.call(op, &[node, i], ::type_()); + + update_builders!(); node }) } diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index a0cbf13..83d65c8 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -954,3 +954,14 @@ fn atomic() { assert_eq!(foo_min, expected_foo_min); } + + +#[test] +fn expr_proxy() { + let device = get_device(); + let foo = device.create_buffer_from_fn(1024, |_| Foo { + i: 0, + v: Float2::new(0.0, 0.0), + a: [0; 4], + }); +} \ No newline at end of file diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 112b854..1272ab8 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 112b8547620b2fb4218d0e79b9a9158ef3c5ea9f +Subproject commit 1272ab89c8cddb73189f439ef5fcd6c1737b45d2