Skip to content

Commit

Permalink
replace heap allocations with stack allocated matrix in triangular mu…
Browse files Browse the repository at this point in the history
…ltiplication
  • Loading branch information
sarah authored and sarah committed Sep 16, 2023
1 parent e811fa8 commit ca71ff8
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 101 deletions.
5 changes: 2 additions & 3 deletions faer-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@
//! compatible with the classic contiguous layout that's commonly used by other libraries.
//!
//! # Memory allocation
//! Since most `faer` crates aim to expose a low level api for optimal performance, all algorithms
//! are mostly allocation-free in single threaded settings, and allocate minimal amounts in
//! multithreaded settings.
//! Since most `faer` crates aim to expose a low-level API for optimal performance, most algorithms
//! try to defer memory allocation to the user.
//!
//! However, since a lot of algorithms need some form of temporary space for intermediate
//! computations, they may ask for a slice of memory for that purpose, by taking a [`stack:
Expand Down
290 changes: 192 additions & 98 deletions faer-core/src/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,60 @@ pub fn matmul<E: ComplexField, LhsE: Conjugate<Canonical = E>, RhsE: Conjugate<C
matmul_with_conj::<E>(acc, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism);
}

/// Declares `$name`, a mutable `$nrows` x `$ncols` matrix view with element type `$ty`, backed
/// by a zero-filled buffer of `16 * 16` units per component of `$ty` that is held in a local
/// variable of the calling scope rather than in a heap-allocated matrix.
///
/// The view uses a row stride of `1` and a column stride of `16`, so `$nrows` and `$ncols` must
/// both be at most 16; this is checked in debug builds.
///
/// The macro expands to several `let` statements and therefore has to be invoked in statement
/// position. Call sites in this file pair each invocation with a matching
/// `stack_mat_16x16_end!` once the view is no longer needed.
macro_rules! stack_mat_16x16_begin {
    ($name: ident, $nrows: expr, $ncols: expr, $ty: ty) => {
        // Evaluate the dimension expressions exactly once.
        let __nrows = $nrows;
        let __ncols = $ncols;
        // The view created below addresses a 16x16 buffer with a column stride of 16; larger
        // dimensions would make it reach past the end of that buffer.
        ::core::debug_assert!(
            __nrows <= 16 && __ncols <= 16,
            "stack_mat_16x16 dimensions must not exceed 16x16"
        );

        // One uninitialized `[MaybeUninit<Unit>; 16 * 16]` array per component of `$ty`.
        let mut __data = <$ty as $crate::Entity>::map(
            <$ty as $crate::Entity>::from_copy(<$ty as $crate::Entity>::UNIT),
            #[inline(always)]
            // SAFETY: an uninitialized `MaybeUninit<[T; N]>` is reinterpreted as
            // `[MaybeUninit<T>; N]`; neither type requires its contents to be initialized.
            |()| unsafe {
                $crate::transmute_unchecked::<
                    ::core::mem::MaybeUninit<[<$ty as $crate::Entity>::Unit; 16 * 16]>,
                    [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16],
                >(::core::mem::MaybeUninit::<
                    [<$ty as $crate::Entity>::Unit; 16 * 16],
                >::uninit())
            },
        );

        // Initialize every slot of every component array with the matching unit of zero.
        let zero = <$ty as $crate::Entity>::into_units(<$ty as $crate::ComplexField>::zero());
        <$ty as $crate::Entity>::map(
            <$ty as $crate::Entity>::zip(<$ty as $crate::Entity>::as_mut(&mut __data), zero),
            #[inline(always)]
            |(__data, zero)| {
                for __data in __data {
                    *__data = ::core::mem::MaybeUninit::new(::core::clone::Clone::clone(&zero));
                }
            },
        );

        // Turn each (now fully initialized) array into a raw pointer to its first unit.
        let __data =
            <$ty as $crate::Entity>::map(<$ty as $crate::Entity>::as_mut(&mut __data), |__data| {
                (__data as *mut [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16]
                    as *mut <$ty as $crate::Entity>::Unit)
            });

        // SAFETY: the pointers refer to the initialized 16x16 arrays declared above, which stay
        // alive for the rest of the enclosing scope, and the dimensions were checked against the
        // buffer size (in debug builds) at the top of the expansion.
        let mut $name = unsafe {
            $crate::MatMut::<'_, $ty>::from_raw_parts(__data, __nrows, __ncols, 1isize, 16isize)
        };
    };
}

