Skip to content

Commit

Permalink
Merge pull request #18 from LuisaGroup/refactor-kernel-creation
Browse files: browse the repository at this point in the history
Refactor kernel creation
  • Loading branch information
shiinamiyuki authored Sep 23, 2023
2 parents 6a1910a + d148b13 commit b52438f
Show file tree
Hide file tree
Showing 25 changed files with 1,387 additions and 1,247 deletions.
15 changes: 9 additions & 6 deletions luisa_compute/examples/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ fn main() {
let sum = device.create_buffer::<f32>(1);
x.view(..).fill_fn(|i| i as f32);
sum.view(..).fill(0.0);
let shader = device.create_kernel::<fn()>(&track!(|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = dispatch_id().x;
buf_sum.atomic_fetch_add(0, buf_x.read(tid));
}));
let shader = Kernel::<fn()>::new(
&device,
track!(|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = dispatch_id().x;
buf_sum.atomic_fetch_add(0, buf_x.read(tid));
}),
);
shader.dispatch([x.len() as u32, 1, 1]);
let mut sum_data = vec![0.0];
sum.view(..).copy_to(&mut sum_data);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn main() {
let dy_gt = device.create_buffer::<f32>(1024);
x.fill_fn(|i| i as f32);
y.fill_fn(|i| 1.0 + i as f32);
let shader = device.create_kernel::<fn()>(track!(&|| {
let shader = Kernel::<fn()>::new(&device, track!(|| {
let tid = dispatch_id().x;
let buf_x = x.var();
let buf_y = y.var();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/backtrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<fn(Buffer<f32>)>(track!(&|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ fn main() {
y,
exclude: 42.0,
};
let shader = device.create_kernel::<fn(MyArgStruct<f32>)>(&|_args| {});
let shader = Kernel::<fn(MyArgStruct<f32>)>::new(&device, |_args| {});
shader.dispatch([1024, 1, 1], &my_args);
}
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn main() {
bindless.emplace_buffer_async(1, &y);
bindless.emplace_tex2d_async(0, &img, Sampler::default());
bindless.update();
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&track!(|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
let bindless = bindless.var();
let tid = dispatch_id().x;
let buf_x = bindless.buffer::<f32>(0_u32.expr());
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ fn main() {
"cpu"
});
let add =
device.create_callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>(&|a, b| track!(a + b));
Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(&device, |a, b| track!(a + b));
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&track!(|buf_z| {
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
Expand Down
38 changes: 21 additions & 17 deletions luisa_compute/examples/callable_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ fn main() {
} else {
"cpu"
});
let add = device.create_dyn_callable::<fn(DynExpr, DynExpr) -> DynExpr>(Box::new(
|a: DynExpr, b: DynExpr| -> DynExpr {
let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
&device,
Box::new(|a: DynExpr, b: DynExpr| -> DynExpr {
if let Some(a) = a.downcast::<f32>() {
let b = b.downcast::<f32>().unwrap();
return DynExpr::new(track!(a + b));
Expand All @@ -29,28 +30,31 @@ fn main() {
} else {
unreachable!()
}
},
));
}),
);
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
let w = device.create_buffer::<i32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);

buf_z.write(tid, add.call(x.into(), y.into()).get::<f32>());
w.var().write(
tid,
add.call(x.as_::<i32>().into(), y.as_::<i32>().into())
.get::<i32>(),
);
}));
buf_z.write(tid, add.call(x.into(), y.into()).get::<f32>());
w.var().write(
tid,
add.call(x.as_::<i32>().into(), y.as_::<i32>().into())
.get::<i32>(),
);
}),
);
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
Expand Down
33 changes: 18 additions & 15 deletions luisa_compute/examples/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,24 @@ fn main() {
println!("Hello from thread 0!");
}
});
let shader = device.create_kernel::<fn(Buffer<f32>)>(&track!(|buf_z: BufferVar<f32>| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let args = MyAddArgs::new_expr(x, y, 0.0f32.expr());
let result = my_add.call(args);
let _ = my_print.call(tid);
if tid == 0 {
cpu_dbg!(args);
}
buf_z.write(tid, result.result);
}));
let shader = Kernel::<fn(Buffer<f32>)>::new(
&device,
track!(|buf_z: BufferVar<f32>| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let args = MyAddArgs::new_expr(x, y, 0.0f32.expr());
let result = my_add.call(args);
let _ = my_print.call(tid);
if tid == 0 {
cpu_dbg!(args);
}
buf_z.write(tid, result.result);
}),
);
shader.dispatch([1024, 1, 1], &z);
let mut z_data = vec![0.0; 1024];
z.view(..).copy_to(&mut z_data);
Expand Down
93 changes: 52 additions & 41 deletions luisa_compute/examples/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,24 +118,25 @@ fn main() {
}
);

