From f8fd3796cc5a65b3051b0916574ad8d4f0ae25a7 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 8 Oct 2023 00:13:34 -0400 Subject: [PATCH] prevent SOA overflow SRV limit on dx --- luisa_compute/src/lang/soa.rs | 18 ++++++++++++++---- luisa_compute/src/runtime.rs | 2 +- luisa_compute_sys/LuisaCompute | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs index d1911e4..70322df 100644 --- a/luisa_compute/src/lang/soa.rs +++ b/luisa_compute/src/lang/soa.rs @@ -14,7 +14,7 @@ use super::types::SoaValue; pub struct SoaBuffer { pub(crate) device: Device, pub(crate) storage: Arc, - pub(crate) metadata_buf: Buffer, + pub(crate) metadata_buf: Arc>, pub(crate) metadata: SoaMetadata, pub(crate) copy_kernel: Mutex>>, pub(crate) _marker: std::marker::PhantomData, @@ -51,6 +51,7 @@ impl SoaBuffer { pub fn len_expr(&self) -> Expr { self.metadata_buf.read(0).count } + pub fn view>(&self, range: S) -> SoaBufferView { let lower = range.start_bound(); let upper = range.end_bound(); @@ -71,8 +72,13 @@ impl SoaBuffer { view_start: lower as u64, view_count: (upper - lower) as u64, }; + let is_full = lower == 0 && upper == self.len(); SoaBufferView { - metadata_buf: self.device.create_buffer_from_slice(&[metadata]), + metadata_buf: if is_full { + self.metadata_buf.clone() + } else { + Arc::new(self.device.create_buffer_from_slice(&[metadata])) + }, metadata, buffer: self, } @@ -97,6 +103,10 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> { *copy_kernel = Some(SoaBufferCopyKernel::new(&self.buffer.device)); } } + + /// **WARNING** when capturing the view, if the view is not equal to the full range, a new metadata buffer will be created. + /// However, DX has a limit on the number of SRVs, so it is not recommended to call this method repeatedly. + /// Instead, call it once per view and store the result. pub fn var(&self) -> SoaBufferVar { SoaBufferVar { proxy: T::SoaBuffer::from_soa_storage( @@ -137,7 +147,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> { submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(buffer)]); } } -#[derive(Clone, Copy, Value)] +#[derive(Clone, Copy, Value, PartialEq, Eq, Hash, Debug)] #[repr(C)] pub struct SoaMetadata { /// number of elements in the global buffer @@ -147,7 +157,7 @@ pub struct SoaMetadata { pub view_count: u64, } pub struct SoaBufferView<'a, T: SoaValue> { - pub(crate) metadata_buf: Buffer, + pub(crate) metadata_buf: Arc>, pub(crate) metadata: SoaMetadata, pub(crate) buffer: &'a SoaBuffer, } diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 8ce842f..b886a3a 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -206,7 +206,7 @@ impl Device { view_start: 0, view_count: count as u64, }; - let metadata_buf = self.create_buffer_from_slice(&[metadata]); + let metadata_buf = Arc::new(self.create_buffer_from_slice(&[metadata])); let buffer = SoaBuffer { storage, metadata_buf, diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index f865c63..d3c8a0d 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit f865c635ba50732a5edda22ff1873e233fc66f08 +Subproject commit d3c8a0d62943753f1a1c2ed90ed3ab5a0b04a4ee