Skip to content

Commit

Permalink
fix return_v()
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 17, 2023
1 parent c04e475 commit d2be706
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 14 deletions.
14 changes: 13 additions & 1 deletion luisa_compute/examples/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,19 @@ fn main() {
use luisa::*;
init_logger();
let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device("cpu");
let args: Vec<String> = std::env::args().collect();
assert!(
args.len() <= 2,
"Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
args[0]
);

let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device(if args.len() == 2 {
args[1].as_str()
} else {
"cpu"
});
let add = device.create_callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>(&|a, b| a + b);
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
Expand Down
14 changes: 13 additions & 1 deletion luisa_compute/examples/callable_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@ fn main() {
use luisa::*;
init_logger();
let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device("cpu");
let args: Vec<String> = std::env::args().collect();
assert!(
args.len() <= 2,
"Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
args[0]
);

let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device(if args.len() == 2 {
args[1].as_str()
} else {
"cpu"
});
let add = device.create_dyn_callable::<fn(DynExpr, DynExpr) -> DynExpr>(Box::new(
|a: DynExpr, b: DynExpr| -> DynExpr {
if let Some(a) = a.downcast::<f32>() {
Expand Down
47 changes: 42 additions & 5 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ pub(crate) struct Recorder {
pub(crate) building_kernel: bool,
pub(crate) pools: Option<CArc<ModulePools>>,
pub(crate) arena: Bump,
pub(crate) callable_ret_type: Option<CArc<Type>>,
}

impl Recorder {
Expand All @@ -619,6 +620,7 @@ impl Recorder {
self.arena.reset();
self.shared.clear();
self.kernel_id = None;
self.callable_ret_type = None;
}
pub(crate) fn new() -> Self {
Recorder {
Expand All @@ -634,6 +636,7 @@ impl Recorder {
arena: Bump::new(),
building_kernel: false,
kernel_id: None,
callable_ret_type: None,
}
}
}
Expand Down Expand Up @@ -2009,6 +2012,14 @@ impl KernelBuilder {
let mut r = r.borrow_mut();
assert!(r.lock);
r.lock = false;
if let Some(t) = &r.callable_ret_type {
assert!(
luisa_compute_ir::context::is_type_equal(t, &ret_type),
"Return type mismatch"
);
} else {
r.callable_ret_type = Some(ret_type.clone());
}
assert_eq!(r.scopes.len(), 1);
let scope = r.scopes.pop().unwrap();
let entry = scope.finish();
Expand Down Expand Up @@ -2732,12 +2743,38 @@ pub fn continue_() {
});
}

// pub fn return_v<T: FromNode>(v: T) {
// __current_scope(|b| {
// b.return_(Some(v.node()));
// });
// }
pub fn return_v<T: FromNode>(v: T) {
RECORDER.with(|r| {
let mut r = r.borrow_mut();
if r.callable_ret_type.is_none() {
r.callable_ret_type = Some(v.node().type_().clone());
} else {
assert!(
luisa_compute_ir::context::is_type_equal(
r.callable_ret_type.as_ref().unwrap(),
v.node().type_()
),
"return type mismatch"
);
}
});
__current_scope(|b| {
b.return_(v.node());
});
}

pub fn return_() {
RECORDER.with(|r| {
let mut r = r.borrow_mut();
if r.callable_ret_type.is_none() {
r.callable_ret_type = Some(Type::void());
} else {
assert!(luisa_compute_ir::context::is_type_equal(
r.callable_ret_type.as_ref().unwrap(),
&Type::void()
));
}
});
__current_scope(|b| {
b.return_(INVALID_REF);
});
Expand Down
69 changes: 63 additions & 6 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,55 @@ fn event() {
assert_eq!(v[0], (1 + 3) * (4 + 5));
}
#[test]
#[should_panic]
fn callable_return_mismatch() {
let device = get_device();
let _abs = device.create_callable::<fn(Expr<f32>) -> Expr<f32>>(&|x| {
if_!(x.cmpgt(0.0), {
return_v(const_(true));
});
-x
});
}
#[test]
#[should_panic]
fn callable_return_void_mismatch() {
let device = get_device();
let _abs = device.create_callable::<fn(Var<f32>)>(&|x| {
if_!(x.cmpgt(0.0), {
return_v(const_(true));
});
x.store(-*x);
});
}
#[test]
fn callable_early_return() {
let device = get_device();
let abs = device.create_callable::<fn(Expr<f32>) -> Expr<f32>>(&|x| {
if_!(x.cmpgt(0.0), {
return_v(x);
});
-x
});
let x = device.create_buffer::<f32>(1024);
let mut rng = StdRng::seed_from_u64(0);
x.fill_fn(|_| rng.gen());
let y = device.create_buffer::<f32>(1024);
device
.create_kernel::<fn()>(&|| {
let i = dispatch_id().x();
let x = x.var().read(i);
let y = y.var();
y.write(i, abs.call(x));
})
.dispatch([x.len() as u32, 1, 1]);
let x = x.copy_to_vec();
let y = y.copy_to_vec();
for i in 0..x.len() {
assert_eq!(y[i], x[i].abs());
}
}
#[test]
fn callable() {
let device = get_device();
let write = device.create_callable::<fn(BufferVar<u32>, Expr<u32>, Var<u32>)>(
Expand All @@ -91,7 +140,7 @@ fn callable() {
v.store(v.load() + 1);
},
);
let add = device.create_callable::<fn(Expr<u32>, Expr<u32>)->Expr<u32>>(&|a, b| a + b);
let add = device.create_callable::<fn(Expr<u32>, Expr<u32>) -> Expr<u32>>(&|a, b| a + b);
let x = device.create_buffer::<u32>(1024);
let y = device.create_buffer::<u32>(1024);
let z = device.create_buffer::<u32>(1024);
Expand Down Expand Up @@ -715,15 +764,19 @@ fn byte_buffer() {
($t:ty, $offset:expr) => {{
let s = std::mem::size_of::<$t>();
let bytes = &data[$offset..$offset + s];
let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) };
let v = unsafe {
std::mem::transmute_copy::<[u8; { std::mem::size_of::<$t>() }], $t>(
bytes.try_into().unwrap(),
)
};
v
}};
}
let v0 = pop!(Float3, i0);
let v1 = pop!(Big, i1);
let v2 = pop!(i32, i2);
let v3 = pop!(f32, i3);
assert_eq!(v0, Float3::new(1.0,2.0,3.0));
assert_eq!(v0, Float3::new(1.0, 2.0, 3.0));
assert_eq!(v2, 1);
assert_eq!(v3, 2.0);
for i in 0..32 {
Expand Down Expand Up @@ -759,7 +812,7 @@ fn bindless_byte_buffer() {
let i2 = push!(i32, 0i32);
let i3 = push!(f32, 1f32);
device
.create_kernel::<fn(ByteBuffer)>(&|out:ByteBufferVar| {
.create_kernel::<fn(ByteBuffer)>(&|out: ByteBufferVar| {
let heap = heap.var();
let buf = heap.byte_address_buffer(0);
let i0 = i0 as u64;
Expand Down Expand Up @@ -787,15 +840,19 @@ fn bindless_byte_buffer() {
($t:ty, $offset:expr) => {{
let s = std::mem::size_of::<$t>();
let bytes = &data[$offset..$offset + s];
let v = unsafe { std::mem::transmute_copy::<[u8; {std::mem::size_of::<$t>()}], $t>(bytes.try_into().unwrap()) };
let v = unsafe {
std::mem::transmute_copy::<[u8; { std::mem::size_of::<$t>() }], $t>(
bytes.try_into().unwrap(),
)
};
v
}};
}
let v0 = pop!(Float3, i0);
let v1 = pop!(Big, i1);
let v2 = pop!(i32, i2);
let v3 = pop!(f32, i3);
assert_eq!(v0, Float3::new(1.0,2.0,3.0));
assert_eq!(v0, Float3::new(1.0, 2.0, 3.0));
assert_eq!(v2, 1);
assert_eq!(v3, 2.0);
for i in 0..32 {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_sys/LuisaCompute

0 comments on commit d2be706

Please sign in to comment.