From 77d4c4e024e6335bcd3ff5addfbfb87fce56a905 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 24 Sep 2023 01:02:50 -0400 Subject: [PATCH] add checks for cross device sharing --- luisa_compute/src/lang.rs | 37 ++++++++++++++-- luisa_compute/src/lang/debug.rs | 4 +- luisa_compute/src/resource.rs | 69 ++++++++++++----------------- luisa_compute/src/rtx.rs | 19 ++++---- luisa_compute/src/runtime.rs | 21 +++++++++ luisa_compute/src/runtime/kernel.rs | 7 +-- luisa_compute/tests/misc.rs | 16 +++++++ 7 files changed, 113 insertions(+), 60 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index a9ebb710..a6fada1e 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -255,7 +255,7 @@ pub(crate) struct Recorder { pub(crate) scopes: Vec, pub(crate) kernel_id: Option, pub(crate) lock: bool, - pub(crate) captured_buffer: IndexMap)>, + pub(crate) captured_resources: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, pub(crate) shared: Vec, @@ -270,7 +270,7 @@ pub(crate) struct Recorder { impl Recorder { pub(crate) fn reset(&mut self) { self.scopes.clear(); - self.captured_buffer.clear(); + self.captured_resources.clear(); self.cpu_custom_ops.clear(); self.callables.clear(); self.lock = false; @@ -281,11 +281,42 @@ impl Recorder { self.kernel_id = None; self.callable_ret_type = None; } + + pub(crate) fn check_on_same_device(&self, other: &Device) -> Option<(String, String)> { + if let Some(device) = &self.device { + let device = device.upgrade().unwrap(); + if !Arc::ptr_eq(&device.inner, &other.inner) { + return Some(( + format!("{} at {:?}", device.name(), Arc::as_ptr(&device.inner)), + format!("{} at {:?}", other.name(), Arc::as_ptr(&other.inner)), + )); + } + } else { + // @FIXME: What should we do? + } + None + } + pub(crate) fn capture_or_get( + &mut self, + binding: ir::Binding, + handle: &Arc, + create_node: impl FnOnce() -> Node, + ) -> NodeRef { + if let Some((_, node, _, _)) = self.captured_resources.get(&binding) { + *node + } else { + let node = new_node(self.pools.as_ref().unwrap(), create_node()); + let i = self.captured_resources.len(); + self.captured_resources + .insert(binding, (i, node, binding, handle.clone())); + node + } + } pub(crate) fn new() -> Self { Recorder { scopes: vec![], lock: false, - captured_buffer: IndexMap::new(), + captured_resources: IndexMap::new(), cpu_custom_ops: IndexMap::new(), callables: IndexMap::new(), shared: vec![], diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 4ac39c59..afe728be 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -154,7 +154,7 @@ pub fn __unreachable(file: &str, line: u32, col: u32) { } else { pretty_filename = file.to_string(); } - let msg = if is_cpu_backend() && __env_need_backtrace() { + let msg = if __env_need_backtrace() { let backtrace = get_backtrace(); format!( "unreachable code at {}:{}:{} \nbacktrace: {}", @@ -190,7 +190,7 @@ pub fn __assert(cond: impl Into>, msg: &str, file: &str, line: u32, c } else { pretty_filename = file.to_string(); } - let msg = if is_cpu_backend() && __env_need_backtrace() { + let msg = if __env_need_backtrace() { let backtrace = get_backtrace(); format!( "assertion failed: {} at {}:{}:{} \nbacktrace:\n{}", diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index ac18f52a..1322076d 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1451,19 +1451,15 @@ impl BindlessArrayVar { ); let handle: u64 = array.handle().0; let binding = Binding::BindlessArray(BindlessArrayBinding { handle }); - - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Bindless), Type::void()), + if let Some((a, b)) = r.check_on_same_device(&array.device) { + panic!( + "BindlessArray created for a device: `{:?}` but used in `{:?}`", + b, a ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, array.handle.clone())); - node } + r.capture_or_get(binding, &array.handle, || { + Node::new(CArc::new(Instruction::Bindless), Type::void()) + }) }); Self { node, @@ -1525,18 +1521,15 @@ impl BufferVar { size: buffer.len * std::mem::size_of::(), offset: (buffer.offset * std::mem::size_of::()) as u64, }); - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Buffer), T::type_()), + if let Some((a, b)) = r.check_on_same_device(&buffer.buffer.device) { + panic!( + "Buffer created for a device: `{:?}` but used in `{:?}`", + b, a ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, buffer.buffer.handle.clone())); - node } + r.capture_or_get(binding, &buffer.buffer.handle, || { + Node::new(CArc::new(Instruction::Buffer), T::type_()) + }) }); Self { node, @@ -1766,18 +1759,15 @@ impl Tex2dVar { handle, level: view.level, }); - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Texture2D), T::RwType::type_()), + if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) { + panic!( + "Tex2d created for a device: `{:?}` but used in `{:?}`", + b, a ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, view.tex.handle.clone())); - node } + r.capture_or_get(binding, &view.tex.handle, || { + Node::new(CArc::new(Instruction::Texture2D), T::RwType::type_()) + }) }); Self { node, @@ -1825,18 +1815,15 @@ impl Tex3dVar { handle, level: view.level, }); - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Texture3D), T::RwType::type_()), + if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) { + panic!( + "Tex3d created for a device: `{:?}` but used in `{:?}`", + b, a ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, view.tex.handle.clone())); - node } + r.capture_or_get(binding, &view.tex.handle, || { + Node::new(CArc::new(Instruction::Texture3D), T::RwType::type_()) + }) }); Self { node, diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 4bce93d1..1ba2f80b 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -662,21 +662,18 @@ impl AccelVar { pub fn new(accel: &rtx::Accel) -> Self { let node = RECORDER.with(|r| { let mut r = r.borrow_mut(); - assert!(r.lock, "BufferVar must be created from within a kernel"); + assert!(r.lock, "AccelVar must be created from within a kernel"); let handle: u64 = accel.handle().0; let binding = Binding::Accel(AccelBinding { handle }); - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Accel), Type::void()), + if let Some((a, b)) = r.check_on_same_device(&accel.handle.device) { + panic!( + "Accel created for a device: `{:?}` but used in `{:?}`", + b, a ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, accel.handle.clone())); - node } + r.capture_or_get(binding, &accel.handle, || { + Node::new(CArc::new(Instruction::Accel), Type::void()) + }) }); Self { node, diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index ca318667..6f699def 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -99,6 +99,12 @@ impl Drop for DeviceHandle { } impl Device { + pub fn query(&self, name: &str) -> Option { + self.inner.query(name) + } + pub fn name(&self) -> String { + self.query("device_name").unwrap_or("unknown".to_string()) + } pub fn create_swapchain( &self, window: &Window, @@ -1237,10 +1243,24 @@ impl DynCallable { unsafe impl Send for RawCallable {} unsafe impl Sync for RawCallable {} pub struct RawCallable { + #[allow(dead_code)] + pub(crate) device: Option, pub(crate) module: ir::CallableModuleRef, #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, } +impl RawCallable { + pub(crate) fn check_on_same_device(&self) { + RECORDER.with(|r| { + let r =r.borrow(); + if let Some(device) = &self.device { + if let Some((a,b)) = r.check_on_same_device(device) { + panic!("Callable created on a different device than the one it is called on: {:?} vs {:?}", a,b); + } + } + }); + } +} pub struct RawKernelDef { #[allow(dead_code)] pub(crate) device: Option, @@ -1339,6 +1359,7 @@ macro_rules! impl_call_for_callable { #[allow(unused_mut)] pub fn call(&self, $($Ts:$Ts),*) -> R { let mut encoder = CallableArgEncoder::new(); + self.inner.check_on_same_device(); $($Ts.encode(&mut encoder);)* CallableRet::_from_return( crate::lang::__invoke_callable(&self.inner.module, &encoder.args)) diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index b1b2a450..1e6fe1bc 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -268,9 +268,9 @@ impl KernelBuilder { let mut resource_tracker = ResourceTracker::new(); let r = r.borrow_mut(); let mut captured: Vec = Vec::new(); - let mut captured_buffers: Vec<_> = r.captured_buffer.values().cloned().collect(); - captured_buffers.sort_by_key(|(i, _, _, _)| *i); - for (j, (i, node, binding, handle)) in captured_buffers.into_iter().enumerate() { + let mut captured_resources: Vec<_> = r.captured_resources.values().cloned().collect(); + captured_resources.sort_by_key(|(i, _, _, _)| *i); + for (j, (i, node, binding, handle)) in captured_resources.into_iter().enumerate() { assert_eq!(j, i); captured.push(Capture { node, binding }); resource_tracker.add_any(handle); @@ -352,6 +352,7 @@ impl KernelBuilder { let module = CallableModuleRef(CArc::new(module)); r.reset(); RawCallable { + device: self.device.clone(), module, resource_tracker: rt, } diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 1e2002ad..b386d1e8 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -48,7 +48,23 @@ fn event() { let v = a.copy_to_vec(); assert_eq!(v[0], (1 + 3) * (4 + 5)); } + #[test] +#[should_panic] +fn callable_different_device() { + let device1 = get_device(); + let device2 = get_device(); + let abs = Callable::) -> Expr>::new( + &device1, + track!(|x| { + if x > 0.0 { + return x; + } + -x + }), + ); + let _foo = Callable::) -> Expr>::new(&device2, |x| abs.call(x)); +} #[test] #[should_panic] fn callable_return_mismatch() {