Skip to content

Commit

Permalink
fixed capture
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 10, 2023
1 parent 699f98e commit 7c98a1b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
10 changes: 5 additions & 5 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ pub(crate) struct FnRecorder {
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
pub(crate) captured_vars: IndexMap<SafeNodeRef, SafeNodeRef>,
pub(crate) captured_vars: IndexMap<NodeRef, (NodeRef, SafeNodeRef)>,
pub(crate) shared: Vec<NodeRef>,
pub(crate) device: Option<WeakDevice>,
pub(crate) block_size: Option<[u32; 3]>,
Expand Down Expand Up @@ -372,8 +372,8 @@ impl FnRecorder {
if node0.recorder == self as *mut _ {
return node0;
}
if self.captured_vars.contains_key(&node0) {
return self.captured_vars[&node0];
if self.captured_vars.contains_key(&node0.node) {
return self.captured_vars[&node0.node].1;
}
let ptr = self as *mut _;
let node = {
Expand Down Expand Up @@ -431,7 +431,7 @@ impl FnRecorder {
panic!("cannot capture node {:?}", node.node.get().instruction)
}
};
self.captured_vars.insert(node0, arg);
self.captured_vars.insert(node0.node, (node.node, arg));
arg
}
}
Expand Down Expand Up @@ -478,7 +478,7 @@ pub(crate) fn push_recorder(kernel_id: usize) {
let mut r = r.borrow_mut();
let old = (*r).clone();
let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id, old)));
std::mem::replace(&mut *r, Some(new.clone()));
*r = Some(new.clone());
})
}
pub(crate) fn pop_recorder() -> FnRecorderPtr {
Expand Down
11 changes: 4 additions & 7 deletions luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1307,13 +1307,10 @@ impl<S: CallableSignature> DynCallable<S> {
};
let kernel_id = r_ptr.as_ref().unwrap().borrow().kernel_id;
let r_backup = (*r_ptr).clone();
std::mem::replace(
&mut *r_ptr,
Some(Rc::new(RefCell::new(FnRecorder::new(
kernel_id,
r_backup.clone(),
)))),
);
*r_ptr = Some(Rc::new(RefCell::new(FnRecorder::new(
kernel_id,
r_backup.clone(),
))));
(r_backup, device.upgrade().unwrap())
});
let mut builder = KernelBuilder::new(Some(device), false);
Expand Down
8 changes: 2 additions & 6 deletions luisa_compute/src/runtime/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ impl KernelBuilder {
let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);

let mut args = self.args.clone();
args.extend(r.captured_vars.values().map(|x| unsafe { x.get_raw() }));
args.extend(r.captured_vars.values().map(|x| unsafe { x.1.get_raw() }));
for a in &args {
r.inaccessible.borrow_mut().insert(*a);
}
Expand All @@ -416,11 +416,7 @@ impl KernelBuilder {
device: self.device.clone(),
module,
resource_tracker: rt,
captured_args: r
.captured_vars
.keys()
.map(|x| unsafe { x.get_raw() })
.collect(),
captured_args: r.captured_vars.values().map(|x| x.0).collect(),
}
});
pop_recorder();
Expand Down
14 changes: 14 additions & 0 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,20 @@ fn callable_return_void_mismatch() {
}
#[test]
#[should_panic]
#[tracked]
fn illegal_scope_sharing() {
let device = get_device();
let tid = RefCell::new(None);
Kernel::<fn()>::new(&device, &|| {
let i = dispatch_id().x;
if i % 2 == 0 {
*tid.borrow_mut() = Some(i + 1);
}
let _v = tid.borrow().unwrap() + 1;
});
}
#[test]
#[should_panic]
fn callable_illegal_sharing() {
let device = get_device();
let tid = RefCell::new(None);
Expand Down

0 comments on commit 7c98a1b

Please sign in to comment.