From 7c98a1bea15613f02df042ba6500e04cbc5c1034 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 10 Oct 2023 16:11:41 -0400 Subject: [PATCH] fixed capture --- luisa_compute/src/lang.rs | 10 +++++----- luisa_compute/src/runtime.rs | 11 ++++------- luisa_compute/src/runtime/kernel.rs | 8 ++------ luisa_compute/tests/misc.rs | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index dfbc70f..61503ca 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -299,7 +299,7 @@ pub(crate) struct FnRecorder { pub(crate) captured_resources: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, - pub(crate) captured_vars: IndexMap, + pub(crate) captured_vars: IndexMap, pub(crate) shared: Vec, pub(crate) device: Option, pub(crate) block_size: Option<[u32; 3]>, @@ -372,8 +372,8 @@ impl FnRecorder { if node0.recorder == self as *mut _ { return node0; } - if self.captured_vars.contains_key(&node0) { - return self.captured_vars[&node0]; + if self.captured_vars.contains_key(&node0.node) { + return self.captured_vars[&node0.node].1; } let ptr = self as *mut _; let node = { @@ -431,7 +431,7 @@ impl FnRecorder { panic!("cannot capture node {:?}", node.node.get().instruction) } }; - self.captured_vars.insert(node0, arg); + self.captured_vars.insert(node0.node, (node.node, arg)); arg } } @@ -478,7 +478,7 @@ pub(crate) fn push_recorder(kernel_id: usize) { let mut r = r.borrow_mut(); let old = (*r).clone(); let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id, old))); - std::mem::replace(&mut *r, Some(new.clone())); + *r = Some(new.clone()); }) } pub(crate) fn pop_recorder() -> FnRecorderPtr { diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 6a0e8ee..74f5204 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1307,13 +1307,10 @@ impl DynCallable { }; let kernel_id = r_ptr.as_ref().unwrap().borrow().kernel_id; let r_backup = (*r_ptr).clone(); - std::mem::replace( - &mut *r_ptr, - Some(Rc::new(RefCell::new(FnRecorder::new( - kernel_id, - r_backup.clone(), - )))), - ); + *r_ptr = Some(Rc::new(RefCell::new(FnRecorder::new( + kernel_id, + r_backup.clone(), + )))); (r_backup, device.upgrade().unwrap()) }); let mut builder = KernelBuilder::new(Some(device), false); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 4bb99f8..b8ec62b 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -398,7 +398,7 @@ impl KernelBuilder { let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module); let mut args = self.args.clone(); - args.extend(r.captured_vars.values().map(|x| unsafe { x.get_raw() })); + args.extend(r.captured_vars.values().map(|x| unsafe { x.1.get_raw() })); for a in &args { r.inaccessible.borrow_mut().insert(*a); } @@ -416,11 +416,7 @@ impl KernelBuilder { device: self.device.clone(), module, resource_tracker: rt, - captured_args: r - .captured_vars - .keys() - .map(|x| unsafe { x.get_raw() }) - .collect(), + captured_args: r.captured_vars.values().map(|x| x.0).collect(), } }); pop_recorder(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 13db1ce..411c4c2 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -317,6 +317,20 @@ fn callable_return_void_mismatch() { } #[test] #[should_panic] +#[tracked] +fn illegal_scope_sharing() { + let device = get_device(); + let tid = RefCell::new(None); + Kernel::::new(&device, &|| { + let i = dispatch_id().x; + if i % 2 == 0 { + *tid.borrow_mut() = Some(i + 1); + } + let _v = tid.borrow().unwrap() + 1; + }); +} +#[test] +#[should_panic] fn callable_illegal_sharing() { let device = get_device(); let tid = RefCell::new(None);