Skip to content

Commit

Permalink
get both kernel creation syntax working
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
1 parent 1ef0541 commit a039e1a
Show file tree
Hide file tree
Showing 22 changed files with 160 additions and 113 deletions.
2 changes: 1 addition & 1 deletion luisa_compute/examples/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn main() {
sum.view(..).fill(0.0);
let shader = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = dispatch_id().x;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn main() {
let dy_gt = device.create_buffer::<f32>(1024);
x.fill_fn(|i| i as f32);
y.fill_fn(|i| 1.0 + i as f32);
let shader = Kernel::<fn()>::new(&device, track!(|| {
let shader = Kernel::<fn()>::new(&device, track!(&|| {
let tid = dispatch_id().x;
let buf_x = x.var();
let buf_y = y.var();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/backtrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
// z is passed as an argument
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ fn main() {
y,
exclude: 42.0,
};
let shader = Kernel::<fn(MyArgStruct<f32>)>::new(&device, |_args| {});
let shader = Kernel::<fn(MyArgStruct<f32>)>::new(&device, &|_args| {});
shader.dispatch([1024, 1, 1], &my_args);
}
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn main() {
bindless.emplace_buffer_async(1, &y);
bindless.emplace_tex2d_async(0, &img, Sampler::default());
bindless.update();
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
let bindless = bindless.var();
let tid = dispatch_id().x;
let buf_x = bindless.buffer::<f32>(0_u32.expr());
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/callable_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn main() {
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
track!(|buf_z| {
&track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn main() {
});
let shader = Kernel::<fn(Buffer<f32>)>::new(
&device,
track!(|buf_z: BufferVar<f32>| {
&track!(|buf_z: BufferVar<f32>| {
// z is passed as an argument
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
Expand Down
18 changes: 9 additions & 9 deletions luisa_compute/examples/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn main() {

let advect = Kernel::<fn(Buffer<Float2>, Buffer<Float2>, Buffer<f32>, Buffer<f32>)>::new_async(
&device,
track!(|u0, u1, rho0, rho1| {
&track!(|u0, u1, rho0, rho1| {
let coord = dispatch_id().xy();
let u = u0.read(index(coord));

Expand All @@ -136,7 +136,7 @@ fn main() {

let divergence = Kernel::<fn(Buffer<Float2>, Buffer<f32>)>::new_async(
&device,
track!(|u, div| {
&track!(|u, div| {
let coord = dispatch_id().xy();
if coord.x < (N_GRID as u32 - 1) && coord.y < (N_GRID as u32 - 1) {
let dx = (u.read(index(Uint2::expr(coord.x + 1, coord.y))).x
Expand All @@ -152,7 +152,7 @@ fn main() {

let pressure_solve = Kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>::new_async(
&device,
track!(|p0, p1, div| {
&track!(|p0, p1, div| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -171,7 +171,7 @@ fn main() {

let pressure_apply = Kernel::<fn(Buffer<f32>, Buffer<Float2>)>::new_async(
&device,
track!(|p, u| {
&track!(|p, u| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -193,7 +193,7 @@ fn main() {

let integrate = Kernel::<fn(Buffer<Float2>, Buffer<f32>)>::new_async(
&device,
track!(|u, rho| {
&track!(|u, rho| {
let coord = dispatch_id().xy();
let ij = index(coord);

Expand All @@ -210,7 +210,7 @@ fn main() {

let init = Kernel::<fn(Buffer<f32>, Buffer<Float2>, Float2)>::new_async(
&device,
track!(|rho, u, dir| {
&track!(|rho, u, dir| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -225,7 +225,7 @@ fn main() {
}),
);

let init_grid = Kernel::<fn()>::new_async(&device, || {
let init_grid = Kernel::<fn()>::new_async(&device, &|| {
let idx = index(dispatch_id().xy());
u0.var().write(idx, Float2::expr(0.0f32, 0.0f32));
u1.var().write(idx, Float2::expr(0.0f32, 0.0f32));
Expand All @@ -238,15 +238,15 @@ fn main() {
div.var().write(idx, 0.0f32);
});

let clear_pressure = Kernel::<fn()>::new_async(&device, || {
let clear_pressure = Kernel::<fn()>::new_async(&device, &|| {
let idx = index(dispatch_id().xy());
p0.var().write(idx, 0.0f32);
p1.var().write(idx, 0.0f32);
});

let draw_rho = Kernel::<fn()>::new_async(
&device,
track!(|| {
&track!(|| {
let coord = dispatch_id().xy();
let ij = index(coord);
let value = rho0.var().read(ij);
Expand Down
12 changes: 6 additions & 6 deletions luisa_compute/examples/mpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn main() {

let clear_grid = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let idx = index(dispatch_id().xy());
grid_v.var().write(idx * 2, 0.0f32);
grid_v.var().write(idx * 2 + 1, 0.0f32);
Expand All @@ -102,7 +102,7 @@ fn main() {

let point_to_grid = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let p = dispatch_id().x;
let xp = x.var().read(p) / DX;
let base = (xp - 0.5f32).cast_i32();
Expand Down Expand Up @@ -135,7 +135,7 @@ fn main() {

let simulate_grid = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let coord = dispatch_id().xy();
let i = index(coord);
let v = Var::<Float2>::zeroed();
Expand Down Expand Up @@ -167,7 +167,7 @@ fn main() {

let grid_to_point = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let p = dispatch_id().x;
let xp = x.var().read(p) / DX;
let base = (xp - 0.5f32).cast_i32();
Expand Down Expand Up @@ -210,15 +210,15 @@ fn main() {
}),
);

let clear_display = Kernel::<fn()>::new(&device, || {
let clear_display = Kernel::<fn()>::new(&device, &|| {
display.var().write(
dispatch_id().xy(),
Float4::expr(0.1f32, 0.2f32, 0.3f32, 1.0f32),
);
});
let draw_particles = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let p = dispatch_id().x;
for i in -1..=1 {
for j in -1..=1 {
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/path_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ fn main() {
// use create_kernel_async to compile multiple kernels in parallel
let path_tracer = Kernel::<fn(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>::new_async(
&device,
track!(|image: Tex2dVar<Float4>,
&track!(|image: Tex2dVar<Float4>,
seed_image: Tex2dVar<u32>,
accel: AccelVar,
resolution: Expr<Uint2>| {
Expand Down Expand Up @@ -460,7 +460,7 @@ fn main() {
);
let display = Kernel::<fn(Tex2d<Float4>, Tex2d<Float4>)>::new_async(
&device,
track!(|acc, display| {
&track!(|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ fn main() {
// use create_kernel_async to compile multiple kernels in parallel
let path_tracer = Kernel::<fn(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>::new_async(
&device,
track!(|image: Tex2dVar<Float4>,
&track!(|image: Tex2dVar<Float4>,
seed_image: Tex2dVar<u32>,
accel: AccelVar,
resolution: Expr<Uint2>| {
Expand Down Expand Up @@ -508,7 +508,7 @@ fn main() {
);
let display = Kernel::<fn(Tex2d<Float4>, Tex2d<Float4>)>::new_async(
&device,
track!(|acc, display| {
&track!(|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/polymorphism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() {
let areas = device.create_buffer::<f32>(4);
let shader = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let tid = dispatch_id().x;
let tag = tid / 2;
let index = tid % 2;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/polymorphism_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ fn main() {
let result = device.create_buffer::<f32>(100);
let kernel = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let i = dispatch_id().x;
let x = i.as_f32() / 100.0 * PI;
let ctx = ShaderEvalContext {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn main() {
"cpu"
});
let printer = Printer::new(&device, 65536);
let kernel = Kernel::<fn()>::new(&device, track!(|| {
let kernel = Kernel::<fn()>::new(&device, &track!(|| {
let id = dispatch_id().xy();
if id.x == id.y {
lc_info!(printer, "id = {:?}", id);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/ray_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ fn main() {
let debug_hit_t = device.create_buffer::<f32>(4);
let rt_kernel = Kernel::<fn()>::new(
&device,
track!(|| {
&track!(|| {
let accel = accel.var();
let px = dispatch_id().xy();
let xy = px.as_float2() / Float2::expr(img_w as f32, img_h as f32);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/raytracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn main() {
let img_w = 800;
let img_h = 800;
let img = device.create_tex2d::<Float4>(PixelStorage::Byte4, img_w, img_h, 1);
let rt_kernel = Kernel::<fn()>::new(&device,track!(|| {
let rt_kernel = Kernel::<fn()>::new(&device,&track!(|| {
let accel = accel.var();
let px = dispatch_id().xy();
let xy = px.as_::<Float2>() / Float2::expr(img_w as f32, img_h as f32);
Expand Down
27 changes: 12 additions & 15 deletions luisa_compute/examples/vecadd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,18 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);

let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
track!(|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let vx = 2.0_f32.var(); // create a local mutable variable
*vx += x; // store to vx
buf_z.write(tid, vx + y);
}),
);

let kernel = device.create_kernel::<fn(Buffer<f32>)>(&track!(|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let vx = 2.0_f32.var(); // create a local mutable variable
*vx += x; // store to vx
buf_z.write(tid, vx + y);
}));
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
Expand Down
Loading

0 comments on commit a039e1a

Please sign in to comment.