Skip to content

Commit

Permalink
separate zipped_rw macro
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-quinones committed Oct 2, 2024
1 parent 9a0f0a5 commit 1c4fd07
Show file tree
Hide file tree
Showing 57 changed files with 410 additions and 460 deletions.
8 changes: 4 additions & 4 deletions src/col/colmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
iter,
iter::chunks::ChunkPolicy,
row::{RowMut, RowRef},
unzipped, zipped, Idx, IdxInc, Unbind,
unzipped, zipped_rw, Idx, IdxInc, Unbind,
};

/// Mutable view over a column vector, similar to a mutable reference to a strided [prim@slice].
Expand Down Expand Up @@ -542,7 +542,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
this: ColMut<'_, E, R>,
other: ColRef<'_, ViewE, R>,
) {
zipped!(__rw, this, other)
zipped_rw!(this, other)
.for_each(|unzipped!(mut dst, src)| dst.write(src.read().canonicalize()));
}
implementation(self.rb_mut(), other.as_col_ref())
Expand All @@ -554,7 +554,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
where
E: ComplexField,
{
zipped!(__rw, self.rb_mut()).for_each(
zipped_rw!(self.rb_mut()).for_each(
#[inline(always)]
|unzipped!(mut x)| x.write(E::faer_zero()),
);
Expand All @@ -563,7 +563,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
/// Fills the elements of `self` with copies of `constant`.
#[track_caller]
pub fn fill(&mut self, constant: E) {
zipped!(__rw, (*self).rb_mut()).for_each(
zipped_rw!((*self).rb_mut()).for_each(
#[inline(always)]
|unzipped!(mut x)| x.write(constant),
);
Expand Down
2 changes: 1 addition & 1 deletion src/col/colown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ impl<E: Entity, R: Shape> Clone for Col<E, R> {
}
fn clone_from(&mut self, other: &Self) {
if self.nrows() == other.nrows() {
crate::zipped!(__rw, self, other)
crate::zipped_rw!(self, other)
.for_each(|crate::unzipped!(mut dst, src)| dst.write(src.read()));
} else {
if !R::IS_BOUND {
Expand Down
29 changes: 18 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ impl Conj {
///
/// # Example
/// ```
/// use faer::{mat, unzipped, zipped, Mat};
/// use faer::{mat, unzipped, zipped_rw, Mat};
///
/// let nrows = 2;
/// let ncols = 3;
Expand All @@ -385,28 +385,34 @@ impl Conj {
/// ```
#[macro_export]
macro_rules! zipped {
(__rw, $head: expr $(,)?) => {
$crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head }))
($head: expr $(,)?) => {
$crate::linalg::zip::LastEq($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })))
};

(__rw, $head: expr, $($tail: expr),* $(,)?) => {
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!(__rw, $($tail,)*))
($head: expr, $($tail: expr),* $(,)?) => {
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })), $crate::zipped!( $($tail,)*))
};

}

/// Like the [`zipped!`] macro, but is compatible with potentially uninit values by not forming
/// references.
#[macro_export]
macro_rules! zipped_rw {
($head: expr $(,)?) => {
$crate::linalg::zip::LastEq($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })))
$crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head }))
};

($head: expr, $($tail: expr),* $(,)?) => {
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })), $crate::zipped!($($tail,)*))
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped_rw!($($tail,)*))
};
}

