From 231c2d00b9b12d26e3eaa6b4be0d2a55e102818d Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Mon, 20 Mar 2023 17:19:45 -0400 Subject: [PATCH] added VLA --- .github/workflows/rust.yml | 4 +- luisa_compute/src/lang/mod.rs | 144 +++++++++++++++++++++++++++++++++- luisa_compute/tests/misc.rs | 39 +++++++++ run_tests.sh | 8 ++ 4 files changed, 191 insertions(+), 4 deletions(-) create mode 100644 run_tests.sh diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c373e34..ded11b3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -37,6 +37,8 @@ jobs: - name: "Build" run: CC=clang-14 CXX=clang++-14 cargo build --verbose --release - name: "Run Tests" - run: CC=clang-14 CXX=clang++-14 cargo test --verbose --release + run: | + CC=clang-14 CXX=clang++-14 cargo test --verbose --release + bash run_tests.sh # - name: "Run CUDA Tests" # run: CC=clang-14 CXX=clang++-14 LUISA_TEST_DEVICE=cuda cargo test --features cuda --verbose --release diff --git a/luisa_compute/src/lang/mod.rs b/luisa_compute/src/lang/mod.rs index 1168aee..dc89a49 100644 --- a/luisa_compute/src/lang/mod.rs +++ b/luisa_compute/src/lang/mod.rs @@ -20,8 +20,8 @@ use bumpalo::Bump; use ir::context::type_hash; pub use ir::ir::NodeRef; use ir::ir::{ - AccelBinding, BindlessArrayBinding, ModulePools, SwitchCase, TextureBinding, UserNodeData, - INVALID_REF, + AccelBinding, ArrayType, BindlessArrayBinding, ModulePools, SwitchCase, TextureBinding, + UserNodeData, INVALID_REF, }; pub use ir::CArc; use ir::Pooled; @@ -593,6 +593,144 @@ impl Value for [T; N] { } } #[derive(Clone, Copy, Debug)] +pub struct VLArrayExpr { + marker: std::marker::PhantomData, + node: NodeRef, +} +impl FromNode for VLArrayExpr { + fn from_node(node: NodeRef) -> Self { + Self { + marker: std::marker::PhantomData, + node, + } + } + fn node(&self) -> NodeRef { + self.node + } +} +impl Aggregate for VLArrayExpr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} +#[derive(Clone, Copy, Debug)] +pub struct VLArrayVar { + marker: std::marker::PhantomData, + node: NodeRef, +} +impl FromNode for VLArrayVar { + fn from_node(node: NodeRef) -> Self { + Self { + marker: std::marker::PhantomData, + node, + } + } + fn node(&self) -> NodeRef { + self.node + } +} +impl Aggregate for VLArrayVar { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + Self::from_node(iter.next().unwrap()) + } +} +impl VLArrayVar { + pub fn read>>(&self, i: I) -> Expr { + let i = i.into(); + if __env_need_backtrace() { + assert(i.cmplt(self.len())); + } + Expr::::from_node(__current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.call(Func::Load, &[gep], T::type_()) + })) + } + pub fn len(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => const_(*length as u32), + _ => unreachable!(), + } + } + pub fn static_len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } + pub fn write>, V: Into>>(&self, i: I, value: V) { + let i = i.into(); + let value = value.into(); + if __env_need_backtrace() { + assert(i.cmplt(self.len())); + } + __current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.update(gep, value.node()); + }); + } + pub fn load(&self) -> VLArrayExpr { + VLArrayExpr::from_node(__current_scope(|b| { + b.call(Func::Load, &[self.node], self.node.type_().clone()) + })) + } + pub fn store(&self, value: VLArrayExpr) { + __current_scope(|b| { + b.update(self.node, value.node); + }); + } + pub fn zero(length: usize) -> Self { + FromNode::from_node(__current_scope(|b| { + b.local_zero_init(ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length, + }))) + })) + } +} +impl VLArrayExpr { + pub fn zero(length: usize) -> Self { + let node = __current_scope(|b| { + b.call( + Func::ZeroInitializer, + &[], + ir::context::register_type(Type::Array(ArrayType { + element: T::type_(), + length, + })), + ) + }); + Self::from_node(node) + } + pub fn static_len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } + pub fn read>>(&self, i: I) -> Expr { + let i = i.into(); + if __env_need_backtrace() { + assert(i.cmplt(self.len())); + } + Expr::::from_node(__current_scope(|b| { + let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); + b.call(Func::Load, &[gep], T::type_()) + })) + } + pub fn len(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => const_(*length as u32), + _ => unreachable!(), + } + } +} +#[derive(Clone, Copy, Debug)] pub struct ArrayExpr { marker: std::marker::PhantomData, node: NodeRef, @@ -1427,7 +1565,7 @@ pub struct RtxHit { } impl RtxHitExpr { pub fn valid(&self) -> Expr { - self.inst_id().cmpne(u32::MAX) & self.prim_id().cmpne(u32::MAX) + self.inst_id().cmpne(u32::MAX) & self.prim_id().cmpne(u32::MAX) } } impl AccelVar { diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 18fd12b..6044977 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -411,6 +411,45 @@ fn array_read_write2() { } } #[test] +fn array_read_write_vla() { + init(); + let device = get_device(); + let x: Buffer<[i32; 4]> = device.create_buffer(1024).unwrap(); + let y: Buffer = device.create_buffer(1024).unwrap(); + let kernel = device + .create_kernel::<()>(&|| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x(); + let vl = VLArrayVar::::zero(4); + let i = local_zeroed::(); + while_!(i.load().cmplt(4), { + vl.write(i.load().uint(), tid.int() + i.load()); + i.store(i.load() + 1); + }); + let arr = local_zeroed::<[i32; 4]>(); + let i = local_zeroed::(); + while_!(i.load().cmplt(4), { + arr.write(i.load().uint(), vl.read(i.load().uint())); + i.store(i.load() + 1); + }); + let arr = arr.load(); + buf_x.write(tid, arr); + buf_y.write(tid, arr.read(0)); + }) + .unwrap(); + kernel.dispatch([1024, 1, 1]).unwrap(); + let x_data = x.view(..).copy_to_vec(); + let y_data = y.view(..).copy_to_vec(); + for i in 0..1024 { + assert_eq!( + x_data[i], + [i as i32, i as i32 + 1, i as i32 + 2, i as i32 + 3] + ); + assert_eq!(y_data[i], i as i32); + } +} +#[test] fn array_read_write_async_compile() { init(); let device = get_device(); diff --git a/run_tests.sh b/run_tests.sh new file mode 100644 index 0000000..d02213a --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,8 @@ +cargo run --release --example atomic +cargo run --release --example autodiff +cargo run --release --example bindless +cargo run --release --example custom_aggreagate +cargo run --release --example custom_op +cargo run --release --example polymorphism +cargo run --release --example raytracing +cargo run --release --example vecadd \ No newline at end of file