From 7c0cc43b5a7c5dd09bb1fbcfb55828c6eb6ffaca Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Sat, 23 Sep 2023 18:05:16 +0100 Subject: [PATCH 1/3] Formatted literally everything, and fixed rustfmt being anal about code blocks in comments. --- luisa_compute/examples/autodiff.rs | 87 ++++++++++--------- luisa_compute/examples/backtrace.rs | 27 +++--- luisa_compute/examples/bindless.rs | 21 +++-- luisa_compute/examples/callable.rs | 22 ++--- luisa_compute/examples/custom_aggregate.rs | 3 +- luisa_compute/examples/mpm.rs | 3 +- luisa_compute/examples/path_tracer.rs | 3 +- luisa_compute/examples/path_tracer_cutout.rs | 3 +- luisa_compute/examples/printer.rs | 19 ++-- luisa_compute/examples/ray_query.rs | 3 +- luisa_compute/examples/raytracing.rs | 54 +++++++----- luisa_compute/src/lang.rs | 14 +-- luisa_compute/src/lang/ops/cast_impls.rs | 1 - luisa_compute/src/lang/ops/impls.rs | 2 +- luisa_compute/src/lang/ops/traits.rs | 2 +- luisa_compute/src/lang/soa.rs | 3 +- luisa_compute/src/lang/types.rs | 12 ++- .../src/lang/types/vector/element.rs | 1 - luisa_compute/src/lang/types/vector/impls.rs | 2 +- luisa_compute/src/runtime.rs | 11 +-- luisa_compute/src/runtime/kernel.rs | 10 +-- luisa_compute/tests/autodiff.rs | 8 +- luisa_compute/tests/misc.rs | 3 +- .../src/bin/derive-debug.rs | 2 - rustfmt.toml | 1 - 25 files changed, 171 insertions(+), 146 deletions(-) diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 57932e06..48e168e8 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -29,49 +29,52 @@ fn main() { let dy_gt = device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = Kernel::::new(&device, track!(|| { - let tid = dispatch_id().x; - let buf_x = x.var(); - let buf_y = y.var(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let f = |x: Expr, y: Expr| { - if x > y { - x * y - } else { - y * x + (x / 4.0 * PI).sin() + let shader = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let buf_x = x.var(); + let buf_y = y.var(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let f = |x: Expr, y: Expr| { + if x > y { + x * y + } else { + y * x + (x / 4.0 * PI).sin() + } + }; + let df = |x: Expr, y: Expr| { + if x > y { + (y, x) + } else { + (y + (x / 4.0 * PI).cos() / 4.0 * PI, x) + } + }; + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = f(x, y); + backward(z); + dx_rev.write(tid, gradient(x)); + dy_rev.write(tid, gradient(y)); + }); + forward_autodiff(2, || { + propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]); + propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]); + let z = f(x, y); + let dx = output_gradients(z)[0]; + let dy = output_gradients(z)[1]; + dx_fwd.write(tid, dx); + dy_fwd.write(tid, dy); + }); + { + let (dx, dy) = df(x, y); + dx_gt.write(tid, dx); + dy_gt.write(tid, dy); } - }; - let df = |x: Expr, y: Expr| { - if x > y { - (y, x) - } else { - (y + (x / 4.0 * PI).cos() / 4.0 * PI, x) - } - }; - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = f(x, y); - backward(z); - dx_rev.write(tid, gradient(x)); - dy_rev.write(tid, gradient(y)); - }); - forward_autodiff(2, || { - propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]); - propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]); - let z = f(x, y); - let dx = output_gradients(z)[0]; - let dy = output_gradients(z)[1]; - dx_fwd.write(tid, dx); - dy_fwd.write(tid, dy); - }); - { - let (dx, dy) = df(x, y); - dx_gt.write(tid, dx); - dy_gt.write(tid, dy); - } - })); + }), + ); shader.dispatch([1024, 1, 1]); { diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index 5ab81fa7..9c7825fc 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -18,23 +18,26 @@ fn main() { } else { "cpu" }); - + let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = Kernel::)>::new(&device, track!(|buf_z| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid + 123); - let y = buf_y.read(tid); - let vx = Var::::zeroed(); // create a local mutable variable - *vx = x; - buf_z.write(tid, vx + y); - })); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid + 123); + let y = buf_y.read(tid); + let vx = Var::::zeroed(); // create a local mutable variable + *vx = x; + buf_z.write(tid, vx + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index 48e37a7f..fe2e64e0 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -62,15 +62,18 @@ fn main() { bindless.emplace_buffer_async(1, &y); bindless.emplace_tex2d_async(0, &img, Sampler::default()); bindless.update(); - let kernel = Kernel::)>::new(&device, track!(|buf_z| { - let bindless = bindless.var(); - let tid = dispatch_id().x; - let buf_x = bindless.buffer::(0_u32.expr()); - let buf_y = bindless.buffer::(1_u32.expr()); - let x = buf_x.read(tid).as_::().as_::(); - let y = buf_y.read(tid); - buf_z.write(tid, x + y); - })); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let bindless = bindless.var(); + let tid = dispatch_id().x; + let buf_x = bindless.buffer::(0_u32.expr()); + let buf_y = bindless.buffer::(1_u32.expr()); + let x = buf_x.read(tid).as_::().as_::(); + let y = buf_y.read(tid); + buf_z.write(tid, x + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let mut z_data = vec![0.0; 1024]; z.view(..).copy_to(&mut z_data); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index fc9c7457..daca5a24 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -17,22 +17,24 @@ fn main() { } else { "cpu" }); - let add = - Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); + let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = Kernel::)>::new(&device, track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - buf_z.write(tid, add.call(x, y)); - })); + buf_z.write(tid, add.call(x, y)); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/custom_aggregate.rs b/luisa_compute/examples/custom_aggregate.rs index f47e2f2b..4cad01a1 100644 --- a/luisa_compute/examples/custom_aggregate.rs +++ b/luisa_compute/examples/custom_aggregate.rs @@ -1,4 +1,5 @@ -use luisa::{prelude::*, lang::types::vector::alias::Float3}; +use luisa::lang::types::vector::alias::Float3; +use luisa::prelude::*; use luisa_compute as luisa; #[derive(Aggregate)] pub struct Spectrum { diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index 0500bda7..4ee6d74b 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -2,7 +2,8 @@ use std::env::current_exe; use std::time::Instant; -use luisa::lang::types::vector::{alias::*, Mat2}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::Mat2; use luisa::prelude::*; use luisa_compute as luisa; use rand::Rng; diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index 712f7b40..9cb19216 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -6,7 +6,8 @@ use std::time::Instant; use winit::event::{Event as WinitEvent, WindowEvent}; use winit::event_loop::EventLoop; -use luisa::lang::types::vector::{alias::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index e7a325e6..554bf942 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -1,5 +1,6 @@ use image::Rgb; -use luisa::lang::types::vector::{alias::*, Mat4}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::Mat4; use luisa_compute_api_types::StreamTag; use rand::Rng; use std::env::current_exe; diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index 9312a2b1..a38e54df 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -21,14 +21,17 @@ fn main() { "cpu" }); let printer = Printer::new(&device, 65536); - let kernel = Kernel::::new(&device, track!(|| { - let id = dispatch_id().xy(); - if id.x == id.y { - lc_info!(printer, "id = {:?}", id); - } else { - lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y); - } - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let id = dispatch_id().xy(); + if id.x == id.y { + lc_info!(printer, "id = {:?}", id); + } else { + lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y); + } + }), + ); device.default_stream().with_scope(|s| { s.reset_printer(&printer); s.submit([kernel.dispatch_async([4, 4, 1])]); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index b04007d1..ef982c97 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -1,7 +1,8 @@ use std::env::current_exe; use image::Rgb; -use luisa::lang::types::vector::{alias::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; use luisa::prelude::*; use luisa::rtx::{ Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, TriangleCandidate, diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 672e7246..cddf5a65 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -1,7 +1,10 @@ use std::env::current_exe; use image::Rgb; -use luisa::lang::{types::vector::alias::*, types::vector::*, types::*, *}; +use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::*; +use luisa::lang::types::*; +use luisa::lang::*; use luisa::prelude::*; use luisa::rtx::{AccelBuildRequest, AccelOption, Ray}; use luisa_compute as luisa; @@ -36,29 +39,32 @@ fn main() { let img_w = 800; let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); - let rt_kernel = Kernel::::new(&device,track!(|| { - let accel = accel.var(); - let px = dispatch_id().xy(); - let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); - let xy = 2.0 * xy - 1.0; - let o = Float3::expr(0.0, 0.0, -1.0); - let d = Float3::expr(xy.x, xy.y, 0.0) - o; - let d = d.normalize(); - let ray = Ray::new_expr( - Expr::<[f32; 3]>::from(o), - 1e-3, - Expr::<[f32; 3]>::from(d), - 1e9, - ); - let hit = accel.trace_closest(ray); - let img = img.view(0).var(); - let color = select( - hit.valid(), - Float3::expr(hit.u, hit.v, 1.0), - Float3::expr(0.0, 0.0, 0.0), - ); - img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); - })); + let rt_kernel = Kernel::::new( + &device, + track!(|| { + let accel = accel.var(); + let px = dispatch_id().xy(); + let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); + let xy = 2.0 * xy - 1.0; + let o = Float3::expr(0.0, 0.0, -1.0); + let d = Float3::expr(xy.x, xy.y, 0.0) - o; + let d = d.normalize(); + let ray = Ray::new_expr( + Expr::<[f32; 3]>::from(o), + 1e-3, + Expr::<[f32; 3]>::from(d), + 1e9, + ); + let hit = accel.trace_closest(ray); + let img = img.view(0).var(); + let color = select( + hit.valid(), + Float3::expr(hit.u, hit.v, 1.0), + Float3::expr(0.0, 0.0, 0.0), + ); + img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); + }), + ); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() .with_title("Luisa Compute Rust - Ray Tracing") diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 354761e8..a9ebb710 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -380,10 +380,11 @@ pub fn __module_pools() -> &'static CArc { /// Don't call this function directly unless you know what you are doing /** This function is soley for constructing proxies - * Given a node, __extract selects the correct Func based on the node's type - * It then inserts the extract(node, i) call *at where the node is defined* - * *Note*, after insertion, the IrBuilder in the correct/parent scope might not be up to date - * Thus, for IrBuilder of each scope, it updates the insertion point to the end of the current basic block + * Given a node, __extract selects the correct Func based on the node's + * type It then inserts the extract(node, i) call *at where the node is + * defined* *Note*, after insertion, the IrBuilder in the correct/parent + * scope might not be up to date Thus, for IrBuilder of each scope, it + * updates the insertion point to the end of the current basic block */ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { let inst = &node.get().instruction; @@ -408,8 +409,9 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { } let i = b.const_(Const::Int32(index as i32)); - // Since we have inserted something, the insertion point in cur_builder might not be up to date - // So we need to set it to the end of the current basic block + // Since we have inserted something, the insertion point in cur_builder might + // not be up to date So we need to set it to the end of the current + // basic block macro_rules! update_builders { () => { for scope in &mut r.scopes { diff --git a/luisa_compute/src/lang/ops/cast_impls.rs b/luisa_compute/src/lang/ops/cast_impls.rs index d807947e..b6c34864 100644 --- a/luisa_compute/src/lang/ops/cast_impls.rs +++ b/luisa_compute/src/lang/ops/cast_impls.rs @@ -1,4 +1,3 @@ - #[rustfmt::skip]mod impl_{ use crate::prelude::*; use super::super::*; diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 2c137664..27ebb8bf 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -224,7 +224,7 @@ where self.clone().mul(self.clone()).mul(self.clone()) } fn recip(&self) -> Self { - ::from_node(__current_scope(|b|{ + ::from_node(__current_scope(|b| { let one = b.const_(Const::One(::type_())); b.call(Func::Div, &[one, self.node()], ::type_()) })) diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 947493ad..49c61e87 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -211,7 +211,7 @@ pub trait IntExpr { pub trait FloatExpr: Sized { type Bool; - + fn ceil(&self) -> Self; fn floor(&self) -> Self; fn round(&self) -> Self; diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs index d7d4e626..cee5ed19 100644 --- a/luisa_compute/src/lang/soa.rs +++ b/luisa_compute/src/lang/soa.rs @@ -1,4 +1,5 @@ -use luisa_compute_ir::{ir::Type, CArc}; +use luisa_compute_ir::ir::Type; +use luisa_compute_ir::CArc; use crate::prelude::*; /** A buffer with SOA layout. diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 11bad6b0..35401280 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -12,9 +12,9 @@ pub mod vector; // TODO: Check up on comments. -/// A value that can be used in a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Call [`expr`](Value::expr) or -/// [`var`](Value::var) to convert into a kernel-trackable type. +/// A value that can be used in a [`Kernel`] or [`Callable`]. Call +/// [`expr`](Value::expr) or [`var`](Value::var) to convert into a +/// kernel-trackable type. pub trait Value: Copy + TypeOf + 'static { /// A proxy for additional impls on [`Expr`]. type Expr: ExprProxy; @@ -139,8 +139,7 @@ impl AtomciRefProxyDataProxyData { } } } -/// An expression within a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Created from a raw value +/// An expression within a [`Kernel`] or [`Callable`]. Created from a raw value /// using [`Value::expr`]. /// /// Note that this does not store the value, and in order to get the result of a @@ -158,8 +157,7 @@ pub struct Expr { #[repr(C)] pub struct TypeTag(PhantomData); -/// A variable within a [`Kernel`](crate::runtime::Kernel) or -/// [`Callable`](crate::runtime::Callable). Created using [`Expr::var`] +/// A variable within a [`Kernel`] or [`Callable`]. Created using [`Expr::var`] /// and [`Value::var`]. /// /// Note that setting a `Var` using direct assignment will not work. Instead, diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs index a5afeffd..f82c121e 100644 --- a/luisa_compute/src/lang/types/vector/element.rs +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -9,7 +9,6 @@ macro_rules! element { type VectorAtomicRef = VectorAtomicRefProxy2<$T>; type VectorExprData = VectorExprData<$T, 2>; type VectorVarData = VectorVarData<$T, 2>; - } }; ($T:ty [ 3 ]: $A: ident) => { diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs index 61e387a4..8765d28b 100644 --- a/luisa_compute/src/lang/types/vector/impls.rs +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -303,7 +303,7 @@ macro_rules! impl_mat_proxy { Func::Mul.call2(self, rhs) } } - + impl MatExpr for Expr<$M> { type Scalar = Expr; type Value = $M; diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 68ebe8ca..a0a538eb 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1226,24 +1226,25 @@ pub struct KernelDef { /// An executable kernel /// Kernel creation can be done in multiple ways: /// - Seperate recording and compilation: +/// /// ```no_run /// // Recording: /// use luisa_compute::prelude::*; /// let ctx = Context::new(std::env::current_exe().unwrap()); /// let device = ctx.create_device("cpu"); -/// let kernel = KernelDef::, Buffer, -/// Buffer)>::new(&device, track!(|a,b,c|{ })); -/// // Compilation: +/// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); // Compilation: /// let kernel = device.compile_kernel(&kernel); /// ``` +/// /// - Recording and compilation in one step: +/// /// ```no_run /// use luisa_compute::prelude::*; /// let ctx = Context::new(std::env::current_exe().unwrap()); /// let device = ctx.create_device("cpu"); -/// let kernel = Kernel::, Buffer, -/// Buffer)>::new(&device, track!(|a,b,c|{ })); +/// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); /// ``` +/// /// - Asynchronous compilation use [`Kernel::::new_async`] /// - Custom build options using [`Kernel::::new_with_options`] pub struct Kernel { diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index bd738927..c0c77206 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -399,8 +399,8 @@ impl KernelBuilder { /// * `async_compile`: compile the kernel asynchronously /// * `enable_cache`: enable cache for the compiled kernel /// * `enable_fast_math`: enable fast math in the compiled kernel -/// * `name`: name of the compiled kernel. On CUDA backend, this is the name of the generated PTX kernel -/// +/// * `name`: name of the compiled kernel. On CUDA backend, this is the name of +/// the generated PTX kernel #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct KernelBuildOptions { pub enable_debug_info: bool, @@ -427,12 +427,12 @@ impl Default for KernelBuildOptions { } } } -pub trait CallableBuildFn { +pub trait CallableBuildFn { fn build_callable(&self, args: Option>, builder: &mut KernelBuilder) -> RawCallable; } -pub trait StaticCallableBuildFn: CallableBuildFn {} +pub trait StaticCallableBuildFn: CallableBuildFn {} // @FIXME: this looks redundant pub unsafe trait CallableRet { @@ -589,7 +589,7 @@ impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); macro_rules! impl_callable_build_for_fn { ($($Ts:ident)*) => { - impl CallableBuildFnR> for T + impl CallableBuildFnR> for T where T: Fn($($Ts,)*)->R + 'static { #[allow(non_snake_case)] #[allow(unused_variables)] diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 98b43734..7e9da96e 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -1208,8 +1208,9 @@ fn autodiff_callable() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let callable = - Callable::, Var, Expr)>::new(&device, track!(|vx, vy, t| { + let callable = Callable::, Var, Expr)>::new( + &device, + track!(|vx, vy, t| { let x = **vx; let y = **vy; autodiff(|| { @@ -1224,7 +1225,8 @@ fn autodiff_callable() { *vx = gradient(x); *vy = gradient(y); }); - })); + }), + ); let kernel = Kernel::::new( &device, track!(|| { diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index ea120a3e..423ad905 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1,8 +1,9 @@ use luisa::lang::ops::AddMaybeExpr; use luisa::lang::types::array::VLArrayVar; +use luisa::lang::types::core::*; use luisa::lang::types::dynamic::*; use luisa::lang::types::vector::alias::*; -use luisa::lang::types::{core::*, ExprProxy}; +use luisa::lang::types::ExprProxy; use luisa::prelude::*; use luisa_compute as luisa; use luisa_compute_api_types::StreamTag; diff --git a/luisa_compute_derive_impl/src/bin/derive-debug.rs b/luisa_compute_derive_impl/src/bin/derive-debug.rs index b94b49f7..e71d5cac 100644 --- a/luisa_compute_derive_impl/src/bin/derive-debug.rs +++ b/luisa_compute_derive_impl/src/bin/derive-debug.rs @@ -6,8 +6,6 @@ struct Foo { b: u32, } - - fn main() { let compiler = Compiler; let item: syn::ItemStruct = syn::parse_str( diff --git a/rustfmt.toml b/rustfmt.toml index 174fd7a4..d1bdb5ed 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,3 +1,2 @@ ignore = ["luisa_compute_sys"] imports_granularity = "Module" -wrap_comments = true From 2242db66da7ca73e06fa9c6d8b27913874f92438 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Sat, 23 Sep 2023 20:52:12 +0100 Subject: [PATCH 2/3] Added backrefs for KernelParamter. --- luisa_compute/src/runtime/kernel.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index c0c77206..e0d90d21 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -97,40 +97,47 @@ impl CallableParameter for BindlessArrayVar { } } impl KernelParameter for rtx::AccelVar { + type Arg = rtx::Accel; fn def_param(builder: &mut KernelBuilder) -> Self { builder.accel() } } pub trait KernelParameter { + type Arg: KernelArg; fn def_param(builder: &mut KernelBuilder) -> Self; } impl KernelParameter for Expr { + type Arg = T; fn def_param(builder: &mut KernelBuilder) -> Self { builder.uniform::() } } impl KernelParameter for BufferVar { + type Arg = Buffer; fn def_param(builder: &mut KernelBuilder) -> Self { builder.buffer() } } impl KernelParameter for Tex2dVar { + type Arg = Tex2d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex2d() } } impl KernelParameter for Tex3dVar { + type Arg = Tex3d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex3d() } } impl KernelParameter for BindlessArrayVar { + type Arg = BindlessArray; fn def_param(builder: &mut KernelBuilder) -> Self { builder.bindless_array() } @@ -139,6 +146,7 @@ impl KernelParameter for BindlessArrayVar { macro_rules! impl_kernel_param_for_tuple { ($first:ident $($rest:ident)*) => { impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { + type Arg = ($first::Arg, $($rest::Arg),*); #[allow(non_snake_case)] fn def_param(builder: &mut KernelBuilder) -> Self { ($first::def_param(builder), $($rest::def_param(builder)),*) @@ -148,6 +156,7 @@ macro_rules! impl_kernel_param_for_tuple { }; ()=>{ impl KernelParameter for () { + type Arg = (); fn def_param(_: &mut KernelBuilder) -> Self { } } From 541f1492bcd20f615168c428df94a65b2646a679 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Sat, 23 Sep 2023 21:10:30 +0100 Subject: [PATCH 3/3] Minor fix. --- luisa_compute/src/runtime/kernel.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index e0d90d21..c72a332b 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -104,7 +104,7 @@ impl KernelParameter for rtx::AccelVar { } pub trait KernelParameter { - type Arg: KernelArg; + type Arg: KernelArg + 'static; fn def_param(builder: &mut KernelBuilder) -> Self; }