Skip to content

Commit

Permalink
add checks for cross device sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
1 parent b9230ba commit 77d4c4e
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 60 deletions.
37 changes: 34 additions & 3 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ pub(crate) struct Recorder {
pub(crate) scopes: Vec<IrBuilder>,
pub(crate) kernel_id: Option<usize>,
pub(crate) lock: bool,
pub(crate) captured_buffer: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
pub(crate) shared: Vec<NodeRef>,
Expand All @@ -270,7 +270,7 @@ pub(crate) struct Recorder {
impl Recorder {
pub(crate) fn reset(&mut self) {
self.scopes.clear();
self.captured_buffer.clear();
self.captured_resources.clear();
self.cpu_custom_ops.clear();
self.callables.clear();
self.lock = false;
Expand All @@ -281,11 +281,42 @@ impl Recorder {
self.kernel_id = None;
self.callable_ret_type = None;
}

/// Compares the device this recorder is bound to against `other`.
///
/// Returns `None` when the recorder has no device set, or when both devices
/// share the same `inner` allocation (checked with `Arc::ptr_eq`). Otherwise
/// returns `Some((recorder_device, other_device))`, where each entry is a
/// `"<name> at <inner pointer>"` description intended for diagnostics.
///
/// # Panics
///
/// Panics if the recorder holds a device handle whose `upgrade()` result
/// cannot be unwrapped.
pub(crate) fn check_on_same_device(&self, other: &Device) -> Option<(String, String)> {
    let current = match &self.device {
        Some(handle) => handle.upgrade().unwrap(),
        // @FIXME: What should we do? For now a recorder without a device
        // is reported as "no mismatch".
        None => return None,
    };
    if Arc::ptr_eq(&current.inner, &other.inner) {
        return None;
    }
    let recorder_side = format!("{} at {:?}", current.name(), Arc::as_ptr(&current.inner));
    let other_side = format!("{} at {:?}", other.name(), Arc::as_ptr(&other.inner));
    Some((recorder_side, other_side))
}
/// Returns the node associated with `binding`, capturing the resource on
/// first use.
///
/// If `binding` is already present in `captured_resources`, the stored node
/// is returned and `create_node` is never called. Otherwise a node is built
/// from `create_node()`, registered together with its insertion index, the
/// binding itself and a clone of `handle`, and then returned.
///
/// # Panics
///
/// Panics if a new node has to be created while `self.pools` is `None`.
pub(crate) fn capture_or_get<T: Any>(
    &mut self,
    binding: ir::Binding,
    handle: &Arc<T>,
    create_node: impl FnOnce() -> Node,
) -> NodeRef {
    // Fast path: this binding was captured before.
    if let Some(entry) = self.captured_resources.get(&binding) {
        return entry.1;
    }
    let fresh = new_node(self.pools.as_ref().unwrap(), create_node());
    // The index records capture order; it equals the map size before insertion.
    let index = self.captured_resources.len();
    self.captured_resources
        .insert(binding, (index, fresh, binding, handle.clone()));
    fresh
}
pub(crate) fn new() -> Self {
Recorder {
scopes: vec![],
lock: false,
captured_buffer: IndexMap::new(),
captured_resources: IndexMap::new(),
cpu_custom_ops: IndexMap::new(),
callables: IndexMap::new(),
shared: vec![],
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/src/lang/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ pub fn __unreachable(file: &str, line: u32, col: u32) {
} else {
pretty_filename = file.to_string();
}
let msg = if is_cpu_backend() && __env_need_backtrace() {
let msg = if __env_need_backtrace() {
let backtrace = get_backtrace();
format!(
"unreachable code at {}:{}:{} \nbacktrace: {}",
Expand Down Expand Up @@ -190,7 +190,7 @@ pub fn __assert(cond: impl Into<Expr<bool>>, msg: &str, file: &str, line: u32, c
} else {
pretty_filename = file.to_string();
}
let msg = if is_cpu_backend() && __env_need_backtrace() {
let msg = if __env_need_backtrace() {
let backtrace = get_backtrace();
format!(
"assertion failed: {} at {}:{}:{} \nbacktrace:\n{}",
Expand Down
69 changes: 28 additions & 41 deletions luisa_compute/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1451,19 +1451,15 @@ impl BindlessArrayVar {
);
let handle: u64 = array.handle().0;
let binding = Binding::BindlessArray(BindlessArrayBinding { handle });

if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) {
*node
} else {
let node = new_node(
r.pools.as_ref().unwrap(),
Node::new(CArc::new(Instruction::Bindless), Type::void()),
if let Some((a, b)) = r.check_on_same_device(&array.device) {
panic!(
"BindlessArray created for a device: `{:?}` but used in `{:?}`",
b, a
);
let i = r.captured_buffer.len();
r.captured_buffer
.insert(binding, (i, node, binding, array.handle.clone()));
node
}
r.capture_or_get(binding, &array.handle, || {
Node::new(CArc::new(Instruction::Bindless), Type::void())
})
});
Self {
node,
Expand Down Expand Up @@ -1525,18 +1521,15 @@ impl<T: Value> BufferVar<T> {
size: buffer.len * std::mem::size_of::<T>(),
offset: (buffer.offset * std::mem::size_of::<T>()) as u64,
});
if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) {
*node
} else {
let node = new_node(
r.pools.as_ref().unwrap(),
Node::new(CArc::new(Instruction::Buffer), T::type_()),
if let Some((a, b)) = r.check_on_same_device(&buffer.buffer.device) {
panic!(
"Buffer created for a device: `{:?}` but used in `{:?}`",
b, a
);
let i = r.captured_buffer.len();
r.captured_buffer
.insert(binding, (i, node, binding, buffer.buffer.handle.clone()));
node
}
r.capture_or_get(binding, &buffer.buffer.handle, || {
Node::new(CArc::new(Instruction::Buffer), T::type_())
})
});
Self {
node,
Expand Down Expand Up @@ -1766,18 +1759,15 @@ impl<T: IoTexel> Tex2dVar<T> {
handle,
level: view.level,
});
if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) {
*node
} else {
let node = new_node(
r.pools.as_ref().unwrap(),
Node::new(CArc::new(Instruction::Texture2D), T::RwType::type_()),
if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) {
panic!(
"Tex2d created for a device: `{:?}` but used in `{:?}`",
b, a
);
let i = r.captured_buffer.len();
r.captured_buffer
.insert(binding, (i, node, binding, view.tex.handle.clone()));
node
}
r.capture_or_get(binding, &view.tex.handle, || {
Node::new(CArc::new(Instruction::Texture2D), T::RwType::type_())
})
});
Self {
node,
Expand Down Expand Up @@ -1825,18 +1815,15 @@ impl<T: IoTexel> Tex3dVar<T> {
handle,
level: view.level,
});
if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) {
*node
} else {
let node = new_node(
r.pools.as_ref().unwrap(),
Node::new(CArc::new(Instruction::Texture3D), T::RwType::type_()),
if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) {
panic!(
"Tex3d created for a device: `{:?}` but used in `{:?}`",
b, a
);
let i = r.captured_buffer.len();
r.captured_buffer
.insert(binding, (i, node, binding, view.tex.handle.clone()));
node
}
r.capture_or_get(binding, &view.tex.handle, || {
Node::new(CArc::new(Instruction::Texture3D), T::RwType::type_())
})
});
Self {
node,
Expand Down
19 changes: 8 additions & 11 deletions luisa_compute/src/rtx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,21 +662,18 @@ impl AccelVar {
pub fn new(accel: &rtx::Accel) -> Self {
let node = RECORDER.with(|r| {
let mut r = r.borrow_mut();
assert!(r.lock, "BufferVar must be created from within a kernel");
assert!(r.lock, "AccelVar must be created from within a kernel");
let handle: u64 = accel.handle().0;
let binding = Binding::Accel(AccelBinding { handle });
if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) {
*node
} else {
let node = new_node(
r.pools.as_ref().unwrap(),
Node::new(CArc::new(Instruction::Accel), Type::void()),
if let Some((a, b)) = r.check_on_same_device(&accel.handle.device) {
panic!(
"Accel created for a device: `{:?}` but used in `{:?}`",
b, a
);
let i = r.captured_buffer.len();
r.captured_buffer
.insert(binding, (i, node, binding, accel.handle.clone()));
node
}
r.capture_or_get(binding, &accel.handle, || {
Node::new(CArc::new(Instruction::Accel), Type::void())
})
});
Self {
node,
Expand Down
21 changes: 21 additions & 0 deletions luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ impl Drop for DeviceHandle {
}

impl Device {
/// Looks up the property `name` on this device.
///
/// Delegates directly to `self.inner.query(name)` and returns its result
/// unchanged; `None` means the inner device reported no value for `name`.
pub fn query(&self, name: &str) -> Option<String> {
self.inner.query(name)
}
/// Returns the device name reported by the `"device_name"` query.
///
/// Falls back to `"unknown"` when the query yields `None`.
pub fn name(&self) -> String {
    // Build the fallback lazily so no `String` is allocated when the query succeeds.
    self.query("device_name")
        .unwrap_or_else(|| "unknown".to_string())
}
pub fn create_swapchain(
&self,
window: &Window,
Expand Down Expand Up @@ -1237,10 +1243,24 @@ impl<S: CallableSignature> DynCallable<S> {
// NOTE(review): `Send`/`Sync` are asserted manually for `RawCallable`; the
// invariant that makes this sound is not visible in this block — confirm
// against the field types.
unsafe impl Send for RawCallable {}
unsafe impl Sync for RawCallable {}
/// A built callable: its IR module plus the bookkeeping stored alongside it.
pub struct RawCallable {
/// Device the callable was created with, if any. `check_on_same_device`
/// compares it against the recorder's current device.
#[allow(dead_code)]
pub(crate) device: Option<Device>,
/// Reference to the callable's IR module.
pub(crate) module: ir::CallableModuleRef,
/// Presumably keeps the resources captured by the callable alive —
/// TODO confirm against `ResourceTracker`.
#[allow(dead_code)]
pub(crate) resource_tracker: ResourceTracker,
}
impl RawCallable {
    /// Verifies that this callable belongs to the device currently being
    /// recorded on.
    ///
    /// Does nothing when the callable has no device attached, or when the
    /// recorder reports no mismatch.
    ///
    /// # Panics
    ///
    /// Panics when the thread-local recorder reports that its device differs
    /// from the one this callable was created with.
    pub(crate) fn check_on_same_device(&self) {
        RECORDER.with(|recorder| {
            let recorder = recorder.borrow();
            let mismatch = self
                .device
                .as_ref()
                .and_then(|device| recorder.check_on_same_device(device));
            if let Some((a, b)) = mismatch {
                panic!("Callable created on a different device than the one it is called on: {:?} vs {:?}", a, b);
            }
        });
    }
}
pub struct RawKernelDef {
#[allow(dead_code)]
pub(crate) device: Option<Device>,
Expand Down Expand Up @@ -1339,6 +1359,7 @@ macro_rules! impl_call_for_callable {
#[allow(unused_mut)]
pub fn call(&self, $($Ts:$Ts),*) -> R {
let mut encoder = CallableArgEncoder::new();
self.inner.check_on_same_device();
$($Ts.encode(&mut encoder);)*
CallableRet::_from_return(
crate::lang::__invoke_callable(&self.inner.module, &encoder.args))
Expand Down
7 changes: 4 additions & 3 deletions luisa_compute/src/runtime/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ impl KernelBuilder {
let mut resource_tracker = ResourceTracker::new();
let r = r.borrow_mut();
let mut captured: Vec<Capture> = Vec::new();
let mut captured_buffers: Vec<_> = r.captured_buffer.values().cloned().collect();
captured_buffers.sort_by_key(|(i, _, _, _)| *i);
for (j, (i, node, binding, handle)) in captured_buffers.into_iter().enumerate() {
let mut captured_resources: Vec<_> = r.captured_resources.values().cloned().collect();
captured_resources.sort_by_key(|(i, _, _, _)| *i);
for (j, (i, node, binding, handle)) in captured_resources.into_iter().enumerate() {
assert_eq!(j, i);
captured.push(Capture { node, binding });
resource_tracker.add_any(handle);
Expand Down Expand Up @@ -352,6 +352,7 @@ impl KernelBuilder {
let module = CallableModuleRef(CArc::new(module));
r.reset();
RawCallable {
device: self.device.clone(),
module,
resource_tracker: rt,
}
Expand Down
16 changes: 16 additions & 0 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,23 @@ fn event() {
let v = a.copy_to_vec();
assert_eq!(v[0], (1 + 3) * (4 + 5));
}

// Regression test for the cross-device check: a callable built with one
// device is invoked while recording a callable for another device, which is
// expected to panic.
#[test]
#[should_panic]
fn callable_different_device() {
// NOTE(review): assumes each `get_device()` call yields a distinct device
// instance — confirm against the test helper.
let device1 = get_device();
let device2 = get_device();
// `abs` is created with `device1`.
let abs = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(
&device1,
track!(|x| {
if x > 0.0 {
return x;
}
-x
}),
);
// Calling `abs` while recording a callable for `device2` should trigger the panic.
let _foo = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device2, |x| abs.call(x));
}
#[test]
#[should_panic]
fn callable_return_mismatch() {
Expand Down

0 comments on commit 77d4c4e

Please sign in to comment.