diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index a25d6a7..12897a9 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -187,12 +187,12 @@ impl Aggregate for VLArrayVar { Self::from_node(iter.next().unwrap()) } } - -impl VLArrayVar { - pub fn read>>(&self, i: I) -> Expr { - let i = i.into(); +impl IndexRead for VLArrayVar { + type Element = T; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); + check_index_lt_usize(i, self.len()); } Expr::::from_node(__current_scope(|b| { @@ -200,24 +200,14 @@ impl VLArrayVar { b.call(Func::Load, &[gep], T::type_()) })) } - pub fn len(&self) -> Expr { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => (*length as u32).expr(), - _ => unreachable!(), - } - } - pub fn static_len(&self) -> usize { - match self.node.type_().as_ref() { - Type::Array(ArrayType { element: _, length }) => *length, - _ => unreachable!(), - } - } - pub fn write>, V: Into>>(&self, i: I, value: V) { - let i = i.into(); - let value = value.into(); +} +impl IndexWrite for VLArrayVar { + fn write>(&self, i: I, value: V) { + let i = i.to_u64(); + let value = value.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); + check_index_lt_usize(i, self.len()); } __current_scope(|b| { @@ -225,6 +215,20 @@ impl VLArrayVar { b.update(gep, value.node()); }); } +} +impl VLArrayVar { + pub fn len_expr(&self) -> Expr { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(), + _ => unreachable!(), + } + } + pub fn len(&self) -> usize { + match self.node.type_().as_ref() { + Type::Array(ArrayType { element: _, length }) => *length, + _ => unreachable!(), + } + } pub fn load(&self) -> VLArrayExpr { VLArrayExpr::from_node(__current_scope(|b| { b.call(Func::Load, &[self.node], self.node.type_().clone()) @@ -244,7 +248,18 @@ impl VLArrayVar { })) } } - +impl IndexRead for VLArrayExpr { + type Element = T; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + if need_runtime_check() { + check_index_lt_usize(i, self.len()); + } + Expr::::from_node(__current_scope(|b| { + b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) + })) + } +} impl VLArrayExpr { pub fn zero(length: usize) -> Self { let node = __current_scope(|b| { @@ -265,16 +280,6 @@ impl VLArrayExpr { _ => unreachable!(), } } - pub fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - check_index_lt_usize(i, self.len()); - } - - Expr::::from_node(__current_scope(|b| { - b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) - })) - } pub fn len_expr(&self) -> Expr { match self.node.type_().as_ref() { Type::Array(ArrayType { element: _, length }) => (*length as u64).expr(),