Skip to content

Commit

Permalink
added VLA
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Mar 20, 2023
1 parent edd003e commit 231c2d0
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 4 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ jobs:
- name: "Build"
run: CC=clang-14 CXX=clang++-14 cargo build --verbose --release
- name: "Run Tests"
run: CC=clang-14 CXX=clang++-14 cargo test --verbose --release
run: |
CC=clang-14 CXX=clang++-14 cargo test --verbose --release
bash run_tests.sh
# - name: "Run CUDA Tests"
# run: CC=clang-14 CXX=clang++-14 LUISA_TEST_DEVICE=cuda cargo test --features cuda --verbose --release
144 changes: 141 additions & 3 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use bumpalo::Bump;
use ir::context::type_hash;
pub use ir::ir::NodeRef;
use ir::ir::{
AccelBinding, BindlessArrayBinding, ModulePools, SwitchCase, TextureBinding, UserNodeData,
INVALID_REF,
AccelBinding, ArrayType, BindlessArrayBinding, ModulePools, SwitchCase, TextureBinding,
UserNodeData, INVALID_REF,
};
pub use ir::CArc;
use ir::Pooled;
Expand Down Expand Up @@ -593,6 +593,144 @@ impl<T: Value, const N: usize> Value for [T; N] {
}
}
/// Expression proxy for an array whose length is fixed at kernel-recording
/// time but is not a const generic (contrast with `ArrayExpr<T, N>` below).
/// The length lives in the IR node's `Type::Array` and is recovered via
/// `static_len()` / `len()`.
#[derive(Clone, Copy, Debug)]
pub struct VLArrayExpr<T: Value> {
    // Zero-sized marker tying the element type `T` to this handle.
    marker: std::marker::PhantomData<T>,
    // Underlying IR node; its type is expected to be `Type::Array`.
    node: NodeRef,
}
impl<T: Value> FromNode for VLArrayExpr<T> {
fn from_node(node: NodeRef) -> Self {
Self {
marker: std::marker::PhantomData,
node,
}
}
fn node(&self) -> NodeRef {
self.node
}
}
impl<T: Value> Aggregate for VLArrayExpr<T> {
    /// A VLA expression flattens to exactly one node.
    fn to_nodes(&self, nodes: &mut Vec<NodeRef>) {
        nodes.push(self.node());
    }

    /// Rebuilds the expression by consuming one node from the iterator.
    fn from_nodes<I: Iterator<Item = NodeRef>>(iter: &mut I) -> Self {
        let node = iter.next().unwrap();
        Self::from_node(node)
    }
}
/// Mutable local-variable proxy for a recording-time-fixed-length array
/// (the `Var` counterpart of `VLArrayExpr<T>`). Elements are accessed through
/// `read`/`write`; the whole array through `load`/`store`.
#[derive(Clone, Copy, Debug)]
pub struct VLArrayVar<T: Value> {
    // Zero-sized marker tying the element type `T` to this handle.
    marker: std::marker::PhantomData<T>,
    // Underlying IR node; its type is expected to be `Type::Array`.
    node: NodeRef,
}
impl<T: Value> FromNode for VLArrayVar<T> {
fn from_node(node: NodeRef) -> Self {
Self {
marker: std::marker::PhantomData,
node,
}
}
fn node(&self) -> NodeRef {
self.node
}
}
impl<T: Value> Aggregate for VLArrayVar<T> {
    /// A VLA variable flattens to exactly one node.
    fn to_nodes(&self, nodes: &mut Vec<NodeRef>) {
        nodes.push(self.node());
    }

    /// Rebuilds the variable by consuming one node from the iterator.
    fn from_nodes<I: Iterator<Item = NodeRef>>(iter: &mut I) -> Self {
        let node = iter.next().unwrap();
        Self::from_node(node)
    }
}
impl<T: Value> VLArrayVar<T> {
    /// Reads element `i`, emitting a GEP + Load into the current scope.
    /// A device-side bounds assert is emitted when `__env_need_backtrace()`
    /// is set (NOTE(review): presumably a debug/backtrace env toggle — the
    /// check is skipped entirely otherwise).
    pub fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
        let i = i.into();
        if __env_need_backtrace() {
            assert(i.cmplt(self.len()));
        }
        Expr::<T>::from_node(__current_scope(|b| {
            let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_());
            b.call(Func::Load, &[gep], T::type_())
        }))
    }

    /// Array length as a device-side `u32` constant.
    /// Delegates to `static_len` so the `Type::Array` match lives in one place.
    pub fn len(&self) -> Expr<u32> {
        const_(self.static_len() as u32)
    }

    /// Array length known at kernel-recording time, read from the node's
    /// IR array type. Panics if the node is not of array type (a bug).
    pub fn static_len(&self) -> usize {
        match self.node.type_().as_ref() {
            Type::Array(ArrayType { element: _, length }) => *length,
            _ => unreachable!(),
        }
    }

    /// Writes `value` to element `i` (GEP + update). Bounds-asserted under
    /// the same `__env_need_backtrace()` gate as `read`.
    pub fn write<I: Into<Expr<u32>>, V: Into<Expr<T>>>(&self, i: I, value: V) {
        let i = i.into();
        let value = value.into();
        if __env_need_backtrace() {
            assert(i.cmplt(self.len()));
        }
        __current_scope(|b| {
            let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_());
            b.update(gep, value.node());
        });
    }

    /// Loads the whole array into an immutable `VLArrayExpr<T>`.
    pub fn load(&self) -> VLArrayExpr<T> {
        VLArrayExpr::from_node(__current_scope(|b| {
            b.call(Func::Load, &[self.node], self.node.type_().clone())
        }))
    }

    /// Stores an entire `VLArrayExpr<T>` back into this variable.
    pub fn store(&self, value: VLArrayExpr<T>) {
        __current_scope(|b| {
            b.update(self.node, value.node);
        });
    }

    /// Creates a zero-initialized local array variable of `length` elements.
    /// `length` is baked into the registered IR array type.
    pub fn zero(length: usize) -> Self {
        FromNode::from_node(__current_scope(|b| {
            b.local_zero_init(ir::context::register_type(Type::Array(ArrayType {
                element: T::type_(),
                length,
            })))
        }))
    }
}
impl<T: Value> VLArrayExpr<T> {
    /// Builds a zero-initialized array expression of `length` elements via
    /// `Func::ZeroInitializer`; `length` is baked into the registered IR type.
    pub fn zero(length: usize) -> Self {
        let node = __current_scope(|b| {
            b.call(
                Func::ZeroInitializer,
                &[],
                ir::context::register_type(Type::Array(ArrayType {
                    element: T::type_(),
                    length,
                })),
            )
        });
        Self::from_node(node)
    }

