Skip to content

Commit

Permalink
update submod; autodiff with callables
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 19, 2023
1 parent dd415e7 commit 52cc707
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
70 changes: 70 additions & 0 deletions luisa_compute/tests/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1355,3 +1355,73 @@ fn autodiff_callable() {
}
}
}

#[test]
fn autodiff_callable2() {
let device = get_device();
let t: Buffer<i32> = device.create_buffer(1024);
let x: Buffer<f32> = device.create_buffer(1024);
let y: Buffer<f32> = device.create_buffer(1024);
let dx: Buffer<f32> = device.create_buffer(1024);
let dy: Buffer<f32> = device.create_buffer(1024);
let mut rng = rand::thread_rng();
t.view(..).fill_fn(|_| rng.gen_range(0..3));
x.view(..).fill_fn(|_| rng.gen());
y.view(..).fill_fn(|_| rng.gen());
let callable = Callable::<fn(Expr<f32>, Expr<f32>, Expr<i32>) -> Expr<f32>>::new(
&device,
track!(|x, y, t| {
switch::<Expr<f32>>(t)
.case(0, || x * 4.0)
.case(1, || x * 2.0)
.case(2, || y * 0.5)
.finish()
}),
);
let kernel = Kernel::<fn()>::new(
&device,
&track!(|| {
let buf_t = t.var();
let buf_x = x.var();
let buf_y = y.var();
let buf_dx = dx.var();
let buf_dy = dy.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let t = buf_t.read(tid);
autodiff(|| {
requires_grad(x);
requires_grad(y);
let z = callable.call(x, y, t);
backward(z);
let dx = gradient(x);
let dy = gradient(y);
buf_dx.write(tid, dx);
buf_dy.write(tid, dy);
});
}),
);
kernel.dispatch([1024, 1, 1]);
let dx = dx.view(..).copy_to_vec();
let dy = dy.view(..).copy_to_vec();
let t = t.view(..).copy_to_vec();
let cache_dir = kernel.cache_dir();
for i in 0..1024 {
match t[i] {
0 => {
assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
}
1 => {
assert_eq!(dx[i], 2.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
}
2 => {
assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir);
}
_ => unreachable!(),
}
}
}

0 comments on commit 52cc707

Please sign in to comment.