Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 12, 2023
1 parent da2e826 commit 66471dc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2019,6 +2019,7 @@ impl KernelBuilder {
flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM
| ModuleFlags::REQUIRES_FWD_AD_TRANSFORM,
};
let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);
let module = CallableModule {
module: ir_module,
ret_type,
Expand Down Expand Up @@ -2937,6 +2938,7 @@ pub fn autodiff(body: impl Fn()) {
AD_CONTEXT.with(|c| {
let mut c = c.borrow_mut();
assert!(!c.started, "autodiff section is already started");
*c = AdContext::new_rev();
c.started = true;
});
RECORDER.with(|r| {
Expand Down
6 changes: 6 additions & 0 deletions luisa_compute/tests/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ fn autodiff_helper<F: Fn(&[Float]) -> Float>(
let grad_ad = (0..n_inputs)
.map(|_| device.create_buffer::<f32>(repeats))
.collect::<Vec<_>>();
// let grad_fwd_ad = (0..n_inputs)
// .map(|_| device.create_buffer::<f32>(repeats))
// .collect::<Vec<_>>();
let tic = std::time::Instant::now();
let tmp: Vec<Vec<f32>> = (0..n_inputs)
.into_par_iter()
Expand Down Expand Up @@ -110,6 +113,9 @@ fn autodiff_helper<F: Fn(&[Float]) -> Float>(
grad_ad_vars[i].write(tid, gradient(inputs[i]));
}
});
// forward_autodiff(n_inputs, ||{

// });
let fd = finite_difference(&inputs, &f);
for i in 0..n_inputs {
grad_fd_vars[i].write(tid, fd[i]);
Expand Down

0 comments on commit 66471dc

Please sign in to comment.