/// Used to undo the zipping by the [`zipped!`] macro.
/// Used to undo the zipping by the [`zipped_rw!`] macro.
///
/// # Example
/// ```
/// use faer::{mat, unzipped, zipped, Mat};
/// use faer::{mat, unzipped, zipped_rw, Mat};
///
/// let nrows = 2;
/// let ncols = 3;
Expand All @@ -415,7 +421,7 @@ macro_rules! zipped {
/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]];
/// let mut sum = Mat::<f64>::zeros(nrows, ncols);
///
/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
/// zipped_rw!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
/// *sum = a + b;
/// });
///
Expand Down Expand Up @@ -985,7 +991,8 @@ pub mod prelude {
pub use crate::{
col,
complex_native::{c32, c64},
mat, row, unzipped, zipped, Col, ColMut, ColRef, Mat, MatMut, MatRef, Row, RowMut, RowRef,
mat, row, unzipped, zipped_rw, Col, ColMut, ColRef, Mat, MatMut, MatRef, Row, RowMut,
RowRef,
};

pub use crate::linalg::solvers::{
Expand Down
26 changes: 11 additions & 15 deletions src/linalg/cholesky/bunch_kaufman/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
},
perm::{permute_rows, swap_cols_idx as swap_cols, swap_rows_idx as swap_rows, PermRef},
unzipped, zipped, ColMut, ColRef, Conj, Index, MatMut, MatRef, Parallelism, SignedIndex,
unzipped, zipped_rw, ColMut, ColRef, Conj, Index, MatMut, MatRef, Parallelism, SignedIndex,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use faer_entity::{ComplexField, Entity, RealField};
Expand Down Expand Up @@ -235,8 +235,7 @@ pub mod compute {
if abs_akk >= colmax.faer_mul(alpha) {
kp = k;
} else {
zipped!(
__rw,
zipped_rw!(
w.rb_mut().subrows_mut(k, imax - k).col_mut(k + 1),
a.rb().row(imax).subcols(k, imax - k).transpose(),
)
Expand Down Expand Up @@ -326,9 +325,9 @@ pub mod compute {
let d11 = d11.faer_inv();

let x = a.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k);
zipped!(__rw, x)
zipped_rw!(x)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
zipped_rw!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
} else {
let dd = w.read(k + 1, k).faer_abs();
Expand Down Expand Up @@ -406,13 +405,10 @@ pub mod compute {
a.write(j, k + 1, wkp1);
}

zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
zipped_rw!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
zipped_rw!(w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
zipped!(
__rw,
w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1)
)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
}
}

Expand All @@ -439,7 +435,7 @@ pub mod compute {
parallelism,
);

zipped!(__rw, a_right.diagonal_mut().column_vector_mut())
zipped_rw!(a_right.diagonal_mut().column_vector_mut())
.for_each(|unzipped!(mut x)| x.write(E::faer_from_real(x.read().faer_real())));

let mut j = k - 1;
Expand Down Expand Up @@ -596,7 +592,7 @@ pub mod compute {
}
make_real(trailing.rb_mut(), j, j);
}
zipped!(__rw, x)
zipped_rw!(x)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
} else {
let d21 = a.read(k + 1, k).faer_abs();
Expand Down Expand Up @@ -1042,7 +1038,7 @@ mod tests {

let err = &a * &x - &rhs;
let mut max = 0.0;
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
zipped_rw!(err.as_ref()).for_each(|unzipped!(err)| {
let err = err.read().abs();
if err > max {
max = err
Expand Down Expand Up @@ -1099,7 +1095,7 @@ mod tests {

let err = a.conjugate() * &x - &rhs;
let mut max = 0.0;
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
zipped_rw!(err.as_ref()).for_each(|unzipped!(err)| {
let err = err.read().abs();
if err > max {
max = err
Expand Down
4 changes: 2 additions & 2 deletions src/linalg/cholesky/ldlt_diagonal/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
},
unzipped,
utils::{simd::*, slice::*, DivCeil},
zipped, ComplexField, MatMut, MatRef, Parallelism,
zipped_rw, ComplexField, MatMut, MatRef, Parallelism,
};
use core::{convert::Infallible, marker::PhantomData};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
Expand Down Expand Up @@ -1426,7 +1426,7 @@ fn cholesky_in_place_impl<E: ComplexField>(
let a10_col = a10.rb_mut().col_mut(j);
let d0_elem = d0.read(j).faer_real().faer_inv();

zipped!(__rw, l10xd0_col, a10_col).for_each(
zipped_rw!(l10xd0_col, a10_col).for_each(
|unzipped!(mut l10xd0_elem, mut a10_elem)| {
let a10_elem_read = a10_elem.read();
a10_elem.write(a10_elem_read.faer_scale_real(d0_elem));
Expand Down
6 changes: 3 additions & 3 deletions src/linalg/cholesky/ldlt_diagonal/solve.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity,
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
Expand Down Expand Up @@ -184,7 +184,7 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}

Expand Down Expand Up @@ -216,6 +216,6 @@ pub fn solve_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}
10 changes: 5 additions & 5 deletions src/linalg/cholesky/ldlt_diagonal/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
},
unzipped,
utils::{simd::*, slice::*},
zipped, ColMut, MatMut, Parallelism,
zipped_rw, ColMut, MatMut, Parallelism,
};
use core::iter::zip;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
Expand Down Expand Up @@ -422,7 +422,7 @@ fn rank_update_step_impl4<E: ComplexField>(
let [p0, p1, p2, p3] = p_array;
let [beta0, beta1, beta2, beta3] = beta_array;

zipped!(__rw, l_col, w0, w1, w2, w3).for_each(
zipped_rw!(l_col, w0, w1, w2, w3).for_each(
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
Expand Down Expand Up @@ -482,7 +482,7 @@ fn rank_update_step_impl3<E: ComplexField>(
let [p0, p1, p2] = p_array;
let [beta0, beta1, beta2] = beta_array;

zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
zipped_rw!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -532,7 +532,7 @@ fn rank_update_step_impl2<E: ComplexField>(
let [p0, p1] = p_array;
let [beta0, beta1] = beta_array;

zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
zipped_rw!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -574,7 +574,7 @@ fn rank_update_step_impl1<E: ComplexField>(
let [p0] = p_array;
let [beta0] = beta_array;

zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
zipped_rw!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();

Expand Down
4 changes: 2 additions & 2 deletions src/linalg/cholesky/llt/reconstruct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
temp_mat_req, temp_mat_uninit,
zip::Diag,
},
unzipped, zipped, ComplexField, Entity, MatMut, MatRef, Parallelism,
unzipped, zipped_rw, ComplexField, Entity, MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
Expand Down Expand Up @@ -78,7 +78,7 @@ pub fn reconstruct_lower_in_place<E: ComplexField>(
let (mut tmp, stack) = temp_mat_uninit::<E>(n, n, stack);
let mut tmp = tmp.as_mut();
reconstruct_lower(tmp.rb_mut(), cholesky_factor.rb(), parallelism, stack);
zipped!(__rw, cholesky_factor, tmp.rb())
zipped_rw!(cholesky_factor, tmp.rb())
.for_each_triangular_lower(Diag::Include, |unzipped!(mut dst, src)| {
dst.write(src.read())
});
Expand Down
6 changes: 3 additions & 3 deletions src/linalg/cholesky/llt/solve.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity,
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
Expand Down Expand Up @@ -135,7 +135,7 @@ pub fn solve_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}

Expand Down Expand Up @@ -202,6 +202,6 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}
10 changes: 5 additions & 5 deletions src/linalg/cholesky/llt/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
},
unzipped,
utils::{simd::*, slice::*},
zipped, ColMut, MatMut, Parallelism,
zipped_rw, ColMut, MatMut, Parallelism,
};
use core::iter::zip;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
Expand Down Expand Up @@ -533,7 +533,7 @@ fn rank_update_step_impl4<E: ComplexField>(
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2, alpha_wj_over_nljj3] =
alpha_wj_over_nljj_array;

zipped!(__rw, l_col, w0, w1, w2, w3,).for_each(
zipped_rw!(l_col, w0, w1, w2, w3,).for_each(
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
Expand Down Expand Up @@ -607,7 +607,7 @@ fn rank_update_step_impl3<E: ComplexField>(
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2] =
alpha_wj_over_nljj_array;

zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
zipped_rw!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -668,7 +668,7 @@ fn rank_update_step_impl2<E: ComplexField>(
let [nljj_over_ljj0, nljj_over_ljj1] = nljj_over_ljj_array;
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1] = alpha_wj_over_nljj_array;

zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
zipped_rw!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -719,7 +719,7 @@ fn rank_update_step_impl1<E: ComplexField>(
let [nljj_over_ljj0] = nljj_over_ljj_array;
let [alpha_wj_over_nljj0] = alpha_wj_over_nljj_array;

zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
zipped_rw!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();

Expand Down
5 changes: 2 additions & 3 deletions src/linalg/cholesky/piv_llt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ pub mod compute {
crate::perm::swap_rows_idx(a.rb_mut().get_mut(.., ..j), j, pvt);
crate::perm::swap_cols_idx(a.rb_mut().get_mut(pvt + 1.., ..), j, pvt);
unsafe {
zipped!(
__rw,
zipped_rw!(
a.rb().get(j + 1..pvt, j).const_cast(),
a.rb().get(pvt, j + 1..pvt).const_cast().transpose_mut(),
)
Expand Down Expand Up @@ -173,7 +172,7 @@ pub mod compute {
);
}
let ajj = ajj.faer_inv();
zipped!(__rw, a.rb_mut().get_mut(j + 1.., j))
zipped_rw!(a.rb_mut().get_mut(j + 1.., j))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(ajj)));
}

Expand Down
Loading

0 comments on commit 1c4fd07

Please sign in to comment.