Skip to content

Commit

Permalink
fixed callable capture
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
2 parents a039e1a + 541f149 commit 79a872b
Show file tree
Hide file tree
Showing 26 changed files with 215 additions and 140 deletions.
87 changes: 45 additions & 42 deletions luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,49 +29,52 @@ fn main() {
let dy_gt = device.create_buffer::<f32>(1024);
x.fill_fn(|i| i as f32);
y.fill_fn(|i| 1.0 + i as f32);
let shader = Kernel::<fn()>::new(&device, track!(&|| {
let tid = dispatch_id().x;
let buf_x = x.var();
let buf_y = y.var();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let f = |x: Expr<f32>, y: Expr<f32>| {
if x > y {
x * y
} else {
y * x + (x / 4.0 * PI).sin()
let shader = Kernel::<fn()>::new(
&device,
&track!(|| {
let tid = dispatch_id().x;
let buf_x = x.var();
let buf_y = y.var();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let f = |x: Expr<f32>, y: Expr<f32>| {
if x > y {
x * y
} else {
y * x + (x / 4.0 * PI).sin()
}
};
let df = |x: Expr<f32>, y: Expr<f32>| {
if x > y {
(y, x)
} else {
(y + (x / 4.0 * PI).cos() / 4.0 * PI, x)
}
};
autodiff(|| {
requires_grad(x);
requires_grad(y);
let z = f(x, y);
backward(z);
dx_rev.write(tid, gradient(x));
dy_rev.write(tid, gradient(y));
});
forward_autodiff(2, || {
propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]);
propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]);
let z = f(x, y);
let dx = output_gradients(z)[0];
let dy = output_gradients(z)[1];
dx_fwd.write(tid, dx);
dy_fwd.write(tid, dy);
});
{
let (dx, dy) = df(x, y);
dx_gt.write(tid, dx);
dy_gt.write(tid, dy);
}
};
let df = |x: Expr<f32>, y: Expr<f32>| {
if x > y {
(y, x)
} else {
(y + (x / 4.0 * PI).cos() / 4.0 * PI, x)
}
};
autodiff(|| {
requires_grad(x);
requires_grad(y);
let z = f(x, y);
backward(z);
dx_rev.write(tid, gradient(x));
dy_rev.write(tid, gradient(y));
});
forward_autodiff(2, || {
propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]);
propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]);
let z = f(x, y);
let dx = output_gradients(z)[0];
let dy = output_gradients(z)[1];
dx_fwd.write(tid, dx);
dy_fwd.write(tid, dy);
});
{
let (dx, dy) = df(x, y);
dx_gt.write(tid, dx);
dy_gt.write(tid, dy);
}
}));
}),
);

shader.dispatch([1024, 1, 1]);
{
Expand Down
27 changes: 15 additions & 12 deletions luisa_compute/examples/backtrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@ fn main() {
} else {
"cpu"
});

let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid + 123);
let y = buf_y.read(tid);
let vx = Var::<f32>::zeroed(); // create a local mutable variable
*vx = x;
buf_z.write(tid, vx + y);
}));
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
&track!(|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid + 123);
let y = buf_y.read(tid);
let vx = Var::<f32>::zeroed(); // create a local mutable variable
*vx = x;
buf_z.write(tid, vx + y);
}),
);
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
Expand Down
21 changes: 12 additions & 9 deletions luisa_compute/examples/bindless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,18 @@ fn main() {
bindless.emplace_buffer_async(1, &y);
bindless.emplace_tex2d_async(0, &img, Sampler::default());
bindless.update();
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
let bindless = bindless.var();
let tid = dispatch_id().x;
let buf_x = bindless.buffer::<f32>(0_u32.expr());
let buf_y = bindless.buffer::<f32>(1_u32.expr());
let x = buf_x.read(tid).as_::<u32>().as_::<f32>();
let y = buf_y.read(tid);
buf_z.write(tid, x + y);
}));
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
&track!(|buf_z| {
let bindless = bindless.var();
let tid = dispatch_id().x;
let buf_x = bindless.buffer::<f32>(0_u32.expr());
let buf_y = bindless.buffer::<f32>(1_u32.expr());
let x = buf_x.read(tid).as_::<u32>().as_::<f32>();
let y = buf_y.read(tid);
buf_z.write(tid, x + y);
}),
);
kernel.dispatch([1024, 1, 1], &z);
let mut z_data = vec![0.0; 1024];
z.view(..).copy_to(&mut z_data);
Expand Down
22 changes: 12 additions & 10 deletions luisa_compute/examples/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@ fn main() {
} else {
"cpu"
});
let add =
Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(&device, |a, b| track!(a + b));
let add = Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(&device, |a, b| track!(a + b));
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(&device, &track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
&track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);

