Skip to content

Commit

Permalink
fixed minor
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 8, 2023
1 parent 775485b commit 7d31694
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 13 deletions.
16 changes: 12 additions & 4 deletions luisa_compute/src/lang/soa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ impl<T: SoaValue> SoaBufferCopyKernel<T> {
fn new(device: &Device) -> Self {
let copy_to =
device.create_kernel::<fn(SoaBuffer<T>, Buffer<T>, u64)>(&|soa, buf, offset| {
let i = dispatch_id().x.as_u64() + offset;
let v = soa.read(i);
let i = dispatch_id().x.as_u64();
let v = soa.read(i + offset);
buf.write(i, v);
});
let copy_from =
device.create_kernel::<fn(SoaBuffer<T>, Buffer<T>, u64)>(&|soa, buf, offset| {
let i = dispatch_id().x.as_u64() + offset;
let i = dispatch_id().x.as_u64();
let v = buf.read(i);
soa.write(i, v);
soa.write(i + offset, v);
});
Self { copy_to, copy_from }
}
Expand Down Expand Up @@ -108,6 +108,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> {
}
pub fn copy_from_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
self.init_copy_kernel();
assert_eq!(self.metadata.view_count, buffer.len() as u64);
let copy_kernel = self.buffer.copy_kernel.lock();
let copy_kernel = copy_kernel.as_ref().unwrap();
copy_kernel.copy_from.dispatch_async(
Expand All @@ -122,6 +123,7 @@ impl<'a, T: SoaValue> SoaBufferView<'a, T> {
}
pub fn copy_to_buffer_async(&self, buffer: &Buffer<T>) -> Command<'static, 'static> {
self.init_copy_kernel();
assert_eq!(self.metadata.view_count, buffer.len() as u64);
let copy_kernel = self.buffer.copy_kernel.lock();
let copy_kernel = copy_kernel.as_ref().unwrap();
copy_kernel.copy_to.dispatch_async(
Expand Down Expand Up @@ -152,6 +154,12 @@ pub struct SoaBufferView<'a, T: SoaValue> {
pub struct SoaBufferVar<T: SoaValue> {
pub(crate) proxy: T::SoaBuffer,
}
impl<T: SoaValue> std::ops::Deref for SoaBufferVar<T> {
type Target = T::SoaBuffer;
fn deref(&self) -> &Self::Target {
&self.proxy
}
}
impl<T: SoaValue> IndexRead for SoaBufferVar<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
Expand Down
60 changes: 52 additions & 8 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1181,22 +1181,66 @@ pub struct Foo {
v: Float2,
a: [i32; 4],
}
#[derive(Clone, Copy, Debug, Value, Soa, PartialEq)]
#[repr(C)]
#[value_new(pub)]
pub struct Bar {
i: u32,
v: Float2,
a: [i32; 4],
f: Foo,
}
#[test]
fn soa() {
let device = get_device();
let mut rng = thread_rng();
let foos = device.create_buffer_from_fn(1024, |_| Foo {
let bars = device.create_buffer_from_fn(1024, |_| Bar {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
f: Foo {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
},
});
let foos_soa = device.create_soa_buffer::<Foo>(1024);
foos_soa.copy_from_buffer(&foos);
let also_foos = device.create_buffer(1024);
foos_soa.copy_to_buffer(&also_foos);
let foos_data = foos.view(..).copy_to_vec();
let also_foos_data = also_foos.view(..).copy_to_vec();
assert_eq!(foos_data, also_foos_data);
let bars_soa = device.create_soa_buffer::<Bar>(1024);
bars_soa.copy_from_buffer(&bars);
let also_bars = device.create_buffer(1024);
bars_soa.copy_to_buffer(&also_bars);
let bars_data = bars.view(..).copy_to_vec();
let also_bars_data = also_bars.view(..).copy_to_vec();
assert_eq!(bars_data, also_bars_data);
}
#[test]
fn soa_view() {
let device = get_device();
let mut rng = thread_rng();
let bars = device.create_buffer_from_fn(1024, |_| Bar {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
f: Foo {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
},
});
let bars_soa = device.create_soa_buffer::<Bar>(2048);
bars_soa.view(..1024).copy_from_buffer(&bars);
bars_soa.view(1024..2048).copy_from_buffer(&bars);

let also_bars = device.create_buffer(1024);
bars_soa.view(..1024).copy_to_buffer(&also_bars);
let bars_data = bars.view(..).copy_to_vec();
let also_bars_data = also_bars.view(..).copy_to_vec();
assert_eq!(bars_data, also_bars_data);

let also_bars = device.create_buffer(1024);
bars_soa.view(1024..2048).copy_to_buffer(&also_bars);
let bars_data = bars.view(..).copy_to_vec();
let also_bars_data = also_bars.view(..).copy_to_vec();
assert_eq!(bars_data, also_bars_data);
}
#[test]
fn atomic() {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_sys/LuisaCompute
Submodule LuisaCompute updated 38 files
+1 −0 include/luisa/ast/ast2json.h
+1 −1 include/luisa/dsl/builtin.h
+11 −0 include/luisa/dsl/resource.h
+1 −78 include/luisa/ir/ast2ir.h
+8 −4 include/luisa/runtime/rhi/command.h
+3 −0 include/luisa/rust/ir.hpp
+11 −0 src/ast/ast2json.cpp
+1 −0 src/ast/expression.cpp
+2 −4 src/backends/common/hlsl/builtin/accel_header
+2 −2 src/backends/common/hlsl/builtin/accel_header.c
+17 −7 src/backends/common/hlsl/hlsl_codegen_util.cpp
+12 −1 src/backends/cuda/cuda_builtin/cuda_device_resource.h
+3,964 −3,935 src/backends/cuda/cuda_builtin_embedded.cpp
+1 −1 src/backends/cuda/cuda_builtin_embedded.h
+23 −14 src/backends/cuda/cuda_shader_native.cpp
+15 −4 src/backends/cuda/cuda_shader_optix.cpp
+9 −9 src/backends/dx/d3dx12.h
+4 −4 src/backends/metal/metal_bindless_array.cpp
+1 −1 src/backends/metal/metal_bindless_array.h
+15 −1 src/backends/metal/metal_builtin/metal_device_lib.metal
+2,322 −2,294 src/backends/metal/metal_builtin_embedded.cpp
+1 −1 src/backends/metal/metal_builtin_embedded.h
+14 −13 src/backends/metal/metal_codegen_ast.cpp
+19 −12 src/backends/metal/metal_shader.cpp
+33 −1,480 src/ir/ast2ir.cpp
+4 −1 src/ir/ir2ast.cpp
+11 −11 src/py/luisa/autodiff.py
+6 −0 src/rust/luisa_compute_backend_impl/src/cpu/accel.rs
+21 −8 src/rust/luisa_compute_backend_impl/src/cpu/codegen/cpp.rs
+8 −0 src/rust/luisa_compute_backend_impl/src/cpu/codegen/cpu_resource.h
+14 −5 src/rust/luisa_compute_backend_impl/src/cpu/stream.rs
+1 −0 src/rust/luisa_compute_cpu_kernel_defs/cpu_kernel_defs.h
+1 −0 src/rust/luisa_compute_cpu_kernel_defs/src/lib.rs
+91 −44 src/rust/luisa_compute_ir/src/ast2ir.rs
+18 −2 src/rust/luisa_compute_ir/src/ir.rs
+4 −0 src/rust/luisa_compute_ir/src/serialize/convert.rs
+2 −0 src/rust/luisa_compute_ir/src/serialize/mod.rs
+4 −3 src/tests/test_autodiff_full.cpp

0 comments on commit 7d31694

Please sign in to comment.