    /// Array length known at kernel-recording time, read from the node's
    /// IR array type. Panics if the node is not of array type (a bug).
    pub fn static_len(&self) -> usize {
        match self.node.type_().as_ref() {
            Type::Array(ArrayType { element: _, length }) => *length,
            _ => unreachable!(),
        }
    }

    /// Reads element `i`, emitting a GEP + Load into the current scope.
    /// A device-side bounds assert is emitted when `__env_need_backtrace()`
    /// is set (NOTE(review): presumably a debug/backtrace env toggle).
    pub fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
        let i = i.into();
        if __env_need_backtrace() {
            assert(i.cmplt(self.len()));
        }
        Expr::<T>::from_node(__current_scope(|b| {
            let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_());
            b.call(Func::Load, &[gep], T::type_())
        }))
    }

    /// Array length as a device-side `u32` constant.
    /// Delegates to `static_len` so the `Type::Array` match lives in one place.
    pub fn len(&self) -> Expr<u32> {
        const_(self.static_len() as u32)
    }
}
#[derive(Clone, Copy, Debug)]
pub struct ArrayExpr<T: Value, const N: usize> {
marker: std::marker::PhantomData<T>,
node: NodeRef,
Expand Down Expand Up @@ -1427,7 +1565,7 @@ pub struct RtxHit {
}
impl RtxHitExpr {
pub fn valid(&self) -> Expr<bool> {
self.inst_id().cmpne(u32::MAX) & self.prim_id().cmpne(u32::MAX)
self.inst_id().cmpne(u32::MAX) & self.prim_id().cmpne(u32::MAX)
}
}
impl AccelVar {
Expand Down
39 changes: 39 additions & 0 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,45 @@ fn array_read_write2() {
}
}
#[test]
// End-to-end check of VLArrayVar: fill a 4-element VLA per thread, copy it
// into a const-generic `[i32; 4]` local, write both the array and its first
// element to buffers, then verify on the host.
fn array_read_write_vla() {
    init();
    let device = get_device();
    // One `[i32; 4]` element and one scalar per thread (1024 threads).
    let x: Buffer<[i32; 4]> = device.create_buffer(1024).unwrap();
    let y: Buffer<i32> = device.create_buffer(1024).unwrap();
    let kernel = device
        .create_kernel::<()>(&|| {
            let buf_x = x.var();
            let buf_y = y.var();
            let tid = dispatch_id().x();
            // Zero-initialized VLA local with 4 elements.
            let vl = VLArrayVar::<i32>::zero(4);
            let i = local_zeroed::<i32>();
            // Fill vl[j] = tid + j for j in 0..4.
            while_!(i.load().cmplt(4), {
                vl.write(i.load().uint(), tid.int() + i.load());
                i.store(i.load() + 1);
            });
            // Copy the VLA element-by-element into a fixed-size array local.
            let arr = local_zeroed::<[i32; 4]>();
            let i = local_zeroed::<i32>();
            while_!(i.load().cmplt(4), {
                arr.write(i.load().uint(), vl.read(i.load().uint()));
                i.store(i.load() + 1);
            });
            let arr = arr.load();
            buf_x.write(tid, arr);
            buf_y.write(tid, arr.read(0));
        })
        .unwrap();
    kernel.dispatch([1024, 1, 1]).unwrap();
    // Host-side verification: x[i] == [i, i+1, i+2, i+3] and y[i] == i.
    let x_data = x.view(..).copy_to_vec();
    let y_data = y.view(..).copy_to_vec();
    for i in 0..1024 {
        assert_eq!(
            x_data[i],
            [i as i32, i as i32 + 1, i as i32 + 2, i as i32 + 3]
        );
        assert_eq!(y_data[i], i as i32);
    }
}
#[test]
fn array_read_write_async_compile() {
init();
let device = get_device();
Expand Down
8 changes: 8 additions & 0 deletions run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Smoke-test every example in release mode. Invoked from CI as
# `bash run_tests.sh` after `cargo test`.
#
# Without `set -e` a failing example would NOT fail the CI step, since the
# script's exit status would be that of the last command only.
set -euo pipefail

cargo run --release --example atomic
cargo run --release --example autodiff
cargo run --release --example bindless
# NOTE(review): "aggreagate" looks like a typo for "aggregate" — confirm this
# matches the actual example file name before renaming.
cargo run --release --example custom_aggreagate
cargo run --release --example custom_op
cargo run --release --example polymorphism
cargo run --release --example raytracing
cargo run --release --example vecadd

0 comments on commit 231c2d0

Please sign in to comment.