From 28d7e23552512806732dd7cfa7ade11db4ed92af Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sat, 23 Sep 2023 03:47:10 -0400 Subject: [PATCH 1/3] refactored kernel creation --- luisa_compute/examples/atomic.rs | 15 +- luisa_compute/examples/autodiff.rs | 2 +- luisa_compute/examples/backtrace.rs | 2 +- luisa_compute/examples/bindgroup.rs | 2 +- luisa_compute/examples/bindless.rs | 2 +- luisa_compute/examples/callable.rs | 4 +- luisa_compute/examples/callable_advanced.rs | 38 +- luisa_compute/examples/custom_op.rs | 33 +- luisa_compute/examples/fluid.rs | 93 +-- luisa_compute/examples/mpm.rs | 245 ++++--- luisa_compute/examples/path_tracer.rs | 27 +- luisa_compute/examples/path_tracer_cutout.rs | 21 +- luisa_compute/examples/polymorphism.rs | 21 +- .../examples/polymorphism_advanced.rs | 29 +- luisa_compute/examples/printer.rs | 2 +- luisa_compute/examples/ray_query.rs | 165 ++--- luisa_compute/examples/raytracing.rs | 2 +- luisa_compute/examples/vecadd.rs | 30 +- luisa_compute/src/lib.rs | 4 +- luisa_compute/src/printer.rs | 25 +- luisa_compute/src/rtx.rs | 4 +- luisa_compute/src/runtime.rs | 328 +++++---- luisa_compute/src/runtime/kernel.rs | 403 +++++------ luisa_compute/tests/autodiff.rs | 476 +++++++------ luisa_compute/tests/misc.rs | 656 ++++++++++-------- 25 files changed, 1382 insertions(+), 1247 deletions(-) diff --git a/luisa_compute/examples/atomic.rs b/luisa_compute/examples/atomic.rs index f6651cef..82f01919 100644 --- a/luisa_compute/examples/atomic.rs +++ b/luisa_compute/examples/atomic.rs @@ -10,12 +10,15 @@ fn main() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_sum = sum.var(); - let tid = dispatch_id().x; - buf_sum.atomic_fetch_add(0, buf_x.read(tid)); - })); + let shader = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_sum = sum.var(); + let tid = dispatch_id().x; + 
buf_sum.atomic_fetch_add(0, buf_x.read(tid)); + }), + ); shader.dispatch([x.len() as u32, 1, 1]); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 15842cdc..57932e06 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -29,7 +29,7 @@ fn main() { let dy_gt = device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = device.create_kernel::(track!(&|| { + let shader = Kernel::::new(&device, track!(|| { let tid = dispatch_id().x; let buf_x = x.var(); let buf_y = y.var(); diff --git a/luisa_compute/examples/backtrace.rs b/luisa_compute/examples/backtrace.rs index fdaa0361..5ab81fa7 100644 --- a/luisa_compute/examples/backtrace.rs +++ b/luisa_compute/examples/backtrace.rs @@ -24,7 +24,7 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(track!(&|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); diff --git a/luisa_compute/examples/bindgroup.rs b/luisa_compute/examples/bindgroup.rs index a2965fb8..77df3ed8 100644 --- a/luisa_compute/examples/bindgroup.rs +++ b/luisa_compute/examples/bindgroup.rs @@ -22,6 +22,6 @@ fn main() { y, exclude: 42.0, }; - let shader = device.create_kernel::)>(&|_args| {}); + let shader = Kernel::)>::new(&device, |_args| {}); shader.dispatch([1024, 1, 1], &my_args); } diff --git a/luisa_compute/examples/bindless.rs b/luisa_compute/examples/bindless.rs index 0ee206dd..48e37a7f 100644 --- a/luisa_compute/examples/bindless.rs +++ b/luisa_compute/examples/bindless.rs @@ -62,7 +62,7 @@ fn main() { bindless.emplace_buffer_async(1, &y); bindless.emplace_tex2d_async(0, &img, Sampler::default()); bindless.update(); - let kernel = 
device.create_kernel::)>(&track!(|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { let bindless = bindless.var(); let tid = dispatch_id().x; let buf_x = bindless.buffer::(0_u32.expr()); diff --git a/luisa_compute/examples/callable.rs b/luisa_compute/examples/callable.rs index 23255297..fc9c7457 100644 --- a/luisa_compute/examples/callable.rs +++ b/luisa_compute/examples/callable.rs @@ -18,13 +18,13 @@ fn main() { "cpu" }); let add = - device.create_callable::, Expr) -> Expr>(&|a, b| track!(a + b)); + Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { + let kernel = Kernel::)>::new(&device, track!(|buf_z| { let buf_x = x.var(); let buf_y = y.var(); let tid = dispatch_id().x; diff --git a/luisa_compute/examples/callable_advanced.rs b/luisa_compute/examples/callable_advanced.rs index 5de40244..f8630ab3 100644 --- a/luisa_compute/examples/callable_advanced.rs +++ b/luisa_compute/examples/callable_advanced.rs @@ -18,8 +18,9 @@ fn main() { } else { "cpu" }); - let add = device.create_dyn_callable:: DynExpr>(Box::new( - |a: DynExpr, b: DynExpr| -> DynExpr { + let add = DynCallable:: DynExpr>::new( + &device, + Box::new(|a: DynExpr, b: DynExpr| -> DynExpr { if let Some(a) = a.downcast::() { let b = b.downcast::().unwrap(); return DynExpr::new(track!(a + b)); @@ -29,28 +30,31 @@ fn main() { } else { unreachable!() } - }, - )); + }), + ); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let x 
= buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - buf_z.write(tid, add.call(x.into(), y.into()).get::()); - w.var().write( - tid, - add.call(x.as_::().into(), y.as_::().into()) - .get::(), - ); - })); + buf_z.write(tid, add.call(x.into(), y.into()).get::()); + w.var().write( + tid, + add.call(x.as_::().into(), y.as_::().into()) + .get::(), + ); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index f039264d..65bcbcf9 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -28,21 +28,24 @@ fn main() { println!("Hello from thread 0!"); } }); - let shader = device.create_kernel::)>(&track!(|buf_z: BufferVar| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let args = MyAddArgs::new_expr(x, y, 0.0f32.expr()); - let result = my_add.call(args); - let _ = my_print.call(tid); - if tid == 0 { - cpu_dbg!(args); - } - buf_z.write(tid, result.result); - })); + let shader = Kernel::)>::new( + &device, + track!(|buf_z: BufferVar| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let args = MyAddArgs::new_expr(x, y, 0.0f32.expr()); + let result = my_add.call(args); + let _ = my_print.call(tid); + if tid == 0 { + cpu_dbg!(args); + } + buf_z.write(tid, result.result); + }), + ); shader.dispatch([1024, 1, 1], &z); let mut z_data = vec![0.0; 1024]; z.view(..).copy_to(&mut z_data); diff --git a/luisa_compute/examples/fluid.rs 
b/luisa_compute/examples/fluid.rs index cf5d1eed..c17ac49d 100644 --- a/luisa_compute/examples/fluid.rs +++ b/luisa_compute/examples/fluid.rs @@ -118,24 +118,25 @@ fn main() { } ); - let advect = device - .create_kernel_async::, Buffer, Buffer, Buffer)>( - track!(&|u0, u1, rho0, rho1| { - let coord = dispatch_id().xy(); - let u = u0.read(index(coord)); - - // trace backward - let mut p = Float2::expr(coord.x.as_f32(), coord.y.as_f32()); - p = p - u * dt; - - // advect - u1.write(index(coord), sample_vel(u0, p.x, p.y)); - rho1.write(index(coord), sample_float(rho0, p.x, p.y)); - }), - ); + let advect = Kernel::, Buffer, Buffer, Buffer)>::new_async( + &device, + track!(|u0, u1, rho0, rho1| { + let coord = dispatch_id().xy(); + let u = u0.read(index(coord)); + + // trace backward + let mut p = Float2::expr(coord.x.as_f32(), coord.y.as_f32()); + p = p - u * dt; + + // advect + u1.write(index(coord), sample_vel(u0, p.x, p.y)); + rho1.write(index(coord), sample_float(rho0, p.x, p.y)); + }), + ); - let divergence = - device.create_kernel_async::, Buffer)>(track!(&|u, div| { + let divergence = Kernel::, Buffer)>::new_async( + &device, + track!(|u, div| { let coord = dispatch_id().xy(); if coord.x < (N_GRID as u32 - 1) && coord.y < (N_GRID as u32 - 1) { let dx = (u.read(index(Uint2::expr(coord.x + 1, coord.y))).x @@ -146,10 +147,12 @@ fn main() { * 0.5; div.write(index(coord), dx + dy); } - })); + }), + ); - let pressure_solve = device.create_kernel_async::, Buffer, Buffer)>( - track!(&|p0, p1, div| { + let pressure_solve = Kernel::, Buffer, Buffer)>::new_async( + &device, + track!(|p0, p1, div| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -166,8 +169,9 @@ fn main() { }), ); - let pressure_apply = - device.create_kernel_async::, Buffer)>(track!(&|p, u| { + let pressure_apply = Kernel::, Buffer)>::new_async( + &device, + track!(|p, u| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -184,10 
+188,12 @@ fn main() { u.write(ij, u.read(ij) - f_p); } - })); + }), + ); - let integrate = - device.create_kernel_async::, Buffer)>(track!(&|u, rho| { + let integrate = Kernel::, Buffer)>::new_async( + &device, + track!(|u, rho| { let coord = dispatch_id().xy(); let ij = index(coord); @@ -199,10 +205,12 @@ fn main() { // fade rho.write(ij, rho.read(ij) * (1.0f32 - 0.1f32 * dt)); - })); + }), + ); - let init = device.create_kernel_async::, Buffer, Float2)>(track!( - &|rho, u, dir| { + let init = Kernel::, Buffer, Float2)>::new_async( + &device, + track!(|rho, u, dir| { let coord = dispatch_id().xy(); let i = coord.x.as_i32(); let j = coord.y.as_i32(); @@ -214,10 +222,10 @@ fn main() { rho.write(ij, 1.0f32); u.write(ij, dir); } - } - )); + }), + ); - let init_grid = device.create_kernel_async::(&|| { + let init_grid = Kernel::::new_async(&device, || { let idx = index(dispatch_id().xy()); u0.var().write(idx, Float2::expr(0.0f32, 0.0f32)); u1.var().write(idx, Float2::expr(0.0f32, 0.0f32)); @@ -230,21 +238,24 @@ fn main() { div.var().write(idx, 0.0f32); }); - let clear_pressure = device.create_kernel_async::(&|| { + let clear_pressure = Kernel::::new_async(&device, || { let idx = index(dispatch_id().xy()); p0.var().write(idx, 0.0f32); p1.var().write(idx, 0.0f32); }); - let draw_rho = device.create_kernel_async::(&track!(|| { - let coord = dispatch_id().xy(); - let ij = index(coord); - let value = rho0.var().read(ij); - display.var().write( - Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y), - Float4::expr(value, 0.0f32, 0.0f32, 1.0f32), - ); - })); + let draw_rho = Kernel::::new_async( + &device, + track!(|| { + let coord = dispatch_id().xy(); + let ij = index(coord); + let value = rho0.var().read(ij); + display.var().write( + Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y), + Float4::expr(value, 0.0f32, 0.0f32, 1.0f32), + ); + }), + ); event_loop.run(move |event, _, control_flow| { control_flow.set_poll(); diff --git a/luisa_compute/examples/mpm.rs 
b/luisa_compute/examples/mpm.rs index 8044ccf4..b731b8e0 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -90,137 +90,154 @@ fn main() { p.x + p.y * N_GRID as u32 }); - let clear_grid = device.create_kernel_async::(track!(&|| { - let idx = index(dispatch_id().xy()); - grid_v.var().write(idx * 2, 0.0f32); - grid_v.var().write(idx * 2 + 1, 0.0f32); - grid_m.var().write(idx, 0.0f32); - })); - - let point_to_grid = device.create_kernel_async::(track!(&|| { - let p = dispatch_id().x; - let xp = x.var().read(p) / DX; - let base = (xp - 0.5f32).cast_i32(); - let fx = xp - base.cast_f32(); - - let w = [ - 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), - 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), - 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), - ]; - let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX); - let affine = - Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); - let vp = v.var().read(p); - escape!(for ii in 0..9 { - let (i, j) = (ii % 3, ii / 3); - track!({ - let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX; - let weight = w[i].x * w[j].y; - let vadd = weight * (P_MASS * vp + affine * dpos); - let idx = index((base + offset).cast_u32()); - grid_v.var().atomic_fetch_add(idx * 2, vadd.x); - grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); - grid_m.var().atomic_fetch_add(idx, weight * P_MASS); - }); - }); - let _ = (); // WHAT? 
- })); - - let simulate_grid = device.create_kernel_async::(&track!(|| { - let coord = dispatch_id().xy(); - let i = index(coord); - let v = Var::::zeroed(); - v.store(Float2::expr( - grid_v.var().read(i * 2u32), - grid_v.var().read(i * 2u32 + 1u32), - )); - let m = grid_m.var().read(i); - - v.store(select(m > 0.0f32, v.load() / m, v.load())); - let vx = v.load().x; - let vy = v.load().y - DT * GRAVITY; - let vx = select( - coord.x < BOUND && (vx < 0.0f32) || coord.x + BOUND > N_GRID as u32 && (vx > 0.0f32), - 0.0f32.expr(), - vx, - ); - let vy = select( - coord.y < BOUND && (vy < 0.0f32) || coord.y + BOUND > N_GRID as u32 && (vy > 0.0f32), - 0.0f32.expr(), - vy, - ); - grid_v.var().write(i * 2, vx); - grid_v.var().write(i * 2 + 1, vy); - })); + let clear_grid = Kernel::::new( + &device, + track!(|| { + let idx = index(dispatch_id().xy()); + grid_v.var().write(idx * 2, 0.0f32); + grid_v.var().write(idx * 2 + 1, 0.0f32); + grid_m.var().write(idx, 0.0f32); + }), + ); - let grid_to_point = device.create_kernel_async::(track!(&|| { - let p = dispatch_id().x; - let xp = x.var().read(p) / DX; - let base = (xp - 0.5f32).cast_i32(); - let fx = xp - base.cast_f32(); + let point_to_grid = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + let xp = x.var().read(p) / DX; + let base = (xp - 0.5f32).cast_i32(); + let fx = xp - base.cast_f32(); - let w = [ - 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), - 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), - 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), - ]; - let new_v = Var::::zeroed(); - let new_C = Var::::zeroed(); - new_v.store(Float2::expr(0.0f32, 0.0f32)); - new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); - escape!({ - for ii in 0..9 { + let w = [ + 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), + 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), + 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), + ]; + let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX); + let affine = + 
Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); + let vp = v.var().read(p); + escape!(for ii in 0..9 { let (i, j) = (ii % 3, ii / 3); track!({ let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX.expr(); + let dpos = (offset.cast_f32() - fx) * DX; let weight = w[i].x * w[j].y; + let vadd = weight * (P_MASS * vp + affine * dpos); let idx = index((base + offset).cast_u32()); - let g_v = Float2::expr( - grid_v.var().read(idx * 2u32), - grid_v.var().read(idx * 2u32 + 1u32), - ); - new_v.store(new_v.load() + weight * g_v); - new_C.store( - new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX), - ); + grid_v.var().atomic_fetch_add(idx * 2, vadd.x); + grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); + grid_m.var().atomic_fetch_add(idx, weight * P_MASS); }); - } - }); + }); + let _ = (); // WHAT? + }), + ); - v.var().write(p, new_v); - x.var().write(p, x.var().read(p) + new_v.load() * DT); - J.var() - .write(p, J.var().read(p) * (1.0f32 + DT * trace(new_C.load()))); - C.var().write(p, new_C); - })); + let simulate_grid = Kernel::::new( + &device, + track!(|| { + let coord = dispatch_id().xy(); + let i = index(coord); + let v = Var::::zeroed(); + v.store(Float2::expr( + grid_v.var().read(i * 2u32), + grid_v.var().read(i * 2u32 + 1u32), + )); + let m = grid_m.var().read(i); - let clear_display = device.create_kernel_async::(&|| { + v.store(select(m > 0.0f32, v.load() / m, v.load())); + let vx = v.load().x; + let vy = v.load().y - DT * GRAVITY; + let vx = select( + coord.x < BOUND && (vx < 0.0f32) + || coord.x + BOUND > N_GRID as u32 && (vx > 0.0f32), + 0.0f32.expr(), + vx, + ); + let vy = select( + coord.y < BOUND && (vy < 0.0f32) + || coord.y + BOUND > N_GRID as u32 && (vy > 0.0f32), + 0.0f32.expr(), + vy, + ); + grid_v.var().write(i * 2, vx); + grid_v.var().write(i * 2 + 1, vy); + }), + ); + + let grid_to_point = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + let xp 
= x.var().read(p) / DX; + let base = (xp - 0.5f32).cast_i32(); + let fx = xp - base.cast_f32(); + + let w = [ + 0.5f32 * (1.5f32 - fx) * (1.5f32 - fx), + 0.75f32 - (fx - 1.0f32) * (fx - 1.0f32), + 0.5f32 * (fx - 0.5f32) * (fx - 0.5f32), + ]; + let new_v = Var::::zeroed(); + let new_C = Var::::zeroed(); + new_v.store(Float2::expr(0.0f32, 0.0f32)); + new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); + escape!({ + for ii in 0..9 { + let (i, j) = (ii % 3, ii / 3); + track!({ + let offset = Int2::expr(i as i32, j as i32); + let dpos = (offset.cast_f32() - fx) * DX.expr(); + let weight = w[i].x * w[j].y; + let idx = index((base + offset).cast_u32()); + let g_v = Float2::expr( + grid_v.var().read(idx * 2u32), + grid_v.var().read(idx * 2u32 + 1u32), + ); + new_v.store(new_v.load() + weight * g_v); + new_C.store( + new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX), + ); + }); + } + }); + + v.var().write(p, new_v); + x.var().write(p, x.var().read(p) + new_v.load() * DT); + J.var() + .write(p, J.var().read(p) * (1.0f32 + DT * trace(new_C.load()))); + C.var().write(p, new_C); + }), + ); + + let clear_display = Kernel::::new(&device, || { display.var().write( dispatch_id().xy(), Float4::expr(0.1f32, 0.2f32, 0.3f32, 1.0f32), ); }); - let draw_particles = device.create_kernel_async::(&track!(|| { - let p = dispatch_id().x; - for i in -1..=1 { - for j in -1..=1 { - let pos = (x.var().read(p) * RESOLUTION as f32).cast_i32() + Int2::expr(i, j); - if pos.x >= (0i32) - && pos.x < (RESOLUTION as i32) - && pos.y >= (0i32) - && pos.y < (RESOLUTION as i32) - { - display.var().write( - Uint2::expr(pos.x.cast_u32(), RESOLUTION - 1u32 - pos.y.cast_u32()), - Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32), - ); + let draw_particles = Kernel::::new( + &device, + track!(|| { + let p = dispatch_id().x; + for i in -1..=1 { + for j in -1..=1 { + let pos = (x.var().read(p) * RESOLUTION as f32).cast_i32() + Int2::expr(i, j); + if pos.x >= (0i32) + && pos.x < 
(RESOLUTION as i32) + && pos.y >= (0i32) + && pos.y < (RESOLUTION as i32) + { + display.var().write( + Uint2::expr(pos.x.cast_u32(), RESOLUTION - 1u32 - pos.y.cast_u32()), + Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32), + ); + } } } - } - })); + }), + ); event_loop.run(move |event, _, control_flow| { control_flow.set_poll(); match event { diff --git a/luisa_compute/examples/path_tracer.rs b/luisa_compute/examples/path_tracer.rs index 7278494b..712f7b40 100644 --- a/luisa_compute/examples/path_tracer.rs +++ b/luisa_compute/examples/path_tracer.rs @@ -8,7 +8,9 @@ use winit::event_loop::EventLoop; use luisa::lang::types::vector::{alias::*, *}; use luisa::prelude::*; -use luisa::rtx::{offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps}; +use luisa::rtx::{ + offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps, +}; use luisa_compute as luisa; #[derive(Value, Clone, Copy)] @@ -244,11 +246,12 @@ fn main() { }); // use create_kernel_async to compile multiple kernels in parallel - let path_tracer = device.create_kernel_async::, Tex2d, Accel, Uint2)>( - track!(&|image: Tex2dVar, - seed_image: Tex2dVar, - accel: AccelVar, - resolution: Expr| { + let path_tracer = Kernel::, Tex2d, Accel, Uint2)>::new_async( + &device, + track!(|image: Tex2dVar, + seed_image: Tex2dVar, + accel: AccelVar, + resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); let cbox_materials = ([ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor @@ -263,7 +266,7 @@ fn main() { .expr(); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::) -> Expr>(|state: Var| { + let lcg = Callable::) -> Expr>::new_static(|state: Var| { const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state = LCG_A * state + LCG_C; @@ -278,7 +281,7 @@ fn main() { orig: o.into(), tmin: tmin, dir: d.into(), - tmax: tmax + tmax: tmax, }) }; @@ -455,8 +458,9 @@ fn main() { ); }), ); - let display = - device.create_kernel_async::, 
Tex2d)>(track!(&|acc, display| { + let display = Kernel::, Tex2d)>::new_async( + &device, + track!(|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); @@ -468,7 +472,8 @@ fn main() { let srgb = radiance.lt(0.0031308).select(radiance * 12.92, r); display.write(coord, Float4::expr(srgb.x, srgb.y, srgb.z, 1.0f32)); - })); + }), + ); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/path_tracer_cutout.rs b/luisa_compute/examples/path_tracer_cutout.rs index 92cd10de..e7a325e6 100644 --- a/luisa_compute/examples/path_tracer_cutout.rs +++ b/luisa_compute/examples/path_tracer_cutout.rs @@ -256,11 +256,12 @@ fn main() { }); // use create_kernel_async to compile multiple kernels in parallel - let path_tracer = device.create_kernel_async::, Tex2d, Accel, Uint2)>( - track!(&|image: Tex2dVar, - seed_image: Tex2dVar, - accel: AccelVar, - resolution: Expr| { + let path_tracer = Kernel::, Tex2d, Accel, Uint2)>::new_async( + &device, + track!(|image: Tex2dVar, + seed_image: Tex2dVar, + accel: AccelVar, + resolution: Expr| { set_block_size([16u32, 16u32, 1u32]); let cbox_materials = [ Float3::new(0.725f32, 0.710f32, 0.680f32), // floor @@ -275,7 +276,7 @@ fn main() { .expr(); let lcg = |state: Var| -> Expr { - let lcg = create_static_callable::) -> Expr>(|state: Var| { + let lcg = Callable::) -> Expr>::new_static(|state: Var| { const LCG_A: u32 = 1664525u32; const LCG_C: u32 = 1013904223u32; *state = LCG_A * state + LCG_C; @@ -505,8 +506,9 @@ fn main() { ); }), ); - let display = - device.create_kernel_async::, Tex2d)>(track!(&|acc, display| { + let display = Kernel::, Tex2d)>::new_async( + &device, + track!(|acc, display| { set_block_size([16, 16, 1]); let coord = dispatch_id().xy(); let radiance = acc.read(coord); @@ -518,7 +520,8 @@ fn main() { let srgb = radiance.lt(0.0031308).select(radiance * 12.92, r); display.write(coord, 
Float4::expr(srgb.x, srgb.y, srgb.z, 1.0f32)); - })); + }), + ); let img_w = 1024; let img_h = 1024; let acc_img = device.create_tex2d::(PixelStorage::Float4, img_w, img_h, 1); diff --git a/luisa_compute/examples/polymorphism.rs b/luisa_compute/examples/polymorphism.rs index 9a9c4126..8539c4e4 100644 --- a/luisa_compute/examples/polymorphism.rs +++ b/luisa_compute/examples/polymorphism.rs @@ -50,15 +50,18 @@ fn main() { poly_area.register((), &circles); poly_area.register((), &squares); let areas = device.create_buffer::(4); - let shader = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let tag = tid / 2; - let index = tid % 2; - let area = poly_area - .get(TagIndex::new_expr(tag, index)) - .dispatch(|_tag, _key, obj| obj.area()); - areas.var().write(tid, area); - })); + let shader = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let tag = tid / 2; + let index = tid % 2; + let area = poly_area + .get(TagIndex::new_expr(tag, index)) + .dispatch(|_tag, _key, obj| obj.area()); + areas.var().write(tid, area); + }), + ); shader.dispatch([4, 1, 1]); let areas = areas.view(..).copy_to_vec(); println!("{:?}", areas); diff --git a/luisa_compute/examples/polymorphism_advanced.rs b/luisa_compute/examples/polymorphism_advanced.rs index 316600a6..a62a03a4 100644 --- a/luisa_compute/examples/polymorphism_advanced.rs +++ b/luisa_compute/examples/polymorphism_advanced.rs @@ -132,19 +132,22 @@ fn main() { ); let poly_shader = builder.build(); let result = device.create_buffer::(100); - let kernel = device.create_kernel::(&track!(|| { - let i = dispatch_id().x; - let x = i.as_f32() / 100.0 * PI; - let ctx = ShaderEvalContext { - poly_shader: &poly_shader, - key: &shader_final_key, - }; - let tag_index = TagIndex::new_expr(shader_final.tag, shader_final.index); - let v = poly_shader - .get(tag_index) - .dispatch(|_, _, shader| shader.evaluate(x, &ctx)); - result.var().write(i, v); - })); + let kernel = Kernel::::new( + &device, + track!(|| { 
+ let i = dispatch_id().x; + let x = i.as_f32() / 100.0 * PI; + let ctx = ShaderEvalContext { + poly_shader: &poly_shader, + key: &shader_final_key, + }; + let tag_index = TagIndex::new_expr(shader_final.tag, shader_final.index); + let v = poly_shader + .get(tag_index) + .dispatch(|_, _, shader| shader.evaluate(x, &ctx)); + result.var().write(i, v); + }), + ); kernel.dispatch([100, 1, 1]); let result = result.copy_to_vec(); for i in 0..100 { diff --git a/luisa_compute/examples/printer.rs b/luisa_compute/examples/printer.rs index af173ae7..9312a2b1 100644 --- a/luisa_compute/examples/printer.rs +++ b/luisa_compute/examples/printer.rs @@ -21,7 +21,7 @@ fn main() { "cpu" }); let printer = Printer::new(&device, 65536); - let kernel = device.create_kernel::(track!(&|| { + let kernel = Kernel::::new(&device, track!(|| { let id = dispatch_id().xy(); if id.x == id.y { lc_info!(printer, "id = {:?}", id); diff --git a/luisa_compute/examples/ray_query.rs b/luisa_compute/examples/ray_query.rs index 1b6d63fc..b04007d1 100644 --- a/luisa_compute/examples/ray_query.rs +++ b/luisa_compute/examples/ray_query.rs @@ -119,93 +119,98 @@ fn main() { let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); let debug_hit_t = device.create_buffer::(4); - let rt_kernel = device.create_kernel::(&track!(|| { - let accel = accel.var(); - let px = dispatch_id().xy(); - let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32); - let xy = 2.0 * xy - 1.0; - let o = Float3::expr(0.0, 0.0, -2.0); - let d = Float3::expr(xy.x, xy.y, 0.0) - o; - let d = d.normalize(); + let rt_kernel = Kernel::::new( + &device, + track!(|| { + let accel = accel.var(); + let px = dispatch_id().xy(); + let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32); + let xy = 2.0 * xy - 1.0; + let o = Float3::expr(0.0, 0.0, -2.0); + let d = Float3::expr(xy.x, xy.y, 0.0) - o; + let d = d.normalize(); - let ray = Ray::new_expr( - Expr::<[f32; 3]>::from(o + translate.expr()), - 
1e-3f32, - Expr::<[f32; 3]>::from(d), - 1e9f32, - ); - let hit = accel.query_all( - ray, - 255, - RayQuery { - on_triangle_hit: |candidate: TriangleCandidate| { - let bary = candidate.bary; - let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); - let t = candidate.committed_ray_t; - if (px == Uint2::expr(400, 400)).all() { - debug_hit_t.write(0, t); - debug_hit_t.write(1, candidate.ray().tmax); - }; - // if (uvw.xy().length() < 0.8) - // & (uvw.yz().length() < 0.8) - // & (uvw.xz().length() < 0.8) - if uvw.xy().length() < 0.8 && uvw.yz().length() < 0.8 && uvw.xz().length() < 0.8 - { - candidate.commit(); - } - }, - on_procedural_hit: |candidate: ProceduralCandidate| { - let ray = candidate.ray(); - let prim = candidate.prim; - let sphere = spheres.var().read(prim); - let o: Expr = ray.orig.into(); - let d: Expr = ray.dir.into(); - let t = Var::::zeroed(); + let ray = Ray::new_expr( + Expr::<[f32; 3]>::from(o + translate.expr()), + 1e-3f32, + Expr::<[f32; 3]>::from(d), + 1e9f32, + ); + let hit = accel.query_all( + ray, + 255, + RayQuery { + on_triangle_hit: |candidate: TriangleCandidate| { + let bary = candidate.bary; + let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); + let t = candidate.committed_ray_t; + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(0, t); + debug_hit_t.write(1, candidate.ray().tmax); + }; + // if (uvw.xy().length() < 0.8) + // & (uvw.yz().length() < 0.8) + // & (uvw.xz().length() < 0.8) + if uvw.xy().length() < 0.8 + && uvw.yz().length() < 0.8 + && uvw.xz().length() < 0.8 + { + candidate.commit(); + } + }, + on_procedural_hit: |candidate: ProceduralCandidate| { + let ray = candidate.ray(); + let prim = candidate.prim; + let sphere = spheres.var().read(prim); + let o: Expr = ray.orig.into(); + let d: Expr = ray.dir.into(); + let t = Var::::zeroed(); - for _ in 0..100 { - let dist = (o + d * t - (sphere.center + translate.expr())).length() - - sphere.radius; - if dist < 0.001 { - if (px == Uint2::expr(400, 
400)).all() { - debug_hit_t.write(2, t); - debug_hit_t.write(3, candidate.ray().tmax); - } - if t < ray.tmax { - candidate.commit(t); + for _ in 0..100 { + let dist = (o + d * t - (sphere.center + translate.expr())).length() + - sphere.radius; + if dist < 0.001 { + if (px == Uint2::expr(400, 400)).all() { + debug_hit_t.write(2, t); + debug_hit_t.write(3, candidate.ray().tmax); + } + if t < ray.tmax { + candidate.commit(t); + } + break; } - break; + *t += dist; } - *t += dist; - } + }, }, - }, - ); - let img = img.view(0).var(); - let color = if hit.triangle_hit() { - let bary = hit.bary; - let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); - uvw - } else { - if hit.procedural_hit() { - let prim = hit.prim_id; - let sphere = spheres.var().read(prim); - let normal = (Expr::::from(ray.orig) - + Expr::::from(ray.dir) * hit.committed_ray_t - - sphere.center) - .normalize(); - let light_dir = Float3::expr(1.0, 0.6, -0.2).normalize(); - let light = Float3::expr(1.0, 1.0, 1.0); - let ambient = Float3::expr(0.1, 0.1, 0.1); - let diffuse = luisa::max(light * normal.dot(light_dir), 0.0); - let color = ambient + diffuse; - color + ); + let img = img.view(0).var(); + let color = if hit.triangle_hit() { + let bary = hit.bary; + let uvw = Float3::expr(1.0 - bary.x - bary.y, bary.x, bary.y); + uvw } else { - Float3::expr(0.0, 0.0, 0.0) - } - }; + if hit.procedural_hit() { + let prim = hit.prim_id; + let sphere = spheres.var().read(prim); + let normal = (Expr::::from(ray.orig) + + Expr::::from(ray.dir) * hit.committed_ray_t + - sphere.center) + .normalize(); + let light_dir = Float3::expr(1.0, 0.6, -0.2).normalize(); + let light = Float3::expr(1.0, 1.0, 1.0); + let ambient = Float3::expr(0.1, 0.1, 0.1); + let diffuse = luisa::max(light * normal.dot(light_dir), 0.0); + let color = ambient + diffuse; + color + } else { + Float3::expr(0.0, 0.0, 0.0) + } + }; - img.write(px, Float4::expr(color.x, color.y, color.z, 1.0)); - })); + img.write(px, Float4::expr(color.x, color.y, 
color.z, 1.0)); + }), + ); let event_loop = EventLoop::new(); let window = winit::window::WindowBuilder::new() .with_title("Luisa Compute Rust - Ray Query") diff --git a/luisa_compute/examples/raytracing.rs b/luisa_compute/examples/raytracing.rs index 0adb7cb3..672e7246 100644 --- a/luisa_compute/examples/raytracing.rs +++ b/luisa_compute/examples/raytracing.rs @@ -36,7 +36,7 @@ fn main() { let img_w = 800; let img_h = 800; let img = device.create_tex2d::(PixelStorage::Byte4, img_w, img_h, 1); - let rt_kernel = device.create_kernel::(&track!(|| { + let rt_kernel = Kernel::::new(&device,track!(|| { let accel = accel.var(); let px = dispatch_id().xy(); let xy = px.as_::() / Float2::expr(img_w as f32, img_h as f32); diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index 448023d5..a90882f2 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -2,8 +2,9 @@ use std::env::current_exe; use luisa::lang::types::vector::alias::*; use luisa::prelude::*; -use std::cell::RefCell; +use luisa::runtime::{Kernel, KernelDef}; use luisa_compute as luisa; +use std::cell::RefCell; fn main() { luisa::init_logger(); let args: Vec = std::env::args().collect(); @@ -24,18 +25,21 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(track!(&|buf_z| { - // z is pass by arg - let buf_x = x.var(); // x and y are captured - let buf_y = y.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let vx = 2.0_f32.var(); // create a local mutable variable - *vx += x; // store to vx - *vx = vx; - buf_z.write(tid, vx + y); - })); + + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + // z is pass by arg + let buf_x = x.var(); // x and y are captured + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let vx = 2.0_f32.var(); 
// create a local mutable variable + *vx += x; // store to vx + buf_z.write(tid, vx + y); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 4c117b25..12e3d3cd 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -40,7 +40,8 @@ pub mod prelude { pub use crate::resource::{IoTexel, StorageTexel, *}; pub use crate::runtime::api::StreamTag; pub use crate::runtime::{ - create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + Callable, Command, Device, DynCallable, Kernel, KernelBuildOptions, KernelDef, Scope, + Stream, }; pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, while_, Context}; @@ -155,6 +156,7 @@ impl Context { } } +#[derive(Clone)] pub struct ResourceTracker { resources: Vec>, } diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index eab2eb67..56c4d97e 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -149,31 +149,8 @@ impl Printer { count_per_arg: args.count_per_arg, }); } - pub fn reset(&self) -> PrinterReset { - PrinterReset { inner: self } - } - pub fn print(&self) -> PrinterPrint { - PrinterPrint { inner: self } - } -} -pub struct PrinterPrint<'a> { - inner: &'a Printer, -} -pub struct PrinterReset<'a> { - inner: &'a Printer, -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - fn shl(self, rhs: PrinterPrint<'a>) -> Self::Output { - self.print(rhs.inner) - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - fn shl(self, rhs: PrinterReset<'a>) -> Self::Output { - self.reset_printer(rhs.inner) - } } + impl<'a> Scope<'a> { pub fn reset_printer(&self, printer: &Printer) -> &Self { printer diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 303e58b7..4bce93d1 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -391,7 +391,7 @@ pub enum 
HitType { pub fn offset_ray_origin(p: Expr, n: Expr) -> Expr { lazy_static! { static ref F: Callable, Expr) -> Expr> = - create_static_callable::, Expr) -> Expr>(|p, n| { + Callable::, Expr) -> Expr>::new_static(|p, n| { const ORIGIN: f32 = 1.0f32 / 32.0f32; const FLOAT_SCALE: f32 = 1.0f32 / 65536.0f32; const INT_SCALE: f32 = 256.0f32; @@ -473,7 +473,7 @@ impl Deref for TriangleCandidate { } } impl ProceduralCandidate { - pub fn commit(&self, t: impl AsExpr) { + pub fn commit(&self, t: impl AsExpr) { let t = t.as_expr(); __current_scope(|b| { b.call( diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 218afa93..871aa50f 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -202,7 +202,7 @@ impl Device { }), _marker: PhantomData {}, len: count, - _is_byte_buffer:false, + _is_byte_buffer: false, }; buffer } @@ -400,93 +400,59 @@ impl Device { modifications: RwLock::new(HashMap::new()), } } - pub fn create_callable<'a, S: CallableSignature<'a>>(&self, f: S::Fn) -> S::Callable { - let mut builder = KernelBuilder::new(Some(self.clone()), false); - let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); - S::wrap_raw_callable(raw_callable) - } - pub fn create_dyn_callable<'a, S: CallableSignature<'a>>(&self, f: S::DynFn) -> S::DynCallable { - S::create_dyn_callable(self.clone(), false, f) + + /// Compile a [`KernelDef`] into a [`Kernel`]. 
See [`Kernel`] for more details on kernel creation + pub fn compile_kernel(&self, k: &KernelDef) -> Kernel { + self.compile_kernel_with_options(k, KernelBuildOptions::default()) } - pub fn create_dyn_callable_once<'a, S: CallableSignature<'a>>( - &self, - f: S::DynFn, - ) -> S::DynCallable { - S::create_dyn_callable(self.clone(), true, f) - } - pub fn create_kernel<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = - KernelBuildFn::build_kernel(&f, &mut builder, KernelBuildOptions::default()); - S::wrap_raw_kernel(raw_kernel) - } - pub fn create_kernel_async<'a, S: KernelSignature<'a>>(&self, f: S::Fn) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel( - &f, - &mut builder, + + /// Compile a [`KernelDef`] into a [`Kernel`] asynchronously. See [`Kernel`] for more details on kernel creation + pub fn compile_kernel_async(&self, k: &KernelDef) -> Kernel { + self.compile_kernel_with_options( + k, KernelBuildOptions { async_compile: true, ..Default::default() }, - ); - S::wrap_raw_kernel(raw_kernel) + ) } - pub fn create_kernel_with_options<'a, S: KernelSignature<'a>>( + + /// Compile a [`KernelDef`] into a [`Kernel`] with options. 
See [`Kernel`] for more details on kernel creation + pub fn compile_kernel_with_options( &self, - f: S::Fn, + k: &KernelDef, options: KernelBuildOptions, - ) -> S::Kernel { - let mut builder = KernelBuilder::new(Some(self.clone()), true); - let raw_kernel = KernelBuildFn::build_kernel(&f, &mut builder, options); - S::wrap_raw_kernel(raw_kernel) - } -} - -pub fn create_static_callable<'a, S: CallableSignature<'a>>(f: S::StaticFn) -> S::Callable { - let r_backup = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - std::mem::replace(&mut *r, Recorder::new()) - }); - let mut builder = KernelBuilder::new(None, false); - let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); - let callable = S::wrap_raw_callable(raw_callable); - RECORDER.with(|r| { - *r.borrow_mut() = r_backup; - }); - callable -} -#[macro_export] -macro_rules! fn_n_args { - (0)=>{ dyn Fn()}; - (1)=>{ dyn Fn(_)}; - (2)=>{ dyn Fn(_,_)}; - (3)=>{ dyn Fn(_,_,_)}; - (4)=>{ dyn Fn(_,_,_,_)}; - (5)=>{ dyn Fn(_,_,_,_,_)}; - (6)=>{ dyn Fn(_,_,_,_,_,_)}; - (7)=>{ dyn Fn(_,_,_,_,_,_,_)}; - (8)=>{ dyn Fn(_,_,_,_,_,_,_,_)}; - (9)=>{ dyn Fn(_,_,_,_,_,_,_,_,_)}; - (10)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_)}; - (11)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_)}; - (12)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_)}; - (13)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_)}; - (14)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_,_)}; - (15)=>{dyn Fn(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)}; -} -#[macro_export] -macro_rules! wrap_fn { - ($arg_count:tt, $f:expr) => { - &$f as &fn_n_args!($arg_count) - }; -} -#[macro_export] -macro_rules! 
create_kernel { - ($device:expr, $arg_count:tt, $f:expr) => {{ - let kernel: fn_n_args!($arg_count) = Box::new($f); - $device.create_kernel(kernel) - }}; + ) -> Kernel { + let name = options.name.unwrap_or("".to_string()); + let name = Arc::new(CString::new(name).unwrap()); + let shader_options = api::ShaderOption { + enable_cache: options.enable_cache, + enable_fast_math: options.enable_fast_math, + enable_debug_info: options.enable_debug_info, + compile_only: false, + name: name.as_ptr(), + }; + let module = k.inner.module.clone(); + let artifact = if options.async_compile { + ShaderArtifact::Async(AsyncShaderArtifact::new( + self.clone(), + module.clone(), + shader_options, + name, + )) + } else { + ShaderArtifact::Sync(self.inner.create_shader(&module, &shader_options)) + }; + Kernel { + inner: RawKernel { + device: self.clone(), + artifact, + module, + resource_tracker: k.inner.resource_tracker.clone(), + }, + _marker: PhantomData {}, + } + } } pub(crate) enum StreamHandle { Default { @@ -793,31 +759,6 @@ impl<'a> Scope<'a> { self } } -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: Command<'a>) -> Self::Output { - self.submit(std::iter::once(rhs)); - self - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: EventSignal<'a>) -> Self::Output { - self.signal(rhs.event, rhs.ticket) - } -} -impl<'a> std::ops::Shl> for &'a Scope<'a> { - type Output = Self; - #[inline] - #[allow(unused_must_use)] - fn shl(self, rhs: EventWait<'a>) -> Self::Output { - self.wait(rhs.event, rhs.ticket) - } -} impl<'a> Drop for Scope<'a> { fn drop(&mut self) { if !self.synchronized.get() { @@ -939,7 +880,7 @@ pub struct RawKernel { pub(crate) artifact: ShaderArtifact, #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, - pub(crate) module: CArc, + pub(crate) module: CArc, } pub struct CallableArgEncoder { @@ -1111,17 
+1052,17 @@ macro_rules! impl_kernel_arg_for_tuple { fn encode(&self, _: &mut KernelArgEncoder) { } } }; - ($first:ident $($rest:ident) *) => { - impl<$first:KernelArg, $($rest: KernelArg),*> KernelArg for ($first, $($rest,)*) { - type Parameter = ($first::Parameter, $($rest::Parameter),*); + ($first:ident $($Ts:ident) *) => { + impl<$first:KernelArg, $($Ts: KernelArg),*> KernelArg for ($first, $($Ts,)*) { + type Parameter = ($first::Parameter, $($Ts::Parameter),*); #[allow(non_snake_case)] fn encode(&self, encoder: &mut KernelArgEncoder) { - let ($first, $($rest,)*) = self; + let ($first, $($Ts,)*) = self; $first.encode(encoder); - $($rest.encode(encoder);)* + $($Ts.encode(encoder);)* } } - impl_kernel_arg_for_tuple!($($rest)*); + impl_kernel_arg_for_tuple!($($Ts)*); }; } @@ -1142,6 +1083,7 @@ impl RawKernel { } } } + pub fn dispatch_async(&self, args: KernelArgEncoder, dispatch_size: [u32; 3]) -> Command { let mut rt = ResourceTracker::new(); rt.add(Arc::new(args.uniform_data)); @@ -1165,23 +1107,23 @@ impl RawKernel { } } -pub struct Callable> { +pub struct Callable { #[allow(dead_code)] pub(crate) inner: RawCallable, pub(crate) _marker: PhantomData, } -pub(crate) struct DynCallableInner> { +pub(crate) struct DynCallableInner { builder: Box, &mut KernelBuilder) -> Callable>, callables: Vec>, } -pub struct DynCallable> { +pub struct DynCallable { #[allow(dead_code)] pub(crate) inner: RefCell>, pub(crate) device: Device, pub(crate) init_once: bool, } -impl> DynCallable { - pub(crate) fn new( +impl DynCallable { + pub(crate) fn _new( device: Device, init_once: bool, builder: Box, &mut KernelBuilder) -> Callable>, @@ -1253,14 +1195,45 @@ pub struct RawCallable { #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, } +pub struct RawKernelDef { + #[allow(dead_code)] + pub(crate) device: Option, + pub(crate) module: CArc, + #[allow(dead_code)] + pub(crate) resource_tracker: ResourceTracker, +} -pub struct Kernel> { +/// A kernel definition +/// See 
[`Kernel`] for more information +pub struct KernelDef { + pub(crate) inner: RawKernelDef, + pub(crate) _marker: PhantomData, +} + +/// An executable kernel +/// Kernel creation can be done in multiple ways: +/// - Seperate recording and compilation: +/// ```no_run +/// // Recording: +/// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ... })); +/// // Compilation: +/// let kernel = device.compile_kernel(&kernel); +/// ``` +/// - Recording and compilation in one step: +/// ```no_run +/// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ... })); +/// ``` +/// - Asynchronous compilation use [`Kernel::::new_async`] +/// - Custom build options using [`Kernel::::new_with_options`] +/// ``` +/// +pub struct Kernel { pub(crate) inner: RawKernel, pub(crate) _marker: PhantomData, } -unsafe impl> Send for Kernel {} -unsafe impl> Sync for Kernel {} -impl> Kernel { +unsafe impl Send for Kernel {} +unsafe impl Sync for Kernel {} +impl Kernel { pub fn cache_dir(&self) -> Option { let handle = self.inner.unwrap(); let device = &self.inner.device; @@ -1302,80 +1275,101 @@ impl AsKernelArg> for Tex3d {} impl AsKernelArg for BindlessArray {} impl AsKernelArg for Accel {} + macro_rules! 
impl_call_for_callable { - ($first:ident $($rest:ident)*) => { - impl CallableR> { + ( $($Ts:ident)*) => { + impl CallableR> { #[allow(non_snake_case)] - pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { + #[allow(unused_mut)] + pub fn call(&self, $($Ts:$Ts),*) -> R { let mut encoder = CallableArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* CallableRet::_from_return( crate::lang::__invoke_callable(&self.inner.module, &encoder.args)) } } - impl DynCallableR> { + impl DynCallableR> { #[allow(non_snake_case)] - pub fn call(&self, $first:$first, $($rest:$rest),*) -> R { + #[allow(unused_mut)] + pub fn call(&self, $($Ts:$Ts),*) -> R { let mut encoder = CallableArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* - self.call_impl(std::rc::Rc::new(($first, $($rest,)*)), &encoder.args) + $($Ts.encode(&mut encoder);)* + self.call_impl(std::rc::Rc::new(($($Ts,)*)), &encoder.args) } } - impl_call_for_callable!($($rest)*); }; - ()=>{ - impl CallableR> { - pub fn call(&self)->R { - CallableRet::_from_return( - crate::lang::__invoke_callable(&self.inner.module, &[])) - } - } - impl DynCallableR> { - pub fn call(&self)-> R{ - self.call_impl(std::rc::Rc::new(()), &[]) - } - } - } } +impl_call_for_callable!(); +impl_call_for_callable!(T0); +impl_call_for_callable!(T0 T1 ); +impl_call_for_callable!(T0 T1 T2 ); +impl_call_for_callable!(T0 T1 T2 T3 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_call_for_callable!(T0 T1 T2 T3 
T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_call_for_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + macro_rules! impl_dispatch_for_kernel { - ($first:ident $($rest:ident)*) => { - impl <$first:KernelArg+'static, $($rest: KernelArg+'static),*> Kernel { + ($($Ts:ident)*) => { + impl <$($Ts: KernelArg+'static),*> Kernel { #[allow(non_snake_case)] - pub fn dispatch(&self, dispatch_size: [u32; 3], $first:&impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),*) { + #[allow(unused_mut)] + pub fn dispatch(&self, dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),*) { let mut encoder = KernelArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* self.inner.dispatch(encoder, dispatch_size) } #[allow(non_snake_case)] + #[allow(unused_mut)] pub fn dispatch_async<'a>( &'a self, - dispatch_size: [u32; 3], $first: &impl AsKernelArg<$first>, $($rest:&impl AsKernelArg<$rest>),* + dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),* ) -> Command<'a> { let mut encoder = KernelArgEncoder::new(); - $first.encode(&mut encoder); - $($rest.encode(&mut encoder);)* + $($Ts.encode(&mut encoder);)* self.inner.dispatch_async(encoder, dispatch_size) } + /// Blocks until the kernel is compiled + pub fn ensure_ready(&self) { + self.inner.unwrap(); + } } - impl_dispatch_for_kernel!($($rest)*); }; - ()=>{ - impl Kernel { - pub fn dispatch(&self, dispatch_size: [u32; 3]) { - self.inner.dispatch(KernelArgEncoder::new(), dispatch_size) - } - pub fn dispatch_async<'a>( - &'a self, - dispatch_size: [u32; 3], - ) -> Command<'a> { - self.inner.dispatch_async(KernelArgEncoder::new(), dispatch_size) - } - } -} } + +impl_dispatch_for_kernel!(); +impl_dispatch_for_kernel!(T0); +impl_dispatch_for_kernel!(T0 T1 ); +impl_dispatch_for_kernel!(T0 T1 T2 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 ); 
+impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_dispatch_for_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +#[macro_export] +macro_rules! remove_last { + ($x:ident) => { + + }; + ($first:ident $($xs:ident)*) => { + $first remove_last!($($xs)*) + }; +} diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index ea118b9c..bd738927 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -348,14 +348,13 @@ impl KernelBuilder { } }) } - fn build_kernel( + fn build_kernel( &mut self, - options: KernelBuildOptions, body: impl FnOnce(&mut Self), - ) -> crate::runtime::RawKernel { + ) -> crate::runtime::KernelDef { body(self); let (rt, cpu_custom_ops, captures) = self.collect_module_info(); - RECORDER.with(|r| -> crate::runtime::RawKernel { + RECORDER.with(|r| -> crate::runtime::KernelDef { let mut r = r.borrow_mut(); assert!(r.lock); r.lock = false; @@ -380,45 +379,28 @@ impl KernelBuilder { block_size: r.block_size.unwrap_or([64, 1, 1]), pools: r.pools.clone().unwrap(), }; - - let module = CArc::new(module); - let name = options.name.unwrap_or("".to_string()); - let name = Arc::new(CString::new(name).unwrap()); - let shader_options = api::ShaderOption { - enable_cache: options.enable_cache, - enable_fast_math: options.enable_fast_math, - enable_debug_info: options.enable_debug_info, - compile_only: false, - 
name: name.as_ptr(), - }; - let artifact = if options.async_compile { - ShaderArtifact::Async(AsyncShaderArtifact::new( - self.device.clone().unwrap(), - module.clone(), - shader_options, - name, - )) - } else { - ShaderArtifact::Sync( - self.device - .as_ref() - .unwrap() - .inner - .create_shader(&module, &shader_options), - ) - }; - // r.reset(); - RawKernel { - artifact, - device: self.device.clone().unwrap(), - resource_tracker: rt, - module, + + KernelDef { + inner: RawKernelDef { + device: self.device.clone(), + resource_tracker: rt, + module: CArc::new(module), + }, + _marker: PhantomData, } }) } } +/// Build options for kernel compilation +/// * `enable_debug_info`: enable debug info, default true on debug build +/// * `enable_optimization`: enable optimization, default true +/// * `async_compile`: compile the kernel asynchronously +/// * `enable_cache`: enable cache for the compiled kernel +/// * `enable_fast_math`: enable fast math in the compiled kernel +/// * `name`: name of the compiled kernel. 
On CUDA backend, this is the name of the generated PTX kernel +/// #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct KernelBuildOptions { pub enable_debug_info: bool, @@ -445,21 +427,12 @@ impl Default for KernelBuildOptions { } } } - -pub trait KernelBuildFn { - fn build_kernel( - &self, - builder: &mut KernelBuilder, - options: KernelBuildOptions, - ) -> crate::runtime::RawKernel; -} - -pub trait CallableBuildFn { +pub trait CallableBuildFn { fn build_callable(&self, args: Option>, builder: &mut KernelBuilder) -> RawCallable; } -pub trait StaticCallableBuildFn: CallableBuildFn {} +pub trait StaticCallableBuildFn: CallableBuildFn {} // @FIXME: this looks redundant pub unsafe trait CallableRet { @@ -486,204 +459,206 @@ unsafe impl CallableRet for Expr { } } -pub trait CallableSignature<'a> { - type Callable; - type DynCallable; - type Fn: CallableBuildFn; - type StaticFn: StaticCallableBuildFn; - type DynFn: CallableBuildFn + 'static; +pub trait CallableSignature { type Ret: CallableRet; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable; - fn create_dyn_callable(device: Device, init_once: bool, f: Self::DynFn) -> Self::DynCallable; } -pub trait KernelSignature<'a> { - type Fn: KernelBuildFn; - type Kernel; - - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel; -} -macro_rules! impl_callable_signature { - ()=>{ - impl<'a, R: CallableRet +'static> CallableSignature<'a> for fn()->R { - type Fn = &'a dyn Fn() ->R; - type DynFn = BoxR>; - type StaticFn = fn() -> R; - type Callable = CallableR>; - type DynCallable = DynCallableR>; +pub trait KernelSignature {} +macro_rules! 
impl_callable { + ($($Ts:ident)*) => { + impl CallableSignature for fn($($Ts,)*)->R { type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:PhantomData, + } + impl CallableR> { + pub fn newR>(device: &Device, f:F)->Self where F:CallableBuildFnR> { + let mut builder = KernelBuilder::new(Some(device.clone()), false); + let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); + Self{ + inner: raw_callable, + _marker: PhantomData, } } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { - let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) - })) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, R:CallableRet +'static, $first:CallableParameter +'static, $($rest: CallableParameter +'static),*> CallableSignature<'a> for fn($first, $($rest,)*)->R { - type Fn = &'a dyn Fn($first, $($rest),*)->R; - type DynFn = BoxR>; - type Callable = CallableR>; - type StaticFn = fn($first, $($rest,)*)->R; - type DynCallable = DynCallableR>; - type Ret = R; - fn wrap_raw_callable(callable: RawCallable) -> Self::Callable{ - Callable { - inner: callable, - _marker:PhantomData, + pub fn new_static(f:fn($($Ts,)*)->R)->Self where fn($($Ts,)*)->R :CallableBuildFnR> { + let r_backup = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + std::mem::replace(&mut *r, Recorder::new()) + }); + let mut builder = KernelBuilder::new(None, false); + let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); + RECORDER.with(|r| { + *r.borrow_mut() = r_backup; + }); + Self{ + inner: raw_callable, + _marker: PhantomData, } } - fn create_dyn_callable(device:Device, init_once:bool, f: Self::DynFn) -> Self::DynCallable { - DynCallable::new(device, init_once, Box::new(move |arg, builder| { + } + impl DynCallableR> { + pub fn 
new(device: &Device, f:BoxR>)->Self where BoxR> : CallableBuildFnR> { + DynCallable::_new(device.clone(), false, Box::new(move |arg, builder| { let raw_callable = CallableBuildFn::build_callable(&f, Some(arg), builder); - Self::wrap_raw_callable(raw_callable) + Callable { + inner: raw_callable, + _marker: PhantomData, + } })) } } - impl_callable_signature!($($rest)*); - }; -} -impl_callable_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! impl_kernel_signature { - ()=>{ - impl<'a> KernelSignature<'a> for fn() { - type Fn = &'a dyn Fn(); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:PhantomData, - } - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<'a, $first:KernelArg +'static, $($rest: KernelArg +'static),*> KernelSignature<'a> for fn($first, $($rest,)*) { - type Fn = &'a dyn Fn($first::Parameter, $($rest::Parameter),*); - type Kernel = Kernel; - fn wrap_raw_kernel(kernel: crate::runtime::RawKernel) -> Self::Kernel { - Self::Kernel{ - inner:kernel, - _marker:PhantomData, - } - } - } - impl_kernel_signature!($($rest)*); }; } -impl_kernel_signature!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! 
impl_callable_build_for_fn { - ()=>{ - impl CallableBuildFn for &dyn Fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() +impl_callable!(); +impl_callable!(T0); +impl_callable!(T0 T1 ); +impl_callable!(T0 T1 T2 ); +impl_callable!(T0 T1 T2 T3 ); +impl_callable!(T0 T1 T2 T3 T4 ); +impl_callable!(T0 T1 T2 T3 T4 T5 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); +impl_callable!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +macro_rules! impl_kernel { + ($($Ts:ident)*) => { + impl<$($Ts: KernelArg +'static),*> KernelSignature for fn($($Ts,)*) {} + impl<$($Ts: KernelArg +'static),*> KernelDef { + #[allow(non_snake_case)] + #[allow(unused_variables)] + pub fn new_maybe_device(device: Option<&Device>, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let mut builder = KernelBuilder::new(device.cloned(), true); + builder.build_kernel(move |builder| { + $(let $Ts = <$Ts::Parameter as KernelParameter>::def_param(builder);)* + (f)($($Ts,)*) }) } - } - impl CallableBuildFn for fn()->R { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) + pub fn new(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + Self::new_maybe_device(Some(device), f) } - } - impl CallableBuildFn for BoxR> { - fn build_callable(&self, _args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |_| { - self() - }) + pub fn new_static(f:fn($($Ts::Parameter,)*))->Self { 
+ Self::new_maybe_device(None, f) } } - impl StaticCallableBuildFn for fn()->R {} - }; - ($first:ident $($rest:ident)*) => { - impl CallableBuildFn for &dyn Fn($first, $($rest,)*)->R { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) + impl<$($Ts: KernelArg +'static),*> Kernel { + /// Compile a kernel with given recording function `f`. + pub fn new(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel(&def) } - } - impl CallableBuildFn for BoxR> { - #[allow(non_snake_case)] - fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { - builder.build_callable( |builder| { - if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) - } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) - } - }) + /// Compile a kernel asynchronously with given recording function `f`. 
+ /// This function returns immediately after `f` returns + + pub fn new_async(device: &Device, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel_async(&def) + } + + // Compile a kernel with given recording function `f` and build options [`KernelBuildOptions`] + pub fn new_with_options(device: &Device, options: KernelBuildOptions, f:impl FnOnce($($Ts::Parameter,)*))->Self { + let def = KernelDef::::new(device, f); + device.compile_kernel_with_options(&def, options) } } - impl CallableBuildFn for fn($first, $($rest,)*)->R { + }; +} + +impl_kernel!(); +impl_kernel!(T0); +impl_kernel!(T0 T1 ); +impl_kernel!(T0 T1 T2 ); +impl_kernel!(T0 T1 T2 T3 ); +impl_kernel!(T0 T1 T2 T3 T4 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); +impl_kernel!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); + +macro_rules! 
impl_callable_build_for_fn { + ($($Ts:ident)*) => { + impl CallableBuildFnR> for T + where T: Fn($($Ts,)*)->R + 'static { #[allow(non_snake_case)] + #[allow(unused_variables)] fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { builder.build_callable( |builder| { if let Some(args) = args { - let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); - let $first = $first::def_param(Some(Rc::new($first)), builder); - $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* - self($first, $($rest,)*) + let ($($Ts,)*) = args.downcast_ref::<($($Ts,)*)>().cloned().unwrap(); + $(let $Ts = $Ts::def_param(Some(Rc::new($Ts)), builder);)* + self($($Ts,)*) } else { - let $first = $first::def_param(None, builder); - $(let $rest = $rest::def_param(None, builder);)* - self($first, $($rest,)*) + $(let $Ts = $Ts::def_param(None, builder);)* + self($($Ts,)*) } }) } } - impl StaticCallableBuildFn for fn($first, $($rest,)*)->R {} - impl_callable_build_for_fn!($($rest)*); + // impl CallableBuildFn for BoxR> { + // #[allow(non_snake_case)] + // fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + // builder.build_callable( |builder| { + // if let Some(args) = args { + // let ($first, $($rest,)*) = args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + // let $first = $first::def_param(Some(Rc::new($first)), builder); + // $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + // self($first, $($rest,)*) + // } else { + // let $first = $first::def_param(None, builder); + // $(let $rest = $rest::def_param(None, builder);)* + // self($first, $($rest,)*) + // } + // }) + // } + // } + // impl CallableBuildFn for fn($first, $($rest,)*)->R { + // #[allow(non_snake_case)] + // fn build_callable(&self, args: Option>, builder: &mut KernelBuilder)->RawCallable { + // builder.build_callable( |builder| { + // if let Some(args) = args { + // let ($first, $($rest,)*) = 
args.downcast_ref::<($first, $($rest,)*)>().cloned().unwrap(); + // let $first = $first::def_param(Some(Rc::new($first)), builder); + // $(let $rest = $rest::def_param(Some(Rc::new($rest)), builder);)* + // self($first, $($rest,)*) + // } else { + // let $first = $first::def_param(None, builder); + // $(let $rest = $rest::def_param(None, builder);)* + // self($first, $($rest,)*) + // } + // }) + // } + // } + impl StaticCallableBuildFnR> for fn($($Ts,)*)->R + where fn($($Ts,)*)->R : CallableBuildFnR> {} }; } + +impl_callable_build_for_fn!(); +impl_callable_build_for_fn!(T0); +impl_callable_build_for_fn!(T0 T1 ); +impl_callable_build_for_fn!(T0 T1 T2 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 ); +impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 ); impl_callable_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); -macro_rules! 
impl_kernel_build_for_fn { - ()=>{ - impl KernelBuildFn for &dyn Fn() { - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |_| { - self() - }) - } - } - }; - ($first:ident $($rest:ident)*) => { - impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelBuildFn for &dyn Fn($first, $($rest,)*) { - #[allow(non_snake_case)] - fn build_kernel(&self, builder: &mut KernelBuilder, options:KernelBuildOptions) -> crate::runtime::RawKernel { - builder.build_kernel(options, |builder| { - let $first = $first::def_param(builder); - $(let $rest = $rest::def_param(builder);)* - self($first, $($rest,)*) - }) - } - } - impl_kernel_build_for_fn!($($rest)*); - }; -} -impl_kernel_build_for_fn!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 1888326e..98b43734 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -74,7 +74,7 @@ fn autodiff_helper]) -> Expr>( // inputs[i].view(..).copy_from(&tmp); // } println!("init time: {:?}", tic.elapsed()); - let kernel = device.create_kernel_async::(&|| { + let kernel = Kernel::::new_async(&device, || { let input_vars = inputs.iter().map(|input| input.var()).collect::>(); let grad_fd_vars = grad_fd.iter().map(|grad| grad.var()).collect::>(); let grad_ad_vars = grad_ad.iter().map(|grad| grad.var()).collect::>(); @@ -707,23 +707,26 @@ fn autodiff_select() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = select(x > y, x * 4.0, y * 0.5); - backward(z); - buf_dx.write(tid, 
gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = select(x > y, x * 4.0, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -751,24 +754,27 @@ fn autodiff_detach() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let k = detach(x * y); - let z = (x + y) * k; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let k = detach(x * y); + let z = (x + y) * k; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -801,25 +807,28 @@ fn autodiff_select_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = 
dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x > y; - let a = (x - y).sqrt(); - let z = select(cond, a, y * 0.5); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x > y; + let a = (x - y).sqrt(); + let z = select(cond, a, y * 0.5); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -842,30 +851,33 @@ fn autodiff_if_nan() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen::() + 10.0); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let cond = x > y; - let z = if cond { - let a = (x - y).sqrt(); - a - } else { - y * 0.5 - }; - // cpu_dbg!(f32, z); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let cond = x > y; + let z = if cond { + let a = (x - y).sqrt(); + a + } else { + y * 0.5 + }; + // cpu_dbg!(f32, z); + 
backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -893,25 +905,28 @@ fn autodiff_if_phi() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - if true.expr() { - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if x > y { x * 4.0 } else { y * 0.5 }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - } - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + if true.expr() { + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = if x > y { x * 4.0 } else { y * 0.5 }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + } + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -939,31 +954,34 @@ fn autodiff_if_phi2() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if x > y { - if x > 3.0 { - x * 4.0 + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = 
dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = if x > y { + if x > 3.0 { + x * 4.0 + } else { + x * 2.0 + } } else { - x * 2.0 - } - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -995,37 +1013,40 @@ fn autodiff_if_phi3() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let const_two = 2.0_f32.var(); - let const_three = 3.0_f32.var(); - let const_four = f32::var_zeroed(); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let const_two = 2.0_f32.var(); + let const_three = 3.0_f32.var(); + let const_four = f32::var_zeroed(); - autodiff(|| { - requires_grad(x); - requires_grad(y); - const_four.store(4.0); - let c = (x > const_three).as_::(); - let z = if x > y { - switch::>(c) - .case(0, || x * const_two) - .default(|| x * const_four) - .finish() - * const_two - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + autodiff(|| { + requires_grad(x); + requires_grad(y); + const_four.store(4.0); + let c = (x > const_three).as_::(); + let z = if x > y { + switch::>(c) + .case(0, || x * const_two) + .default(|| x * const_four) + .finish() + * 
const_two + } else { + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1057,38 +1078,41 @@ fn autodiff_if_phi4() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); - let consts = Float3::var_zeroed(); - autodiff(|| { - requires_grad(x); - requires_grad(y); - *consts = Float3::expr(2.0, 3.0, 4.0); - let const_two = consts.x; - let const_three = consts.y; - let const_four = consts.z; - let c = (x > const_three).as_::(); - let z = if x > y { - switch::>(c) - .case(0, || x * const_two) - .default(|| x * const_four) - .finish() - * const_two - } else { - y * 0.5 - }; - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let consts = Float3::var_zeroed(); + autodiff(|| { + requires_grad(x); + requires_grad(y); + *consts = Float3::expr(2.0, 3.0, 4.0); + let const_two = consts.x; + let const_three = consts.y; + let const_four = consts.z; + let c = (x > const_three).as_::(); + let z = if x > y { + switch::>(c) + .case(0, || x * const_two) + .default(|| x * const_four) + .finish() + * const_two + } else { + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1122,29 +1146,32 @@ fn 
autodiff_switch() { t.view(..).fill_fn(|_| rng.gen_range(0..3)); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let buf_t = t.var(); - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let t = buf_t.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = switch::>(t) - .case(0, || x * 4.0) - .case(1, || x * 2.0) - .case(2, || y * 0.5) - .finish(); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = switch::>(t) + .case(0, || x * 4.0) + .case(1, || x * 2.0) + .case(2, || y * 0.5) + .finish(); + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); @@ -1182,7 +1209,7 @@ fn autodiff_callable() { x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); let callable = - device.create_callable::, Var, Expr)>(track!(&|vx, vy, t| { + Callable::, Var, Expr)>::new(&device, track!(|vx, vy, t| { let x = **vx; let y = **vy; autodiff(|| { @@ -1198,22 +1225,25 @@ fn autodiff_callable() { *vy = gradient(y); }); })); - let kernel = device.create_kernel::(&track!(|| { - let buf_t = t.var(); - let buf_x = x.var(); - let buf_y = y.var(); - let buf_dx = dx.var(); - let buf_dy = dy.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let t = buf_t.read(tid); - let dx = x.var(); - let 
dy = y.var(); - callable.call(dx, dy, t); - buf_dx.write(tid, dx); - buf_dy.write(tid, dy); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + let dx = x.var(); + let dy = y.var(); + callable.call(dx, dy, t); + buf_dx.write(tid, dx); + buf_dy.write(tid, dy); + }), + ); kernel.dispatch([1024, 1, 1]); let dx = dx.view(..).copy_to_vec(); let dy = dy.view(..).copy_to_vec(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index f10b83a4..439fc5e4 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1,5 +1,5 @@ use luisa::lang::types::array::VLArrayVar; -use luisa::lang::types::core::*; +use luisa::lang::types::dynamic::*; use luisa::lang::types::vector::alias::*; use luisa::prelude::*; use luisa_compute as luisa; @@ -15,46 +15,32 @@ fn event() { let a: Buffer = device.create_buffer_from_slice(&[0]); let b: Buffer = device.create_buffer_from_slice(&[0]); // compute (1 + 3) * (4 + 5) - let add = device.create_kernel::, i32)>(&|buf: BufferVar, v: Expr| { + let add = Kernel::, i32)>::new(&device, |buf: BufferVar, v: Expr| { track!(buf.write(0, buf.read(0) + v)); }); - let mul = device.create_kernel::, Buffer)>( - &|a: BufferVar, b: BufferVar| { - track!(a.write(0, a.read(0) * b.read(0))); - }, - ); + let mul = Kernel::, Buffer)>::new(&device, |a, b| { + track!(a.write(0, a.read(0) * b.read(0))); + }); let stream_a = device.create_stream(StreamTag::Compute); let stream_b = device.create_stream(StreamTag::Compute); { let scope_a = stream_a.scope(); let scope_b = stream_b.scope(); let event = device.create_event(); - let _ = &scope_a - << add.dispatch_async([1, 1, 1], &a, &1) - << add.dispatch_async([1, 1, 1], &b, &4) - << event.signal(1); - let _ = &scope_b - << event.wait(1) - << 
add.dispatch_async([1, 1, 1], &a, &3) - << add.dispatch_async([1, 1, 1], &b, &5) - << event.signal(2); - let _ = - &scope_a << event.wait(2) << mul.dispatch_async([1, 1, 1], &a, &b) << event.signal(3); + scope_a + .submit([add.dispatch_async([1, 1, 1], &a, &1)]) + .submit([add.dispatch_async([1, 1, 1], &b, &4)]) + .signal(&event, 1); + scope_b + .wait(&event, 1) + .submit([add.dispatch_async([1, 1, 1], &a, &3)]) + .submit([add.dispatch_async([1, 1, 1], &b, &5)]) + .signal(&event, 2); + scope_a + .wait(&event, 2) + .submit([mul.dispatch_async([1, 1, 1], &a, &b)]) + .signal(&event, 3); event.synchronize(3); - // scope_a - // .submit([add.dispatch_async([1, 1, 1], &a, &1)]) - // .submit([add.dispatch_async([1, 1, 1], &b, &4)]) - // .signal(&event, 1); - // scope_b - // .wait(&event, 1) - // .submit([add.dispatch_async([1, 1, 1], &a, &3)]) - // .submit([add.dispatch_async([1, 1, 1], &b, &5)]) - // .signal(&event, 2); - // scope_a - // .wait(&event, 2) - // .submit([mul.dispatch_async([1, 1, 1], &a, &b)]) - // .signal(&event, 3); - // event.synchronize(3); } let v = a.copy_to_vec(); assert_eq!(v[0], (1 + 3) * (4 + 5)); @@ -63,57 +49,68 @@ fn event() { #[should_panic] fn callable_return_mismatch() { let device = get_device(); - let _abs = device.create_callable::) -> Expr>(&track!(|x| { - if x > 0.0 { - return true.expr(); - } - -x - })); + let _abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return true.expr(); + } + -x + }), + ); } #[test] #[should_panic] fn callable_return_mismatch2() { let device = get_device(); - let _abs = device.create_callable::) -> Expr>(&track!(|x| { - if x > 0.0 { - return; - } - -x - })); + let _abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return; + } + -x + }), + ); } #[test] #[should_panic] fn callable_return_void_mismatch() { let device = get_device(); - let _abs = device.create_callable::)>(&track!(|x| { - if x > 0.0 { - return true.expr(); - } - *x = -x; - })); + let _abs = 
Callable::)>::new( + &device, + track!(|x| { + if x > 0.0 { + return true.expr(); + } + *x = -x; + }), + ); } #[test] fn callable_early_return() { let device = get_device(); - let abs = device.create_callable::) -> Expr>(track!(&|x| { - if x > 0.0 { - return x; - } - -x - })); + let abs = Callable::) -> Expr>::new( + &device, + track!(|x| { + if x > 0.0 { + return x; + } + -x + }), + ); let x = device.create_buffer::(1024); let mut rng = StdRng::seed_from_u64(0); x.fill_fn(|_| rng.gen()); let y = device.create_buffer::(1024); - device - .create_kernel::(&|| { - let i = dispatch_id().x; - let x = x.var().read(i); - let y = y.var(); - y.write(i, abs.call(x)); - }) - .dispatch([x.len() as u32, 1, 1]); + Kernel::::new(&device, || { + let i = dispatch_id().x; + let x = x.var().read(i); + let y = y.var(); + y.write(i, abs.call(x)); + }) + .dispatch([x.len() as u32, 1, 1]); let x = x.copy_to_vec(); let y = y.copy_to_vec(); for i in 0..x.len() { @@ -123,31 +120,34 @@ fn callable_early_return() { #[test] fn callable() { let device = get_device(); - let write = device.create_callable::, Expr, Var)>( - &|buf: BufferVar, i: Expr, v: Var| { + let write = Callable::, Expr, Var)>::new( + &device, + |buf: BufferVar, i: Expr, v: Var| { buf.write(i, v.load()); track!(*v+=1;) }, ); - let add = - device.create_callable::, Expr) -> Expr>(&|a, b| track!(a + b)); + let add = Callable::, Expr) -> Expr>::new(&device, |a, b| track!(a + b)); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); let z = device.create_buffer::(1024); let w = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as u32); y.view(..).fill_fn(|i| 1000 * i as u32); - let kernel = device.create_kernel::)>(&track!(|buf_z| { - let buf_x = x.var(); - let buf_y = y.var(); - let buf_w = w.var(); - let tid = dispatch_id().x; - let x = buf_x.read(tid); - let y = buf_y.read(tid); - let z = add.call(x, y).var(); - write.call(buf_z, tid, z); - buf_w.write(tid, z); - })); + let kernel = Kernel::)>::new( 
+ &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_w = w.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let z = add.call(x, y).var(); + write.call(buf_z, tid, z); + buf_w.write(tid, z); + }), + ); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); let w_data = w.view(..).copy_to_vec(); @@ -163,7 +163,12 @@ fn vec_cast() { let i: Buffer = device.create_buffer(1024); f.view(..) .fill_fn(|i| Float2::new(i as f32 + 0.5, i as f32 + 1.5)); - let kernel = device.create_kernel_with_options::( + let kernel = Kernel::::new_with_options( + &device, + KernelBuildOptions { + name: Some("vec_cast".to_string()), + ..KernelBuildOptions::default() + }, &|| { let f = f.var(); let i = i.var(); @@ -171,10 +176,6 @@ fn vec_cast() { let v = f.read(tid); i.write(tid, v.as_int2()); }, - KernelBuildOptions { - name: Some("vec_cast".to_string()), - ..KernelBuildOptions::default() - }, ); kernel.dispatch([1024, 1, 1]); let mut i_data = vec![Int2::new(0, 0); 1024]; @@ -196,19 +197,22 @@ fn bool_op() { let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| rng.gen()); y.view(..).fill_fn(|_| rng.gen()); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = x.var().read(tid); - let y = y.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -237,19 +241,22 @@ fn bvec_op() 
{ let mut rng = rand::thread_rng(); x.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Bool2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = x.var().read(tid); - let y = y.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -275,16 +282,19 @@ fn test_var_replace() { let device = get_device(); let xs: Buffer = device.create_buffer(1024); let ys: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = xs.var().read(tid).var(); - *x = Int4::expr(1, 2, 3, 4); - let y = **x; - *x.y = 10; - *x.z = 20; - xs.write(tid, x); - ys.write(tid, y); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = xs.var().read(tid).var(); + *x = Int4::expr(1, 2, 3, 4); + let y = **x; + *x.y = 10; + *x.z = 20; + xs.write(tid, x); + ys.write(tid, y); + }), + ); kernel.dispatch([1024, 1, 1]); let xs = xs.view(..).copy_to_vec(); let ys = ys.view(..).copy_to_vec(); @@ -316,26 +326,29 @@ fn vec_bit_minmax() { x.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); y.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); z.view(..).fill_fn(|_| Int2::new(rng.gen(), rng.gen())); - let kernel = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let x = x.var().read(tid); - let y = 
y.var().read(tid); - let z = z.var().read(tid); - let and = and.var(); - let or = or.var(); - let xor = xor.var(); - let not = not.var(); - let min = min.var(); - let max = max.var(); - let clamp = clamp.var(); - and.write(tid, x & y); - or.write(tid, x | y); - xor.write(tid, x ^ y); - not.write(tid, !x); - min.write(tid, luisa::min(x, y)); - max.write(tid, luisa::max(x, y)); - clamp.write(tid, z.clamp(luisa::min(x, y), luisa::max(x, y))); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let x = x.var().read(tid); + let y = y.var().read(tid); + let z = z.var().read(tid); + let and = and.var(); + let or = or.var(); + let xor = xor.var(); + let not = not.var(); + let min = min.var(); + let max = max.var(); + let clamp = clamp.var(); + and.write(tid, x & y); + or.write(tid, x | y); + xor.write(tid, x ^ y); + not.write(tid, !x); + min.write(tid, luisa::min(x, y)); + max.write(tid, luisa::max(x, y)); + clamp.write(tid, z.clamp(luisa::min(x, y), luisa::max(x, y))); + }), + ); kernel.dispatch([1024, 1, 1]); let x = x.view(..).copy_to_vec(); let y = y.view(..).copy_to_vec(); @@ -375,7 +388,7 @@ fn vec_permute() { let v3: Buffer = device.create_buffer(1024); v2.view(..) 
.fill_fn(|i| Int2::new(i as i32 + 0, i as i32 + 1)); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let v2 = v2.var(); let v3 = v3.var(); let tid = dispatch_id().x; @@ -398,18 +411,21 @@ fn if_phi() { let x: Buffer = device.create_buffer(1024); let even: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::(&track!(|| { - let x = x.var(); - let even = even.var(); - let tid = dispatch_id().x; - let v = x.read(tid); - let result = if v % 2 == 0 { - true.expr() - } else { - false.expr() - }; - even.write(tid, result); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let x = x.var(); + let even = even.var(); + let tid = dispatch_id().x; + let v = x.read(tid); + let result = if v % 2 == 0 { + true.expr() + } else { + false.expr() + }; + even.write(tid, result); + }), + ); kernel.dispatch([1024, 1, 1]); let mut i_data = vec![false; 1024]; even.view(..).copy_to(&mut i_data); @@ -425,7 +441,7 @@ fn switch_phi() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -472,7 +488,7 @@ fn switch_unreachable() { let y: Buffer = device.create_buffer(1024); let z: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32 % 3); - let kernel = device.create_kernel::(&|| { + let kernel = Kernel::::new(&device, || { let buf_x = x.var(); let buf_y = y.var(); let buf_z = z.var(); @@ -514,17 +530,20 @@ fn switch_unreachable() { fn array_read_write() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() 
+ i); - *i += 1; - } - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -538,15 +557,18 @@ fn array_read_write() { fn array_read_write3() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - for_range(0..4u32, |i| { - arr.write(i, tid.as_i32() + i.as_i32()); - }); - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + for_range(0..4u32, |i| { + arr.write(i, tid.as_i32() + i.as_i32()); + }); + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -560,17 +582,20 @@ fn array_read_write3() { fn array_read_write4() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - for_range(0..6u32, |_| { - for_range(0..4u32, |i| { - arr.write(i, arr.read(i) + tid.as_i32() + i.as_i32()); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + for_range(0..6u32, |_| { + for_range(0..4u32, |i| { + arr.write(i, arr.read(i) + tid.as_i32() + i.as_i32()); + }); }); - }); - buf_x.write(tid, arr); - })); + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i 
in 0..1024 { @@ -590,19 +615,22 @@ fn array_read_write2() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - buf_x.write(tid, arr); - buf_y.write(tid, arr.read(0)); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + buf_y.write(tid, arr.read(0)); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); let y_data = y.view(..).copy_to_vec(); @@ -619,25 +647,28 @@ fn array_read_write_vla() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let buf_y = y.var(); - let tid = dispatch_id().x; - let vl = VLArrayVar::::zero(4); - let i = i32::var_zeroed(); - while i < 4 { - vl.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), vl.read(i.as_u32())); - *i += 1; - } - buf_x.write(tid, arr); - buf_y.write(tid, arr.read(0)); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let vl = VLArrayVar::::zero(4); + let i = i32::var_zeroed(); + while i < 4 { + vl.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), 
vl.read(i.as_u32())); + *i += 1; + } + buf_x.write(tid, arr); + buf_y.write(tid, arr.read(0)); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); let y_data = y.view(..).copy_to_vec(); @@ -653,17 +684,20 @@ fn array_read_write_vla() { fn array_read_write_async_compile() { let device = get_device(); let x: Buffer<[i32; 4]> = device.create_buffer(1024); - let kernel = device.create_kernel::(&track!(|| { - let buf_x = x.var(); - let tid = dispatch_id().x; - let arr = Var::<[i32; 4]>::zeroed(); - let i = i32::var_zeroed(); - while i < 4 { - arr.write(i.as_u32(), tid.as_i32() + i); - *i += 1; - } - buf_x.write(tid, arr); - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let buf_x = x.var(); + let tid = dispatch_id().x; + let arr = Var::<[i32; 4]>::zeroed(); + let i = i32::var_zeroed(); + while i < 4 { + arr.write(i.as_u32(), tid.as_i32() + i); + *i += 1; + } + buf_x.write(tid, arr); + }), + ); kernel.dispatch([1024, 1, 1]); let x_data = x.view(..).copy_to_vec(); for i in 0..1024 { @@ -680,19 +714,22 @@ fn capture_same_buffer_multiple_view() { let sum = device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|| { - let tid = dispatch_id().x; - let buf_x_lo = x.view(0..64).var(); - let buf_x_hi = x.view(64..).var(); - let x = if tid < 64 { - buf_x_lo.read(tid) - } else { - buf_x_hi.read(tid - 64) - }; - let buf_sum = sum.var(); + let shader = Kernel::::new( + &device, + track!(|| { + let tid = dispatch_id().x; + let buf_x_lo = x.view(0..64).var(); + let buf_x_hi = x.view(64..).var(); + let x = if tid < 64 { + buf_x_lo.read(tid) + } else { + buf_x_hi.read(tid - 64) + }; + let buf_sum = sum.var(); - buf_sum.atomic_fetch_add(0, x); - })); + buf_sum.atomic_fetch_add(0, x); + }), + ); shader.dispatch([x.len() as u32, 1, 1]); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); @@ -708,19 +745,22 @@ fn uniform() { let sum = 
device.create_buffer::(1); x.view(..).fill_fn(|i| i as f32); sum.view(..).fill(0.0); - let shader = device.create_kernel::(&track!(|v: Expr| { - let tid = dispatch_id().x; - let buf_x_lo = x.view(0..64).var(); - let buf_x_hi = x.view(64..).var(); - let x = if tid < 64 { - buf_x_lo.read(tid) - } else { - buf_x_hi.read(tid - 64) - }; - let buf_sum = sum.var(); - let x = x * v.reduce_prod(); - buf_sum.atomic_fetch_add(0, x); - })); + let shader = Kernel::::new( + &device, + track!(|v: Expr| { + let tid = dispatch_id().x; + let buf_x_lo = x.view(0..64).var(); + let buf_x_hi = x.view(64..).var(); + let x = if tid < 64 { + buf_x_lo.read(tid) + } else { + buf_x_hi.read(tid - 64) + }; + let buf_sum = sum.var(); + let x = x * v.reduce_prod(); + buf_sum.atomic_fetch_add(0, x); + }), + ); shader.dispatch([x.len() as u32, 1, 1], &Float3::new(1.0, 2.0, 3.0)); let mut sum_data = vec![0.0]; sum.view(..).copy_to(&mut sum_data); @@ -757,8 +797,9 @@ fn byte_buffer() { let i1 = push!(Big, big); let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); - device - .create_kernel::(&track!(|| unsafe { + Kernel::::new( + &device, + track!(|| unsafe { let buf = buf.var(); let i0 = i0 as u64; let i1 = i1 as u64; @@ -778,8 +819,9 @@ fn byte_buffer() { buf.write_as::(i1, v1.load()); buf.write_as::(i2, v2.load()); buf.write_as::(i3, v3.load()); - })) - .dispatch([1, 1, 1]); + }), + ) + .dispatch([1, 1, 1]); let data = buf.copy_to_vec(); macro_rules! 
pop { ($t:ty, $offset:expr) => {{ @@ -832,8 +874,9 @@ fn bindless_byte_buffer() { let i1 = push!(Big, big); let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); - device - .create_kernel::(&track!(|out: ByteBufferVar| unsafe { + Kernel::::new( + &device, + track!(|out: ByteBufferVar| unsafe { let heap = heap.var(); let buf = heap.byte_address_buffer(0u32); let i0 = i0 as u64; @@ -854,8 +897,9 @@ fn bindless_byte_buffer() { out.write_as::(i1, v1.load()); out.write_as::(i2, v2.load()); out.write_as::(i3, v3.load()); - })) - .dispatch([1, 1, 1], &out); + }), + ) + .dispatch([1, 1, 1], &out); let data = out.copy_to_vec(); macro_rules! pop { ($t:ty, $offset:expr) => {{ @@ -911,25 +955,28 @@ fn atomic() { }; let foo_max = device.create_buffer_from_slice(&[foo_max_init]); let foo_min = device.create_buffer_from_slice(&[foo_min_init]); - let kernel = device.create_kernel::(&track!(|| { - let i = dispatch_id().x; - let foos = foos.var(); - let foo = foos.read(i); - let foo_max = foo_max.var().atomic_ref(0); - let foo_min = foo_min.var().atomic_ref(0); - foo_max.i.fetch_max(foo.i); - foo_max.v.x.fetch_max(foo.v.x); - foo_max.v.y.fetch_max(foo.v.y); - for i in 0..4u32 { - foo_max.a[i].fetch_max(foo.a[i]); - } - foo_min.i.fetch_min(foo.i); - foo_min.v.x.fetch_min(foo.v.x); - foo_min.v.y.fetch_min(foo.v.y); - for i in 0..4u32 { - foo_min.a[i].fetch_min(foo.a[i]); - } - })); + let kernel = Kernel::::new( + &device, + track!(|| { + let i = dispatch_id().x; + let foos = foos.var(); + let foo = foos.read(i); + let foo_max = foo_max.var().atomic_ref(0); + let foo_min = foo_min.var().atomic_ref(0); + foo_max.i.fetch_max(foo.i); + foo_max.v.x.fetch_max(foo.v.x); + foo_max.v.y.fetch_max(foo.v.y); + for i in 0..4u32 { + foo_max.a[i].fetch_max(foo.a[i]); + } + foo_min.i.fetch_min(foo.i); + foo_min.v.x.fetch_min(foo.v.x); + foo_min.v.y.fetch_min(foo.v.y); + for i in 0..4u32 { + foo_min.a[i].fetch_min(foo.a[i]); + } + }), + ); kernel.dispatch([foos.len() as u32, 1, 1]); let foos = 
foos.view(..).copy_to_vec(); let foo_min = foo_min.view(..).copy_to_vec()[0]; @@ -953,3 +1000,52 @@ fn atomic() { assert_eq!(foo_max, expected_foo_max); assert_eq!(foo_min, expected_foo_min); } + +#[test] +fn dyn_callable() { + let device = get_device(); + let add = DynCallable:: DynExpr>::new( + &device, + Box::new(|a: DynExpr, b: DynExpr| -> DynExpr { + if let Some(a) = a.downcast::() { + let b = b.downcast::().unwrap(); + return DynExpr::new(track!(a + b)); + } else if let Some(a) = a.downcast::() { + let b = b.downcast::().unwrap(); + return DynExpr::new(track!(a + b)); + } else { + unreachable!() + } + }), + ); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + let w = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + + buf_z.write(tid, add.call(x.into(), y.into()).get::()); + w.var().write( + tid, + add.call(x.as_::().into(), y.as_::().into()) + .get::(), + ); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + let w_data = w.view(..).copy_to_vec(); + for i in 0..1024 { + assert_eq!(z_data[i], i as f32 + 1000.0 * i as f32); + assert_eq!(w_data[i], i as i32 + 1000 * i as i32); + } +} From a9d826c78f48ae2e82f831d860c804e51ac3154a Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sat, 23 Sep 2023 03:48:06 -0400 Subject: [PATCH 2/3] fixed doc --- luisa_compute/src/runtime.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 871aa50f..7475694e 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1213,14 +1213,14 @@ pub struct KernelDef { /// An executable kernel /// Kernel creation can be done 
in multiple ways: /// - Seperate recording and compilation: -/// ```no_run +/// ```rust /// // Recording: /// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ... })); /// // Compilation: /// let kernel = device.compile_kernel(&kernel); /// ``` /// - Recording and compilation in one step: -/// ```no_run +/// ```rust /// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ... })); /// ``` /// - Asynchronous compilation use [`Kernel::::new_async`] From d148b139268150f984438db13b70e185b2066126 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sat, 23 Sep 2023 03:52:37 -0400 Subject: [PATCH 3/3] fixed doc --- luisa_compute/src/runtime.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 7475694e..6c1fb73f 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1213,19 +1213,24 @@ pub struct KernelDef { /// An executable kernel /// Kernel creation can be done in multiple ways: /// - Seperate recording and compilation: -/// ```rust -/// // Recording: -/// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ... })); +/// ```no_run +/// // Recording: +/// use luisa_compute::prelude::*; +/// let ctx = Context::new(std::env::current_exe().unwrap()); +/// let device = ctx.create_device("cpu"); +/// let kernel = KernelDef::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); /// // Compilation: /// let kernel = device.compile_kernel(&kernel); /// ``` /// - Recording and compilation in one step: -/// ```rust -/// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ ...
})); +/// ```no_run +/// use luisa_compute::prelude::*; +/// let ctx = Context::new(std::env::current_exe().unwrap()); +/// let device = ctx.create_device("cpu"); +/// let kernel = Kernel::, Buffer, Buffer)>::new(&device, track!(|a,b,c|{ })); /// ``` /// - Asynchronous compilation use [`Kernel::::new_async`] /// - Custom build options using [`Kernel::::new_with_options`] -/// ``` /// pub struct Kernel { pub(crate) inner: RawKernel,