diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 1acfa4f..dbd0902 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -1355,3 +1355,73 @@ fn autodiff_callable() { } } } + +#[test] +fn autodiff_callable2() { + let device = get_device(); + let t: Buffer = device.create_buffer(1024); + let x: Buffer = device.create_buffer(1024); + let y: Buffer = device.create_buffer(1024); + let dx: Buffer = device.create_buffer(1024); + let dy: Buffer = device.create_buffer(1024); + let mut rng = rand::thread_rng(); + t.view(..).fill_fn(|_| rng.gen_range(0..3)); + x.view(..).fill_fn(|_| rng.gen()); + y.view(..).fill_fn(|_| rng.gen()); + let callable = Callable::, Expr, Expr) -> Expr>::new( + &device, + track!(|x, y, t| { + switch::>(t) + .case(0, || x * 4.0) + .case(1, || x * 2.0) + .case(2, || y * 0.5) + .finish() + }), + ); + let kernel = Kernel::::new( + &device, + &track!(|| { + let buf_t = t.var(); + let buf_x = x.var(); + let buf_y = y.var(); + let buf_dx = dx.var(); + let buf_dy = dy.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let t = buf_t.read(tid); + autodiff(|| { + requires_grad(x); + requires_grad(y); + let z = callable.call(x, y, t); + backward(z); + let dx = gradient(x); + let dy = gradient(y); + buf_dx.write(tid, dx); + buf_dy.write(tid, dy); + }); + }), + ); + kernel.dispatch([1024, 1, 1]); + let dx = dx.view(..).copy_to_vec(); + let dy = dy.view(..).copy_to_vec(); + let t = t.view(..).copy_to_vec(); + let cache_dir = kernel.cache_dir(); + for i in 0..1024 { + match t[i] { + 0 => { + assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + 1 => { + assert_eq!(dx[i], 2.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir); + } + 2 => { + assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir); + assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir); + } + _ => unreachable!(), + } + } +} diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 34fbbdb..8eab5cb 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 34fbbdb082a41c0162f9320a060e2c5edba2386c +Subproject commit 8eab5cb80eafcba6c969258fb8c63b1105580703