Skip to content

Commit

Permalink
docs and zipped
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-quinones committed Oct 2, 2024
1 parent ff832b2 commit 9a0f0a5
Show file tree
Hide file tree
Showing 67 changed files with 953 additions and 700 deletions.
3 changes: 3 additions & 0 deletions faer-entity/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ pub type Slice<'a, E> = GroupFor<E, &'a [<E as Entity>::Unit]>;
pub type SliceMut<'a, E> = GroupFor<E, &'a mut [<E as Entity>::Unit]>;
pub type UninitSliceMut<'a, E> = GroupFor<E, &'a mut [core::mem::MaybeUninit<<E as Entity>::Unit>]>;

extern crate alloc;
pub type Vector<E> = GroupFor<E, alloc::vec::Vec<<E as Entity>::Unit>>;

pub type GroupFor<E, T> = <<E as Entity>::Group as ForType>::FaerOf<T>;
pub type GroupCopyFor<E, T> = <<E as Entity>::Group as ForCopyType>::FaerOfCopy<T>;
pub type GroupDebugFor<E, T> = <<E as Entity>::Group as ForDebugType>::FaerOfDebug<T>;
Expand Down
6 changes: 3 additions & 3 deletions src/col/colmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
this: ColMut<'_, E, R>,
other: ColRef<'_, ViewE, R>,
) {
zipped!(this.as_2d_mut(), other.as_2d())
zipped!(__rw, this, other)
.for_each(|unzipped!(mut dst, src)| dst.write(src.read().canonicalize()));
}
implementation(self.rb_mut(), other.as_col_ref())
Expand All @@ -554,7 +554,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
where
E: ComplexField,
{
zipped!(self.rb_mut().as_2d_mut()).for_each(
zipped!(__rw, self.rb_mut()).for_each(
#[inline(always)]
|unzipped!(mut x)| x.write(E::faer_zero()),
);
Expand All @@ -563,7 +563,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
/// Fills the elements of `self` with copies of `constant`.
#[track_caller]
pub fn fill(&mut self, constant: E) {
zipped!((*self).rb_mut().as_2d_mut()).for_each(
zipped!(__rw, (*self).rb_mut()).for_each(
#[inline(always)]
|unzipped!(mut x)| x.write(constant),
);
Expand Down
2 changes: 1 addition & 1 deletion src/col/colown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ impl<E: Entity, R: Shape> Clone for Col<E, R> {
}
fn clone_from(&mut self, other: &Self) {
if self.nrows() == other.nrows() {
crate::zipped!(self, other)
crate::zipped!(__rw, self, other)
.for_each(|crate::unzipped!(mut dst, src)| dst.write(src.read()));
} else {
if !R::IS_BOUND {
Expand Down
20 changes: 12 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,7 @@ impl Conj {
/// let mut sum = Mat::<f64>::zeros(nrows, ncols);
///
/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
/// let a = a.read();
/// let b = b.read();
/// sum.write(a + b);
/// *sum = a + b;
/// });
///
/// for i in 0..nrows {
Expand All @@ -387,12 +385,20 @@ impl Conj {
/// ```
#[macro_export]
macro_rules! zipped {
($head: expr $(,)?) => {
(__rw, $head: expr $(,)?) => {
$crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head }))
};

(__rw, $head: expr, $($tail: expr),* $(,)?) => {
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!(__rw, $($tail,)*))
};

($head: expr $(,)?) => {
$crate::linalg::zip::LastEq($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })))
};

($head: expr, $($tail: expr),* $(,)?) => {
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!($($tail,)*))
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })), $crate::zipped!($($tail,)*))
};
}

Expand All @@ -410,9 +416,7 @@ macro_rules! zipped {
/// let mut sum = Mat::<f64>::zeros(nrows, ncols);
///
/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
/// let a = a.read();
/// let b = b.read();
/// sum.write(a + b);
/// *sum = a + b;
/// });
///
/// for i in 0..nrows {
Expand Down
24 changes: 15 additions & 9 deletions src/linalg/cholesky/bunch_kaufman/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ pub mod compute {
kp = k;
} else {
zipped!(
__rw,
w.rb_mut().subrows_mut(k, imax - k).col_mut(k + 1),
a.rb().row(imax).subcols(k, imax - k).transpose(),
)
Expand Down Expand Up @@ -325,8 +326,9 @@ pub mod compute {
let d11 = d11.faer_inv();

let x = a.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k);
zipped!(x).for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
zipped!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
zipped!(__rw, x)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
} else {
let dd = w.read(k + 1, k).faer_abs();
Expand Down Expand Up @@ -404,10 +406,13 @@ pub mod compute {
a.write(j, k + 1, wkp1);
}

zipped!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
zipped!(w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1))
zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
zipped!(
__rw,
w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1)
)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
}
}

