From 4db529c807dd2f278450044266289fe760411c90 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Thu, 21 Dec 2023 03:34:18 -0500 Subject: [PATCH] proxies for matrices --- luisa_compute/src/lang/types/vector.rs | 173 +++++++++++++++++++++++-- luisa_compute/tests/misc.rs | 26 +++- luisa_compute_sys/LuisaCompute | 2 +- 3 files changed, 190 insertions(+), 11 deletions(-) diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index ddbf8e0..204a805 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -457,17 +457,172 @@ where } } -impl_simple_expr_proxy!(SquareMatrixExpr2 for SquareMatrix<2>); -impl_simple_var_proxy!(SquareMatrixVar2 for SquareMatrix<2>); -impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef2 for SquareMatrix<2>); +macro_rules! matrix_proxies { + ($N:literal [ $($real_c:ident),* ] [ $($c:ident),* ]: $ExprName:ident, $VarName:ident, $AtomicName:ident, $SoaName:ident) => { + #[repr(C)] + #[derive(Copy, Clone)] + pub struct $ExprName { + self_: Expr>, + $(pub $c: Expr>),* + } + #[repr(C)] + #[derive(Copy, Clone)] + pub struct $VarName { + self_: Var>, + $(pub $c: Var>),* + } + #[repr(C)] + #[derive(Copy, Clone)] + pub struct $AtomicName { + self_: AtomicRef>, + $(pub $c: AtomicRef>),* + } + + #[repr(C)] + #[derive(Clone)] + pub struct $SoaName { + $(pub $c: as SoaValue>::SoaBuffer),* + } + + impl SoaValue for SquareMatrix<$N> { + type SoaBuffer = $SoaName; + } + impl SoaBufferProxy for $SoaName { + type Value = SquareMatrix<$N>; + #[allow(unused_assignments)] + fn from_soa_storage( + storage: ByteBufferVar, + meta: Expr, + global_offset: usize, + ) -> Self { + let s = < as SoaValue>::SoaBuffer as SoaBufferProxy>::num_buffers(); + let mut i = 0; + $( + let $c = as SoaValue>::SoaBuffer::from_soa_storage( + storage.clone(), + meta.clone(), + global_offset + i * s, + ); + i += 1; + if i >= $N { i = 0; } + )* + Self{ + $($c),* + } + } + fn num_buffers() -> usize { + < as SoaValue>::SoaBuffer as SoaBufferProxy>::num_buffers() * $N + } + } + impl IndexRead for $SoaName { + type Element = SquareMatrix<$N>; + fn read(&self, i: I) -> Expr { + let i = i.to_u64(); + $( + let $real_c = self.$real_c.read(i); + )* + SquareMatrix::<$N>::from_elems_expr([$($real_c),*]) + } + } + impl IndexWrite for $SoaName { + #[allow(unused_assignments)] + fn write>( + &self, + i: I, + value: V, + ) { + let i = i.to_u64(); + let v = value.as_expr(); + let mut comp = 0; + $( + { + let el = Expr::>::from_node(__extract::>(v.node(), comp)); + self.$real_c.write(i, el); + comp += 1; + } + )* + } + } + + impl ExprProxy for $ExprName { + type Value = SquareMatrix<$N>; + #[allow(unused_assignments)] + fn from_expr(e:Expr) -> Self { + let data: [Expr>;$N] = std::array::from_fn(|i| { + FromNode::from_node(__extract::>(e.node(), i)) + }); + let mut i = 0; + $( + let $c = data[i].clone(); + i += 1; + if i >= $N { i = 0; } + )* + Self{ + self_: e, + $($c),* + } + } + fn as_expr_from_proxy(&self)->&Expr { + &self.self_ + } + } + impl Deref for $VarName{ + type Target = Expr>; + fn deref(&self) -> &Self::Target { + _deref_proxy(self) + } + } + impl AtomicRefProxy for $AtomicName { + type Value = SquareMatrix<$N>; + #[allow(unused_assignments)] + fn from_atomic_ref(e:AtomicRef) -> Self { + let data: [AtomicRef>;$N] = std::array::from_fn(|i| { + FromNode::from_node(__extract::>(e.node(), i)) + }); + let mut i = 0; + $( + let $c = data[i].clone(); + i += 1; + if i >= $N { i = 0; } + )* + Self{ + self_: e, + $($c),* + } + } + fn as_atomic_ref_from_proxy(&self)->&AtomicRef { + &self.self_ + } + } -impl_simple_expr_proxy!(SquareMatrixExpr3 for SquareMatrix<3>); -impl_simple_var_proxy!(SquareMatrixVar3 for SquareMatrix<3>); -impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef3 for SquareMatrix<3>); + impl VarProxy for $VarName { + type Value = SquareMatrix<$N>; + #[allow(unused_assignments)] + fn from_var(e:Var) -> Self { + let data: [Var>;$N] = std::array::from_fn(|i| { + FromNode::from_node(__extract::>(e.node(), i)) + }); + let mut i = 0; + $( + let $c = data[i].clone(); + i += 1; + if i >= $N { i = 0; } + )* + Self{ + self_: e, + $($c),* + } + } + fn as_var_from_proxy(&self)->&Var { + &self.self_ + } + } + } +} -impl_simple_expr_proxy!(SquareMatrixExpr4 for SquareMatrix<4>); -impl_simple_var_proxy!(SquareMatrixVar4 for SquareMatrix<4>); -impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef4 for SquareMatrix<4>); +matrix_proxies!(2 [x, y] [x, y]: SquareMatrixExpr2, SquareMatrixVar2, SquareMatrixAtomicRef2, SquareMatrixSoaProxy2); +matrix_proxies!(3 [x, y, z] [x, y, z]: SquareMatrixExpr3, SquareMatrixVar3, SquareMatrixAtomicRef3, SquareMatrixSoaProxy3); +matrix_proxies!(4 [x, y, z, w] [x, y, z, w]: SquareMatrixExpr4, SquareMatrixVar4, SquareMatrixAtomicRef4, SquareMatrixSoaProxy4); impl Value for SquareMatrix<2> { type Expr = SquareMatrixExpr2; diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 1bff4d5..74b0d99 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use luisa::lang::types::array::VLArrayVar; use luisa::lang::types::dynamic::*; -use luisa::lang::types::vector::alias::*; +use luisa::lang::types::vector::{alias::*, Mat2}; use luisa::prelude::*; use luisa_compute as luisa; use luisa_compute_api_types::StreamTag; @@ -1461,6 +1461,7 @@ pub struct Foo { i: u32, v: Float2, a: [i32; 4], + m: Mat2, } #[derive(Clone, Copy, Debug, Value, Soa, PartialEq)] #[repr(C)] @@ -1483,6 +1484,7 @@ fn soa() { i: rng.gen(), v: Float2::new(rng.gen(), rng.gen()), a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]), }, }); let bars_soa = device.create_soa_buffer::(1024); @@ -1505,6 +1507,7 @@ fn soa_view() { i: rng.gen(), v: Float2::new(rng.gen(), rng.gen()), a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]), }, }); let bars_soa = device.create_soa_buffer::(2048); @@ -1531,16 +1534,19 @@ fn atomic() { i: rng.gen(), v: Float2::new(rng.gen(), rng.gen()), a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()], + m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]), }); let foo_max_init = Foo { i: u32::MIN, v: Float2::new(f32::MIN, f32::MIN), a: [i32::MIN; 4], + m: Mat2::from_column_array(&[[f32::MIN; 2]; 2]), }; let foo_min_init = Foo { i: u32::MAX, v: Float2::new(f32::MAX, f32::MAX), a: [i32::MAX; 4], + m: Mat2::from_column_array(&[[f32::MAX; 2]; 2]), }; let foo_max = device.create_buffer_from_slice(&[foo_max_init]); let foo_min = device.create_buffer_from_slice(&[foo_min_init]); @@ -1558,12 +1564,21 @@ fn atomic() { for i in 0..4u32 { foo_max.a[i].fetch_max(foo.a[i]); } + foo_max.m.x.x.fetch_max(foo.m.x.x); + foo_max.m.x.y.fetch_max(foo.m.x.y); + foo_max.m.y.x.fetch_max(foo.m.y.x); + foo_max.m.y.y.fetch_max(foo.m.y.y); + foo_min.i.fetch_min(foo.i); foo_min.v.x.fetch_min(foo.v.x); foo_min.v.y.fetch_min(foo.v.y); for i in 0..4u32 { foo_min.a[i].fetch_min(foo.a[i]); } + foo_min.m.x.x.fetch_min(foo.m.x.x); + foo_min.m.x.y.fetch_min(foo.m.x.y); + foo_min.m.y.x.fetch_min(foo.m.y.x); + foo_min.m.y.y.fetch_min(foo.m.y.y); }), ); kernel.dispatch([foos.len() as u32, 1, 1]); @@ -1579,12 +1594,21 @@ fn atomic() { for i in 0..4 { expected_foo_max.a[i] = expected_foo_max.a[i].max(foo.a[i]); } + expected_foo_max.m.cols[0].x = expected_foo_max.m.cols[0].x.max(foo.m.cols[0].x); + expected_foo_max.m.cols[0].y = expected_foo_max.m.cols[0].y.max(foo.m.cols[0].y); + expected_foo_max.m.cols[1].x = expected_foo_max.m.cols[1].x.max(foo.m.cols[1].x); + expected_foo_max.m.cols[1].y = expected_foo_max.m.cols[1].y.max(foo.m.cols[1].y); + expected_foo_min.i = expected_foo_min.i.min(foo.i); expected_foo_min.v.x = expected_foo_min.v.x.min(foo.v.x); expected_foo_min.v.y = expected_foo_min.v.y.min(foo.v.y); for i in 0..4 { expected_foo_min.a[i] = expected_foo_min.a[i].min(foo.a[i]); } + expected_foo_min.m.cols[0].x = expected_foo_min.m.cols[0].x.min(foo.m.cols[0].x); + expected_foo_min.m.cols[0].y = expected_foo_min.m.cols[0].y.min(foo.m.cols[0].y); + expected_foo_min.m.cols[1].x = expected_foo_min.m.cols[1].x.min(foo.m.cols[1].x); + expected_foo_min.m.cols[1].y = expected_foo_min.m.cols[1].y.min(foo.m.cols[1].y); } assert_eq!(foo_max, expected_foo_max); assert_eq!(foo_min, expected_foo_min); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 53b0b2c..4cfe950 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 53b0b2c3efbda8ac054308dbedf93f69e058d29c +Subproject commit 4cfe950e7f254eb6fc8a5fcf35d7747dc1a84eac