diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 23ca818..dfbc70f 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -1,5 +1,6 @@ use std::any::Any; use std::cell::{Cell, RefCell}; +use std::collections::HashSet; use std::fmt::Debug; use std::rc::Rc; use std::sync::atomic::AtomicUsize; @@ -53,14 +54,14 @@ pub(crate) trait CallFuncTrait { } impl CallFuncTrait for Func { fn call(self, x: Expr) -> Expr { - let x = process_potential_capture(x.node()).node; + let x = x.node().get(); Expr::::from_node(make_safe_node(__current_scope(|b| { b.call(self, &[x], ::type_()) }))) } fn call2(self, x: Expr, y: Expr) -> Expr { - let x = process_potential_capture(x.node()).node; - let y = process_potential_capture(y.node()).node; + let x = x.node().get(); + let y = y.node().get(); Expr::::from_node(make_safe_node(__current_scope(|b| { b.call(self, &[x, y], ::type_()) }))) @@ -71,31 +72,31 @@ impl CallFuncTrait for Func { y: Expr, z: Expr, ) -> Expr { - let x = process_potential_capture(x.node()).node; - let y = process_potential_capture(y.node()).node; - let z = process_potential_capture(z.node()).node; + let x = x.node().get(); + let y = y.node().get(); + let z = z.node().get(); Expr::::from_node(make_safe_node(__current_scope(|b| { b.call(self, &[x, y, z], ::type_()) }))) } fn call_void(self, x: Expr) { - let x = process_potential_capture(x.node()).node; + let x = x.node().get(); __current_scope(|b| { b.call(self, &[x], Type::void()); }); } fn call2_void(self, x: Expr, y: Expr) { - let x = process_potential_capture(x.node()).node; - let y = process_potential_capture(y.node()).node; + let x = x.node().get(); + let y = y.node().get(); __current_scope(|b| { b.call(self, &[x, y], Type::void()); }); } fn call3_void(self, x: Expr, y: Expr, z: Expr) { - let x = process_potential_capture(x.node()).node; - let y = process_potential_capture(y.node()).node; - let z = process_potential_capture(z.node()).node; + let x = x.node().get(); + let y = y.node().get(); + let z = z.node().get(); __current_scope(|b| { b.call(self, &[x, y, z], Type::void()); }); @@ -291,6 +292,9 @@ impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15) pub(crate) struct FnRecorder { pub(crate) parent: Option, pub(crate) scopes: Vec, + /// Nodes that are should not be acess + /// Once a basicblock is finished, all nodes in it are added to this set + pub(crate) inaccessible: Rc>>, pub(crate) kernel_id: usize, pub(crate) captured_resources: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, @@ -306,6 +310,12 @@ pub(crate) struct FnRecorder { } pub(crate) type FnRecorderPtr = Rc>; impl FnRecorder { + pub(crate) fn add_block_to_inaccessible(&self, block: &BasicBlock) { + let mut inaccessible = self.inaccessible.borrow_mut(); + for n in block.iter() { + inaccessible.insert(n); + } + } pub(crate) fn check_on_same_device(&mut self, other: &Device) -> Option<(String, String)> { if let Some(device) = &self.device { let device = device.upgrade().unwrap(); @@ -336,8 +346,12 @@ impl FnRecorder { node } } - pub(crate) fn new(kernel_id: usize) -> Self { + pub(crate) fn new(kernel_id: usize, parent: Option) -> Self { FnRecorder { + inaccessible: parent + .as_ref() + .map(|p| p.borrow().inaccessible.clone()) + .unwrap_or_else(|| Rc::new(RefCell::new(HashSet::new()))), scopes: vec![], captured_resources: IndexMap::new(), cpu_custom_ops: IndexMap::new(), @@ -351,45 +365,73 @@ impl FnRecorder { building_kernel: false, callable_ret_type: None, kernel_id, - parent: None, + parent, } } - pub(crate) fn map_captured_vars(&mut self, node: SafeNodeRef) -> SafeNodeRef { - if node.recorder == self as *mut _ { - return node; + pub(crate) fn map_captured_vars(&mut self, node0: SafeNodeRef) -> SafeNodeRef { + if node0.recorder == self as *mut _ { + return node0; } - if self.captured_vars.contains_key(&node) { - return self.captured_vars[&node]; + if self.captured_vars.contains_key(&node0) { + return self.captured_vars[&node0]; } let ptr = self as *mut _; - let parent = self - .parent - .as_mut() - .unwrap_or_else(|| panic!("Captured var outside kernel")); - match node.node.get().instruction.as_ref() { - Instruction::Local { .. } => {} - Instruction::Call { .. } => {} - Instruction::Argument { .. } => {} + let node = { + let parent = self.parent.as_mut().unwrap_or_else(|| { + panic!( + "Captured var outside kernel {:?}", + node0.node.get().instruction + ) + }); + let mut parent = parent.borrow_mut(); + let node = parent.map_captured_vars(node0); + assert_eq!(node.recorder, &mut *parent as *mut _); + assert_ne!(node.recorder, ptr); + node + }; + + let arg = match node.node.get().instruction.as_ref() { + Instruction::Local { .. } + | Instruction::Call { .. } + | Instruction::Argument { .. } + | Instruction::Phi(_) + | Instruction::Const(_) + | Instruction::Uniform => SafeNodeRef { + recorder: ptr, + node: new_node( + &self.pools, + Node::new( + CArc::new(Instruction::Argument { + by_value: !node.node.is_lvalue(), + }), + node.node.type_().clone(), + ), + ), + kernel_id: node.kernel_id, + }, + Instruction::Buffer + | Instruction::Accel + | Instruction::Bindless + | Instruction::Texture2D + | Instruction::Texture3D => { + // captured resource + SafeNodeRef { + recorder: ptr, + node: new_node( + &self.pools, + Node::new( + node.node.get().instruction.clone(), + node.node.type_().clone(), + ), + ), + kernel_id: node.kernel_id, + } + } _ => { panic!("cannot capture node {:?}", node.node.get().instruction) } - } - let arg = SafeNodeRef { - recorder: ptr, - node: new_node( - &self.pools, - Node::new( - CArc::new(Instruction::Argument { - by_value: !node.node.is_lvalue(), - }), - node.node.type_().clone(), - ), - ), - kernel_id: node.kernel_id, }; - self.captured_vars.insert(node, arg); - let mut parent = parent.borrow_mut(); - parent.map_captured_vars(node); + self.captured_vars.insert(node0, arg); arg } } @@ -416,6 +458,13 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef { cur_kernel_id, node.kernel_id, "Referencing node from another kernel!" ); + if r.inaccessible.borrow().contains(&node.node) { + panic!( + r#"Detected using node outside of its scope. It is possible that you use `RefCell` or `Cell` to store an `Expr` or `Var` +that is defined inside an if branch/loop body/switch case and use it outside its scope. +Please define a `Var` in the parent scope and assign to it instead!"# + ); + } let ptr = r as *mut _; // defined in same callable, no need to capture if ptr == node.recorder { @@ -427,9 +476,9 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef { pub(crate) fn push_recorder(kernel_id: usize) { RECORDER.with(|r| { let mut r = r.borrow_mut(); - let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id))); - let old = std::mem::replace(&mut *r, Some(new.clone())); - new.borrow_mut().parent = old; + let old = (*r).clone(); + let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id, old))); + std::mem::replace(&mut *r, Some(new.clone())); }) } pub(crate) fn pop_recorder() -> FnRecorderPtr { @@ -510,7 +559,9 @@ pub(crate) fn __check_callable(callable: &CallableModuleRef, args: &[NodeRef]) - pub fn __pop_scope() -> Pooled { with_recorder(|r| { let s = &mut r.scopes; - s.pop().unwrap().finish() + let bb = s.pop().unwrap().finish(); + r.add_block_to_inaccessible(&bb); + bb }) } @@ -529,9 +580,16 @@ pub fn __module_pools() -> &'static CArc { * scope might not be up to date Thus, for IrBuilder of each scope, it * updates the insertion point to the end of the current basic block */ -pub fn __extract(node: NodeRef, index: usize) -> NodeRef { +pub fn __extract(safe_node: SafeNodeRef, index: usize) -> SafeNodeRef { + let node = unsafe { safe_node.get_raw() }; let inst = &node.get().instruction; - with_recorder(|r| { + let r = unsafe { + safe_node + .recorder + .as_mut() + .unwrap_or_else(|| panic!("Node {:?} not in any kernel", node.get().instruction)) + }; + { let pools = { let cur_builder = r.scopes.last_mut().unwrap(); cur_builder.pools() @@ -560,6 +618,15 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { } }; } + macro_rules! wrap_up { + ($n:expr) => { + SafeNodeRef { + recorder: safe_node.recorder, + node: $n, + kernel_id: safe_node.kernel_id, + } + }; + } let op = match inst.as_ref() { Instruction::Local { .. } => Func::GetElementPtr, Instruction::Argument { by_value } => { @@ -575,14 +642,14 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { indices.push(i); let n = b.call_no_append(Func::AtomicRef, &indices, ::type_()); update_builders!(); - return n; + return wrap_up!(n); } Func::GetElementPtr => { let mut indices = args.to_vec(); indices.push(i); let n = b.call(Func::GetElementPtr, &indices, ::type_()); update_builders!(); - return n; + return wrap_up!(n); } _ => Func::ExtractElement, }, @@ -591,8 +658,8 @@ pub fn __extract(node: NodeRef, index: usize) -> NodeRef { let node = b.call(op, &[node, i], ::type_()); update_builders!(); - node - }) + wrap_up!(node) + } } pub fn __insert(node: NodeRef, index: usize, value: NodeRef) -> NodeRef { diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs index d17b928..2c43a61 100644 --- a/luisa_compute/src/lang/control_flow.rs +++ b/luisa_compute/src/lang/control_flow.rs @@ -113,6 +113,7 @@ pub fn if_then_else( let s = &mut r.scopes; let then_block = s.pop().unwrap().finish(); s.push(IrBuilder::new(pools)); + r.add_block_to_inaccessible(&then_block); then_block }); let else_ = else_(); @@ -123,7 +124,9 @@ pub fn if_then_else( .collect::>(); let else_block = with_recorder(|r| { let s = &mut r.scopes; - s.pop().unwrap().finish() + let else_block = s.pop().unwrap().finish(); + r.add_block_to_inaccessible(&else_block); + else_block }); __current_scope(|b| { b.if_(cond, then_block, else_block); @@ -210,6 +213,7 @@ pub fn generic_loop( let s = &mut r.scopes; let prepare = s.pop().unwrap().finish(); s.push(IrBuilder::new(pools)); + r.add_block_to_inaccessible(&prepare); prepare }); body(); @@ -218,12 +222,15 @@ pub fn generic_loop( let s = &mut r.scopes; let body = s.pop().unwrap().finish(); s.push(IrBuilder::new(pools)); + r.add_block_to_inaccessible(&body); body }); update(); let update = with_recorder(|r| { let s = &mut r.scopes; - s.pop().unwrap().finish() + let update_block = s.pop().unwrap().finish(); + r.add_block_to_inaccessible(&update_block); + update_block }); __current_scope(|b| { b.generic_loop(prepare, cond_v, body, update); diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index 3fc4d77..a9c8c36 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -131,7 +131,7 @@ macro_rules! impl_array_vec_conversion{ where T: vector::VectorAlign<$N>, { fn from(vec:Expr>)->Self{ - let elems = (0..$N).map(|i| __extract::(vec.node().get(), i)).collect::>(); + let elems = (0..$N).map(|i| __extract::(vec.node(), i).get()).collect::>(); let node = __current_scope(|b| b.call(Func::Array, &elems, <[T;$N]>::type_())); Self::from_node(node.into()) } diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index a87bbe4..18e4755 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -48,7 +48,7 @@ pub struct VectorExprData, const N: usize>([Expr; N]); impl, const N: usize> FromNode for VectorExprData { fn from_node(node: SafeNodeRef) -> Self { Self(std::array::from_fn(|i| { - FromNode::from_node(__extract::(node.get(), i).into()) + FromNode::from_node(__extract::(node, i)) })) } } @@ -58,7 +58,7 @@ pub struct VectorVarData, const N: usize>([Var; N]); impl, const N: usize> FromNode for VectorVarData { fn from_node(node: SafeNodeRef) -> Self { Self(std::array::from_fn(|i| { - FromNode::from_node(__extract::(node.get(), i).into()) + FromNode::from_node(__extract::(node, i)) })) } } @@ -69,7 +69,7 @@ pub struct VectorAtomicRefData, const N: usize>([AtomicRef; impl, const N: usize> FromNode for VectorAtomicRefData { fn from_node(node: SafeNodeRef) -> Self { Self(std::array::from_fn(|i| { - FromNode::from_node(__extract::(node.get(), i).into()) + FromNode::from_node(__extract::(node, i)) })) } } @@ -161,7 +161,7 @@ macro_rules! vector_proxies { let mut comp = 0; $( { - let el = Expr::::from_node(__extract::(v.node().get(), comp).into()); + let el = Expr::::from_node(__extract::(v.node(), comp)); self.$real_c.write(i, el); comp += 1; } diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 92fb78c..6a0e8ee 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1306,13 +1306,15 @@ impl DynCallable { r.borrow().device.clone().unwrap() }; let kernel_id = r_ptr.as_ref().unwrap().borrow().kernel_id; - ( - std::mem::replace( - &mut *r_ptr, - Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id)))), - ), - device.upgrade().unwrap(), - ) + let r_backup = (*r_ptr).clone(); + std::mem::replace( + &mut *r_ptr, + Some(Rc::new(RefCell::new(FnRecorder::new( + kernel_id, + r_backup.clone(), + )))), + ); + (r_backup, device.upgrade().unwrap()) }); let mut builder = KernelBuilder::new(Some(device), false); let new_callable = (inner.builder)(args, &mut builder); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 41ca6b7..4bb99f8 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -387,6 +387,7 @@ impl KernelBuilder { assert_eq!(r.scopes.len(), 1); let scope = r.scopes.pop().unwrap(); let entry = scope.finish(); + r.add_block_to_inaccessible(&entry); let ir_module = Module { entry, kind: ModuleKind::Kernel, @@ -398,7 +399,9 @@ impl KernelBuilder { let mut args = self.args.clone(); args.extend(r.captured_vars.values().map(|x| unsafe { x.get_raw() })); - + for a in &args { + r.inaccessible.borrow_mut().insert(*a); + } let module = CallableModule { module: ir_module, ret_type, @@ -562,7 +565,8 @@ macro_rules! impl_callable { let r_backup = RECORDER.with(|r| { let mut r = r.borrow_mut(); let kernel_id = r.as_ref().unwrap().borrow().kernel_id; - std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id))))) + let old = r.clone(); + std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id, old))))) }); let mut builder = KernelBuilder::new(None, false); let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 459fecb..13db1ce 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -118,6 +118,145 @@ fn nested_callable_capture_by_ref() { } } #[test] +fn nested_callable_outline_twice() { + let device = get_device(); + + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + let z = 0.0f32.var(); + outline(|| { + outline(|| { + *z += y; + }); + outline(|| { + *z += x; + }) + }); + buf_z.write(tid, z); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + for i in 0..x.len() { + assert_eq!(z_data[i], (i as f32 + 1000.0 * i as f32)); + } +} + +#[derive(Clone, Copy, Debug, Value, Soa, PartialEq)] +#[repr(C)] +#[value_new(pub)] +pub struct A { + v: Float3, +} +#[test] +fn nested_callable_capture_gep() { + let device = get_device(); + + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let a = Var::::zeroed(); + outline(|| { + let x = buf_x.read(tid); + let y = buf_y.read(tid); + *a.v = Float3::expr(x, y, 0.0); + outline(|| { + let v = a.v; + *v.z += v.x; + }); + outline(|| { + let v = a.v; + *a.v.z += v.y; + }); + buf_z.write(tid, a.v.z); + }); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + for i in 0..x.len() { + assert_eq!(z_data[i], (i as f32 + 1000.0 * i as f32)); + } +} +#[test] +fn nested_callable_capture_buffer() { + let device = get_device(); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let tid = dispatch_id().x; + let z = 0.0f32.var(); + outline(|| { + let buf_x = x.var(); + let buf_y = y.var(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + *z = x + y; + }); + buf_z.write(tid, z); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + for i in 0..x.len() { + assert_eq!(z_data[i], (i as f32 + 1000.0 * i as f32)); + } +} +#[test] +fn nested_callable_capture_buffer_var() { + let device = get_device(); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + + let z = 0.0f32.var(); + outline(|| { + let x = buf_x.read(tid); + let y = buf_y.read(tid); + *z = x + y; + }); + buf_z.write(tid, z); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + for i in 0..x.len() { + assert_eq!(z_data[i], (i as f32 + 1000.0 * i as f32)); + } +} +#[test] #[should_panic] fn callable_different_device() { let device1 = get_device(); diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index f169bd5..b78c1ea 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -253,7 +253,7 @@ impl Compiler { quote_spanned!(span=> let #ident = < #lang_path::types::Expr::<#ty> as #lang_path::FromNode>::from_node(#lang_path::__extract::<#ty>( __node, #i, - ).into()); + )); ) }) .collect(); @@ -266,7 +266,7 @@ impl Compiler { quote_spanned!(span=> let #ident = < #lang_path::types::Var::<#ty> as #lang_path::FromNode>::from_node(#lang_path::__extract::<#ty>( __node, #i, - ).into()); + )); ) }) .collect(); @@ -279,7 +279,7 @@ impl Compiler { quote_spanned!(span=> let #ident = < #lang_path::types::AtomicRef::<#ty> as #lang_path::FromNode>::from_node(#lang_path::__extract::<#ty>( __node, #i, - ).into()); + )); ) }) .collect(); @@ -362,7 +362,7 @@ impl Compiler { type Value = #name #ty_generics; fn from_expr(expr: #lang_path::types::Expr<#name #ty_generics>) -> Self { use #lang_path::ToNode; - let __node = expr.node().get(); + let __node = expr.node(); #(#extract_expr_fields)* Self{ self_:expr, @@ -380,7 +380,7 @@ impl Compiler { type Value = #name #ty_generics; fn from_var(var: #lang_path::types::Var<#name #ty_generics>) -> Self { use #lang_path::ToNode; - let __node = var.node().get(); + let __node = var.node(); #(#extract_var_fields)* Self{ self_:var, @@ -397,7 +397,7 @@ impl Compiler { type Value = #name #ty_generics; fn from_atomic_ref(var: #lang_path::types::AtomicRef<#name #ty_generics>) -> Self { use #lang_path::ToNode; - let __node = var.node().get(); + let __node = var.node(); #(#extract_atomic_ref_fields)* Self{ self_:var,