diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 969f6c77..6c112245 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -29,49 +29,52 @@ fn main() { let dy_gt = device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = Kernel::::new(&device, track!(&|| { - let tid = dispatch_id().x; - let buf_x = x.var(); - let buf_y = y.var(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let f = |x: Expr, y: Expr| { - if x > y { - x * y - } else { - y * x + (x / 4.0 * PI).sin() + let shader = Kernel::::new( + &device, + &track!(|| { + let tid = dispatch_id().x; + let buf_x = x.var(); + let buf_y = y.var(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let f = |x: Expr, y: Expr| { + if x > y { + x * y + } else { + y * x + (x / 4.0 * PI).sin() + } + }; + let df = |x: Expr, y: Expr| { + if x > y { + (y, x) + } else { + (y + (x / 4.0 * PI).cos() / 4.0 * PI, x) + } + }; + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = f(x, y); + backward(z); + dx_rev.write(tid, gradient(x)); + dy_rev.write(tid, gradient(y)); + }); + forward_autodiff(2, || { + propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]); + propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]); + let z = f(x, y); + let dx = output_gradients(z)[0]; + let dy = output_gradients(z)[1]; + dx_fwd.write(tid, dx); + dy_fwd.write(tid, dy); + }); + { + let (dx, dy) = df(x, y); + dx_gt.write(tid, dx); + dy_gt.write(tid, dy); } - }; - let df = |x: Expr, y: Expr| { - if x > y { - (y, x) - } else { - (y + (x / 4.0 * PI).cos() / 4.0 * PI, x) - } - }; - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = f(x, y); - backward(z); - dx_rev.write(tid, gradient(x)); - dy_rev.write(tid, gradient(y)); - }); - forward_autodiff(2, || { - propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]); - propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]); - let z = f(x, y); - let dx = output_gradients(z)[0]; - let dy = output_gradients(z)[1]; - dx_fwd.write(tid, dx); - dy_fwd.write(tid, dy); - }); - { - let (dx, dy) = df(x, y); - dx_gt.write(tid, dx); - dy_gt.write(tid, dy); - } - })); + }), + ); shader.dispatch([1024, 1, 1]); { diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index fbf99db6..a369da14 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -18,23 +18,26 @@ fn main() { } else { "cpu" }); - + let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = Kernel::)>::new(&device, &track!(|buf_z| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid + 123); - let y = buf_y.read(tid); - let vx = Var::::zeroed(); // create a local mutable variable - *vx = x; - buf_z.write(tid, vx + y); - })); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid + 123); + let y = buf_y.read(tid); + let vx = Var::::zeroed(); // create a local mutable variable + *vx = x; + buf_z.write(tid, vx + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index a39d2431..8426df3a 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -62,15 +62,18 @@ fn main() { bindless.emplace_buffer_async(1, &y); bindless.emplace_tex2d_async(0, &img, Sampler::default()); bindless.update(); - let kernel = Kernel::)>::new(&device, &track!(|buf_z| { - let bindless = bindless.var(); - let tid = dispatch_id().x; - let buf_x = bindless.buffer::(0_u32.expr()); - let buf_y = bindless.buffer::(1_u32.expr()); - let x = buf_x.read(tid).as_::().as_::(); - let y = buf_y.read(tid); - buf_z.write(tid, x + y); - })); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let bindless = bindless.var(); + let tid = dispatch_id().x; + let buf_x = bindless.buffer::(0_u32.expr()); + let buf_y = bindless.buffer::(1_u32.expr()); + let x = buf_x.read(tid).as_::().as_::(); + let y = buf_y.read(tid); + buf_z.write(tid, x + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let mut z_data = vec![0.0; 1024]; z.view(..).copy_to(&mut z_data); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 1df7c7c1..92090052 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -17,22 +17,24 @@ fn main() { } else { "cpu" }); - let add = - Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); + let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = Kernel::)>::new(&device, &track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - buf_z.write(tid, add.call(x, y)); - })); + buf_z.write(tid, add.call(x, y)); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/custom_aggregate.rs b/luisa_compute/examples/custom_aggregate.rs index f47e2f2b..4cad01a1 100644 --- a/luisa_compute/examples/custom_aggregate.rs +++ b/luisa_compute/examples/custom_aggregate.rs @@ -1,4 +1,5 @@ -use luisa::{prelude::*, lang::types::vector::alias::Float3}; +use luisa::lang::types::vector::alias::Float3; +use luisa::prelude::*; use luisa_compute as luisa; #[derive(Aggregate)] pub struct Spectrum { diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index 5c567377..66f621d3 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -2,7 +2,8 @@ use std::env::current_exe; use std::time::Instant; -use luisa::lang::types::vector::{alias::*, Mat2}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::Mat2; use luisa::prelude::*; use luisa_compute as luisa; use rand::Rng; diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index e38879ce..8fdd0b00 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -6,7 +6,8 @@ use std::time::Instant; use winit::event::{Event as WinitEvent, WindowEvent}; use winit::event_loop::EventLoop; -use luisa::lang::types::vector::{alias::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 0fa8c1ec..4fee5d38 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -1,5 +1,6 @@ use image::Rgb; -use luisa::lang::types::vector::{alias::*, Mat4}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::Mat4; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index 0cc015eb..ea070dc9 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -21,14 +21,17 @@ fn main() { "cpu" }); let printer = Printer::new(&device, 65536); - let kernel = Kernel::::new(&device, &track!(|| { - let id = dispatch_id().xy(); - if id.x == id.y { - lc_info!(printer, "id = {:?}", id); - } else { - lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y); - } - })); + let kernel = Kernel::::new( + &device, + &track!(|| { + let id = dispatch_id().xy(); + if id.x == id.y { + lc_info!(printer, "id = {:?}", id); + } else { + lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y); + } + }), + ); device.default_stream().with_scope(|s| { s.reset_printer(&printer); s.submit([kernel.dispatch_async([4, 4, 1])]); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 055b395f..287bf50c 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -1,7 +1,8 @@ use std::env::current_exe; use image::Rgb; -use luisa::lang::types::vector::{alias::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, TriangleCandidate, diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 4240c2a6..cc3abd81 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -1,7 +1,10 @@ use std::env::current_exe; use image::Rgb; -use luisa::lang::{types::vector::alias::*, types::vector::*, types::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; +use luisa::lang::types::*; +use luisa::lang::*; use luisa::prelude::*; use luisa::rtx::{AccelBuildRequest, AccelOption, Ray}; use luisa_compute as luisa; @@ -36,29 +39,32 @@ fn main() { let img_w = 800; let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); - let rt_kernel = Kernel::::new(&device,&track!(|| { - let accel = accel.var(); - let px = dispatch_id().xy(); - let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); - let xy = 2.0 * xy - 1.0; - let o = Float3::expr(0.0, 0.0, -1.0); - let d = Float3::expr(xy.x, xy.y, 0.0) - o; - let d = d.normalize(); - let ray = Ray::new_expr( - Expr::<[f32; 3]>::from(o), - 1e-3, - Expr::<[f32; 3]>::from(d), - 1e9, - ); - let hit = accel.trace_closest(ray); - let img = img.view(0).var(); - let color = select( - hit.valid(), - Float3::expr(hit.u, hit.v, 1.0), - Float3::expr(0.0, 0.0, 0.0), - ); - img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); - })); + let rt_kernel = Kernel::::new( + &device, + &track!(|| { + let accel = accel.var(); + let px = dispatch_id().xy(); + let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); + let xy = 2.0 * xy - 1.0; + let o = Float3::expr(0.0, 0.0, -1.0); + let d = Float3::expr(xy.x, xy.y, 0.0) - o; + let d = d.normalize(); + let ray = Ray::new_expr( + Expr::<[f32; 3]>::from(o), + 1e-3, + Expr::<[f32; 3]>::from(d), + 1e9, + ); + let hit = accel.trace_closest(ray); + let img = img.view(0).var(); + let color = select( + hit.valid(), + Float3::expr(hit.u, hit.v, 1.0), + Float3::expr(0.0, 0.0, 0.0), + ); + img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); + }), + ); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() .with_title("Luisa Compute Rust - Ray Tracing") diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 354761e8..a9ebb710 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -380,10 +380,11 @@ pub fn __module_pools() -> &'static CArc { /// Don't call this function directly unless you know what you are doing /** This function is soley for constructing proxies - * Given a node, __extract selects the correct Func based on the node's type - * It then inserts the extract(node, i) call *at where the node is defined* - * *Note*, after insertion, the IrBuilder in the correct/parent scope might not be up to date - * Thus, for IrBuilder of each scope, it updates the insertion point to the end of the current basic block + * Given a node, __extract selects the correct Func based on the node's + * type It then inserts the extract(node, i) call *at where the node is + * defined* *Note*, after insertion, the IrBuilder in the correct/parent + * scope might not be up to date Thus, for IrBuilder of each scope, it + * updates the insertion point to the end of the current basic block */ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { let inst = &node.get().instruction; @@ -408,8 +409,9 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { } let i = b.const_(Const::Int32(index as i32)); - // Since we have inserted something, the insertion point in cur_builder might not be up to date - // So we need to set it to the end of the current basic block + // Since we have inserted something, the insertion point in cur_builder might + // not be up to date So we need to set it to the end of the current + // basic block macro_rules! update_builders { () => { for scope in &mut r.scopes { diff --git a/luisa_compute/src/lang/ops/cast_impls.rs b/luisa_compute/src/lang/ops/cast_impls.rs index d807947e..b6c34864 100644 --- a/luisa_compute/src/lang/ops/cast_impls.rs +++ b/luisa_compute/src/lang/ops/cast_impls.rs @@ -1,4 +1,3 @@ - #[rustfmt::skip]mod impl_{ use crate::prelude::*; use super::super::*; diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 2c137664..27ebb8bf 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -224,7 +224,7 @@ where self.clone().mul(self.clone()).mul(self.clone()) } fn recip(&self) -> Self { - ::from_node(__current_scope(|b|{ + ::from_node(__current_scope(|b| { let one = b.const_(Const::One(::type_())); b.call(Func::Div, &[one, self.node()], ::type_()) })) diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 947493ad..49c61e87 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -211,7 +211,7 @@ pub trait IntExpr { pub trait FloatExpr: Sized { type Bool; - + fn ceil(&self) -> Self; fn floor(&self) -> Self; fn round(&self) -> Self; diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs index d7d4e626..cee5ed19 100644 --- a/luisa_compute/src/lang/soa.rs +++ b/luisa_compute/src/lang/soa.rs @@ -1,4 +1,5 @@ -use luisa_compute_ir::{ir::Type, CArc}; +use luisa_compute_ir::ir::Type; +use luisa_compute_ir::CArc; use crate::prelude::*; /** A buffer with SOA layout. diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 11bad6b0..35401280 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -12,9 +12,9 @@ pub mod vector; // TODO: Check up on comments. -/// A value that can be used in a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Call [`expr`](Value::expr) or -/// [`var`](Value::var) to convert into a kernel-trackable type. +/// A value that can be used in a [`Kernel`] or [`Callable`]. Call +/// [`expr`](Value::expr) or [`var`](Value::var) to convert into a +/// kernel-trackable type. pub trait Value: Copy + TypeOf + 'static { /// A proxy for additional impls on [`Expr`]. type Expr: ExprProxy; @@ -139,8 +139,7 @@ impl AtomciRefProxyDataProxyData { } } } -/// An expression within a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Created from a raw value +/// An expression within a [`Kernel`] or [`Callable`]. Created from a raw value /// using [`Value::expr`]. /// /// Note that this does not store the value, and in order to get the result of a @@ -158,8 +157,7 @@ pub struct Expr { #[repr(C)] pub struct TypeTag(PhantomData); -/// A variable within a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Created using [`Expr::var`] +/// A variable within a [`Kernel`] or [`Callable`]. Created using [`Expr::var`] /// and [`Value::var`]. /// /// Note that setting a `Var` using direct assignment will not work. Instead, diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs index a5afeffd..f82c121e 100644 --- a/luisa_compute/src/lang/types/vector/element.rs +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -9,7 +9,6 @@ macro_rules! element { type VectorAtomicRef = VectorAtomicRefProxy2<$T>; type VectorExprData = VectorExprData<$T, 2>; type VectorVarData = VectorVarData<$T, 2>; - } }; ($T:ty [ 3 ]: $A: ident) => { diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs index 61e387a4..8765d28b 100644 --- a/luisa_compute/src/lang/types/vector/impls.rs +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -303,7 +303,7 @@ macro_rules! impl_mat_proxy { Func::Mul.call2(self, rhs) } } - + impl MatExpr for Expr<$M> { type Scalar = Expr; type Value = $M; diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index fe430422..ca318667 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1259,6 +1259,7 @@ pub struct KernelDef { /// An executable kernel /// Kernel creation can be done in multiple ways: /// - Seperate recording and compilation: +/// /// ```no_run /// // Recording: /// use luisa_compute::prelude::*; @@ -1269,7 +1270,9 @@ pub struct KernelDef { /// // Compilation: /// let kernel = device.compile_kernel_def(&kernel); /// ``` +/// /// - Recording and compilation in one step: +/// /// ```no_run /// use luisa_compute::prelude::*; /// let ctx = Context::new(std::env::current_exe().unwrap()); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 58b1b780..b1b2a450 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -97,40 +97,47 @@ impl CallableParameter for BindlessArrayVar { } } impl KernelParameter for rtx::AccelVar { + type Arg = rtx::Accel; fn def_param(builder: &mut KernelBuilder) -> Self { builder.accel() } } pub trait KernelParameter { + type Arg: KernelArg + 'static; fn def_param(builder: &mut KernelBuilder) -> Self; } impl KernelParameter for Expr { + type Arg = T; fn def_param(builder: &mut KernelBuilder) -> Self { builder.uniform::() } } impl KernelParameter for BufferVar { + type Arg = Buffer; fn def_param(builder: &mut KernelBuilder) -> Self { builder.buffer() } } impl KernelParameter for Tex2dVar { + type Arg = Tex2d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex2d() } } impl KernelParameter for Tex3dVar { + type Arg = Tex3d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex3d() } } impl KernelParameter for BindlessArrayVar { + type Arg = BindlessArray; fn def_param(builder: &mut KernelBuilder) -> Self { builder.bindless_array() } @@ -139,6 +146,7 @@ impl KernelParameter for BindlessArrayVar { macro_rules! impl_kernel_param_for_tuple { ($first:ident $($rest:ident)*) => { impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { + type Arg = ($first::Arg, $($rest::Arg),*); #[allow(non_snake_case)] fn def_param(builder: &mut KernelBuilder) -> Self { ($first::def_param(builder), $($rest::def_param(builder)),*) @@ -148,6 +156,7 @@ macro_rules! impl_kernel_param_for_tuple { }; ()=>{ impl KernelParameter for () { + type Arg = (); fn def_param(_: &mut KernelBuilder) -> Self { } } @@ -399,8 +408,8 @@ impl KernelBuilder { /// * `async_compile`: compile the kernel asynchronously /// * `enable_cache`: enable cache for the compiled kernel /// * `enable_fast_math`: enable fast math in the compiled kernel -/// * `name`: name of the compiled kernel. On CUDA backend, this is the name of the generated PTX kernel -/// +/// * `name`: name of the compiled kernel. On CUDA backend, this is the name of +/// the generated PTX kernel #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct KernelBuildOptions { pub enable_debug_info: bool, @@ -606,7 +615,7 @@ impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_callable_build_for_fn { ($($Ts:ident)*) => { impl CallableBuildFnR> for T - where T: Fn($($Ts,)*)->R + 'static { + where T: Fn($($Ts,)*)->R { #[allow(non_snake_case)] #[allow(unused_variables)] fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 2f6fd668..8fe8dd10 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -1208,8 +1208,9 @@ fn autodiff_callable() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let callable = - Callable::, Var, Expr)>::new(&device, track!(|vx, vy, t| { + let callable = Callable::, Var, Expr)>::new( + &device, + track!(|vx, vy, t| { let x = **vx; let y = **vy; autodiff(|| { @@ -1224,7 +1225,8 @@ fn autodiff_callable() { *vx = gradient(x); *vy = gradient(y); }); - })); + }), + ); let kernel = Kernel::::new( &device, &track!(|| { diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 69d4ca59..1e2002ad 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1,8 +1,9 @@ use luisa::lang::ops::AddMaybeExpr; use luisa::lang::types::array::VLArrayVar; +use luisa::lang::types::core::*; use luisa::lang::types::dynamic::*; use luisa::lang::types::vector::alias::*; -use luisa::lang::types::{core::*, ExprProxy}; +use luisa::lang::types::ExprProxy; use luisa::prelude::*; use luisa_compute as luisa; use luisa_compute_api_types::StreamTag; @@ -48,6 +49,7 @@ fn event() { assert_eq!(v[0], (1 + 3) * (4 + 5)); } #[test] +#[test] #[should_panic] fn callable_return_mismatch() { let device = get_device(); @@ -126,7 +128,7 @@ fn callable() { &device, |buf: BufferVar, i: Expr, v: Var| { buf.write(i, v.load()); - track!(*v+=1;) + track!(*v += 1); }, ); let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); @@ -159,6 +161,43 @@ fn callable() { } } #[test] +fn callable_capture() { + let device = get_device(); + + let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + let w = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as u32); + y.view(..).fill_fn(|i| 1000 * i as u32); + let write = Callable::, Var)>::new(&device, |i, v| { + z.write(i, v); + track!(*v += 1); + }); + let kernel = Kernel::::new( + &device, + &track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_w = w.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let z = add.call(x, y).var(); + write.call(tid, z); + buf_w.write(tid, z); + }), + ); + kernel.dispatch([1024, 1, 1]); + let z_data = z.view(..).copy_to_vec(); + let w_data = w.view(..).copy_to_vec(); + for i in 0..z_data.len() { + assert_eq!(z_data[i], (i + 1000 * i) as u32); + assert_eq!(w_data[i], (i + 1000 * i) as u32 + 1); + } +} +#[test] fn vec_cast() { let device = get_device(); let f: Buffer = device.create_buffer(1024); diff --git a/luisa_compute_derive_impl/src/bin/derive-debug.rs b/luisa_compute_derive_impl/src/bin/derive-debug.rs index b94b49f7..e71d5cac 100644 --- a/luisa_compute_derive_impl/src/bin/derive-debug.rs +++ b/luisa_compute_derive_impl/src/bin/derive-debug.rs @@ -6,8 +6,6 @@ struct Foo { b: u32, } - - fn main() { let compiler = Compiler; let item: syn::ItemStruct = syn::parse_str( diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index 47874390..3791fd26 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -128,6 +128,7 @@ impl Compiler { #parameter_def impl #impl_generics #runtime_path::KernelParameter for #parameter_name #ty_generics #where_clause{ + type Arg = #name #ty_generics #where_clause; fn def_param(builder: &mut #runtime_path::KernelBuilder) -> Self { Self{ #(#field_names: #runtime_path::KernelParameter::def_param(builder)),* diff --git a/rustfmt.toml b/rustfmt.toml index 174fd7a4..d1bdb5ed 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,3 +1,2 @@ ignore = ["luisa_compute_sys"] imports_granularity = "Module" -wrap_comments = true