diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs index 6ce80fa..d1911e4 100644 --- a/luisa_compute/src/lang/soa.rs +++ b/luisa_compute/src/lang/soa.rs @@ -28,15 +28,15 @@ impl SoaBufferCopyKernel { fn new(device: &Device) -> Self { let copy_to = device.create_kernel::, Buffer, u64)>(&|soa, buf, offset| { - let i = dispatch_id().x.as_u64() + offset; - let v = soa.read(i); + let i = dispatch_id().x.as_u64(); + let v = soa.read(i + offset); buf.write(i, v); }); let copy_from = device.create_kernel::, Buffer, u64)>(&|soa, buf, offset| { - let i = dispatch_id().x.as_u64() + offset; + let i = dispatch_id().x.as_u64(); let v = buf.read(i); - soa.write(i, v); + soa.write(i + offset, v); }); Self { copy_to, copy_from } } @@ -108,6 +108,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> { } pub fn copy_from_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { self.init_copy_kernel(); + assert_eq!(self.metadata.view_count, buffer.len() as u64); let copy_kernel = self.buffer.copy_kernel.lock(); let copy_kernel = copy_kernel.as_ref().unwrap(); copy_kernel.copy_from.dispatch_async( @@ -122,6 +123,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> { } pub fn copy_to_buffer_async(&self, buffer: &Buffer) -> Command<'static, 'static> { self.init_copy_kernel(); + assert_eq!(self.metadata.view_count, buffer.len() as u64); let copy_kernel = self.buffer.copy_kernel.lock(); let copy_kernel = copy_kernel.as_ref().unwrap(); copy_kernel.copy_to.dispatch_async( @@ -152,6 +154,12 @@ pub struct SoaBufferView<'a, T: SoaValue> { pub struct SoaBufferVar { pub(crate) proxy: T::SoaBuffer, } +impl std::ops::Deref for SoaBufferVar { + type Target = T::SoaBuffer; + fn deref(&self) -> &Self::Target { + &self.proxy + } +} impl IndexRead for SoaBufferVar { type Element = T; fn read(&self, i: I) -> Expr { diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index ecbadd7..86c410f 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1181,22 +1181,66 @@ pub struct Foo { v: Float2, a: [i32; 4], } +#[derive(Clone, Copy, Debug, Value, Soa, PartialEq)] +#[repr(C)] +#[value_new(pub)] +pub struct Bar { + i: u32, + v: Float2, + a: [i32; 4], + f: Foo, +} #[test] fn soa() { let device = get_device(); let mut rng = thread_rng(); - let foos = device.create_buffer_from_fn(1024, |_| Foo { + let bars = device.create_buffer_from_fn(1024, |_| Bar { i: rng.gen(), v: Float2::new(rng.gen(), rng.gen()), a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + f: Foo { + i: rng.gen(), + v: Float2::new(rng.gen(), rng.gen()), + a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + }, }); - let foos_soa = device.create_soa_buffer::(1024); - foos_soa.copy_from_buffer(&foos); - let also_foos = device.create_buffer(1024); - foos_soa.copy_to_buffer(&also_foos); - let foos_data = foos.view(..).copy_to_vec(); - let also_foos_data = also_foos.view(..).copy_to_vec(); - assert_eq!(foos_data, also_foos_data); + let bars_soa = device.create_soa_buffer::(1024); + bars_soa.copy_from_buffer(&bars); + let also_bars = device.create_buffer(1024); + bars_soa.copy_to_buffer(&also_bars); + let bars_data = bars.view(..).copy_to_vec(); + let also_bars_data = also_bars.view(..).copy_to_vec(); + assert_eq!(bars_data, also_bars_data); +} +#[test] +fn soa_view() { + let device = get_device(); + let mut rng = thread_rng(); + let bars = device.create_buffer_from_fn(1024, |_| Bar { + i: rng.gen(), + v: Float2::new(rng.gen(), rng.gen()), + a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + f: Foo { + i: rng.gen(), + v: Float2::new(rng.gen(), rng.gen()), + a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + }, + }); + let bars_soa = device.create_soa_buffer::(2048); + bars_soa.view(..1024).copy_from_buffer(&bars); + bars_soa.view(1024..2048).copy_from_buffer(&bars); + + let also_bars = device.create_buffer(1024); + bars_soa.view(..1024).copy_to_buffer(&also_bars); + let bars_data = bars.view(..).copy_to_vec(); + let also_bars_data = also_bars.view(..).copy_to_vec(); + assert_eq!(bars_data, also_bars_data); + + let also_bars = device.create_buffer(1024); + bars_soa.view(1024..2048).copy_to_buffer(&also_bars); + let bars_data = bars.view(..).copy_to_vec(); + let also_bars_data = also_bars.view(..).copy_to_vec(); + assert_eq!(bars_data, also_bars_data); } #[test] fn atomic() { diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 46d8999..f865c63 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 46d89997f6495156083d6cf2d138c50442c708d1 +Subproject commit f865c635ba50732a5edda22ff1873e233fc66f08