From 66471dc7100883420008ea7174b8d2aa6d018fe4 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 12 Sep 2023 05:26:04 -0400 Subject: [PATCH] minor --- luisa_compute/src/lang/mod.rs | 2 ++ luisa_compute/tests/autodiff.rs | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 525a3a8..dd85ab2 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -2019,6 +2019,7 @@ impl KernelBuilder { flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM | ModuleFlags::REQUIRES_FWD_AD_TRANSFORM, }; + let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); let module = CallableModule { module: ir_module, ret_type, @@ -2937,6 +2938,7 @@ pub fn autodiff(body: impl Fn()) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); assert!(!c.started, "autodiff section is already started"); + *c = AdContext::new_rev(); c.started = true; }); RECORDER.with(|r| { diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 98ea5c9..9ed0a2b 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -66,6 +66,9 @@ fn autodiff_helper Float>( let grad_ad = (0..n_inputs) .map(|_| device.create_buffer::(repeats)) .collect::>(); + // let grad_fwd_ad = (0..n_inputs) + // .map(|_| device.create_buffer::(repeats)) + // .collect::>(); let tic = std::time::Instant::now(); let tmp: Vec> = (0..n_inputs) .into_par_iter() @@ -110,6 +113,9 @@ fn autodiff_helper Float>( grad_ad_vars[i].write(tid, gradient(inputs[i])); } }); + // forward_autodiff(n_inputs, ||{ + + // }); let fd = finite_difference(&inputs, &f); for i in 0..n_inputs { grad_fd_vars[i].write(tid, fd[i]);