diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index ca8188a..2cb8f0f 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -1,9 +1,9 @@ -use std::env::current_exe; +use std::{env::current_exe, f32::consts::PI}; use luisa::*; use luisa_compute as luisa; fn main() { - luisa::init_logger(); + luisa::init_logger_verbose(); let ctx = Context::new(current_exe().unwrap()); let args: Vec = std::env::args().collect(); @@ -19,36 +19,53 @@ fn main() { }); let x = device.create_buffer::(1024); let y = device.create_buffer::(1024); - let dx = device.create_buffer::(1024); - let dy = device.create_buffer::(1024); + let dx_rev = device.create_buffer::(1024); + let dy_rev = device.create_buffer::(1024); + let dx_fwd = device.create_buffer::(1024); + let dy_fwd = device.create_buffer::(1024); x.fill_fn(|i| i as f32); y.fill_fn(|i| 1.0 + i as f32); - let shader = device.create_kernel::, Buffer, Buffer, Buffer)>( - &|buf_x: BufferVar, - buf_y: BufferVar, - buf_dx: BufferVar, - buf_dy: BufferVar| { - let tid = dispatch_id().x(); - let x = buf_x.read(tid); - let y = buf_y.read(tid); - autodiff(|| { - requires_grad(x); - requires_grad(y); - let z = if_!(x.cmpgt(y), { - x * 4.0 - }, else { - y * 0.5 - }); - backward(z); - buf_dx.write(tid, gradient(x)); - buf_dy.write(tid, gradient(y)); - }); - }, - ); + let shader = device.create_kernel::(&|| { + let tid = dispatch_id().x(); + let buf_x = x.var(); + let buf_y = y.var(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let f = |x: Expr, y: Expr| { + if_!(x.cmpgt(y), { x * y }, else, { + y * x + (x / 32.0 * PI).sin() + }) + }; + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = f(x, y); + backward(z); + dx_rev.write(tid, gradient(x)); + dy_rev.write(tid, gradient(y)); + }); + forward_autodiff(2, || { + propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]); + propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]); + let z = f(x, y); + let dx = output_gradients(z)[0]; + let dy = output_gradients(z)[1]; + dx_fwd.write(tid, dx); + dy_fwd.write(tid, dy); + }); + }); - shader.dispatch([1024, 1, 1], &x.view(..), &y, &dx, &dy); - let dx = dx.copy_to_vec(); - println!("{:?}", &dx[0..16]); - let dy = dy.copy_to_vec(); - println!("{:?}", &dy[0..16]); + shader.dispatch([1024, 1, 1]); + { + let dx = dx_rev.copy_to_vec(); + println!("{:?}", &dx[0..16]); + let dy = dy_rev.copy_to_vec(); + println!("{:?}", &dy[0..16]); + } + { + let dx = dx_fwd.copy_to_vec(); + println!("{:?}", &dx[0..16]); + let dy = dy_fwd.copy_to_vec(); + println!("{:?}", &dy[0..16]); + } } diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index a22999b..525a3a8 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -2016,7 +2016,8 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM, + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM + | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, }; let module = CallableModule { module: ir_module, @@ -2053,7 +2054,8 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), - flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM, + flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM + | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, }; let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); let module = KernelModule { @@ -2877,13 +2879,15 @@ pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { s.push(IrBuilder::new(pools)); }); body(); - AD_CONTEXT.with(|c| { + let n_grads = AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); + let n_grads = c.n_forward_grads; c.reset(); + n_grads }); let body = __pop_scope(); __current_scope(|b| { - b.ad_scope(body, true); + b.fwd_ad_scope(body, n_grads); }); } @@ -2922,7 +2926,7 @@ pub fn output_gradients(v: T) -> Vec { grads.push(T::from_node(b.call( Func::OutputGrad, &[v.node(), idx], - Type::void(), + v.node().type_().clone(), ))); } grads @@ -2950,7 +2954,7 @@ pub fn autodiff(body: impl Fn()) { }); let body = __pop_scope(); __current_scope(|b| { - b.ad_scope(body, false); + b.ad_scope(body); }); } diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 259b58d..f6b8ebf 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 259b58d69e67f311bf9de31d19e9c38ea3d4c6ff +Subproject commit f6b8ebf8eaead724a3dd75cd80d0bb5d2f0ca1ef