buf_z.write(tid, add.call(x, y));
}));
buf_z.write(tid, add.call(x, y));
}),
);
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/examples/custom_aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use luisa::{prelude::*, lang::types::vector::alias::Float3};
use luisa::lang::types::vector::alias::Float3;
use luisa::prelude::*;
use luisa_compute as luisa;
#[derive(Aggregate)]
pub struct Spectrum {
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/examples/mpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
use std::env::current_exe;
use std::time::Instant;

use luisa::lang::types::vector::{alias::*, Mat2};
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::Mat2;
use luisa::prelude::*;
use luisa_compute as luisa;
use rand::Rng;
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/examples/path_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::time::Instant;
use winit::event::{Event as WinitEvent, WindowEvent};
use winit::event_loop::EventLoop;

use luisa::lang::types::vector::{alias::*, *};
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::*;
use luisa::prelude::*;
use luisa::rtx::{
offset_ray_origin, Accel, AccelBuildRequest, AccelOption, AccelVar, Index, Ray, RayComps,
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use image::Rgb;
use luisa::lang::types::vector::{alias::*, Mat4};
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::Mat4;
use luisa_compute_api_types::StreamTag;
use rand::Rng;
use std::env::current_exe;
Expand Down
19 changes: 11 additions & 8 deletions luisa_compute/examples/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ fn main() {
"cpu"
});
let printer = Printer::new(&device, 65536);
let kernel = Kernel::<fn()>::new(&device, &track!(|| {
let id = dispatch_id().xy();
if id.x == id.y {
lc_info!(printer, "id = {:?}", id);
} else {
lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y);
}
}));
let kernel = Kernel::<fn()>::new(
&device,
&track!(|| {
let id = dispatch_id().xy();
if id.x == id.y {
lc_info!(printer, "id = {:?}", id);
} else {
lc_info!(printer, "not equal!, id = [{} {}]", id.x, id.y);
}
}),
);
device.default_stream().with_scope(|s| {
s.reset_printer(&printer);
s.submit([kernel.dispatch_async([4, 4, 1])]);
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/examples/ray_query.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::env::current_exe;

use image::Rgb;
use luisa::lang::types::vector::{alias::*, *};
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::*;
use luisa::prelude::*;
use luisa::rtx::{
Aabb, AccelBuildRequest, AccelOption, ProceduralCandidate, Ray, RayQuery, TriangleCandidate,
Expand Down
54 changes: 30 additions & 24 deletions luisa_compute/examples/raytracing.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::env::current_exe;

use image::Rgb;
use luisa::lang::{types::vector::alias::*, types::vector::*, types::*, *};
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::*;
use luisa::lang::types::*;
use luisa::lang::*;
use luisa::prelude::*;
use luisa::rtx::{AccelBuildRequest, AccelOption, Ray};
use luisa_compute as luisa;
Expand Down Expand Up @@ -36,29 +39,32 @@ fn main() {
let img_w = 800;
let img_h = 800;
let img = device.create_tex2d::<Float4>(PixelStorage::Byte4, img_w, img_h, 1);
let rt_kernel = Kernel::<fn()>::new(&device,&track!(|| {
let accel = accel.var();
let px = dispatch_id().xy();
let xy = px.as_::<Float2>() / Float2::expr(img_w as f32, img_h as f32);
let xy = 2.0 * xy - 1.0;
let o = Float3::expr(0.0, 0.0, -1.0);
let d = Float3::expr(xy.x, xy.y, 0.0) - o;
let d = d.normalize();
let ray = Ray::new_expr(
Expr::<[f32; 3]>::from(o),
1e-3,
Expr::<[f32; 3]>::from(d),
1e9,
);
let hit = accel.trace_closest(ray);
let img = img.view(0).var();
let color = select(
hit.valid(),
Float3::expr(hit.u, hit.v, 1.0),
Float3::expr(0.0, 0.0, 0.0),
);
img.write(px, Float4::expr(color.x, color.y, color.z, 1.0));
}));
let rt_kernel = Kernel::<fn()>::new(
&device,
&track!(|| {
let accel = accel.var();
let px = dispatch_id().xy();
let xy = px.as_::<Float2>() / Float2::expr(img_w as f32, img_h as f32);
let xy = 2.0 * xy - 1.0;
let o = Float3::expr(0.0, 0.0, -1.0);
let d = Float3::expr(xy.x, xy.y, 0.0) - o;
let d = d.normalize();
let ray = Ray::new_expr(
Expr::<[f32; 3]>::from(o),
1e-3,
Expr::<[f32; 3]>::from(d),
1e9,
);
let hit = accel.trace_closest(ray);
let img = img.view(0).var();
let color = select(
hit.valid(),
Float3::expr(hit.u, hit.v, 1.0),
Float3::expr(0.0, 0.0, 0.0),
);
img.write(px, Float4::expr(color.x, color.y, color.z, 1.0));
}),
);
let event_loop = EventLoop::new();
let window = winit::window::WindowBuilder::new()
.with_title("Luisa Compute Rust - Ray Tracing")
Expand Down
14 changes: 8 additions & 6 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,11 @@ pub fn __module_pools() -> &'static CArc<ModulePools> {

/// Don't call this function directly unless you know what you are doing
/** This function is solely for constructing proxies
* Given a node, __extract selects the correct Func based on the node's type
* It then inserts the extract(node, i) call *at where the node is defined*
* *Note*, after insertion, the IrBuilder in the correct/parent scope might not be up to date
* Thus, for IrBuilder of each scope, it updates the insertion point to the end of the current basic block
* Given a node, __extract selects the correct Func based on the node's
* type. It then inserts the extract(node, i) call *at where the node is
* defined*. *Note*: after insertion, the IrBuilder in the correct/parent
* scope might not be up to date. Thus, for the IrBuilder of each scope, it
* updates the insertion point to the end of the current basic block.
*/
pub fn __extract<T: Value>(node: NodeRef, index: usize) -> NodeRef {
let inst = &node.get().instruction;
Expand All @@ -408,8 +409,9 @@ pub fn __extract<T: Value>(node: NodeRef, index: usize) -> NodeRef {
}

let i = b.const_(Const::Int32(index as i32));
// Since we have inserted something, the insertion point in cur_builder might not be up to date
// So we need to set it to the end of the current basic block
// Since we have inserted something, the insertion point in cur_builder might
// not be up to date. So we need to set it to the end of the current
// basic block.
macro_rules! update_builders {
() => {
for scope in &mut r.scopes {
Expand Down
1 change: 0 additions & 1 deletion luisa_compute/src/lang/ops/cast_impls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#[rustfmt::skip]mod impl_{
use crate::prelude::*;
use super::super::*;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ where
self.clone().mul(self.clone()).mul(self.clone())
}
fn recip(&self) -> Self {
<Self as FromNode>::from_node(__current_scope(|b|{
<Self as FromNode>::from_node(__current_scope(|b| {
let one = b.const_(Const::One(<X as TypeOf>::type_()));
b.call(Func::Div, &[one, self.node()], <X as TypeOf>::type_())
}))
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lang/ops/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub trait IntExpr {

pub trait FloatExpr: Sized {
type Bool;

fn ceil(&self) -> Self;
fn floor(&self) -> Self;
fn round(&self) -> Self;
Expand Down
3 changes: 2 additions & 1 deletion luisa_compute/src/lang/soa.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use luisa_compute_ir::{ir::Type, CArc};
use luisa_compute_ir::ir::Type;
use luisa_compute_ir::CArc;

use crate::prelude::*;
/** A buffer with SOA layout.
Expand Down
Loading

0 comments on commit 79a872b

Please sign in to comment.