Skip to content

Commit

Permalink
proxies for matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 21, 2023
1 parent 9523435 commit 4db529c
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 11 deletions.
173 changes: 164 additions & 9 deletions luisa_compute/src/lang/types/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,172 @@ where
}
}

impl_simple_expr_proxy!(SquareMatrixExpr2 for SquareMatrix<2>);
impl_simple_var_proxy!(SquareMatrixVar2 for SquareMatrix<2>);
impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef2 for SquareMatrix<2>);
macro_rules! matrix_proxies {
($N:literal [ $($real_c:ident),* ] [ $($c:ident),* ]: $ExprName:ident, $VarName:ident, $AtomicName:ident, $SoaName:ident) => {
#[repr(C)]
#[derive(Copy, Clone)]
pub struct $ExprName {
self_: Expr<SquareMatrix<$N>>,
$(pub $c: Expr<Vector<f32, $N>>),*
}
#[repr(C)]
#[derive(Copy, Clone)]
pub struct $VarName {
self_: Var<SquareMatrix<$N>>,
$(pub $c: Var<Vector<f32, $N>>),*
}
#[repr(C)]
#[derive(Copy, Clone)]
pub struct $AtomicName {
self_: AtomicRef<SquareMatrix<$N>>,
$(pub $c: AtomicRef<Vector<f32, $N>>),*
}

#[repr(C)]
#[derive(Clone)]
pub struct $SoaName {
$(pub $c: <Vector<f32, $N> as SoaValue>::SoaBuffer),*
}

impl SoaValue for SquareMatrix<$N> {
type SoaBuffer = $SoaName;
}
impl SoaBufferProxy for $SoaName {
type Value = SquareMatrix<$N>;
#[allow(unused_assignments)]
fn from_soa_storage(
storage: ByteBufferVar,
meta: Expr<SoaMetadata>,
global_offset: usize,
) -> Self {
let s = <<Vector<f32,$N> as SoaValue>::SoaBuffer as SoaBufferProxy>::num_buffers();
let mut i = 0;
$(
let $c = <Vector<f32,$N> as SoaValue>::SoaBuffer::from_soa_storage(
storage.clone(),
meta.clone(),
global_offset + i * s,
);
i += 1;
if i >= $N { i = 0; }
)*
Self{
$($c),*
}
}
fn num_buffers() -> usize {
<<Vector::<f32,$N> as SoaValue>::SoaBuffer as SoaBufferProxy>::num_buffers() * $N
}
}
impl IndexRead for $SoaName {
type Element = SquareMatrix<$N>;
fn read<I: crate::lang::index::IntoIndex>(&self, i: I) -> Expr<Self::Element> {
let i = i.to_u64();
$(
let $real_c = self.$real_c.read(i);
)*
SquareMatrix::<$N>::from_elems_expr([$($real_c),*])
}
}
impl IndexWrite for $SoaName {
#[allow(unused_assignments)]
fn write<I: crate::lang::index::IntoIndex, V: AsExpr<Value = Self::Element>>(
&self,
i: I,
value: V,
) {
let i = i.to_u64();
let v = value.as_expr();
let mut comp = 0;
$(
{
let el = Expr::<Vector<f32, $N>>::from_node(__extract::<Vector<f32, $N>>(v.node(), comp));
self.$real_c.write(i, el);
comp += 1;
}
)*
}
}

impl ExprProxy for $ExprName {
type Value = SquareMatrix<$N>;
#[allow(unused_assignments)]
fn from_expr(e:Expr<Self::Value>) -> Self {
let data: [Expr<Vector<f32, $N>>;$N] = std::array::from_fn(|i| {
FromNode::from_node(__extract::<Vector<f32, $N>>(e.node(), i))
});
let mut i = 0;
$(
let $c = data[i].clone();
i += 1;
if i >= $N { i = 0; }
)*
Self{
self_: e,
$($c),*
}
}
fn as_expr_from_proxy(&self)->&Expr<Self::Value> {
&self.self_
}
}
impl Deref for $VarName{
type Target = Expr<SquareMatrix<$N>>;
fn deref(&self) -> &Self::Target {
_deref_proxy(self)
}
}
impl AtomicRefProxy for $AtomicName {
type Value = SquareMatrix<$N>;
#[allow(unused_assignments)]
fn from_atomic_ref(e:AtomicRef<Self::Value>) -> Self {
let data: [AtomicRef<Vector<f32, $N>>;$N] = std::array::from_fn(|i| {
FromNode::from_node(__extract::<Vector<f32, $N>>(e.node(), i))
});
let mut i = 0;
$(
let $c = data[i].clone();
i += 1;
if i >= $N { i = 0; }
)*
Self{
self_: e,
$($c),*
}
}
fn as_atomic_ref_from_proxy(&self)->&AtomicRef<Self::Value> {
&self.self_
}
}

impl_simple_expr_proxy!(SquareMatrixExpr3 for SquareMatrix<3>);
impl_simple_var_proxy!(SquareMatrixVar3 for SquareMatrix<3>);
impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef3 for SquareMatrix<3>);
impl VarProxy for $VarName {
type Value = SquareMatrix<$N>;
#[allow(unused_assignments)]
fn from_var(e:Var<Self::Value>) -> Self {
let data: [Var<Vector<f32, $N>>;$N] = std::array::from_fn(|i| {
FromNode::from_node(__extract::<Vector<f32, $N>>(e.node(), i))
});
let mut i = 0;
$(
let $c = data[i].clone();
i += 1;
if i >= $N { i = 0; }
)*
Self{
self_: e,
$($c),*
}
}
fn as_var_from_proxy(&self)->&Var<Self::Value> {
&self.self_
}
}
}
}

