From 3638d4a2c6d9a666cf1b8ea5d268abf21a9de4c7 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Fri, 29 Dec 2023 05:51:24 -0500 Subject: [PATCH] update submod --- luisa_compute/src/resource.rs | 6 ++++ luisa_compute/tests/autodiff.rs | 54 ++++++++++++++++++++++++++++++++- luisa_compute_sys/LuisaCompute | 2 +- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 5e71d1f..dda9129 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1358,6 +1358,9 @@ impl Tex2d { pub fn format(&self) -> PixelFormat { self.handle.format } + pub fn storage(&self) -> PixelStorage { + self.handle.storage + } // pub fn read(&self, uv: impl AsExpr) -> Expr { // self.var().read(uv) // } @@ -1384,6 +1387,9 @@ impl Tex3d { pub fn format(&self) -> PixelFormat { self.handle.format } + pub fn storage(&self) -> PixelStorage { + self.handle.storage + } // pub fn read(&self, uv: impl AsExpr) -> Expr { // self.var().read(uv) // } diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index d419ea9..d799f85 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -956,7 +956,59 @@ fn autodiff_if_phi() { } } } - +#[test] +fn autodiff_if_phi_outer_no_else() { + let device = get_device(); + let x: Buffer = device.create_buffer(1024); + let y: Buffer = device.create_buffer(1024); + let dx: Buffer = device.create_buffer(1024); + let dy: Buffer = device.create_buffer(1024); + let mut rng = rand::thread_rng(); + x.view(..).fill_fn(|_| rng.gen()); + y.view(..).fill_fn(|_| rng.gen()); + let kernel = Kernel::::new( + &device, + &track!(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let z = 0.0f32.var(); + if true.expr() { + autodiff(|| { + requires_grad(x); + requires_grad(y); + if x > y { + let tmp = 0.0f32.var(); + *tmp = x * 4.0; + *z = tmp; + }; + backward(**z); + buf_dx.write(tid, gradient(x)); + buf_dy.write(tid, gradient(y)); + }); + } + }), + ); + kernel.dispatch([1024, 1, 1]); + let dx = dx.view(..).copy_to_vec(); + let dy = dy.view(..).copy_to_vec(); + let x = x.view(..).copy_to_vec(); + let y = y.view(..).copy_to_vec(); + let cache_dir = kernel.cache_dir(); + for i in 0..1024 { + if x[i] > y[i] { + assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } else { + assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + } +} #[test] fn autodiff_if_phi2() { let device = get_device(); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 4cfe950..11aec87 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 4cfe950e7f254eb6fc8a5fcf35d7747dc1a84eac +Subproject commit 11aec87024ff757ec2ac544b2ff6c5baa790a2f8