Skip to content

Commit

Permalink
Merge remote-tracking branch 'root/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
entropylost committed Oct 6, 2023
2 parents 1d9baf0 + 0c66b59 commit 0a84858
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 12 deletions.
2 changes: 1 addition & 1 deletion luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::env::current_exe;
use std::f32::consts::PI;

use luisa::lang::diff::*;
use luisa::lang::autodiff::*;
use luisa::prelude::*;
use luisa_compute as luisa;
fn main() {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use self::index::IntoIndex;

pub mod control_flow;
pub mod debug;
pub mod diff;
pub mod autodiff;
pub mod functions;
pub mod index;
pub mod ops;
Expand Down
File renamed without changes.
4 changes: 4 additions & 0 deletions luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,10 @@ impl<T: KernelSignature> Kernel<T> {
pub fn dump(&self) -> String {
ir::debug::dump_ir_human_readable(&self.inner.module.module)
}
#[doc(hidden)]
pub fn raw(&self) -> &RawKernel {
&self.inner
}
}

// A trait signifying that this argument can be used in place of an argument of type `Self::T`.
Expand Down
14 changes: 13 additions & 1 deletion luisa_compute/src/runtime/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,15 @@ impl KernelBuilder {
(resource_tracker, cpu_custom_ops, captured)
})
}
fn build_callable<R: CallableRet>(&mut self, body: impl FnOnce(&mut Self) -> R) -> RawCallable {


/// Don't use this directly
/// See [`Callable`] for how to create a callable
#[doc(hidden)]
pub fn build_callable<R: CallableRet>(
&mut self,
body: impl FnOnce(&mut Self) -> R,
) -> RawCallable {
let ret = body(self);
let ret_type = ret._return();
let (rt, cpu_custom_ops, captures) = self.collect_module_info();
Expand Down Expand Up @@ -381,6 +389,10 @@ impl KernelBuilder {
}
})
}

/// Don't use this directly
/// See [`Kernel`] for how to create a kernel
#[doc(hidden)]
pub fn build_kernel<S: KernelSignature>(
&mut self,
body: impl FnOnce(&mut Self),
Expand Down
97 changes: 89 additions & 8 deletions luisa_compute/tests/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::ops::Range;

use alias::*;
use luisa::lang::diff::*;
use luisa::lang::autodiff::*;
use luisa::lang::types::core::*;
use luisa::lang::types::vector::*;
use luisa::prelude::*;
Expand Down Expand Up @@ -249,19 +249,27 @@ macro_rules! autodiff_3 {
#[repr(C)]
#[value_new]
struct Foo {
x: f32,
y: f32,
v: Float3,
f: f32,
}

autodiff_2!(autodiff_const, 1.0..10.0, |x: Expr<f32>, y: Expr<f32>| {
let k = 2.0 / 3.0_f32.expr();
x * k + y * k
});
autodiff_2!(autodiff_struct, 1.0..10.0, |x: Expr<f32>, y: Expr<f32>| {
let foo = Foo::new_expr(x, y).var();
*foo.x += 1.0;
foo.x + foo.y
});
autodiff_3!(
autodiff_struct,
-1.0..1.0,
track!(|x: Expr<f32>, y: Expr<f32>, z: Expr<f32>| {
let foo = Foo::new_expr(Float3::expr(x, y, z), x + y + z).var();
if foo.v.x > 3.0 {
*foo.v.x -= 1.0;
} else {
*foo.v.x -= foo.f;
};
foo.v.x * foo.v.y + foo.v.z * foo.f
})
);
autodiff_1!(autodiff_sin, -10.0..10.0, |x: Expr<f32>| x.sin());
autodiff_1!(autodiff_cos, -10.0..10.0, |x: Expr<f32>| x.cos());
autodiff_1!(autodiff_sincos, -10.0..10.0, |x: Expr<f32>| x.cos()
Expand Down Expand Up @@ -1141,6 +1149,79 @@ fn autodiff_if_phi4() {
}
}
#[test]
fn autodiff_if_phi5() {
let device = get_device();
let x: Buffer<f32> = device.create_buffer(1024);
let y: Buffer<f32> = device.create_buffer(1024);
let dx: Buffer<f32> = device.create_buffer(1024);
let dy: Buffer<f32> = device.create_buffer(1024);
let mut rng = rand::thread_rng();
x.view(..).fill_fn(|_| rng.gen());
y.view(..).fill_fn(|_| rng.gen());
let kernel = Kernel::<fn()>::new(
&device,
&track!(|| {
let buf_x = x.var();
let buf_y = y.var();
let buf_dx = dx.var();
let buf_dy = dy.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let tmp_x = 0.0f32.var();
let tmp_y = 0.0f32.var();
let consts = Float3::var_zeroed();
autodiff(|| {
requires_grad(x);
requires_grad(y);
*consts = Float3::expr(2.0, 3.0, 4.0);
let const_two = consts.x;
let const_three = consts.y;
let const_four = consts.z;
let c = (x > const_three).as_::<i32>();
let z = if x > y {
switch::<Expr<f32>>(c)
.case(0, || {
*tmp_x = x * const_two;
**tmp_x
})
.default(|| {
*tmp_y = x * const_four;
**tmp_y
})
.finish()
* const_two
} else {
y * 0.5
};
backward(z);
buf_dx.write(tid, gradient(x));
buf_dy.write(tid, gradient(y));
});
}),
);
kernel.dispatch([1024, 1, 1]);
let dx = dx.view(..).copy_to_vec();
let dy = dy.view(..).copy_to_vec();
let x = x.view(..).copy_to_vec();
let y = y.view(..).copy_to_vec();
let cache_dir = kernel.cache_dir();
for i in 0..1024 {
if x[i] > y[i] {
if x[i] > 3.0 {
assert_eq!(dx[i], 8.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
} else {
assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
}
} else {
assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir);
assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir);
}
}
}
#[test]
fn autodiff_switch() {
let device = get_device();
let t: Buffer<i32> = device.create_buffer(1024);
Expand Down

0 comments on commit 0a84858

Please sign in to comment.