diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index a9c8c36..7f9812f 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -93,6 +93,22 @@ impl Index for ArrayExpr { ._ref() } } + +impl Index for ArrayVar { + type Output = Var; + fn index(&self, i: X) -> &Self::Output { + let i = i.to_u64(); + + // TODO: Add need_runtime_check()? + if need_runtime_check() { + check_index_lt_usize(i, N); + } + let i = i.node().get(); + let self_node = self.0.node().get(); + Var::::from_node(__current_scope(|b| b.gep_chained(self_node, &[i], T::type_())).into()) + ._ref() + } +} impl Index for ArrayAtomicRef { type Output = AtomicRef; fn index(&self, i: X) -> &Self::Output { diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 74b0d99..a0d36d2 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -922,7 +922,7 @@ fn array_read_write() { let arr = Var::<[i32; 4]>::zeroed(); let i = i32::var_zeroed(); while i < 4 { - arr.write(i.as_u32(), tid.as_i32() + i); + *arr[i.cast_u32()] = tid.as_i32() + i; *i += 1; } buf_x.write(tid, arr);