/// Finishes the use of a matrix view `$name` that was declared with `stack_mat_16x16_begin!`.
///
/// For every component of `$ty`, the pointer obtained from `$name.as_ptr()` is reinterpreted as
/// a pointer to a `[Unit; 16 * 16]` array, which is then dropped in place. `$nrows` and `$ncols`
/// are accepted so that the invocation mirrors the `begin` macro; the expansion does not use
/// them.
macro_rules! stack_mat_16x16_end {
    ($name: ident, $nrows: expr, $ncols: expr, $ty: ty) => {
        <$ty as $crate::Entity>::map(
            $name.as_ptr(),
            #[inline(always)]
            |__ptr| {
                // View the first-unit pointer as a pointer to the whole backing array.
                let __units = __ptr as *mut <$ty as $crate::Entity>::Unit
                    as *mut [<$ty as $crate::Entity>::Unit; 16 * 16];
                // SAFETY: relies on `$name` coming from `stack_mat_16x16_begin!`, whose
                // expansion points the view at a fully initialized array of `16 * 16` units.
                unsafe { ::core::ptr::drop_in_place(__units) }
            },
        );
    };
}

/// Triangular matrix multiplication module, where some of the operands are treated as triangular
/// matrices.
pub mod triangular {
Expand Down Expand Up @@ -2174,22 +2228,30 @@ pub mod triangular {
debug_assert!(n == rhs.ncols());

if n <= 16 {
let mut dst_buffer = crate::Mat::zeros(n, n);
let mut temp_dst = dst_buffer.as_mut();
let mut rhs_buffer = crate::Mat::zeros(n, n);
let mut temp_rhs = rhs_buffer.as_mut();
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
mul(
temp_dst.rb_mut(),
lhs,
temp_rhs.into_const(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.into_const(), skip_diag, alpha);
let op = {
#[inline(never)]
|| {
stack_mat_16x16_begin!(temp_dst, n, n, E);
stack_mat_16x16_begin!(temp_rhs, n, n, E);

copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
mul(
temp_dst.rb_mut(),
lhs,
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);

stack_mat_16x16_end!(temp_dst, n, n, E);
stack_mat_16x16_end!(temp_rhs, n, n, E);
}
};
op();
} else {
let bs = n / 2;

Expand Down Expand Up @@ -2294,25 +2356,24 @@ pub mod triangular {
let op = {
#[inline(never)]
|| {
let mut rhs_buffer = crate::Mat::zeros(n, n);
let mut temp_rhs = rhs_buffer.as_mut();
stack_mat_16x16_begin!(temp_rhs, n, n, E);

copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
let temp_rhs = temp_rhs.into_const();

mul(
dst,
lhs,
temp_rhs,
temp_rhs.rb(),
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);

stack_mat_16x16_end!(temp_rhs, n, n, E);
}
};

op();
} else {
// split rhs into 3 sections
Expand Down Expand Up @@ -2388,29 +2449,34 @@ pub mod triangular {
debug_assert!(n == dst.ncols());

if n <= 16 {
let mut dst_buffer = crate::Mat::zeros(n, n);
let mut temp_dst = dst_buffer.as_mut();
let mut lhs_buffer = crate::Mat::zeros(n, n);
let mut temp_lhs = lhs_buffer.as_mut();
let mut rhs_buffer = crate::Mat::zeros(n, n);
let mut temp_rhs = rhs_buffer.as_mut();

copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

let temp_lhs = temp_lhs.into_const();
let temp_rhs = temp_rhs.into_const();
mul(
temp_dst.rb_mut(),
temp_lhs,
temp_rhs,
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.into_const(), skip_diag, alpha);
let op = {
#[inline(never)]
|| {
stack_mat_16x16_begin!(temp_dst, n, n, E);
stack_mat_16x16_begin!(temp_lhs, n, n, E);
stack_mat_16x16_begin!(temp_rhs, n, n, E);

copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

mul(
temp_dst.rb_mut(),
temp_lhs.rb(),
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);

stack_mat_16x16_end!(temp_dst, n, n, E);
stack_mat_16x16_end!(temp_lhs, n, n, E);
stack_mat_16x16_end!(temp_rhs, n, n, E);
}
};
op();
} else {
let bs = n / 2;

Expand Down Expand Up @@ -2495,26 +2561,31 @@ pub mod triangular {
debug_assert!(n == dst.ncols());

if n <= 16 {
let mut lhs_buffer = crate::Mat::zeros(n, n);
let mut temp_lhs = lhs_buffer.as_mut();
let mut rhs_buffer = crate::Mat::zeros(n, n);
let mut temp_rhs = rhs_buffer.as_mut();
let op = {
#[inline(never)]
|| {
stack_mat_16x16_begin!(temp_lhs, n, n, E);
stack_mat_16x16_begin!(temp_rhs, n, n, E);

copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

let temp_lhs = temp_lhs.into_const();
let temp_rhs = temp_rhs.into_const();
mul(
dst,
temp_lhs,
temp_rhs,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mul(
dst,
temp_lhs.rb(),
temp_rhs.rb(),
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);

stack_mat_16x16_end!(temp_lhs, n, n, E);
stack_mat_16x16_end!(temp_rhs, n, n, E);
}
};
op();
} else {
let bs = n / 2;

Expand Down Expand Up @@ -2626,30 +2697,35 @@ pub mod triangular {
debug_assert!(n == dst.ncols());

if n <= 16 {
let mut dst_buffer = crate::Mat::zeros(n, n);
let mut temp_dst = dst_buffer.as_mut();
let mut lhs_buffer = crate::Mat::zeros(n, n);
let mut temp_lhs = lhs_buffer.as_mut();
let mut rhs_buffer = crate::Mat::zeros(n, n);
let mut temp_rhs = rhs_buffer.as_mut();

copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

let temp_lhs = temp_lhs.into_const();
let temp_rhs = temp_rhs.into_const();
mul(
temp_dst.rb_mut(),
temp_lhs,
temp_rhs,
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
let op = {
#[inline(never)]
|| {
stack_mat_16x16_begin!(temp_dst, n, n, E);
stack_mat_16x16_begin!(temp_lhs, n, n, E);
stack_mat_16x16_begin!(temp_rhs, n, n, E);

accum_lower(dst.rb_mut(), temp_dst.into_const(), skip_diag, alpha);
copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

mul(
temp_dst.rb_mut(),
temp_lhs.rb(),
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);

accum_lower(dst.rb_mut(), temp_dst.rb(), skip_diag, alpha);

stack_mat_16x16_end!(temp_dst, n, n, E);
stack_mat_16x16_end!(temp_lhs, n, n, E);
stack_mat_16x16_end!(temp_rhs, n, n, E);
}
};
op();
} else {
let bs = n / 2;

Expand Down Expand Up @@ -2747,19 +2823,27 @@ pub mod triangular {
};

if n <= 16 {
let mut dst_buffer = crate::Mat::zeros(n, n);
let mut temp_dst = dst_buffer.as_mut();
mul(
temp_dst.rb_mut(),
lhs,
rhs,
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.rb(), skip_diag, alpha)
let op = {
#[inline(never)]
|| {
stack_mat_16x16_begin!(temp_dst, n, n, E);

mul(
temp_dst.rb_mut(),
lhs,
rhs,
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);

stack_mat_16x16_end!(temp_dst, n, n, E);
}
};
op();
} else {
let bs = n / 2;
let [dst_top_left, _, dst_bot_left, dst_bot_right] = dst.split_at(bs, bs);
Expand Down Expand Up @@ -3383,6 +3467,16 @@ mod tests {
use assert_approx_eq::assert_approx_eq;
use num_complex::Complex32;

// Smoke test for the `stack_mat_16x16_begin!` / `stack_mat_16x16_end!` pair: declares a 3x3
// `f64` view, checks its shape, formats it, and then releases it.
#[test]
fn test_stack_mat() {
    stack_mat_16x16_begin!(m, 3, 3, f64);
    {
        // Borrow mutably once; presumably this keeps the `mut` binding declared by the macro
        // from being reported as unused.
        let _ = &mut m;

        // The view must report exactly the dimensions that were requested.
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);

        // Exercise the pretty `Debug` formatting of the view without writing to stderr.
        let _ = format!("{:#?}", &m);
    }
    stack_mat_16x16_end!(m, 3, 3, f64);
}

#[test]
#[ignore = "this takes too long to launch in CI"]
fn test_matmul() {
Expand Down

0 comments on commit ca71ff8

Please sign in to comment.