diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index c8249b4..47c8a5b 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -154,6 +154,23 @@ pub trait Aggregate: Sized { fn from_nodes>(iter: &mut I) -> Self; } +impl Aggregate for [T; N] { + fn to_nodes(&self, nodes: &mut Vec) { + for x in self { + x.to_nodes(nodes); + } + } + fn from_nodes>(iter: &mut I) -> Self { + unsafe { + let mut ret = std::mem::MaybeUninit::<[T; N]>::uninit(); + for i in 0..N { + let x = T::from_nodes(iter); + ret.as_mut_ptr().cast::().add(i).write(x); + } + ret.assume_init() + } + } +} impl Aggregate for Vec { fn to_nodes(&self, nodes: &mut Vec) { let len_node = __new_user_node(nodes.len()); @@ -427,7 +444,8 @@ impl FnRecorder { device: None, block_size: None, pools: pools.clone(), - arena: parent.as_ref() + arena: parent + .as_ref() .map(|p| p.borrow().arena.clone()) .unwrap_or_else(|| Rc::new(Bump::new())), building_kernel: false,