From 1c35b289734f073abc21624a97d5a7adefb6ce53 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 3 Oct 2023 16:09:23 -0400 Subject: [PATCH 1/4] expose RawKernel --- luisa_compute/src/runtime.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index c8f1491..68d9810 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1373,6 +1373,10 @@ impl Kernel { pub fn dump(&self) -> String { ir::debug::dump_ir_human_readable(&self.inner.module.module) } + #[doc(hidden)] + pub fn raw(&self) -> &RawKernel { + &self.inner + } } pub trait AsKernelArg: KernelArg {} From 3550ecd7034ec702f744f2f721b9a90e88bd1db1 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 3 Oct 2023 16:36:26 -0400 Subject: [PATCH 2/4] expose build_kernel --- luisa_compute/src/runtime/kernel.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index a39d9f3..283180b 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -337,7 +337,15 @@ impl KernelBuilder { (resource_tracker, cpu_custom_ops, captured) }) } - fn build_callable(&mut self, body: impl FnOnce(&mut Self) -> R) -> RawCallable { + + + /// Don't use this directly + /// See [`Callable`] for how to create a callable + #[doc(hidden)] + pub fn build_callable( + &mut self, + body: impl FnOnce(&mut Self) -> R, + ) -> RawCallable { let ret = body(self); let ret_type = ret._return(); let (rt, cpu_custom_ops, captures) = self.collect_module_info(); @@ -381,7 +389,11 @@ impl KernelBuilder { } }) } - fn build_kernel( + + /// Don't use this directly + /// See [`Kernel`] for how to create a kernel + #[doc(hidden)] + pub fn build_kernel( &mut self, body: impl FnOnce(&mut Self), ) -> crate::runtime::KernelDef { From 7fa97cd6cd09d6304392485f016b1ce5f6ae161e Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Wed, 4 Oct 2023 22:46:11 -0400 Subject: [PATCH 3/4] update submod, ad fix --- luisa_compute/tests/autodiff.rs | 95 ++++++++++++++++++++++++++++++--- luisa_compute_sys/LuisaCompute | 2 +- 2 files changed, 89 insertions(+), 8 deletions(-) diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 5f272fa..45d877c 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -249,19 +249,27 @@ macro_rules! autodiff_3 { #[repr(C)] #[value_new] struct Foo { - x: f32, - y: f32, + v: Float3, + f: f32, } autodiff_2!(autodiff_const, 1.0..10.0, |x: Expr, y: Expr| { let k = 2.0 / 3.0_f32.expr(); x * k + y * k }); -autodiff_2!(autodiff_struct, 1.0..10.0, |x: Expr, y: Expr| { - let foo = Foo::new_expr(x, y).var(); - *foo.x += 1.0; - foo.x + foo.y -}); +autodiff_3!( + autodiff_struct, + -1.0..1.0, + track!(|x: Expr, y: Expr, z: Expr| { + let foo = Foo::new_expr(Float3::expr(x, y, z), x + y + z).var(); + if foo.v.x > 3.0 { + *foo.v.x -= 1.0; + } else { + *foo.v.x -= foo.f; + }; + foo.v.x * foo.v.y + foo.v.z * foo.f + }) +); autodiff_1!(autodiff_sin, -10.0..10.0, |x: Expr| x.sin()); autodiff_1!(autodiff_cos, -10.0..10.0, |x: Expr| x.cos()); autodiff_1!(autodiff_sincos, -10.0..10.0, |x: Expr| x.cos() @@ -1141,6 +1149,79 @@ fn autodiff_if_phi4() { } } #[test] +fn autodiff_if_phi5() { + let device = get_device(); + let x: Buffer = device.create_buffer(1024); + let y: Buffer = device.create_buffer(1024); + let dx: Buffer = device.create_buffer(1024); + let dy: Buffer = device.create_buffer(1024); + let mut rng = rand::thread_rng(); + x.view(..).fill_fn(|_| rng.gen()); + y.view(..).fill_fn(|_| rng.gen()); + let kernel = Kernel::::new( + &device, + &track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let tmp_x = 0.0f32.var(); + let tmp_y = 0.0f32.var(); + let consts = Float3::var_zeroed(); + autodiff(|| { + requires_grad(x); + requires_grad(y); + *consts = Float3::expr(2.0, 3.0, 4.0); + let const_two = consts.x; + let const_three = consts.y; + let const_four = consts.z; + let c = (x > const_three).as_::(); + let z = if x > y { + switch::>(c) + .case(0, || { + *tmp_x = x * const_two; + **tmp_x + }) + .default(|| { + *tmp_y = x * const_four; + **tmp_y + }) + .finish() + * const_two + } else { + y * 0.5 + }; + backward(z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + }), + ); + kernel.dispatch([1024, 1, 1]); + let dx = dx.view(..).copy_to_vec(); + let dy = dy.view(..).copy_to_vec(); + let x = x.view(..).copy_to_vec(); + let y = y.view(..).copy_to_vec(); + let cache_dir = kernel.cache_dir(); + for i in 0..1024 { + if x[i] > y[i] { + if x[i] > 3.0 { + assert_eq!(dx[i], 8.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } else { + assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + } else { + assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir); + } + } +} +#[test] fn autodiff_switch() { let device = get_device(); let t: Buffer = device.create_buffer(1024); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 360987a..6887785 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 360987a5124b09090c49b9da92347cf8d1452068 +Subproject commit 688778557bf4fe6cb87f0d3cd3caf0dd3259d880 From 0c66b59aa4328b6118c84cb130b75899b41cd05c Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Thu, 5 Oct 2023 08:24:37 -0400 Subject: [PATCH 4/4] rename --- luisa_compute/examples/autodiff.rs | 2 +- luisa_compute/src/lang.rs | 2 +- luisa_compute/src/lang/{diff.rs => autodiff.rs} | 0 luisa_compute/tests/autodiff.rs | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename luisa_compute/src/lang/{diff.rs => autodiff.rs} (100%) diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 6c11224..79de9f0 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -1,7 +1,7 @@ use std::env::current_exe; use std::f32::consts::PI; -use luisa::lang::diff::*; +use luisa::lang::autodiff::*; use luisa::prelude::*; use luisa_compute as luisa; fn main() { diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 8e22085..c8b0bc1 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -28,7 +28,7 @@ use self::index::IntoIndex; pub mod control_flow; pub mod debug; -pub mod diff; +pub mod autodiff; pub mod functions; pub mod index; pub mod ops; diff --git a/luisa_compute/src/lang/diff.rs b/luisa_compute/src/lang/autodiff.rs similarity index 100% rename from luisa_compute/src/lang/diff.rs rename to luisa_compute/src/lang/autodiff.rs diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 45d877c..8d9f1ce 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -1,7 +1,7 @@ use std::ops::Range; use alias::*; -use luisa::lang::diff::*; +use luisa::lang::autodiff::*; use luisa::lang::types::core::*; use luisa::lang::types::vector::*; use luisa::prelude::*;