Expand All @@ -434,7 +439,7 @@ pub mod compute {
parallelism,
);

zipped!(a_right.diagonal_mut().column_vector_mut())
zipped!(__rw, a_right.diagonal_mut().column_vector_mut())
.for_each(|unzipped!(mut x)| x.write(E::faer_from_real(x.read().faer_real())));

let mut j = k - 1;
Expand Down Expand Up @@ -591,7 +596,8 @@ pub mod compute {
}
make_real(trailing.rb_mut(), j, j);
}
zipped!(x).for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
zipped!(__rw, x)
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
} else {
let d21 = a.read(k + 1, k).faer_abs();
let d21_inv = d21.faer_inv();
Expand Down Expand Up @@ -1036,7 +1042,7 @@ mod tests {

let err = &a * &x - &rhs;
let mut max = 0.0;
zipped!(err.as_ref()).for_each(|unzipped!(err)| {
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
let err = err.read().abs();
if err > max {
max = err
Expand Down Expand Up @@ -1093,7 +1099,7 @@ mod tests {

let err = a.conjugate() * &x - &rhs;
let mut max = 0.0;
zipped!(err.as_ref()).for_each(|unzipped!(err)| {
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
let err = err.read().abs();
if err > max {
max = err
Expand Down
2 changes: 1 addition & 1 deletion src/linalg/cholesky/ldlt_diagonal/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ fn cholesky_in_place_impl<E: ComplexField>(
let a10_col = a10.rb_mut().col_mut(j);
let d0_elem = d0.read(j).faer_real().faer_inv();

zipped!(l10xd0_col, a10_col).for_each(
zipped!(__rw, l10xd0_col, a10_col).for_each(
|unzipped!(mut l10xd0_elem, mut a10_elem)| {
let a10_elem_read = a10_elem.read();
a10_elem.write(a10_elem_read.faer_scale_real(d0_elem));
Expand Down
4 changes: 2 additions & 2 deletions src/linalg/cholesky/ldlt_diagonal/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}

Expand Down Expand Up @@ -216,6 +216,6 @@ pub fn solve_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}
8 changes: 4 additions & 4 deletions src/linalg/cholesky/ldlt_diagonal/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ fn rank_update_step_impl4<E: ComplexField>(
let [p0, p1, p2, p3] = p_array;
let [beta0, beta1, beta2, beta3] = beta_array;

zipped!(l_col, w0, w1, w2, w3).for_each(
zipped!(__rw, l_col, w0, w1, w2, w3).for_each(
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
Expand Down Expand Up @@ -482,7 +482,7 @@ fn rank_update_step_impl3<E: ComplexField>(
let [p0, p1, p2] = p_array;
let [beta0, beta1, beta2] = beta_array;

zipped!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -532,7 +532,7 @@ fn rank_update_step_impl2<E: ComplexField>(
let [p0, p1] = p_array;
let [beta0, beta1] = beta_array;

zipped!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -574,7 +574,7 @@ fn rank_update_step_impl1<E: ComplexField>(
let [p0] = p_array;
let [beta0] = beta_array;

zipped!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();

Expand Down
2 changes: 1 addition & 1 deletion src/linalg/cholesky/llt/reconstruct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub fn reconstruct_lower_in_place<E: ComplexField>(
let (mut tmp, stack) = temp_mat_uninit::<E>(n, n, stack);
let mut tmp = tmp.as_mut();
reconstruct_lower(tmp.rb_mut(), cholesky_factor.rb(), parallelism, stack);
zipped!(cholesky_factor, tmp.rb())
zipped!(__rw, cholesky_factor, tmp.rb())
.for_each_triangular_lower(Diag::Include, |unzipped!(mut dst, src)| {
dst.write(src.read())
});
Expand Down
4 changes: 2 additions & 2 deletions src/linalg/cholesky/llt/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub fn solve_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}

Expand Down Expand Up @@ -202,6 +202,6 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
stack: &mut PodStack,
) {
let mut dst = dst;
zipped!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}
8 changes: 4 additions & 4 deletions src/linalg/cholesky/llt/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ fn rank_update_step_impl4<E: ComplexField>(
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2, alpha_wj_over_nljj3] =
alpha_wj_over_nljj_array;

zipped!(l_col, w0, w1, w2, w3,).for_each(
zipped!(__rw, l_col, w0, w1, w2, w3,).for_each(
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
Expand Down Expand Up @@ -607,7 +607,7 @@ fn rank_update_step_impl3<E: ComplexField>(
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2] =
alpha_wj_over_nljj_array;

zipped!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -668,7 +668,7 @@ fn rank_update_step_impl2<E: ComplexField>(
let [nljj_over_ljj0, nljj_over_ljj1] = nljj_over_ljj_array;
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1] = alpha_wj_over_nljj_array;

zipped!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();
let mut local_w1 = w1.read();
Expand Down Expand Up @@ -719,7 +719,7 @@ fn rank_update_step_impl1<E: ComplexField>(
let [nljj_over_ljj0] = nljj_over_ljj_array;
let [alpha_wj_over_nljj0] = alpha_wj_over_nljj_array;

zipped!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
let mut local_l = l.read();
let mut local_w0 = w0.read();

Expand Down
3 changes: 2 additions & 1 deletion src/linalg/cholesky/piv_llt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ pub mod compute {
crate::perm::swap_cols_idx(a.rb_mut().get_mut(pvt + 1.., ..), j, pvt);
unsafe {
zipped!(
__rw,
a.rb().get(j + 1..pvt, j).const_cast(),
a.rb().get(pvt, j + 1..pvt).const_cast().transpose_mut(),
)
Expand Down Expand Up @@ -172,7 +173,7 @@ pub mod compute {
);
}
let ajj = ajj.faer_inv();
zipped!(a.rb_mut().get_mut(j + 1.., j))
zipped!(__rw, a.rb_mut().get_mut(j + 1.., j))
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(ajj)));
}

Expand Down
36 changes: 21 additions & 15 deletions src/linalg/evd/hessenberg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,22 +502,28 @@ fn make_hessenberg_in_place_basic<E: ComplexField>(
(nu.faer_mul(psi.faer_conj())).faer_add(zeta.faer_mul(nu.faer_conj())),
),
);
zipped!(a12.rb_mut(), y21.rb().transpose(), u21.rb().transpose()).for_each(
|unzipped!(mut a, y, u)| {
let y = y.read();
zipped!(
__rw,
a12.rb_mut(),
y21.rb().transpose(),
u21.rb().transpose()
)
.for_each(|unzipped!(mut a, y, u)| {
let y = y.read();
let u = u.read();
a.write(a.read().faer_sub(
(nu.faer_mul(y.faer_conj())).faer_add(zeta.faer_mul(u.faer_conj())),
));
});
zipped!(__rw, a21.rb_mut(), u21.rb(), z21.rb()).for_each(
|unzipped!(mut a, u, z)| {
let z = z.read();
let u = u.read();
a.write(a.read().faer_sub(
(nu.faer_mul(y.faer_conj())).faer_add(zeta.faer_mul(u.faer_conj())),
(u.faer_mul(psi.faer_conj())).faer_add(z.faer_mul(nu.faer_conj())),
));
},
);
zipped!(a21.rb_mut(), u21.rb(), z21.rb()).for_each(|unzipped!(mut a, u, z)| {
let z = z.read();
let u = u.read();
a.write(a.read().faer_sub(
(u.faer_mul(psi.faer_conj())).faer_add(z.faer_mul(nu.faer_conj())),
));
});
}

let (tau, new_head) = {
Expand Down Expand Up @@ -562,14 +568,14 @@ fn make_hessenberg_in_place_basic<E: ComplexField>(
);
}

zipped!(u21.rb_mut(), a21.rb())
zipped!(__rw, u21.rb_mut(), a21.rb())
.for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
a21.write(0, new_head);

let beta = inner_prod_with_conj(u21.rb(), Conj::Yes, z21.rb(), Conj::No)
.faer_scale_power_of_two(E::Real::faer_from_f64(0.5));

zipped!(y21.rb_mut(), u21.rb()).for_each(|unzipped!(mut y, u)| {
zipped!(__rw, y21.rb_mut(), u21.rb()).for_each(|unzipped!(mut y, u)| {
let u = u.read();
let beta = beta.faer_conj();
y.write(
Expand All @@ -578,7 +584,7 @@ fn make_hessenberg_in_place_basic<E: ComplexField>(
.faer_mul(tau_inv),
);
});
zipped!(z21.rb_mut(), u21.rb()).for_each(|unzipped!(mut z, u)| {
zipped!(__rw, z21.rb_mut(), u21.rb()).for_each(|unzipped!(mut z, u)| {
let u = u.read();
z.write(
z.read()
Expand Down Expand Up @@ -703,7 +709,7 @@ fn make_hessenberg_in_place_qgvdg_unblocked<E: ComplexField>(
let u_0 = u.rb().get(.., ..k);
let u10_adjoint = u_0.rb().get(k, ..);

zipped!(tmp.rb_mut(), u10_adjoint.transpose())
zipped!(__rw, tmp.rb_mut(), u10_adjoint.transpose())
.for_each(|unzipped!(mut dst, src)| dst.write(src.read().faer_conj()));
if k > 0 {
tmp.write(k - 1, one);
Expand Down
Loading

0 comments on commit 9a0f0a5

Please sign in to comment.