diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 3e60577..d606874 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -2,9 +2,9 @@ use std::any::Any; use std::cell::{Cell, RefCell}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; -use std::rc::Rc; +use std::rc::{Rc, Weak}; use std::sync::atomic::AtomicUsize; -use std::sync::Arc; +use std::sync::{Arc, Weak as WeakArc}; use std::{env, unreachable}; use crate::internal_prelude::*; @@ -12,7 +12,7 @@ use crate::internal_prelude::*; use bumpalo::Bump; use indexmap::IndexMap; -use crate::runtime::WeakDevice; +use crate::runtime::{RawCallable, WeakDevice}; pub mod ir { pub use luisa_compute_ir::context::register_type; @@ -298,7 +298,7 @@ pub(crate) struct FnRecorder { /// Once a basicblock is finished, all nodes in it are added to this set pub(crate) inaccessible: Rc>>, pub(crate) kernel_id: usize, - pub(crate) captured_resources: IndexMap)>, + pub(crate) captured_resources: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, pub(crate) captured_vars: IndexMap, @@ -311,6 +311,7 @@ pub(crate) struct FnRecorder { pub(crate) callable_ret_type: Option>, pub(crate) const_builder: IrBuilder, pub(crate) index_const_pool: IndexMap, + pub(crate) rt: ResourceTracker, } pub(crate) type FnRecorderPtr = Rc>; impl FnRecorder { @@ -372,7 +373,7 @@ impl FnRecorder { pub(crate) fn capture_or_get( &mut self, binding: ir::Binding, - handle: &Arc, + handle: &WeakArc, create_node: impl FnOnce() -> Node, ) -> NodeRef { if let Some((_, node, _, _)) = self.captured_resources.get(&binding) { @@ -425,6 +426,7 @@ impl FnRecorder { parent, index_const_pool: IndexMap::new(), const_builder: IrBuilder::new(pools.clone()), + rt: ResourceTracker::new(), } } pub(crate) fn map_captured_vars(&mut self, node0: SafeNodeRef) -> SafeNodeRef { @@ -660,21 +662,23 @@ pub(crate) fn check_arg_alias(args: &[NodeRef]) { } } } -pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> NodeRef { +pub(crate) fn __invoke_callable(callable: &RawCallable, args: &[NodeRef]) -> NodeRef { + let inner = &callable.module; with_recorder(|r| { - let id = CArc::as_ptr(&callable.0) as u64; + let id = CArc::as_ptr(&inner.0) as u64; if let Some(c) = r.callables.get(&id) { - assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&callable.0)); + assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&inner.0)); } else { - r.callables.insert(id, callable.clone()); + r.callables.insert(id, inner.clone()); + r.rt.merge(callable.resource_tracker.clone()); } }); check_arg_alias(args); __current_scope(|b| { b.call( - Func::Callable(callable.clone()), + Func::Callable(inner.clone()), args, - callable.0.ret_type.clone(), + inner.0.ret_type.clone(), ) }) } diff --git a/luisa_compute/src/lang/poly.rs b/luisa_compute/src/lang/poly.rs index c131b02..17dd293 100644 --- a/luisa_compute/src/lang/poly.rs +++ b/luisa_compute/src/lang/poly.rs @@ -36,7 +36,7 @@ pub trait PolymorphicImpl: Value { #[macro_export] macro_rules! impl_new_poly_array { ($buffer:expr, $tag:expr, $key:expr) => {{ - let buffer = unsafe { $buffer.shallow_clone() }; + let buffer = $buffer.view(..); luisa_compute::PolyArray::new( $tag, $key, @@ -53,7 +53,7 @@ macro_rules! impl_polymorphic { tag: i32, key: K, ) -> luisa_compute::lang::poly::PolyArray { - let buffer = unsafe { buffer.shallow_clone() }; + let buffer = buffer.view(..); luisa_compute::lang::poly::PolyArray::new( tag, key, diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 7b09fc0..21e70d9 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -159,20 +159,47 @@ impl Context { #[derive(Clone)] pub struct ResourceTracker { - resources: Vec>, + strong_refs: Vec>, + weak_refs: Vec>, } impl ResourceTracker { pub fn add(&mut self, ptr: Arc) -> &mut Self { - self.resources.push(ptr); + self.strong_refs.push(ptr); self } pub fn add_any(&mut self, ptr: Arc) -> &mut Self { - self.resources.push(ptr); + self.strong_refs.push(ptr); self } + pub fn add_weak(&mut self, ptr: Weak) -> &mut Self { + self.weak_refs.push(ptr); + self + } + pub fn add_weak_any(&mut self, ptr: Weak) -> &mut Self { + self.weak_refs.push(ptr); + self + } + pub fn merge(&mut self, other: Self) { + self.strong_refs.extend(other.strong_refs); + self.weak_refs.extend(other.weak_refs); + } + pub fn upgrade(&self) -> Self { + let mut strong_refs = vec![]; + for r in self.weak_refs.iter() { + strong_refs.push(r.upgrade().unwrap_or_else(|| panic!("Bad weak ref. Kernel captured resources might be dropped."))); + } + strong_refs.extend(self.strong_refs.iter().cloned()); + Self { + strong_refs, + weak_refs: vec![], + } + } pub fn new() -> Self { - Self { resources: vec![] } + Self { + strong_refs: vec![], + weak_refs: vec![], + } } } diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 9c5bb40..af82d42 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::fmt; use std::ops::RangeBounds; use std::process::abort; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use parking_lot::lock_api::RawMutex as RawMutexTrait; use parking_lot::RawMutex; @@ -18,7 +18,7 @@ use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE}; use libc::c_void; pub type ByteBuffer = Buffer; -pub type ByteBufferView<'a> = BufferView<'a, u8>; +pub type ByteBufferView = BufferView; pub type ByteBufferVar = BufferVar; // Uncomment if the alias blowup again... @@ -280,10 +280,19 @@ impl BufferVar { } } pub struct Buffer { - pub(crate) device: Device, pub(crate) handle: Arc, - pub(crate) len: usize, - pub(crate) _marker: PhantomData, + pub(crate) full_view: BufferView, +} +impl BufferView { + pub fn copy_async<'a>(&self, s: &'a Scope<'a>) -> Buffer { + let copy = self.device.create_buffer(self.len); + s.submit([self.copy_to_buffer_async(copy.view(..))]); + copy + } + pub fn copy(&self) -> Buffer { + let default_stream = self.device.default_stream(); + default_stream.with_scope(|s| self.copy_async(s)) + } } impl fmt::Debug for Buffer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -325,28 +334,44 @@ impl Drop for BufferHandle { self.device.inner.destroy_buffer(self.handle); } } -#[derive(Clone, Copy)] -pub struct BufferView<'a, T: Value> { - pub(crate) buffer: &'a Buffer, +#[derive(Clone)] +pub struct BufferView { + pub(crate) device: Device, + pub(crate) handle: Weak, /// offset in #elements pub(crate) offset: usize, /// length in #elements pub(crate) len: usize, + pub(crate) _marker: PhantomData, } -impl<'a, T: Value> BufferView<'a, T> { +impl BufferView { + #[inline] pub fn var(&self) -> BufferVar { BufferVar::new(self) } + pub(crate) fn _handle(&self) -> Arc { + Weak::upgrade(&self.handle).unwrap() + } + #[inline] pub fn handle(&self) -> api::Buffer { - self.buffer.handle() + self._handle().handle + } + #[inline] + pub fn native_handle(&self) -> *mut c_void { + self._handle().native_handle } + #[inline] pub fn len(&self) -> usize { self.len } - pub fn copy_to_async<'b>(&'a self, data: &'b mut [T]) -> Command<'b, 'b> { + #[inline] + pub fn size_bytes(&self) -> usize { + self.len * std::mem::size_of::() + } + pub fn copy_to_async<'a>(&self, data: &'a mut [T]) -> Command<'a, 'a> { assert_eq!(data.len(), self.len); let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); + rt.add(self._handle()); Command { inner: api::Command::BufferDownload(BufferDownloadCommand { buffer: self.handle(), @@ -370,14 +395,14 @@ impl<'a, T: Value> BufferView<'a, T> { } pub fn copy_to(&self, data: &mut [T]) { unsafe { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); + submit_default_stream_and_sync(&self.device, [self.copy_to_async(data)]); } } - pub fn copy_from_async<'b>(&'a self, data: &'b [T]) -> Command<'b, 'static> { + pub fn copy_from_async<'a>(&self, data: &'a [T]) -> Command<'a, 'static> { assert_eq!(data.len(), self.len); let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); + rt.add(self._handle()); Command { inner: api::Command::BufferUpload(BufferUploadCommand { buffer: self.handle(), @@ -391,7 +416,7 @@ impl<'a, T: Value> BufferView<'a, T> { } } pub fn copy_from(&self, data: &[T]) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); + submit_default_stream_and_sync(&self.device, [self.copy_from_async(data)]); } pub fn fill_fn T>(&self, f: F) { self.copy_from(&(0..self.len).map(f).collect::>()); @@ -399,11 +424,11 @@ impl<'a, T: Value> BufferView<'a, T> { pub fn fill(&self, value: T) { self.fill_fn(|_| value); } - pub fn copy_to_buffer_async(&self, dst: BufferView<'a, T>) -> Command<'static, 'static> { + pub fn copy_to_buffer_async(&self, dst: BufferView) -> Command<'static, 'static> { assert_eq!(self.len, dst.len); let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - rt.add(dst.buffer.handle.clone()); + rt.add(self._handle()); + rt.add(dst._handle()); Command { inner: api::Command::BufferCopy(api::BufferCopyCommand { src: self.handle(), @@ -418,62 +443,7 @@ impl<'a, T: Value> BufferView<'a, T> { } } pub fn copy_to_buffer(&self, dst: BufferView) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); - } -} -impl Buffer { - #[inline] - pub fn handle(&self) -> api::Buffer { - self.handle.handle - } - #[inline] - pub unsafe fn shallow_clone(&self) -> Buffer { - Buffer { - device: self.device.clone(), - handle: self.handle.clone(), - len: self.len, - _marker: PhantomData, - } - } - #[inline] - pub fn native_handle(&self) -> *mut c_void { - self.handle.native_handle - } - #[inline] - pub fn copy_from(&self, data: &[T]) { - self.view(..).copy_from(data); - } - #[inline] - pub fn copy_from_async<'a>(&self, data: &'a [T]) -> Command<'a, 'static> { - self.view(..).copy_from_async(data) - } - #[inline] - pub fn copy_to(&self, data: &mut [T]) { - self.view(..).copy_to(data); - } - #[inline] - pub fn copy_to_async<'a>(&self, data: &'a mut [T]) -> Command<'a, 'a> { - self.view(..).copy_to_async(data) - } - #[inline] - pub fn copy_to_vec(&self) -> Vec { - self.view(..).copy_to_vec() - } - #[inline] - pub fn copy_to_buffer(&self, dst: &Buffer) { - self.view(..).copy_to_buffer(dst.view(..)); - } - #[inline] - pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a Buffer) -> Command<'static, 'static> { - self.view(..).copy_to_buffer_async(dst.view(..)) - } - #[inline] - pub fn fill_fn T>(&self, f: F) { - self.view(..).fill_fn(f); - } - #[inline] - pub fn fill(&self, value: T) { - self.view(..).fill(value); + submit_default_stream_and_sync(&self.device, [self.copy_to_buffer_async(dst)]); } pub fn view>(&self, range: S) -> BufferView { let lower = range.start_bound(); @@ -491,31 +461,15 @@ impl Buffer { assert!(lower <= upper); assert!(upper <= self.len); BufferView { - buffer: self, + device: self.device.clone(), + handle: self.handle.clone(), offset: lower, len: upper - lower, + _marker: PhantomData, } } - #[inline] - pub fn len(&self) -> usize { - self.len - } - #[inline] - pub fn size_bytes(&self) -> usize { - self.len * std::mem::size_of::() - } - #[inline] - pub fn var(&self) -> BufferVar { - BufferVar::new(&self.view(..)) - } -} -impl Clone for Buffer { - fn clone(&self) -> Self { - let cloned = self.device.create_buffer(self.len); - self.copy_to_buffer(&cloned); - cloned - } } + pub(crate) struct BindlessArrayHandle { pub(crate) device: Device, pub(crate) handle: api::BindlessArray, @@ -533,78 +487,6 @@ pub(crate) struct BindlessArraySlot { pub(crate) tex3d: Option>, } -#[deprecated( - note = "Spamming BufferHeap can cause serious performance issue on CUDA backend. Use BindlessArray instead." -)] -pub struct BufferHeap { - pub(crate) inner: BindlessArray, - pub(crate) _marker: PhantomData, -} -#[allow(deprecated)] -impl fmt::Debug for BufferHeap { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BufferHeap<{}>", std::any::type_name::(),) - } -} -#[deprecated] -pub struct BufferHeapVar { - inner: BindlessArrayVar, - _marker: PhantomData, -} -#[allow(deprecated)] -impl BufferHeap { - #[inline] - pub fn var(&self) -> BufferHeapVar { - BufferHeapVar { - inner: self.inner.var(), - _marker: PhantomData, - } - } - #[inline] - pub fn handle(&self) -> api::BindlessArray { - self.inner.handle() - } - #[inline] - pub fn native_handle(&self) -> *mut std::ffi::c_void { - self.inner.native_handle() - } - pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { - self.inner.emplace_buffer_async(index, buffer); - } - pub fn emplace_buffer_view_async<'a>(&self, index: usize, bufferview: &BufferView<'a, T>) { - self.inner.emplace_buffer_view_async(index, bufferview); - } - pub fn remove_buffer_async(&self, index: usize) { - self.inner.remove_buffer_async(index); - } - #[inline] - pub fn emplace_buffer(&self, index: usize, buffer: &Buffer) { - self.inner.emplace_buffer(index, buffer); - } - #[inline] - pub fn emplace_buffer_view<'a>(&self, index: usize, bufferview: &BufferView<'a, T>) { - self.inner.emplace_buffer_view_async(index, bufferview); - } - #[inline] - pub fn remove_buffer(&self, index: usize) { - self.inner.remove_buffer(index); - } - #[inline] - pub fn update(&self) { - self.inner.update(); - } - #[inline] - pub fn buffer(&self, index: impl AsExpr) -> BindlessBufferVar { - self.inner.buffer(index) - } -} -#[allow(deprecated)] -impl BufferHeapVar { - #[inline] - pub fn buffer(&self, index: impl AsExpr) -> BindlessBufferVar { - self.inner.buffer(index) - } -} pub struct BindlessArray { pub(crate) device: Device, pub(crate) handle: Arc, @@ -646,21 +528,13 @@ impl BindlessArray { pub fn emplace_byte_buffer_async(&self, index: usize, buffer: &ByteBuffer) { self.emplace_byte_buffer_view_async(index, &buffer.view(..)) } - pub fn emplace_byte_buffer_view_async<'a>( - &self, - index: usize, - bufferview: &ByteBufferView<'a>, - ) { + pub fn emplace_byte_buffer_view_async(&self, index: usize, bufferview: &ByteBufferView) { self.emplace_buffer_view_async(index, bufferview) } pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { self.emplace_buffer_view_async(index, &buffer.view(..)) } - pub fn emplace_buffer_view_async<'a, T: Value>( - &self, - index: usize, - bufferview: &BufferView<'a, T>, - ) { + pub fn emplace_buffer_view_async(&self, index: usize, bufferview: &BufferView) { self.lock(); self.modifications .borrow_mut() @@ -677,7 +551,7 @@ impl BindlessArray { offset: bufferview.offset, }; let mut slots = self.slots.borrow_mut(); - slots[index].buffer = Some(bufferview.buffer.handle.clone()); + slots[index].buffer = Some(bufferview._handle()); self.unlock(); } pub fn emplace_tex2d_async( @@ -796,7 +670,7 @@ impl BindlessArray { self.update(); } #[inline] - pub fn emplace_byte_buffer_view(&self, index: usize, buffer: &ByteBufferView<'_>) { + pub fn emplace_byte_buffer_view(&self, index: usize, buffer: &ByteBufferView) { self.emplace_byte_buffer_view_async(index, buffer); self.update(); } @@ -1083,11 +957,13 @@ pub struct Tex2d { #[allow(dead_code)] pub(crate) height: u32, pub(crate) handle: Arc, - pub(crate) marker: PhantomData, + pub(crate) views: Vec>, } -impl Clone for Tex2d { - fn clone(&self) -> Self { +impl Tex2d { + /// Create a new texture with the same dimensions and storage as `self` + /// and copy the contents of `self` to it asynchronously + pub fn copy_async<'a>(&self, s: &Scope<'a>) -> Self { let h = self.handle.as_ref(); let width = self.width; let height = self.height; @@ -1095,13 +971,16 @@ impl Clone for Tex2d { let mips = h.levels; let device = &h.device; let copy = device.create_tex2d::(storage, width, height, mips); - device.default_stream().with_scope(|s| { - s.submit( - (0..mips).map(|level| self.view(level).copy_to_texture_async(copy.view(level))), - ); - }); + s.submit((0..mips).map(|level| self.view(level).copy_to_texture_async(copy.view(level)))); copy } + + /// Create a new texture with the same dimensions and storage as `self` + /// and copy the contents of `self` to it. + pub fn copy(&self) -> Self { + let default_stream = self.handle.device.default_stream(); + default_stream.with_scope(|s| self.copy_async(s)) + } } impl fmt::Debug for Tex2d { @@ -1127,10 +1006,12 @@ pub struct Tex3d { #[allow(dead_code)] pub(crate) depth: u32, pub(crate) handle: Arc, - pub(crate) marker: PhantomData, + pub(crate) views: Vec>, } -impl Clone for Tex3d { - fn clone(&self) -> Self { +impl Tex3d { + /// Create a new texture with the same dimensions and storage as `self` + /// and copy the contents of `self` to it asynchronously + pub fn copy_async(&self, s: &Scope) -> Self { let h = self.handle.as_ref(); let width = self.width; let height = self.height; @@ -1139,13 +1020,16 @@ impl Clone for Tex3d { let mips = h.levels; let device = &h.device; let copy = device.create_tex3d::(storage, width, height, depth, mips); - device.default_stream().with_scope(|s| { - s.submit( - (0..mips).map(|level| self.view(level).copy_to_texture_async(copy.view(level))), - ); - }); + s.submit((0..mips).map(|level| self.view(level).copy_to_texture_async(copy.view(level)))); copy } + + /// Create a new texture with the same dimensions and storage as `self` + /// and copy the contents of `self` to it. + pub fn copy(&self) -> Self { + let default_stream = self.handle.device.default_stream(); + default_stream.with_scope(|s| self.copy_async(s)) + } } impl fmt::Debug for Tex3d { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -1160,15 +1044,33 @@ impl fmt::Debug for Tex3d { } } -#[derive(Clone, Copy)] -pub struct Tex2dView<'a, T: IoTexel> { - pub(crate) tex: &'a Tex2d, +#[derive(Clone)] +pub struct Tex2dView { + pub(crate) device: Device, + #[allow(dead_code)] + pub(crate) width: u32, + #[allow(dead_code)] + pub(crate) height: u32, + pub(crate) storage: PixelStorage, + pub(crate) format: PixelFormat, + pub(crate) handle: Weak, pub(crate) level: u32, + pub(crate) marker: PhantomData, } -#[derive(Clone, Copy)] -pub struct Tex3dView<'a, T: IoTexel> { - pub(crate) tex: &'a Tex3d, +#[derive(Clone)] +pub struct Tex3dView { + pub(crate) device: Device, + #[allow(dead_code)] + pub(crate) width: u32, + #[allow(dead_code)] + pub(crate) height: u32, + #[allow(dead_code)] + pub(crate) depth: u32, + pub(crate) storage: PixelStorage, + pub(crate) format: PixelFormat, + pub(crate) handle: Weak, pub(crate) level: u32, + pub(crate) marker: PhantomData, } impl Tex2d { pub fn handle(&self) -> api::Texture { @@ -1188,19 +1090,22 @@ impl Tex3d { } macro_rules! impl_tex_view { ($name:ident) => { - impl<'a, T: IoTexel> $name<'a, T> { - pub fn copy_to_async>( - &'a self, + impl $name { + pub(crate) fn _handle(&self) -> Arc { + self.handle.upgrade().unwrap() + } + pub fn copy_to_async<'a, U: StorageTexel>( + &self, data: &'a mut [U], ) -> Command<'a, 'a> { assert_eq!(data.len(), self.texel_count() as usize); - assert_eq!(self.tex.handle.storage, U::pixel_storage()); + assert_eq!(self.storage, U::pixel_storage()); let mut rt = ResourceTracker::new(); - rt.add(self.tex.handle.clone()); + rt.add(self._handle()); Command { inner: api::Command::TextureDownload(api::TextureDownloadCommand { texture: self.handle(), - storage: self.tex.handle.storage, + storage: self.storage, level: self.level, size: self.size(), data: data.as_mut_ptr() as *mut u8, @@ -1210,12 +1115,12 @@ macro_rules! impl_tex_view { callback: None, } } - pub fn copy_to>(&'a self, data: &'a mut [U]) { + pub fn copy_to>(&self, data: &mut [U]) { assert_eq!(data.len(), self.texel_count() as usize); - submit_default_stream_and_sync(&self.tex.handle.device, [self.copy_to_async(data)]); + submit_default_stream_and_sync(&self.device, [self.copy_to_async(data)]); } - pub fn copy_to_vec>(&'a self) -> Vec { + pub fn copy_to_vec>(&self) -> Vec { let mut data = Vec::with_capacity(self.texel_count() as usize); unsafe { data.set_len(self.texel_count() as usize); @@ -1223,18 +1128,18 @@ macro_rules! impl_tex_view { self.copy_to(&mut data); data } - pub fn copy_from_async<'b, U: StorageTexel>( - &'a self, - data: &'b [U], - ) -> Command<'b, 'static> { + pub fn copy_from_async>( + &self, + data: &[U], + ) -> Command<'static, 'static> { assert_eq!(data.len(), self.texel_count() as usize); - assert_eq!(self.tex.handle.storage, U::pixel_storage()); + assert_eq!(self.storage, U::pixel_storage()); let mut rt = ResourceTracker::new(); - rt.add(self.tex.handle.clone()); + rt.add(self._handle()); Command { inner: api::Command::TextureUpload(api::TextureUploadCommand { texture: self.handle(), - storage: self.tex.handle.storage, + storage: self.storage, level: self.level, size: self.size(), data: data.as_ptr() as *const u8, @@ -1244,25 +1149,22 @@ macro_rules! impl_tex_view { callback: None, } } - pub fn copy_from>(&'a self, data: &[U]) { - submit_default_stream_and_sync( - &self.tex.handle.device, - [self.copy_from_async(data)], - ); + pub fn copy_from>(&self, data: &[U]) { + submit_default_stream_and_sync(&self.device, [self.copy_from_async(data)]); } - pub fn copy_to_buffer_async<'b, U: StorageTexel + Value>( - &'a self, - buffer_view: &'b BufferView, + pub fn copy_to_buffer_async + Value>( + &self, + buffer_view: &BufferView, ) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); - rt.add(self.tex.handle.clone()); - rt.add(buffer_view.buffer.handle.clone()); + rt.add(self._handle()); + rt.add(buffer_view._handle()); assert_eq!(buffer_view.len, self.texel_count() as usize); - assert_eq!(self.tex.handle.storage, U::pixel_storage()); + assert_eq!(self.storage, U::pixel_storage()); Command { inner: api::Command::TextureToBufferCopy(api::TextureToBufferCopyCommand { texture: self.handle(), - storage: self.tex.handle.storage, + storage: self.storage, texture_level: self.level, texture_size: self.size(), buffer: buffer_view.handle(), @@ -1273,28 +1175,25 @@ macro_rules! impl_tex_view { callback: None, } } - pub fn copy_to_buffer + Value>( - &'a self, - buffer_view: &BufferView, - ) { + pub fn copy_to_buffer + Value>(&self, buffer_view: &BufferView) { submit_default_stream_and_sync( - &self.tex.handle.device, + &self.device, [self.copy_to_buffer_async(buffer_view)], ); } - pub fn copy_from_buffer_async<'b, U: StorageTexel + Value>( - &'a self, + pub fn copy_from_buffer_async + Value>( + &self, buffer_view: BufferView, ) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); - rt.add(self.tex.handle.clone()); - rt.add(buffer_view.buffer.handle.clone()); + rt.add(self._handle()); + rt.add(buffer_view._handle()); assert_eq!(buffer_view.len, self.texel_count() as usize); - assert_eq!(self.tex.handle.storage, U::pixel_storage()); + assert_eq!(self.storage, U::pixel_storage()); Command { inner: api::Command::BufferToTextureCopy(api::BufferToTextureCopyCommand { texture: self.handle(), - storage: self.tex.handle.storage, + storage: self.storage, texture_level: self.level, texture_size: self.size(), buffer: buffer_view.handle(), @@ -1305,26 +1204,23 @@ macro_rules! impl_tex_view { callback: None, } } - pub fn copy_from_buffer + Value>( - &'a self, - buffer_view: BufferView, - ) { + pub fn copy_from_buffer + Value>(&self, buffer_view: BufferView) { submit_default_stream_and_sync( - &self.tex.handle.device, + &self.device, [self.copy_from_buffer_async(buffer_view)], ); } - pub fn copy_to_texture_async(&'a self, other: $name) -> Command<'static, 'static> { + pub fn copy_to_texture_async(&self, other: $name) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); - rt.add(self.tex.handle.clone()); - rt.add(other.tex.handle.clone()); + rt.add(self._handle()); + rt.add(other._handle()); assert_eq!(self.size(), other.size()); - assert_eq!(self.tex.handle.storage, other.tex.handle.storage); - assert_eq!(self.tex.handle.format, other.tex.handle.format); + assert_eq!(self.storage, other.storage); + assert_eq!(self.format, other.format); Command { inner: api::Command::TextureCopy(api::TextureCopyCommand { src: self.handle(), - storage: self.tex.handle.storage, + storage: self.storage, src_level: self.level, size: self.size(), dst: other.handle(), @@ -1335,18 +1231,15 @@ macro_rules! impl_tex_view { callback: None, } } - pub fn copy_to_texture(&'a self, other: $name) { - submit_default_stream_and_sync( - &self.tex.handle.device, - [self.copy_to_texture_async(other)], - ); + pub fn copy_to_texture(&self, other: $name) { + submit_default_stream_and_sync(&self.device, [self.copy_to_texture_async(other)]); } } }; } -impl<'a, T: IoTexel> Tex2dView<'a, T> { +impl Tex2dView { pub fn handle(&self) -> api::Texture { - self.tex.handle.handle + self._handle().handle } pub fn texel_count(&self) -> u32 { let s = self.size(); @@ -1354,19 +1247,19 @@ impl<'a, T: IoTexel> Tex2dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), + (self.width >> self.level).max(1), + (self.height >> self.level).max(1), 1, ] } pub fn var(&self) -> Tex2dVar { - Tex2dVar::new(*self) + Tex2dVar::new(self.clone()) } } impl_tex_view!(Tex2dView); -impl<'a, T: IoTexel> Tex3dView<'a, T> { +impl Tex3dView { pub fn handle(&self) -> api::Texture { - self.tex.handle.handle + self._handle().handle } pub fn texel_count(&self) -> u32 { let s = self.size(); @@ -1374,13 +1267,13 @@ impl<'a, T: IoTexel> Tex3dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), - (self.tex.handle.depth >> self.level).max(1), + (self.width >> self.level).max(1), + (self.height >> self.level).max(1), + (self.depth >> self.level).max(1), ] } pub fn var(&self) -> Tex3dVar { - Tex3dVar::new(*self) + Tex3dVar::new(self.clone()) } } impl_tex_view!(Tex3dView); @@ -1392,7 +1285,7 @@ impl Drop for TextureHandle { impl Tex2d { pub fn view(&self, level: u32) -> Tex2dView { - Tex2dView { tex: self, level } + self.views[level as usize].clone() } pub fn native_handle(&self) -> *mut std::ffi::c_void { self.handle.native_handle @@ -1415,7 +1308,7 @@ impl Tex2d { } impl Tex3d { pub fn view(&self, level: u32) -> Tex3dView { - Tex3dView { tex: self, level } + self.views[level as usize].clone() } pub fn native_handle(&self) -> *mut std::ffi::c_void { self.handle.native_handle @@ -1913,7 +1806,7 @@ impl BindlessArrayVar { b, a ); } - r.capture_or_get(binding, &array.handle, || { + r.capture_or_get(binding, &Arc::downgrade(&array.handle), || { Node::new(CArc::new(Instruction::Bindless), Type::void()) }) }) @@ -1924,6 +1817,13 @@ impl BindlessArrayVar { } } } + +impl std::ops::Deref for Buffer { + type Target = BufferView; + fn deref(&self) -> &Self::Target { + &self.full_view + } +} impl ToNode for Buffer { fn node(&self) -> SafeNodeRef { self.var().node() @@ -1940,6 +1840,7 @@ impl IndexWrite for Buffer { self.var().write(i, v) } } + impl IndexRead for BufferVar { type Element = T; fn read(&self, i: I) -> Expr { @@ -1967,20 +1868,20 @@ impl IndexWrite for BufferVar { } } impl BufferVar { - pub fn new(buffer: &BufferView<'_, T>) -> Self { + pub fn new(buffer: &BufferView) -> Self { let node = with_recorder(|r| { let binding = Binding::Buffer(BufferBinding { handle: buffer.handle().0, size: buffer.len * std::mem::size_of::(), offset: (buffer.offset * std::mem::size_of::()) as u64, }); - if let Some((a, b)) = r.check_on_same_device(&buffer.buffer.device) { + if let Some((a, b)) = r.check_on_same_device(&buffer.device) { panic!( "Buffer created for a device: `{:?}` but used in `{:?}`", b, a ); } - r.capture_or_get(binding, &buffer.buffer.handle, || { + r.capture_or_get(binding, &buffer.handle, || { Node::new(CArc::new(Instruction::Buffer), T::type_()) }) }) @@ -1988,7 +1889,7 @@ impl BufferVar { Self { node, marker: PhantomData, - handle: Some(buffer.buffer.handle.clone()), + handle: Some(buffer._handle()), } } pub fn atomic_ref(&self, i: impl IntoIndex) -> AtomicRef { @@ -2240,27 +2141,27 @@ pub struct Tex2dVar { } impl Tex2dVar { - pub fn new(view: Tex2dView<'_, T>) -> Self { + pub fn new(view: Tex2dView) -> Self { let node = with_recorder(|r| { - let handle: u64 = view.tex.handle().0; + let handle: u64 = view.handle().0; let binding = Binding::Texture(TextureBinding { handle, level: view.level, }); - if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) { + if let Some((a, b)) = r.check_on_same_device(&view.device) { panic!( "Tex2d created for a device: `{:?}` but used in `{:?}`", b, a ); } - r.capture_or_get(binding, &view.tex.handle, || { + r.capture_or_get(binding, &view.handle, || { Node::new(CArc::new(Instruction::Texture2D), T::RwType::type_()) }) }) .into(); Self { node, - handle: Some(view.tex.handle.clone()), + handle: Some(view._handle()), level: Some(view.level), marker: PhantomData, } @@ -2291,27 +2192,27 @@ impl Tex2dVar { } impl Tex3dVar { - pub fn new(view: Tex3dView<'_, T>) -> Self { + pub fn new(view: Tex3dView) -> Self { let node = with_recorder(|r| { - let handle: u64 = view.tex.handle().0; + let handle: u64 = view.handle().0; let binding = Binding::Texture(TextureBinding { handle, level: view.level, }); - if let Some((a, b)) = r.check_on_same_device(&view.tex.handle.device) { + if let Some((a, b)) = r.check_on_same_device(&view.device) { panic!( "Tex3d created for a device: `{:?}` but used in `{:?}`", b, a ); } - r.capture_or_get(binding, &view.tex.handle, || { + r.capture_or_get(binding, &view.handle, || { Node::new(CArc::new(Instruction::Texture3D), T::RwType::type_()) }) }) .into(); Self { node, - handle: Some(view.tex.handle.clone()), + handle: Some(view._handle()), level: Some(view.level), marker: PhantomData, } diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 1a2bd18..b38722e 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -687,7 +687,7 @@ impl AccelVar { b, a ); } - r.capture_or_get(binding, &accel.handle, || { + r.capture_or_get(binding, &Arc::downgrade(&accel.handle), || { Node::new(CArc::new(Instruction::Accel), Type::void()) }) }) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index a3de7c0..97c5024 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -175,6 +175,8 @@ impl Device { }; swapchain } + + /// Creates an **unintialized** buffer of `len` bytes. pub fn create_byte_buffer(&self, len: usize) -> Buffer { let name = self.name(); if name == "dx" { @@ -184,18 +186,25 @@ impl Device { ); } let buffer = self.inner.create_buffer(&Type::void(), len); - let buffer = Buffer { + let handle = Arc::new(BufferHandle { device: self.clone(), - handle: Arc::new(BufferHandle { + handle: api::Buffer(buffer.resource.handle), + native_handle: buffer.resource.native_handle, + }); + let buffer = Buffer { + handle: handle.clone(), + full_view: BufferView { device: self.clone(), - handle: api::Buffer(buffer.resource.handle), - native_handle: buffer.resource.native_handle, - }), - len, - _marker: PhantomData, + handle: Arc::downgrade(&handle), + offset: 0, + len, + _marker: PhantomData, + }, }; buffer } + + /// Creates an **unintialized** buffer of `count` elements of type `T` in SOA layout. pub fn create_soa_buffer(&self, count: usize) -> SoaBuffer { assert!( count <= u32::MAX as usize, @@ -221,6 +230,8 @@ impl Device { }; buffer } + + /// Creates an **unintialized** buffer of `count` elements of type `T`. pub fn create_buffer(&self, count: usize) -> Buffer { let name = self.name(); if name == "dx" { @@ -238,15 +249,20 @@ impl Device { "size of T must be greater than 0" ); let buffer = self.inner.create_buffer(&T::type_(), count); - let buffer = Buffer { + let handle = Arc::new(BufferHandle { device: self.clone(), - handle: Arc::new(BufferHandle { + handle: api::Buffer(buffer.resource.handle), + native_handle: buffer.resource.native_handle, + }); + let buffer = Buffer { + handle: handle.clone(), + full_view: BufferView { device: self.clone(), - handle: api::Buffer(buffer.resource.handle), - native_handle: buffer.resource.native_handle, - }), - _marker: PhantomData {}, - len: count, + handle: Arc::downgrade(&handle), + offset: 0, + len: count, + _marker: PhantomData, + }, }; buffer } @@ -264,17 +280,6 @@ impl Device { buffer.view(..).fill_fn(f); buffer } - #[deprecated( - note = "Spamming BufferHeap can cause serious performance issue on CUDA backend. Use BindlessArray instead." - )] - #[allow(deprecated)] - pub fn create_buffer_heap(&self, slots: usize) -> BufferHeap { - let array = self.create_bindless_array(slots); - BufferHeap { - inner: array, - _marker: PhantomData {}, - } - } pub fn create_bindless_array(&self, slots: usize) -> BindlessArray { assert!(slots > 0, "slots must be greater than 0"); let array = self.inner.create_bindless_array(slots); @@ -319,11 +324,23 @@ impl Device { depth: 1, storage: format.storage(), }); + let weak = Arc::downgrade(&handle); let tex = Tex2d { width, height, handle, - marker: PhantomData {}, + views: (0..mips) + .map(|level| Tex2dView { + device: self.clone(), + width, + height, + storage, + format, + handle: weak.clone(), + level, + marker: PhantomData, + }) + .collect(), }; tex } @@ -350,12 +367,25 @@ impl Device { depth, storage: format.storage(), }); + let weak = Arc::downgrade(&handle); let tex = Tex3d { width, height, depth, handle, - marker: PhantomData {}, + views: (0..mips) + .map(|level| Tex3dView { + device: self.clone(), + width, + height, + depth, + storage, + format, + handle: weak.clone(), + level, + marker: PhantomData, + }) + .collect(), }; tex } @@ -389,7 +419,7 @@ impl Device { } pub fn create_procedural_primitive( &self, - aabb_buffer: BufferView<'_, rtx::Aabb>, + aabb_buffer: BufferView, option: AccelOption, ) -> rtx::ProceduralPrimitive { let primitive = self.inner.create_procedural_primitive(option); @@ -398,7 +428,7 @@ impl Device { device: self.clone(), handle: api::ProceduralPrimitive(primitive.handle), native_handle: primitive.native_handle, - aabb_buffer: aabb_buffer.buffer.handle.clone(), + aabb_buffer: aabb_buffer._handle(), }), aabb_buffer: aabb_buffer.handle(), aabb_buffer_offset: aabb_buffer.offset * std::mem::size_of::() as usize, @@ -407,8 +437,8 @@ impl Device { } pub fn create_mesh( &self, - vbuffer: BufferView<'_, V>, - tbuffer: BufferView<'_, rtx::Index>, + vbuffer: BufferView, + tbuffer: BufferView, option: AccelOption, ) -> Mesh { let mesh = self.inner.create_mesh(option); @@ -419,8 +449,8 @@ impl Device { device: self.clone(), handle: api::Mesh(handle), native_handle, - vbuffer: vbuffer.buffer.handle.clone(), - ibuffer: tbuffer.buffer.handle.clone(), + vbuffer: vbuffer._handle(), + ibuffer: tbuffer._handle(), }), vertex_buffer: vbuffer.handle(), vertex_buffer_offset: vbuffer.offset * std::mem::size_of::() as usize, @@ -1129,7 +1159,7 @@ impl KernelArg for T { } } -impl<'a, T: Value> KernelArg for BufferView<'a, T> { +impl KernelArg for BufferView { type Parameter = BufferVar; fn encode(&self, encoder: &mut KernelArgEncoder) { encoder.buffer_view(self); @@ -1150,14 +1180,14 @@ impl KernelArg for Tex3d { } } -impl<'a, T: IoTexel> KernelArg for Tex2dView<'a, T> { +impl KernelArg for Tex2dView { type Parameter = Tex2dVar; fn encode(&self, encoder: &mut KernelArgEncoder) { encoder.tex2d(self); } } -impl<'a, T: IoTexel> KernelArg for Tex3dView<'a, T> { +impl KernelArg for Tex3dView { type Parameter = Tex3dVar; fn encode(&self, encoder: &mut KernelArgEncoder) { encoder.tex3d(self); @@ -1228,6 +1258,9 @@ impl RawKernel { let args = Arc::new(args); assert_eq!(args.len(), self.module.args.len()); rt.add(args.clone()); + let mut captures = self.resource_tracker.clone(); + captures.upgrade(); + rt.merge(captures); Command { inner: api::Command::ShaderDispatch(api::ShaderDispatchCommand { shader: self.unwrap(), @@ -1291,8 +1324,7 @@ impl DynCallable { for c in callables { if crate::lang::__check_callable(&c.inner.module, nodes) { return CallableRet::_from_return(crate::lang::__invoke_callable( - &c.inner.module, - nodes, + &c.inner, nodes, )); } } @@ -1329,7 +1361,7 @@ impl DynCallable { let callables = &mut inner.callables; callables.push(new_callable); CallableRet::_from_return(crate::lang::__invoke_callable( - &callables.last().unwrap().inner.module, + &callables.last().unwrap().inner, nodes, )) } @@ -1455,7 +1487,7 @@ impl AsKernelArg for Buffer { type Output = Buffer; } -impl<'a, T: Value> AsKernelArg for BufferView<'a, T> { +impl<'a, T: Value> AsKernelArg for BufferView { type Output = Buffer; } impl AsKernelArg for SoaBuffer { @@ -1464,11 +1496,11 @@ impl AsKernelArg for SoaBuffer { impl<'a, T: SoaValue> AsKernelArg for SoaBufferView<'a, T> { type Output = SoaBuffer; } -impl<'a, T: IoTexel> AsKernelArg for Tex2dView<'a, T> { +impl<'a, T: IoTexel> AsKernelArg for Tex2dView { type Output = Tex2d; } -impl<'a, T: IoTexel> AsKernelArg for Tex3dView<'a, T> { +impl<'a, T: IoTexel> AsKernelArg for Tex3dView { type Output = Tex3d; } @@ -1501,7 +1533,7 @@ macro_rules! impl_call_for_callable { args.extend_from_slice(&encoder.args); args.extend_from_slice(&self.inner.captured_args); CallableRet::_from_return( - crate::lang::__invoke_callable(&self.inner.module, &args)) + crate::lang::__invoke_callable(&self.inner, &args)) } } impl DynCallableR> { diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index a1a2ad0..9ea4e71 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -350,14 +350,14 @@ impl KernelBuilder { } fn collect_module_info(&self) -> (ResourceTracker, Vec>, Vec) { with_recorder(|r| { - let mut resource_tracker = ResourceTracker::new(); + let mut resource_tracker = std::mem::replace(&mut r.rt, ResourceTracker::new()); let mut captured: Vec = Vec::new(); let mut captured_resources: Vec<_> = r.captured_resources.values().cloned().collect(); captured_resources.sort_by_key(|(i, _, _, _)| *i); for (j, (i, node, binding, handle)) in captured_resources.into_iter().enumerate() { assert_eq!(j, i); captured.push(Capture { node, binding }); - resource_tracker.add_any(handle); + resource_tracker.add_weak_any(handle); } let mut cpu_custom_ops: Vec<_> = r.cpu_custom_ops.values().cloned().collect(); cpu_custom_ops.sort_by_key(|(i, _)| *i); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index a9d4bb0..3a80c82 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1681,7 +1681,96 @@ fn buffer_size() { let out = out.view(..).copy_to_vec(); assert_eq!(out[0], 1024); } +#[test] +#[should_panic] +fn drop_buffer_before_kernel() { + let device = get_device(); + let x = device.create_buffer::(1024); + let k = device.create_kernel::(track!(&|| { + let tid = dispatch_id().x; + let x = x.var(); + x.write(tid, tid.as_f32()); + })); + std::mem::drop(x); + k.dispatch([1024, 1, 1]); +} +#[test] +#[should_panic] +fn drop_buffer_before_callable() { + let device = get_device(); + let x = device.create_buffer::(1024); + let k = device.create_kernel::(track!(&|| { + let tid = dispatch_id().x; + outline(|| { + let x = x.var(); + x.write(tid, tid.as_f32()); + }) + })); + std::mem::drop(x); + k.dispatch([1024, 1, 1]); +} +#[test] +#[should_panic] +fn drop_buffer_before_callable_captured_outside_kernel() { + let device = get_device(); + let x = device.create_buffer::(1024); + let c = Callable::::new(&device, || { + let tid = dispatch_id().x; + let x = x.var(); + outline(|| { + x.write(tid, tid.as_f32()); + }); + }); + let k = device.create_kernel::(track!(&|| { + c.call(); + })); + std::mem::drop(x); + k.dispatch([1024, 1, 1]); +} +#[test] +#[should_panic] +fn drop_texture_before_kernel() { + let device = get_device(); + let t = device.create_tex2d::(PixelStorage::Byte4, 512, 512, 1); + let k = device.create_kernel::(track!(&|| { + let tid = dispatch_id().xy(); + t.write(tid, Float4::splat_expr(1.0f32)); + })); + std::mem::drop(t); + k.dispatch([512, 512, 1]); +} +#[test] +#[should_panic] +fn drop_texture_before_callable() { + let device = get_device(); + let t = device.create_tex2d::(PixelStorage::Byte4, 512, 512, 1); + let k = device.create_kernel::(track!(&|| { + let tid = dispatch_id().xy(); + outline(|| { + t.write(tid, Float4::splat_expr(1.0f32)); + }); + })); + std::mem::drop(t); + k.dispatch([512, 512, 1]); +} +#[test] +#[should_panic] +fn drop_texture_before_callable_captured_outside_kernel() { + let device = get_device(); + let t = device.create_tex2d::(PixelStorage::Byte4, 512, 512, 1); + let c = Callable::::new(&device, || { + let tid = dispatch_id().xy(); + outline(|| { + t.write(tid, Float4::splat_expr(1.0f32)); + }); + }); + let k = device.create_kernel::(track!(&|| { + c.call(); + })); + std::mem::drop(t); + k.dispatch([512, 512, 1]); +} #[test] #[tracked] fn test_tracked() {