let advect = device
.create_kernel_async::<fn(Buffer<Float2>, Buffer<Float2>, Buffer<f32>, Buffer<f32>)>(
track!(&|u0, u1, rho0, rho1| {
let coord = dispatch_id().xy();
let u = u0.read(index(coord));

// trace backward
let mut p = Float2::expr(coord.x.as_f32(), coord.y.as_f32());
p = p - u * dt;

// advect
u1.write(index(coord), sample_vel(u0, p.x, p.y));
rho1.write(index(coord), sample_float(rho0, p.x, p.y));
}),
);
let advect = Kernel::<fn(Buffer<Float2>, Buffer<Float2>, Buffer<f32>, Buffer<f32>)>::new_async(
&device,
track!(|u0, u1, rho0, rho1| {
let coord = dispatch_id().xy();
let u = u0.read(index(coord));

// trace backward
let mut p = Float2::expr(coord.x.as_f32(), coord.y.as_f32());
p = p - u * dt;

// advect
u1.write(index(coord), sample_vel(u0, p.x, p.y));
rho1.write(index(coord), sample_float(rho0, p.x, p.y));
}),
);

let divergence =
device.create_kernel_async::<fn(Buffer<Float2>, Buffer<f32>)>(track!(&|u, div| {
let divergence = Kernel::<fn(Buffer<Float2>, Buffer<f32>)>::new_async(
&device,
track!(|u, div| {
let coord = dispatch_id().xy();
if coord.x < (N_GRID as u32 - 1) && coord.y < (N_GRID as u32 - 1) {
let dx = (u.read(index(Uint2::expr(coord.x + 1, coord.y))).x
Expand All @@ -146,10 +147,12 @@ fn main() {
* 0.5;
div.write(index(coord), dx + dy);
}
}));
}),
);

let pressure_solve = device.create_kernel_async::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
track!(&|p0, p1, div| {
let pressure_solve = Kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>::new_async(
&device,
track!(|p0, p1, div| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -166,8 +169,9 @@ fn main() {
}),
);

let pressure_apply =
device.create_kernel_async::<fn(Buffer<f32>, Buffer<Float2>)>(track!(&|p, u| {
let pressure_apply = Kernel::<fn(Buffer<f32>, Buffer<Float2>)>::new_async(
&device,
track!(|p, u| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -184,10 +188,12 @@ fn main() {

u.write(ij, u.read(ij) - f_p);
}
}));
}),
);

let integrate =
device.create_kernel_async::<fn(Buffer<Float2>, Buffer<f32>)>(track!(&|u, rho| {
let integrate = Kernel::<fn(Buffer<Float2>, Buffer<f32>)>::new_async(
&device,
track!(|u, rho| {
let coord = dispatch_id().xy();
let ij = index(coord);

Expand All @@ -199,10 +205,12 @@ fn main() {

// fade
rho.write(ij, rho.read(ij) * (1.0f32 - 0.1f32 * dt));
}));
}),
);

let init = device.create_kernel_async::<fn(Buffer<f32>, Buffer<Float2>, Float2)>(track!(
&|rho, u, dir| {
let init = Kernel::<fn(Buffer<f32>, Buffer<Float2>, Float2)>::new_async(
&device,
track!(|rho, u, dir| {
let coord = dispatch_id().xy();
let i = coord.x.as_i32();
let j = coord.y.as_i32();
Expand All @@ -214,10 +222,10 @@ fn main() {
rho.write(ij, 1.0f32);
u.write(ij, dir);
}
}
));
}),
);

let init_grid = device.create_kernel_async::<fn()>(&|| {
let init_grid = Kernel::<fn()>::new_async(&device, || {
let idx = index(dispatch_id().xy());
u0.var().write(idx, Float2::expr(0.0f32, 0.0f32));
u1.var().write(idx, Float2::expr(0.0f32, 0.0f32));
Expand All @@ -230,21 +238,24 @@ fn main() {
div.var().write(idx, 0.0f32);
});

let clear_pressure = device.create_kernel_async::<fn()>(&|| {
let clear_pressure = Kernel::<fn()>::new_async(&device, || {
let idx = index(dispatch_id().xy());
p0.var().write(idx, 0.0f32);
p1.var().write(idx, 0.0f32);
});

let draw_rho = device.create_kernel_async::<fn()>(&track!(|| {
let coord = dispatch_id().xy();
let ij = index(coord);
let value = rho0.var().read(ij);
display.var().write(
Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y),
Float4::expr(value, 0.0f32, 0.0f32, 1.0f32),
);
}));
let draw_rho = Kernel::<fn()>::new_async(
&device,
track!(|| {
let coord = dispatch_id().xy();
let ij = index(coord);
let value = rho0.var().read(ij);
display.var().write(
Uint2::expr(coord.x, (N_GRID - 1) as u32 - coord.y),
Float4::expr(value, 0.0f32, 0.0f32, 1.0f32),
);
}),
);

event_loop.run(move |event, _, control_flow| {
control_flow.set_poll();
Expand Down
Loading

0 comments on commit b52438f

Please sign in to comment.