diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 1b74ea9..ff43f5e 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -13,8 +13,8 @@ use bumpalo::Bump; use indexmap::IndexMap; pub use ir::ir::NodeRef; use ir::ir::{ - ArrayType, CallableModule, CallableModuleRef, ModulePools, SwitchCase, UserNodeData, - INVALID_REF, + ArrayType, CallableModule, CallableModuleRef, ModuleFlags, ModulePools, SwitchCase, + UserNodeData, INVALID_REF, }; pub use ir::CArc; use ir::Pooled; @@ -2016,11 +2016,7 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), - }; - let ir_module = { - // perform IR passes - let ad_transform = transform::autodiff::Autodiff; - ad_transform.transform(ir_module) + flags: ModuleFlags::REQUIRES_AD_TRANSFORM, }; let module = CallableModule { module: ir_module, @@ -2057,12 +2053,9 @@ impl KernelBuilder { entry, kind: ModuleKind::Kernel, pools: r.pools.clone().unwrap(), + flags: ModuleFlags::REQUIRES_AD_TRANSFORM, }; - let ir_module = { - // perform IR passes - let ad_transform = transform::autodiff::Autodiff; - ad_transform.transform(ir_module) - }; + let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); let module = KernelModule { module: ir_module, cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops), diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index c78bdfa..6304364 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -383,7 +383,6 @@ impl Device { } pub fn create_dyn_callable_once<'a, S: CallableSignature<'a, R>, R: CallableRet>( &self, - init_once: bool, f: S::DynFn, ) -> S::DynCallable { S::create_dyn_callable(self.clone(), true, f) diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 74d04fb..3a52972 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -984,7 +984,7 @@ fn autodiff_if_phi4() { autodiff(|| { 
requires_grad(x); requires_grad(y); - consts.store(make_float3(2.0,3.0,4.0)); + consts.store(make_float3(2.0, 3.0, 4.0)); let const_two = consts.x(); let const_three = consts.y(); let const_four = consts.z(); @@ -1081,3 +1081,71 @@ fn autodiff_switch() { } } } + +#[test] +fn autodiff_callable() { + let device = get_device(); + let t: Buffer<i32> = device.create_buffer(1024); + let x: Buffer<f32> = device.create_buffer(1024); + let y: Buffer<f32> = device.create_buffer(1024); + let dx: Buffer<f32> = device.create_buffer(1024); + let dy: Buffer<f32> = device.create_buffer(1024); + let mut rng = rand::thread_rng(); + t.view(..).fill_fn(|_| rng.gen_range(0..3)); + x.view(..).fill_fn(|_| rng.gen()); + y.view(..).fill_fn(|_| rng.gen()); + let callable = device.create_callable::<(Var<f32>, Var<f32>, Expr<i32>), ()>(&|vx, vy, t| { + let x = *vx; + let y = *vy; + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = switch::<Expr<f32>>(t) + .case(0, || x * 4.0) + .case(1, || x * 2.0) + .case(2, || y * 0.5) + .finish(); + backward(z); + *vx.get_mut() = gradient(x); + *vy.get_mut() = gradient(y); + }); + }); + let kernel = device.create_kernel::<()>(&|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + let dx = def(x); + let dy = def(y); + callable.call(dx, dy, t); + buf_dx.write(tid, *dx); + buf_dy.write(tid, *dy); + }); + kernel.dispatch([1024, 1, 1]); + let dx = dx.view(..).copy_to_vec(); + let dy = dy.view(..).copy_to_vec(); + let t = t.view(..).copy_to_vec(); + let cache_dir = kernel.cache_dir(); + for i in 0..1024 { + match t[i] { + 0 => { + assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + 1 => { + assert_eq!(dx[i], 2.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + 2 => { 
assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir); + } + _ => unreachable!(), + } + } +} diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index a84ab03..a065d18 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit a84ab03602e1ba105189d0378e0c986565e9a7f3 +Subproject commit a065d181e5509f51f5b297aaa2522eee735d72f6