added Forward AD
shiinamiyuki committed Sep 12, 2023
1 parent f0ec52d commit da2e826
Showing 3 changed files with 59 additions and 38 deletions.
79 changes: 48 additions & 31 deletions luisa_compute/examples/autodiff.rs
@@ -1,9 +1,9 @@
-use std::env::current_exe;
+use std::{env::current_exe, f32::consts::PI};
 
 use luisa::*;
 use luisa_compute as luisa;
 fn main() {
-    luisa::init_logger();
+    luisa::init_logger_verbose();
 
     let ctx = Context::new(current_exe().unwrap());
     let args: Vec<String> = std::env::args().collect();
@@ -19,36 +19,53 @@ fn main() {
     });
     let x = device.create_buffer::<f32>(1024);
     let y = device.create_buffer::<f32>(1024);
-    let dx = device.create_buffer::<f32>(1024);
-    let dy = device.create_buffer::<f32>(1024);
+    let dx_rev = device.create_buffer::<f32>(1024);
+    let dy_rev = device.create_buffer::<f32>(1024);
+    let dx_fwd = device.create_buffer::<f32>(1024);
+    let dy_fwd = device.create_buffer::<f32>(1024);
     x.fill_fn(|i| i as f32);
     y.fill_fn(|i| 1.0 + i as f32);
-    let shader = device.create_kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
-        &|buf_x: BufferVar<f32>,
-          buf_y: BufferVar<f32>,
-          buf_dx: BufferVar<f32>,
-          buf_dy: BufferVar<f32>| {
-            let tid = dispatch_id().x();
-            let x = buf_x.read(tid);
-            let y = buf_y.read(tid);
-            autodiff(|| {
-                requires_grad(x);
-                requires_grad(y);
-                let z = if_!(x.cmpgt(y), {
-                    x * 4.0
-                }, else {
-                    y * 0.5
-                });
-                backward(z);
-                buf_dx.write(tid, gradient(x));
-                buf_dy.write(tid, gradient(y));
-            });
-        },
-    );
+    let shader = device.create_kernel::<fn()>(&|| {
+        let tid = dispatch_id().x();
+        let buf_x = x.var();
+        let buf_y = y.var();
+        let x = buf_x.read(tid);
+        let y = buf_y.read(tid);
+        let f = |x: Expr<f32>, y: Expr<f32>| {
+            if_!(x.cmpgt(y), { x * y }, else, {
+                y * x + (x / 32.0 * PI).sin()
+            })
+        };
+        autodiff(|| {
+            requires_grad(x);
+            requires_grad(y);
+            let z = f(x, y);
+            backward(z);
+            dx_rev.write(tid, gradient(x));
+            dy_rev.write(tid, gradient(y));
+        });
+        forward_autodiff(2, || {
+            propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]);
+            propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]);
+            let z = f(x, y);
+            let dx = output_gradients(z)[0];
+            let dy = output_gradients(z)[1];
+            dx_fwd.write(tid, dx);
+            dy_fwd.write(tid, dy);
+        });
+    });
 
-    shader.dispatch([1024, 1, 1], &x.view(..), &y, &dx, &dy);
-    let dx = dx.copy_to_vec();
-    println!("{:?}", &dx[0..16]);
-    let dy = dy.copy_to_vec();
-    println!("{:?}", &dy[0..16]);
+    shader.dispatch([1024, 1, 1]);
+    {
+        let dx = dx_rev.copy_to_vec();
+        println!("{:?}", &dx[0..16]);
+        let dy = dy_rev.copy_to_vec();
+        println!("{:?}", &dy[0..16]);
+    }
+    {
+        let dx = dx_fwd.copy_to_vec();
+        println!("{:?}", &dx[0..16]);
+        let dy = dy_fwd.copy_to_vec();
+        println!("{:?}", &dy[0..16]);
+    }
 }
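Since x is filled with i and y with 1.0 + i, x < y holds for every thread, so f always takes the else branch z = y * x + sin(x / 32 * PI); analytically dz/dx = y + (PI / 32) * cos(x * PI / 32) and dz/dy = x, and the reverse-mode and forward-mode buffers should agree. A minimal host-side sanity check along these lines could be appended to main() — a sketch only, not part of the commit; the loop and tolerances are assumptions:

```rust
// Sketch: compare reverse-mode, forward-mode, and analytic gradients on the host.
// Assumes it runs at the end of main() in the example above.
let dx_r = dx_rev.copy_to_vec();
let dy_r = dy_rev.copy_to_vec();
let dx_f = dx_fwd.copy_to_vec();
let dy_f = dy_fwd.copy_to_vec();
for i in 0..1024usize {
    let (xi, yi) = (i as f32, 1.0 + i as f32);
    // x < y everywhere, so z = y * x + sin(x / 32 * PI).
    let expect_dx = yi + (PI / 32.0) * (xi / 32.0 * PI).cos();
    let expect_dy = xi;
    let tol = |v: f32| 1e-3 * v.abs().max(1.0);
    assert!((dx_r[i] - expect_dx).abs() < tol(expect_dx), "dx mismatch at {}", i);
    assert!((dy_r[i] - expect_dy).abs() < tol(expect_dy), "dy mismatch at {}", i);
    assert!((dx_r[i] - dx_f[i]).abs() < tol(expect_dx), "fwd/rev dx disagree at {}", i);
    assert!((dy_r[i] - dy_f[i]).abs() < tol(expect_dy), "fwd/rev dy disagree at {}", i);
}
println!("forward- and reverse-mode gradients match the analytic values");
```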
16 changes: 10 additions & 6 deletions luisa_compute/src/lang/mod.rs
@@ -2016,7 +2016,8 @@ impl KernelBuilder {
             entry,
             kind: ModuleKind::Kernel,
             pools: r.pools.clone().unwrap(),
-            flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM,
+            flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM
+                | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM,
         };
         let module = CallableModule {
             module: ir_module,
@@ -2053,7 +2054,8 @@ impl KernelBuilder {
             entry,
             kind: ModuleKind::Kernel,
             pools: r.pools.clone().unwrap(),
-            flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM,
+            flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM
+                | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM,
         };
         let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);
         let module = KernelModule {
@@ -2877,13 +2879,15 @@ pub fn forward_autodiff(n_grads: usize, body: impl Fn()) {
         s.push(IrBuilder::new(pools));
     });
     body();
-    AD_CONTEXT.with(|c| {
+    let n_grads = AD_CONTEXT.with(|c| {
         let mut c = c.borrow_mut();
+        let n_grads = c.n_forward_grads;
         c.reset();
+        n_grads
     });
     let body = __pop_scope();
     __current_scope(|b| {
-        b.ad_scope(body, true);
+        b.fwd_ad_scope(body, n_grads);
     });
 }

@@ -2922,7 +2926,7 @@ pub fn output_gradients<T: ExprProxy>(v: T) -> Vec<T> {
         grads.push(T::from_node(b.call(
             Func::OutputGrad,
             &[v.node(), idx],
-            Type::void(),
+            v.node().type_().clone(),
         )));
     }
     grads
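The switch from Type::void() to v.node().type_().clone() gives each OutputGrad node the same IR type as the value being differentiated, so the returned expressions can presumably be consumed like any other Expr, for example written straight into a buffer as the example file does:

```rust
// Usage mirroring luisa_compute/examples/autodiff.rs above (names from that file):
forward_autodiff(2, || {
    propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]);
    propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]);
    let z = f(x, y);
    let grads = output_gradients(z); // each element has z's type (Expr<f32>), not void
    dx_fwd.write(tid, grads[0]);     // dz/dx
    dy_fwd.write(tid, grads[1]);     // dz/dy
});
```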
@@ -2950,7 +2954,7 @@ pub fn autodiff(body: impl Fn()) {
     });
     let body = __pop_scope();
     __current_scope(|b| {
-        b.ad_scope(body, false);
+        b.ad_scope(body);
     });
 }
 