impl_simple_expr_proxy!(SquareMatrixExpr4 for SquareMatrix<4>);
impl_simple_var_proxy!(SquareMatrixVar4 for SquareMatrix<4>);
impl_simple_atomic_ref_proxy!(SquareMatrixAtomicRef4 for SquareMatrix<4>);
matrix_proxies!(2 [x, y] [x, y]: SquareMatrixExpr2, SquareMatrixVar2, SquareMatrixAtomicRef2, SquareMatrixSoaProxy2);
matrix_proxies!(3 [x, y, z] [x, y, z]: SquareMatrixExpr3, SquareMatrixVar3, SquareMatrixAtomicRef3, SquareMatrixSoaProxy3);
matrix_proxies!(4 [x, y, z, w] [x, y, z, w]: SquareMatrixExpr4, SquareMatrixVar4, SquareMatrixAtomicRef4, SquareMatrixSoaProxy4);

impl Value for SquareMatrix<2> {
type Expr = SquareMatrixExpr2;
Expand Down
26 changes: 25 additions & 1 deletion luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::cell::RefCell;

use luisa::lang::types::array::VLArrayVar;
use luisa::lang::types::dynamic::*;
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::{alias::*, Mat2};
use luisa::prelude::*;
use luisa_compute as luisa;
use luisa_compute_api_types::StreamTag;
Expand Down Expand Up @@ -1461,6 +1461,7 @@ pub struct Foo {
i: u32,
v: Float2,
a: [i32; 4],
m: Mat2,
}
#[derive(Clone, Copy, Debug, Value, Soa, PartialEq)]
#[repr(C)]
Expand All @@ -1483,6 +1484,7 @@ fn soa() {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]),
},
});
let bars_soa = device.create_soa_buffer::<Bar>(1024);
Expand All @@ -1505,6 +1507,7 @@ fn soa_view() {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]),
},
});
let bars_soa = device.create_soa_buffer::<Bar>(2048);
Expand All @@ -1531,16 +1534,19 @@ fn atomic() {
i: rng.gen(),
v: Float2::new(rng.gen(), rng.gen()),
a: [rng.gen(), rng.gen(), rng.gen(), rng.gen()],
m: Mat2::from_column_array(&[[rng.gen(), rng.gen()], [rng.gen(), rng.gen()]]),
});
let foo_max_init = Foo {
i: u32::MIN,
v: Float2::new(f32::MIN, f32::MIN),
a: [i32::MIN; 4],
m: Mat2::from_column_array(&[[f32::MIN; 2]; 2]),
};
let foo_min_init = Foo {
i: u32::MAX,
v: Float2::new(f32::MAX, f32::MAX),
a: [i32::MAX; 4],
m: Mat2::from_column_array(&[[f32::MAX; 2]; 2]),
};
let foo_max = device.create_buffer_from_slice(&[foo_max_init]);
let foo_min = device.create_buffer_from_slice(&[foo_min_init]);
Expand All @@ -1558,12 +1564,21 @@ fn atomic() {
for i in 0..4u32 {
foo_max.a[i].fetch_max(foo.a[i]);
}
foo_max.m.x.x.fetch_max(foo.m.x.x);
foo_max.m.x.y.fetch_max(foo.m.x.y);
foo_max.m.y.x.fetch_max(foo.m.y.x);
foo_max.m.y.y.fetch_max(foo.m.y.y);

foo_min.i.fetch_min(foo.i);
foo_min.v.x.fetch_min(foo.v.x);
foo_min.v.y.fetch_min(foo.v.y);
for i in 0..4u32 {
foo_min.a[i].fetch_min(foo.a[i]);
}
foo_min.m.x.x.fetch_min(foo.m.x.x);
foo_min.m.x.y.fetch_min(foo.m.x.y);
foo_min.m.y.x.fetch_min(foo.m.y.x);
foo_min.m.y.y.fetch_min(foo.m.y.y);
}),
);
kernel.dispatch([foos.len() as u32, 1, 1]);
Expand All @@ -1579,12 +1594,21 @@ fn atomic() {
for i in 0..4 {
expected_foo_max.a[i] = expected_foo_max.a[i].max(foo.a[i]);
}
expected_foo_max.m.cols[0].x = expected_foo_max.m.cols[0].x.max(foo.m.cols[0].x);
expected_foo_max.m.cols[0].y = expected_foo_max.m.cols[0].y.max(foo.m.cols[0].y);
expected_foo_max.m.cols[1].x = expected_foo_max.m.cols[1].x.max(foo.m.cols[1].x);
expected_foo_max.m.cols[1].y = expected_foo_max.m.cols[1].y.max(foo.m.cols[1].y);

expected_foo_min.i = expected_foo_min.i.min(foo.i);
expected_foo_min.v.x = expected_foo_min.v.x.min(foo.v.x);
expected_foo_min.v.y = expected_foo_min.v.y.min(foo.v.y);
for i in 0..4 {
expected_foo_min.a[i] = expected_foo_min.a[i].min(foo.a[i]);
}
expected_foo_min.m.cols[0].x = expected_foo_min.m.cols[0].x.min(foo.m.cols[0].x);
expected_foo_min.m.cols[0].y = expected_foo_min.m.cols[0].y.min(foo.m.cols[0].y);
expected_foo_min.m.cols[1].x = expected_foo_min.m.cols[1].x.min(foo.m.cols[1].x);
expected_foo_min.m.cols[1].y = expected_foo_min.m.cols[1].y.min(foo.m.cols[1].y);
}
assert_eq!(foo_max, expected_foo_max);
assert_eq!(foo_min, expected_foo_min);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_sys/LuisaCompute
Submodule LuisaCompute updated 154 files

0 comments on commit 4db529c

Please sign in to comment.