Skip to content

Commit

Permalink
imp IndexRead for VLArray
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
1 parent 0fa96b9 commit 21cb900
Showing 1 changed file with 37 additions and 32 deletions.
69 changes: 37 additions & 32 deletions luisa_compute/src/lang/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,44 +187,48 @@ impl<T: Value> Aggregate for VLArrayVar<T> {
Self::from_node(iter.next().unwrap())
}
}

impl<T: Value> VLArrayVar<T> {
pub fn read<I: Into<Expr<u32>>>(&self, i: I) -> Expr<T> {
let i = i.into();
impl<T: Value> IndexRead for VLArrayVar<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
let i = i.to_u64();
if need_runtime_check() {
lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds");
check_index_lt_usize(i, self.len());
}

Expr::<T>::from_node(__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_());
b.call(Func::Load, &[gep], T::type_())
}))
}
pub fn len(&self) -> Expr<u32> {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => (*length as u32).expr(),
_ => unreachable!(),
}
}
pub fn static_len(&self) -> usize {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => *length,
_ => unreachable!(),
}
}
pub fn write<I: Into<Expr<u32>>, V: Into<Expr<T>>>(&self, i: I, value: V) {
let i = i.into();
let value = value.into();
}
impl<T: Value> IndexWrite for VLArrayVar<T> {
fn write<I: IntoIndex, V: AsExpr<Value = Self::Element>>(&self, i: I, value: V) {
let i = i.to_u64();
let value = value.as_expr();

if need_runtime_check() {
lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds");
check_index_lt_usize(i, self.len());
}

__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_());
b.update(gep, value.node());
});
}
}
impl<T: Value> VLArrayVar<T> {
pub fn len_expr(&self) -> Expr<u64> {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(),
_ => unreachable!(),
}
}
pub fn len(&self) -> usize {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => *length,
_ => unreachable!(),
}
}
pub fn load(&self) -> VLArrayExpr<T> {
VLArrayExpr::from_node(__current_scope(|b| {
b.call(Func::Load, &[self.node], self.node.type_().clone())
Expand All @@ -244,7 +248,18 @@ impl<T: Value> VLArrayVar<T> {
}))
}
}

impl<T: Value> IndexRead for VLArrayExpr<T> {
type Element = T;
fn read<I: IntoIndex>(&self, i: I) -> Expr<Self::Element> {
let i = i.to_u64();
if need_runtime_check() {
check_index_lt_usize(i, self.len());
}
Expr::<T>::from_node(__current_scope(|b| {
b.call(Func::ExtractElement, &[self.node, i.node()], T::type_())
}))
}
}
impl<T: Value> VLArrayExpr<T> {
pub fn zero(length: usize) -> Self {
let node = __current_scope(|b| {
Expand All @@ -265,16 +280,6 @@ impl<T: Value> VLArrayExpr<T> {
_ => unreachable!(),
}
}
pub fn read<I: IntoIndex>(&self, i: I) -> Expr<T> {
let i = i.to_u64();
if need_runtime_check() {
check_index_lt_usize(i, self.len());
}

Expr::<T>::from_node(__current_scope(|b| {
b.call(Func::ExtractElement, &[self.node, i.node()], T::type_())
}))
}
pub fn len_expr(&self) -> Expr<u64> {
match self.node.type_().as_ref() {
Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(),
Expand Down

0 comments on commit 21cb900

Please sign in to comment.