diff --git a/faer-traits/src/lib.rs b/faer-traits/src/lib.rs index d0d28e45..d64589b3 100644 --- a/faer-traits/src/lib.rs +++ b/faer-traits/src/lib.rs @@ -402,6 +402,10 @@ macro_rules! help2 { #[allow(unused_macros)] macro_rules! write2 { + ($mat: ident[$idx: expr] = $val: expr) => {{ + let __val = $val; + $crate::utils::write::<$C, _>(&mut $mat.rb_mut().__at_mut($idx), __val) + }}; ($place: expr, $val: expr) => {{ let __val = $val; $crate::utils::write::<$C, _>(&mut $place, __val) @@ -1060,6 +1064,15 @@ impl, S: Simd> SimdCtx { pub fn tail_mask(&self, len: usize) -> T::SimdMask { T::simd_tail_mask(&self.0, len) } + #[inline(always)] + pub fn head_mask(&self, len: usize) -> T::SimdMask { + T::simd_head_mask(&self.0, len) + } + #[inline(always)] + pub fn and_mask(&self, lhs: T::SimdMask, rhs: T::SimdMask) -> T::SimdMask { + T::simd_and_mask(&self.0, lhs, rhs) + } + #[inline(always)] pub unsafe fn mask_load( &self, @@ -1344,6 +1357,14 @@ impl, S: Simd> SimdCtxCopy { unsafe { core::mem::transmute_copy(&T::simd_tail_mask(&self.0, len)) } } #[inline(always)] + pub fn head_mask(&self, len: usize) -> T::SimdMask { + unsafe { core::mem::transmute_copy(&T::simd_head_mask(&self.0, len)) } + } + #[inline(always)] + pub fn and_mask(&self, lhs: T::SimdMask, rhs: T::SimdMask) -> T::SimdMask { + T::simd_and_mask(&self.0, lhs, rhs) + } + #[inline(always)] pub unsafe fn mask_load( &self, mask: T::SimdMask, @@ -1381,7 +1402,7 @@ pub unsafe trait ConjUnit { type Canonical: ConjUnit; } -pub unsafe trait Container: 'static { +pub unsafe trait Container: 'static + core::fmt::Debug { type Of; type OfCopy: Copy; type OfDebug: Debug; @@ -2264,6 +2285,12 @@ pub trait ComplexField: ) -> Self::SimdIndex; fn simd_tail_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask; + fn simd_and_mask( + ctx: &Self::SimdCtx, + lhs: Self::SimdMask, + rhs: Self::SimdMask, + ) -> Self::SimdMask; + fn simd_head_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask; unsafe fn simd_mask_load( ctx: &Self::SimdCtx, mask: Self::SimdMask, @@ -2926,6 +2953,11 @@ impl> ComplexField> for T { let ctx = SimdCtx::::new(ctx); ctx.tail_mask(len) } + #[inline(always)] + fn simd_head_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask { + let ctx = SimdCtx::::new(ctx); + ctx.head_mask(len) + } #[inline(always)] unsafe fn simd_mask_load( @@ -2996,6 +3028,15 @@ impl> ComplexField> for T { fn ctx_from_simd(ctx: &Self::SimdCtx) -> (Self::MathCtx, S) { T::ctx_from_simd(ctx) } + + #[inline(always)] + fn simd_and_mask( + ctx: &Self::SimdCtx, + lhs: Self::SimdMask, + rhs: Self::SimdMask, + ) -> Self::SimdMask { + T::simd_and_mask(ctx, lhs, rhs) + } } impl AsRef for Unit { @@ -3343,6 +3384,10 @@ impl ComplexField for f32 { ctx.tail_mask_f32s(len) } #[inline(always)] + fn simd_head_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask { + ctx.head_mask_f32s(len) + } + #[inline(always)] unsafe fn simd_mask_load( ctx: &Self::SimdCtx, mask: Self::SimdMask, @@ -3399,6 +3444,15 @@ impl ComplexField for f32 { fn ctx_from_simd(ctx: &Self::SimdCtx) -> (Self::MathCtx, S) { (Unit, *ctx) } + + #[inline(always)] + fn simd_and_mask( + simd: &Self::SimdCtx, + lhs: Self::SimdMask, + rhs: Self::SimdMask, + ) -> Self::SimdMask { + simd.and_m32s(lhs, rhs) + } } impl RealField for f32 { @@ -3769,6 +3823,10 @@ impl ComplexField for f64 { ctx.tail_mask_f64s(len) } #[inline(always)] + fn simd_head_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask { + ctx.head_mask_f64s(len) + } + #[inline(always)] unsafe fn simd_mask_load( ctx: &Self::SimdCtx, mask: Self::SimdMask, @@ -3824,6 +3882,15 @@ impl ComplexField for f64 { fn ctx_from_simd(ctx: &Self::SimdCtx) -> (Self::MathCtx, S) { (Unit, *ctx) } + + #[inline(always)] + fn simd_and_mask( + simd: &Self::SimdCtx, + lhs: Self::SimdMask, + rhs: Self::SimdMask, + ) -> Self::SimdMask { + simd.and_m64s(lhs, rhs) + } } impl RealField for f64 { @@ -4272,6 +4339,11 @@ impl ComplexField for Complex { let ctx = SimdCtx::::new(ctx); ctx.tail_mask(2 * len) } + #[inline(always)] + fn simd_head_mask(ctx: &Self::SimdCtx, len: usize) -> Self::SimdMask { + let ctx = SimdCtx::::new(ctx); + ctx.head_mask(2 * len) + } #[inline(always)] unsafe fn simd_mask_load( @@ -4324,6 +4396,15 @@ impl ComplexField for Complex { fn ctx_from_simd(ctx: &Self::SimdCtx) -> (Self::MathCtx, S) { T::ctx_from_simd(ctx) } + + #[inline(always)] + fn simd_and_mask( + simd: &Self::SimdCtx, + lhs: Self::SimdMask, + rhs: Self::SimdMask, + ) -> Self::SimdMask { + T::simd_and_mask(simd, lhs, rhs) + } } impl EnableComplex for f32 { diff --git a/faer/src/lib.rs b/faer/src/lib.rs index 2f001c0c..3667d546 100644 --- a/faer/src/lib.rs +++ b/faer/src/lib.rs @@ -1,9 +1,24 @@ #![allow(non_snake_case)] -use core::{num::NonZeroUsize, sync::atomic::AtomicUsize}; +use core::{num::NonZero, sync::atomic::AtomicUsize}; use equator::{assert, debug_assert}; use faer_traits::*; +macro_rules! stack_mat { + ($ctx: expr, $name: ident, $m: expr, $n: expr, $M: expr, $N: expr, $C: ty, $T: ty $(,)?) => { + let mut __tmp = { + #[repr(align(64))] + struct __Col([T; M]); + struct __Mat([__Col; N]); + + core::mem::MaybeUninit::>>::uninit() + }; + let __stack = DynStack::new_any(core::slice::from_mut(&mut __tmp)); + let mut $name = unsafe { temp_mat_uninit($ctx, $m, $n, __stack) }.0; + let mut $name = $name.as_mat_mut(); + }; +} + #[macro_export] #[doc(hidden)] macro_rules! __dbg { @@ -448,16 +463,16 @@ impl Conj { pub enum Parallelism { None, #[cfg(feature = "rayon")] - Rayon(NonZeroUsize), + Rayon(NonZero), } impl Parallelism { #[cfg(feature = "rayon")] pub fn rayon(nthreads: usize) -> Self { if nthreads == 0 { - Self::Rayon(NonZeroUsize::new(rayon::current_num_threads()).unwrap()) + Self::Rayon(NonZero::new(rayon::current_num_threads()).unwrap()) } else { - Self::Rayon(NonZeroUsize::new(nthreads).unwrap()) + Self::Rayon(NonZero::new(nthreads).unwrap()) } } } diff --git a/faer/src/linalg/cholesky/ldlt/factor.rs b/faer/src/linalg/cholesky/ldlt/factor.rs index 32bcec45..151e0f88 100644 --- a/faer/src/linalg/cholesky/ldlt/factor.rs +++ b/faer/src/linalg/cholesky/ldlt/factor.rs @@ -1,4 +1,6 @@ -use crate::internal_prelude::*; +use crate::{assert, internal_prelude::*}; +use core::num::NonZero; +use faer_traits::RealValue; use linalg::matmul::triangular::BlockStructure; use pulp::Simd; @@ -30,7 +32,10 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> let (disjoint, head, tail, _, _) = mat_rows.split_inc(start, ROW.HEAD, ROW.TAIL); let simd = SimdCtx::::new_force_mask(simd, tail.len()); - let indices = simd.indices(); + let (idx_head, indices, idx_tail) = simd.indices(); + assert!(idx_head.is_none()); + let Some(idx_tail) = idx_tail else { panic!() }; + let ctx = &Ctx::(T::ctx_from_simd(&simd.ctx).0); let mut count = 0usize; @@ -50,7 +55,7 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> let mut Aj = A11.rb_mut().col_mut(right.from_global(j)); { - let mut Aj = Aj.rb_mut().as_array_mut(); + let mut Aj = Aj.rb_mut(); let mut iter = indices.clone(); let i0 = iter.next(); let i1 = iter.next(); @@ -58,10 +63,10 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> match (i0, i1, i2) { (None, None, None) => { - let mut Aij = simd.read_tail(rb!(Aj)); + let mut Aij = simd.read(Aj.rb(), idx_tail); for k in left { - let Ak = A10.col(left.from_global(k)).as_array(); + let Ak = A10.col(left.from_global(k)); let D = math(real(D[mat_cols.from_global(k)])); let D = if is_llt { math.re.one() } else { D }; @@ -71,17 +76,17 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> re.neg(D) )))); - let Aik = simd.read_tail(rb!(Ak)); + let Aik = simd.read(Ak, idx_tail); Aij = simd.mul_add(Ajk, Aik, Aij); } - simd.write_tail(rb_mut!(Aj), Aij); + simd.write(Aj.rb_mut(), idx_tail, Aij); } (Some(i0), None, None) => { - let mut A0j = simd.read(rb!(Aj), i0); - let mut Aij = simd.read_tail(rb!(Aj)); + let mut A0j = simd.read(Aj.rb(), i0); + let mut Aij = simd.read(Aj.rb(), idx_tail); for k in left { - let Ak = A10.col(left.from_global(k)).as_array(); + let Ak = A10.col(left.from_global(k)); let D = math(real(D[mat_cols.from_global(k)])); let D = if is_llt { math.re.one() } else { D }; @@ -91,21 +96,21 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> re.neg(D) )))); - let A0k = simd.read(rb!(Ak), i0); - let Aik = simd.read_tail(rb!(Ak)); + let A0k = simd.read(Ak, i0); + let Aik = simd.read(Ak, idx_tail); A0j = simd.mul_add(Ajk, A0k, A0j); Aij = simd.mul_add(Ajk, Aik, Aij); } - simd.write(rb_mut!(Aj), i0, A0j); - simd.write_tail(rb_mut!(Aj), Aij); + simd.write(Aj.rb_mut(), i0, A0j); + simd.write(Aj.rb_mut(), idx_tail, Aij); } (Some(i0), Some(i1), None) => { - let mut A0j = simd.read(rb!(Aj), i0); - let mut A1j = simd.read(rb!(Aj), i1); - let mut Aij = simd.read_tail(rb!(Aj)); + let mut A0j = simd.read(Aj.rb(), i0); + let mut A1j = simd.read(Aj.rb(), i1); + let mut Aij = simd.read(Aj.rb(), idx_tail); for k in left { - let Ak = A10.col(left.from_global(k)).as_array(); + let Ak = A10.col(left.from_global(k)); let D = math(real(D[mat_cols.from_global(k)])); let D = if is_llt { math.re.one() } else { D }; @@ -115,25 +120,25 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> re.neg(D) )))); - let A0k = simd.read(rb!(Ak), i0); - let A1k = simd.read(rb!(Ak), i1); - let Aik = simd.read_tail(rb!(Ak)); + let A0k = simd.read(Ak, i0); + let A1k = simd.read(Ak, i1); + let Aik = simd.read(Ak, idx_tail); A0j = simd.mul_add(Ajk, A0k, A0j); A1j = simd.mul_add(Ajk, A1k, A1j); Aij = simd.mul_add(Ajk, Aik, Aij); } - simd.write(rb_mut!(Aj), i0, A0j); - simd.write(rb_mut!(Aj), i1, A1j); - simd.write_tail(rb_mut!(Aj), Aij); + simd.write(Aj.rb_mut(), i0, A0j); + simd.write(Aj.rb_mut(), i1, A1j); + simd.write(Aj.rb_mut(), idx_tail, Aij); } (Some(i0), Some(i1), Some(i2)) => { - let mut A0j = simd.read(rb!(Aj), i0); - let mut A1j = simd.read(rb!(Aj), i1); - let mut A2j = simd.read(rb!(Aj), i2); - let mut Aij = simd.read_tail(rb!(Aj)); + let mut A0j = simd.read(Aj.rb(), i0); + let mut A1j = simd.read(Aj.rb(), i1); + let mut A2j = simd.read(Aj.rb(), i2); + let mut Aij = simd.read(Aj.rb(), idx_tail); for k in left { - let Ak = A10.col(left.from_global(k)).as_array(); + let Ak = A10.col(left.from_global(k)); let D = math(real(D[mat_cols.from_global(k)])); let D = if is_llt { math.re.one() } else { D }; @@ -143,19 +148,19 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> re.neg(D) )))); - let A0k = simd.read(rb!(Ak), i0); - let A1k = simd.read(rb!(Ak), i1); - let A2k = simd.read(rb!(Ak), i2); - let Aik = simd.read_tail(rb!(Ak)); + let A0k = simd.read(Ak, i0); + let A1k = simd.read(Ak, i1); + let A2k = simd.read(Ak, i2); + let Aik = simd.read(Ak, idx_tail); A0j = simd.mul_add(Ajk, A0k, A0j); A1j = simd.mul_add(Ajk, A1k, A1j); A2j = simd.mul_add(Ajk, A2k, A2j); Aij = simd.mul_add(Ajk, Aik, Aij); } - simd.write(rb_mut!(Aj), i0, A0j); - simd.write(rb_mut!(Aj), i1, A1j); - simd.write(rb_mut!(Aj), i2, A2j); - simd.write_tail(rb_mut!(Aj), Aij); + simd.write(Aj.rb_mut(), i0, A0j); + simd.write(Aj.rb_mut(), i1, A1j); + simd.write(Aj.rb_mut(), i2, A2j); + simd.write(Aj.rb_mut(), idx_tail, Aij); } _ => { unreachable!(); @@ -219,18 +224,18 @@ fn simd_cholesky_row_batch<'N, C: ComplexContainer, T: ComplexField, S: Simd> let diag = math(real(D)); { - let mut Aj = Aj.rb_mut().as_array_mut(); + let mut Aj = Aj.rb_mut(); let inv = simd.splat_real(as_ref2!(math.re.recip(diag))); for i in indices.clone() { - let mut Aij = simd.read(rb!(Aj), i); + let mut Aij = simd.read(Aj.rb(), i); Aij = simd.mul_real(Aij, inv); - simd.write(rb_mut!(Aj), i, Aij); + simd.write(Aj.rb_mut(), i, Aij); } { - let mut Aij = simd.read_tail(rb!(Aj)); + let mut Aij = simd.read(Aj.rb(), idx_tail); Aij = simd.mul_real(Aij, inv); - simd.write_tail(rb_mut!(Aj), Aij); + simd.write(Aj.rb_mut(), idx_tail, Aij); } } }); @@ -460,11 +465,12 @@ fn cholesky_fallback<'N, C: ComplexContainer, T: ComplexField>( } #[math] -fn cholesky_recursion<'N, C: ComplexContainer, T: ComplexField>( +pub(crate) fn cholesky_recursion<'N, C: ComplexContainer, T: ComplexField>( ctx: &Ctx, A: MatMut<'_, C, T, Dim<'N>, Dim<'N>>, D: RowMut<'_, C, T, Dim<'N>>, + recursion_threshold: usize, is_llt: bool, regularize: bool, eps: ::Of<&T::RealUnit>, @@ -473,7 +479,7 @@ fn cholesky_recursion<'N, C: ComplexContainer, T: ComplexField>( par: Parallelism, ) -> Result { let N = A.ncols(); - if *N <= 1 { + if *N <= recursion_threshold { cholesky_fallback(ctx, A, D, is_llt, regularize, eps, delta, signs) } else { let mut count = 0; @@ -507,6 +513,7 @@ fn cholesky_recursion<'N, C: ComplexContainer, T: ComplexField>( ctx, A00.rb_mut(), D0.rb_mut(), + recursion_threshold, is_llt, regularize, rb2!(eps), @@ -574,6 +581,121 @@ fn cholesky_recursion<'N, C: ComplexContainer, T: ComplexField>( } } +/// Dynamic LDLT regularization. +/// Values below `epsilon` in absolute value, or with the wrong sign are set to `delta` with +/// their corrected sign. +pub struct LdltRegularization<'a, C: ComplexContainer, T: ComplexField> { + /// Expected signs for the diagonal at each step of the decomposition. + pub dynamic_regularization_signs: Option<&'a [i8]>, + /// Regularized value. + pub dynamic_regularization_delta: RealValue, + /// Regularization threshold. + pub dynamic_regularization_epsilon: RealValue, +} + +/// Info about the result of the LDLT factorization. +#[derive(Copy, Clone, Debug)] +pub struct LdltInfo { + /// Number of pivots whose value or sign had to be corrected. + pub dynamic_regularization_count: usize, +} + +/// Error in the LDLT factorization. +#[derive(Copy, Clone, Debug)] +pub enum LdltError { + ZeroPivot { index: usize }, +} + +impl> LdltRegularization<'_, C, T> { + #[math] + pub fn default_with(ctx: &Ctx) -> Self { + Self { + dynamic_regularization_signs: None, + dynamic_regularization_delta: math.re(zero()), + dynamic_regularization_epsilon: math.re(zero()), + } + } +} + +impl> Default + for LdltRegularization<'_, C, T> +{ + fn default() -> Self { + Self::default_with(&ctx()) + } +} + +#[non_exhaustive] +pub struct LdltParams { + pub blocksize: NonZero, +} + +impl Default for LdltParams { + #[inline] + fn default() -> Self { + Self { + blocksize: NonZero::new(64).unwrap(), + } + } +} + +#[inline] +pub fn cholesky_in_place_scratch>( + dim: usize, +) -> Result { + temp_mat_scratch::(dim, 1) +} + +#[math] +pub fn cholesky_in_place<'N, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + A: MatMut<'_, C, T, Dim<'N>, Dim<'N>>, + regularization: LdltRegularization<'_, C, T>, + par: Parallelism, + stack: &mut DynStack, + params: LdltParams, +) -> Result { + let N = A.nrows(); + let mut D = unsafe { temp_mat_uninit(ctx, N, 1, stack).0 }; + let D = D.as_mat_mut(); + let mut D = D.col_mut(0).transpose_mut(); + let mut A = A; + + help!(C::Real); + let ret = match cholesky_recursion( + ctx, + A.rb_mut(), + D.rb_mut(), + params.blocksize.get(), + false, + math.gt_zero(regularization.dynamic_regularization_delta) + && math.gt_zero(regularization.dynamic_regularization_epsilon), + as_ref!(regularization.dynamic_regularization_epsilon), + as_ref!(regularization.dynamic_regularization_delta), + regularization + .dynamic_regularization_signs + .map(|signs| Array::from_ref(signs, N)), + par, + ) { + Ok(count) => Ok(LdltInfo { + dynamic_regularization_count: count, + }), + Err(index) => Err(LdltError::ZeroPivot { index }), + }; + let init = if let Err(LdltError::ZeroPivot { index }) = ret { + N.idx(index).next() + } else { + N.end() + }; + + help2!(C); + for i in zero().to(init) { + write2!(A[(i, i)] = math(copy(D[i]))); + } + + ret +} + #[cfg(test)] mod tests { use super::*; @@ -665,7 +787,7 @@ mod tests { .rand::>(rng); let A = &A * &A.adjoint(); - let A = A.as_ref().as_shape(N, N); + let A = A.as_ref(); let mut L = A.cloned(); let mut L = L.as_mut(); @@ -676,6 +798,7 @@ mod tests { &default(), L.rb_mut(), D.rb_mut(), + 32, llt, false, &0.0, diff --git a/faer/src/linalg/cholesky/ldlt/mod.rs b/faer/src/linalg/cholesky/ldlt/mod.rs index f07cfb5d..7a8a8702 100644 --- a/faer/src/linalg/cholesky/ldlt/mod.rs +++ b/faer/src/linalg/cholesky/ldlt/mod.rs @@ -1 +1,3 @@ pub mod factor; +pub mod solve; +pub mod update; diff --git a/faer/src/linalg/cholesky/ldlt/solve.rs b/faer/src/linalg/cholesky/ldlt/solve.rs new file mode 100644 index 00000000..7e5cc4a4 --- /dev/null +++ b/faer/src/linalg/cholesky/ldlt/solve.rs @@ -0,0 +1,49 @@ +use crate::internal_prelude::*; + +pub fn solve_in_place_scratch>( + dim: usize, + rhs_ncols: usize, + par: Parallelism, +) -> Result { + _ = (dim, rhs_ncols, par); + Ok(StackReq::empty()) +} + +#[math] +pub fn solve_in_place_with_conj<'N, 'K, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + LD_factors: MatRef<'_, C, T, Dim<'N>, Dim<'N>>, + conj_lhs: Conj, + rhs: MatMut<'_, C, T, Dim<'N>, Dim<'K>>, + par: Parallelism, + stack: &mut DynStack, +) { + _ = stack; + + let N = rhs.nrows(); + let K = rhs.ncols(); + let mut rhs = rhs; + linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj( + ctx, + LD_factors, + conj_lhs, + rhs.rb_mut(), + par, + ); + + help!(C); + for j in K.indices() { + for i in N.indices() { + let d = math.re(recip(cx.real(LD_factors[(i, i)]))); + write1!(rhs[(i, j)] = math(mul_real(rhs[(i, j)], d))); + } + } + + linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj( + ctx, + LD_factors.transpose(), + conj_lhs.compose(Conj::Yes), + rhs.rb_mut(), + par, + ); +} diff --git a/faer/src/linalg/cholesky/ldlt/update.rs b/faer/src/linalg/cholesky/ldlt/update.rs new file mode 100644 index 00000000..d627f1cd --- /dev/null +++ b/faer/src/linalg/cholesky/ldlt/update.rs @@ -0,0 +1,566 @@ +use crate::internal_prelude::*; +use pulp::Simd; + +#[math] +fn rank_update_step_simd<'N, 'R, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + L: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + W: MatMut<'_, C, T, Dim<'N>, Dim<'R>, ContiguousFwd>, + p: ColRef<'_, C, T, Dim<'R>>, + beta: ColRef<'_, C, T, Dim<'R>>, + align_offset: usize, +) { + struct Impl<'a, 'N, 'R, C: ComplexContainer, T: ComplexField> { + ctx: &'a Ctx, + L: ColMut<'a, C, T, Dim<'N>, ContiguousFwd>, + W: MatMut<'a, C, T, Dim<'N>, Dim<'R>, ContiguousFwd>, + p: ColRef<'a, C, T, Dim<'R>>, + beta: ColRef<'a, C, T, Dim<'R>>, + align_offset: usize, + } + + impl<'a, 'N, 'R, C: ComplexContainer, T: ComplexField> pulp::WithSimd + for Impl<'a, 'N, 'R, C, T> + { + type Output = (); + #[inline(always)] + fn with_simd(self, simd: S) { + let Self { + ctx, + L, + W, + p, + beta, + align_offset, + } = self; + + let mut L = L; + let mut W = W; + let N = W.nrows(); + let R = W.ncols(); + + let simd = SimdCtx::::new_align(T::simd_ctx(ctx, simd), N, align_offset); + let (head, body, tail) = simd.indices(); + + let mut iter = R.indices(); + let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next()); + + match (i0, i1, i2, i3) { + (Some(i0), None, None, None) => { + let p0 = math(simd.splat(p[i0])); + let beta0 = math(simd.splat(beta[i0])); + + macro_rules! simd { + ($i: expr) => {{ + let i = $i; + let mut l = simd.read(L.rb(), i); + let mut w0 = simd.read(W.rb().col(i0), i); + + w0 = simd.mul_add(p0, l, w0); + l = simd.mul_add(beta0, w0, l); + + simd.write(L.rb_mut(), i, l); + simd.write(W.rb_mut().col_mut(i0), i, w0); + }}; + } + + if let Some(i) = head { + simd!(i); + } + for i in body { + simd!(i); + } + if let Some(i) = tail { + simd!(i); + } + } + (Some(i0), Some(i1), None, None) => { + let (p0, p1) = math((simd.splat(p[i0]), simd.splat(p[i1]))); + let (beta0, beta1) = math((simd.splat(beta[i0]), simd.splat(beta[i1]))); + + macro_rules! simd { + ($i: expr) => {{ + let i = $i; + let mut l = simd.read(L.rb(), i); + let mut w0 = simd.read(W.rb().col(i0), i); + let mut w1 = simd.read(W.rb().col(i1), i); + + w0 = simd.mul_add(p0, l, w0); + l = simd.mul_add(beta0, w0, l); + w1 = simd.mul_add(p1, l, w1); + l = simd.mul_add(beta1, w1, l); + + simd.write(L.rb_mut(), i, l); + simd.write(W.rb_mut().col_mut(i0), i, w0); + simd.write(W.rb_mut().col_mut(i1), i, w1); + }}; + } + + if let Some(i) = head { + simd!(i); + } + for i in body { + simd!(i); + } + if let Some(i) = tail { + simd!(i); + } + } + (Some(i0), Some(i1), Some(i2), None) => { + let (p0, p1, p2) = + math((simd.splat(p[i0]), simd.splat(p[i1]), simd.splat(p[i2]))); + let (beta0, beta1, beta2) = math(( + simd.splat(beta[i0]), + simd.splat(beta[i1]), + simd.splat(beta[i2]), + )); + + macro_rules! simd { + ($i: expr) => {{ + let i = $i; + let mut l = simd.read(L.rb(), i); + let mut w0 = simd.read(W.rb().col(i0), i); + let mut w1 = simd.read(W.rb().col(i1), i); + let mut w2 = simd.read(W.rb().col(i2), i); + + w0 = simd.mul_add(p0, l, w0); + l = simd.mul_add(beta0, w0, l); + w1 = simd.mul_add(p1, l, w1); + l = simd.mul_add(beta1, w1, l); + w2 = simd.mul_add(p2, l, w2); + l = simd.mul_add(beta2, w2, l); + + simd.write(L.rb_mut(), i, l); + simd.write(W.rb_mut().col_mut(i0), i, w0); + simd.write(W.rb_mut().col_mut(i1), i, w1); + simd.write(W.rb_mut().col_mut(i2), i, w2); + }}; + } + + if let Some(i) = head { + simd!(i); + } + for i in body { + simd!(i); + } + if let Some(i) = tail { + simd!(i); + } + } + (Some(i0), Some(i1), Some(i2), Some(i3)) => { + let (p0, p1, p2, p3) = math(( + simd.splat(p[i0]), + simd.splat(p[i1]), + simd.splat(p[i2]), + simd.splat(p[i3]), + )); + let (beta0, beta1, beta2, beta3) = math(( + simd.splat(beta[i0]), + simd.splat(beta[i1]), + simd.splat(beta[i2]), + simd.splat(beta[i3]), + )); + + macro_rules! simd { + ($i: expr) => {{ + let i = $i; + let mut l = simd.read(L.rb(), i); + let mut w0 = simd.read(W.rb().col(i0), i); + let mut w1 = simd.read(W.rb().col(i1), i); + let mut w2 = simd.read(W.rb().col(i2), i); + let mut w3 = simd.read(W.rb().col(i3), i); + + w0 = simd.mul_add(p0, l, w0); + l = simd.mul_add(beta0, w0, l); + w1 = simd.mul_add(p1, l, w1); + l = simd.mul_add(beta1, w1, l); + w2 = simd.mul_add(p2, l, w2); + l = simd.mul_add(beta2, w2, l); + w3 = simd.mul_add(p3, l, w3); + l = simd.mul_add(beta3, w3, l); + + simd.write(L.rb_mut(), i, l); + simd.write(W.rb_mut().col_mut(i0), i, w0); + simd.write(W.rb_mut().col_mut(i1), i, w1); + simd.write(W.rb_mut().col_mut(i2), i, w2); + simd.write(W.rb_mut().col_mut(i3), i, w3); + }}; + } + + if let Some(i) = head { + simd!(i); + } + for i in body { + simd!(i); + } + if let Some(i) = tail { + simd!(i); + } + } + _ => panic!(), + } + } + } + + T::Arch::default().dispatch(Impl { + ctx, + L, + W, + p, + beta, + align_offset, + }) +} + +#[math] +fn rank_update_step_fallback<'N, 'R, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + L: ColMut<'_, C, T, Dim<'N>>, + W: MatMut<'_, C, T, Dim<'N>, Dim<'R>>, + p: ColRef<'_, C, T, Dim<'R>>, + beta: ColRef<'_, C, T, Dim<'R>>, +) { + let mut L = L; + let mut W = W; + let N = W.nrows(); + let R = W.ncols(); + + let body = N.indices(); + + let mut iter = R.indices(); + let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next()); + help!(C); + + match (i0, i1, i2, i3) { + (Some(i0), None, None, None) => { + let p0 = math(p[i0]); + let beta0 = math(beta[i0]); + + for i in body { + { + let mut l = math(copy(L[i])); + let mut w0 = math(copy(W[(i, i0)])); + + w0 = math(p0 * l + w0); + l = math(beta0 * w0 + l); + + write1!(L[i] = l); + write1!(W[(i, i0,)] = w0); + } + } + } + (Some(i0), Some(i1), None, None) => { + let (p0, p1) = math((p[i0], p[i1])); + let (beta0, beta1) = math((beta[i0], beta[i1])); + + for i in body { + { + let mut l = math(copy(L[i])); + let mut w0 = math(copy(W[(i, i0)])); + let mut w1 = math(copy(W[(i, i1)])); + + w0 = math(p0 * l + w0); + l = math(beta0 * w0 + l); + w1 = math(p1 * l + w1); + l = math(beta1 * w1 + l); + + write1!(L[i] = l); + write1!(W[(i, i0,)] = w0); + write1!(W[(i, i1,)] = w1); + } + } + } + (Some(i0), Some(i1), Some(i2), None) => { + let (p0, p1, p2) = math((p[i0], p[i1], p[i2])); + let (beta0, beta1, beta2) = math((beta[i0], beta[i1], beta[i2])); + + for i in body { + { + let mut l = math(copy(L[i])); + let mut w0 = math(copy(W[(i, i0)])); + let mut w1 = math(copy(W[(i, i1)])); + let mut w2 = math(copy(W[(i, i2)])); + + w0 = math(p0 * l + w0); + l = math(beta0 * w0 + l); + w1 = math(p1 * l + w1); + l = math(beta1 * w1 + l); + w2 = math(p2 * l + w2); + l = math(beta2 * w2 + l); + + write1!(L[i] = l); + write1!(W[(i, i0,)] = w0); + write1!(W[(i, i1,)] = w1); + write1!(W[(i, i2,)] = w2); + } + } + } + (Some(i0), Some(i1), Some(i2), Some(i3)) => { + let (p0, p1, p2, p3) = math((p[i0], p[i1], p[i2], p[i3])); + let (beta0, beta1, beta2, beta3) = math((beta[i0], beta[i1], beta[i2], beta[i3])); + + for i in body { + { + let mut l = math(copy(L[i])); + let mut w0 = math(copy(W[(i, i0)])); + let mut w1 = math(copy(W[(i, i1)])); + let mut w2 = math(copy(W[(i, i2)])); + let mut w3 = math(copy(W[(i, i3)])); + + w0 = math(p0 * l + w0); + l = math(beta0 * w0 + l); + w1 = math(p1 * l + w1); + l = math(beta1 * w1 + l); + w2 = math(p2 * l + w2); + l = math(beta2 * w2 + l); + w3 = math(p3 * l + w3); + l = math(beta3 * w3 + l); + + write1!(L[i] = l); + write1!(W[(i, i0,)] = w0); + write1!(W[(i, i1,)] = w1); + write1!(W[(i, i2,)] = w2); + write1!(W[(i, i3,)] = w3); + } + } + } + _ => panic!(), + } +} + +struct RankRUpdate<'a, 'N, 'R, C: ComplexContainer, T: ComplexField> { + ctx: &'a Ctx, + ld: MatMut<'a, C, T, Dim<'N>, Dim<'N>>, + w: MatMut<'a, C, T, Dim<'N>, Dim<'R>>, + alpha: ColMut<'a, C, T, Dim<'R>>, + r: &'a mut dyn FnMut() -> IdxInc<'R>, +} + +impl<'N, 'R, C: ComplexContainer, T: ComplexField> RankRUpdate<'_, 'N, 'R, C, T> { + // On the Modification of LDLT Factorizations + // By R. Fletcher and M. J. D. Powell + // https://www.ams.org/journals/mcom/1974-28-128/S0025-5718-1974-0359297-1/S0025-5718-1974-0359297-1.pdf + + #[math] + fn run(self) { + let Self { + ctx, + mut ld, + mut w, + mut alpha, + r, + } = self; + + let N = w.nrows(); + let K = w.ncols(); + help!(C); + + for j in N.indices() { + ghost_tree!(FULL(HEAD, TAIL), { + let (full, FULL) = N.full(FULL); + let (_, head, tail, _, _) = + full.split_inc(full.from_local_inc(j.next()), FULL.HEAD, FULL.TAIL); + let j = head.idx(*j); + + let mut L_col = ld.rb_mut().col_mut(full.from_global(j)); + + let r = Ord::min((*r)(), K.end()); + ghost_tree!(W_FULL(R, W_UNUSED), { + let (full_w, W_FULL) = K.full(W_FULL); + let (_, R_segment, _, _, _) = + full_w.split_inc(full_w.from_local_inc(r), W_FULL.R, W_FULL.W_UNUSED); + let R = R_segment.len(); + let mut W = w.rb_mut().col_segment_mut(R_segment); + let mut alpha = alpha.rb_mut().row_segment_mut(R_segment); + + const BLOCKSIZE: usize = 4; + + let mut r_next = zero(); + while let Some(r) = R.try_check(*r_next) { + r_next = R.advance(r, BLOCKSIZE); + + ghost_tree!(W_FULL(W_HEAD_UNUSED, W_TAIL(R0, W_TAIL_UNUSED)), { + let (full_r, W_FULL) = R.full(W_FULL); + + let (_, _, _, w_tail, _, TAIL) = full_r.split( + full_r.from_local(r), + W_FULL.W_HEAD_UNUSED, + W_FULL.W_TAIL, + ); + + let j_next = w_tail.idx_inc(*r_next); + + let (_, r0, _, _, _) = + w_tail.split_inc(j_next, TAIL.R0, TAIL.W_TAIL_UNUSED); + + stack_mat!(ctx, p, r0.len(), 1, BLOCKSIZE, 1, C, T); + stack_mat!(ctx, beta, r0.len(), 1, BLOCKSIZE, 1, C, T); + + let mut p = p.rb_mut().col_mut(0); + let mut beta = beta.rb_mut().col_mut(0); + + for k in r0 { + let mut p = p.rb_mut().at_mut(r0.from_global(k)); + let mut beta = beta.rb_mut().at_mut(r0.from_global(k)); + let mut alpha = alpha.rb_mut().at_mut(full_r.from_global(k)); + let mut d = L_col.rb_mut().at_mut(full.from_global(j)); + + let w = W.rb().col(full_r.from_global(k)); + + write1!(p, math(copy(w[full.from_global(j)]))); + + let alpha_conj_p = math(alpha * conj(p)); + let new_d = math.re(cx.real(d) + cx.real(cx.mul(alpha_conj_p, p))); + write1!(beta, math(mul_real(alpha_conj_p, re.recip(new_d)))); + write1!( + alpha, + math.re(cx.from_real(cx.real(alpha) - new_d * cx.abs2(beta))) + ); + write1!(d, math.from_real(new_d)); + write1!(p, math(-p)); + } + + let mut L_col = L_col.rb_mut().row_segment_mut(tail); + let mut W_col = W.rb_mut().col_segment_mut(r0).row_segment_mut(tail); + + if const { T::SIMD_CAPABILITIES.is_simd() } { + if let (Some(L_col), Some(W_col)) = ( + L_col.rb_mut().try_as_col_major_mut(), + W_col.rb_mut().try_as_col_major_mut(), + ) { + rank_update_step_simd( + ctx, + L_col, + W_col, + p.rb(), + beta.rb(), + N.next_power_of_two() - *j, + ); + } else { + rank_update_step_fallback(ctx, L_col, W_col, p.rb(), beta.rb()); + } + } else { + rank_update_step_fallback(ctx, L_col, W_col, p.rb(), beta.rb()); + } + }); + } + }); + }); + } + } +} + +#[track_caller] +pub fn rank_r_update_clobber<'N, 'R, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + cholesky_factors: MatMut<'_, C, T, Dim<'N>, Dim<'N>>, + w: MatMut<'_, C, T, Dim<'N>, Dim<'R>>, + alpha: DiagMut<'_, C, T, Dim<'R>>, +) { + let N = cholesky_factors.nrows(); + let R = w.ncols(); + + if *N == 0 { + return; + } + + RankRUpdate { + ctx, + ld: cholesky_factors, + w, + alpha: alpha.column_vector_mut(), + r: &mut || R.end(), + } + .run(); +} + +#[cfg(test)] +mod tests { + use dyn_stack::GlobalMemBuffer; + use faer_traits::Unit; + + use super::*; + use crate::{assert, c64, stats::prelude::*, utils::approx::*, Col, Mat}; + + #[test] + fn test_rank_update() { + let rng = &mut StdRng::seed_from_u64(0); + + let approx_eq = CwiseMat(ApproxEq { + ctx: ctx::>(), + abs_tol: 1e-12, + rel_tol: 1e-12, + }); + + for r in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10] { + for n in [2, 4, 8, 15] { + with_dim!(N, n); + with_dim!(R, r); + + let A = CwiseMatDistribution { + nrows: N, + ncols: N, + dist: ComplexDistribution::new(StandardNormal, StandardNormal), + } + .rand::>(rng); + let mut W = CwiseMatDistribution { + nrows: N, + ncols: R, + dist: ComplexDistribution::new(StandardNormal, StandardNormal), + } + .rand::>(rng); + let mut alpha = CwiseColDistribution { + nrows: R, + dist: ComplexDistribution::new(StandardNormal, StandardNormal), + } + .rand::>(rng) + .into_diagonal(); + + for j in R.indices() { + alpha.column_vector_mut()[j].im = 0.0; + } + + let A = &A * &A.adjoint(); + let A_new = &A + &W * &alpha * &W.adjoint(); + + let A = A.as_ref(); + let A_new = A_new.as_ref(); + + let mut L = A.cloned(); + let mut L = L.as_mut(); + + linalg::cholesky::ldlt::factor::cholesky_in_place( + &ctx(), + L.rb_mut(), + default(), + Parallelism::None, + DynStack::new(&mut GlobalMemBuffer::new( + linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::(*N) + .unwrap(), + )), + Default::default(), + ) + .unwrap(); + + linalg::cholesky::ldlt::update::rank_r_update_clobber( + &ctx(), + L.rb_mut(), + W.as_mut(), + alpha.as_mut(), + ); + let D = L.as_mut().diagonal().column_vector().as_mat().cloned(); + let D = D.col(0).as_diagonal(); + + for j in N.indices() { + for i in zero().to(j.excl()) { + L[(i, j)] = c64::ZERO; + } + L[(j, j)] = c64::ONE; + } + let L = L.as_ref(); + + assert!(A_new ~ L * D * L.adjoint()); + } + } + } +} diff --git a/faer/src/linalg/cholesky/llt/factor.rs b/faer/src/linalg/cholesky/llt/factor.rs new file mode 100644 index 00000000..96e6451e --- /dev/null +++ b/faer/src/linalg/cholesky/llt/factor.rs @@ -0,0 +1,82 @@ +use crate::{internal_prelude::*, linalg::cholesky::ldlt::factor::cholesky_recursion, RealValue}; +use core::num::NonZero; + +/// Dynamic LDLT regularization. +/// Values below `epsilon` in absolute value, or with the wrong sign are set to `delta` with +/// their corrected sign. +pub struct LltRegularization> { + /// Regularized value. + pub dynamic_regularization_delta: RealValue, + /// Regularization threshold. + pub dynamic_regularization_epsilon: RealValue, +} + +/// Info about the result of the LDLT factorization. +#[derive(Copy, Clone, Debug)] +pub struct LltInfo { + /// Number of pivots whose value or sign had to be corrected. + pub dynamic_regularization_count: usize, +} + +/// Error in the LDLT factorization. +#[derive(Copy, Clone, Debug)] +pub enum LltError { + NonPositivePivot { index: usize }, +} + +impl> LltRegularization { + #[math] + pub fn default_with(ctx: &Ctx) -> Self { + Self { + dynamic_regularization_delta: math.re(zero()), + dynamic_regularization_epsilon: math.re(zero()), + } + } +} + +impl> Default + for LltRegularization +{ + fn default() -> Self { + Self::default_with(&ctx()) + } +} + +#[non_exhaustive] +pub struct LltParams { + pub blocksize: NonZero, +} + +#[math] +pub fn cholesky_in_place<'N, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + A: MatMut<'_, C, T, Dim<'N>, Dim<'N>>, + regularization: LltRegularization, + par: Parallelism, + stack: &mut DynStack, + params: LltParams, +) -> Result { + let N = A.nrows(); + let mut D = unsafe { temp_mat_uninit(ctx, N, 1, stack).0 }; + let D = D.as_mat_mut(); + + help!(C::Real); + match cholesky_recursion( + ctx, + A, + D.col_mut(0).transpose_mut(), + params.blocksize.get(), + true, + math.gt_zero(regularization.dynamic_regularization_delta) + && math.gt_zero(regularization.dynamic_regularization_epsilon), + as_ref!(regularization.dynamic_regularization_epsilon), + as_ref!(regularization.dynamic_regularization_delta), + None, + par, + ) { + Ok(count) => Ok(LltInfo { + dynamic_regularization_count: count, + }), + Err(index) => Err(LltError::NonPositivePivot { index }), + } +} diff --git a/faer/src/linalg/cholesky/llt/mod.rs b/faer/src/linalg/cholesky/llt/mod.rs new file mode 100644 index 00000000..7a8a8702 --- /dev/null +++ b/faer/src/linalg/cholesky/llt/mod.rs @@ -0,0 +1,3 @@ +pub mod factor; +pub mod solve; +pub mod update; diff --git a/faer/src/linalg/cholesky/llt/solve.rs b/faer/src/linalg/cholesky/llt/solve.rs new file mode 100644 index 00000000..dd1410ed --- /dev/null +++ b/faer/src/linalg/cholesky/llt/solve.rs @@ -0,0 +1,62 @@ +use crate::internal_prelude::*; + +pub fn solve_in_place_scratch>( + dim: usize, + rhs_ncols: usize, + par: Parallelism, +) -> Result { + _ = (dim, rhs_ncols, par); + Ok(StackReq::empty()) +} + +pub fn solve_scratch>( + dim: usize, + rhs_ncols: usize, + par: Parallelism, +) -> Result { + _ = (dim, rhs_ncols, par); + Ok(StackReq::empty()) +} + +#[math] +pub fn solve_in_place_with_conj<'N, 'K, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + L: MatRef<'_, C, T, Dim<'N>, Dim<'N>>, + conj_lhs: Conj, + rhs: MatMut<'_, C, T, Dim<'N>, Dim<'K>>, + par: Parallelism, + stack: &mut DynStack, +) { + _ = stack; + let mut rhs = rhs; + linalg::triangular_solve::solve_lower_triangular_in_place_with_conj( + ctx, + L, + conj_lhs, + rhs.rb_mut(), + par, + ); + + linalg::triangular_solve::solve_upper_triangular_in_place_with_conj( + ctx, + L.transpose(), + conj_lhs.compose(Conj::Yes), + rhs.rb_mut(), + par, + ); +} + +#[math] +pub fn solve_with_conj<'N, 'K, C: ComplexContainer, T: ComplexField>( + ctx: &Ctx, + dst: MatMut<'_, C, T, Dim<'N>, Dim<'K>>, + L: MatRef<'_, C, T, Dim<'N>, Dim<'N>>, + conj_lhs: Conj, + rhs: MatRef<'_, C, T, Dim<'N>, Dim<'K>>, + par: Parallelism, + stack: &mut DynStack, +) { + let mut dst = dst; + dst.copy_from_with_ctx(ctx, rhs); + solve_in_place_with_conj(ctx, L, conj_lhs, dst, par, stack); +} diff --git a/faer/src/linalg/cholesky/llt/update.rs b/faer/src/linalg/cholesky/llt/update.rs new file mode 100644 index 00000000..e69de29b diff --git a/faer/src/linalg/cholesky/mod.rs b/faer/src/linalg/cholesky/mod.rs index d9f14305..1878e23c 100644 --- a/faer/src/linalg/cholesky/mod.rs +++ b/faer/src/linalg/cholesky/mod.rs @@ -1,2 +1,3 @@ pub mod bunch_kaufman; pub mod ldlt; +pub mod llt; diff --git a/faer/src/linalg/householder.rs b/faer/src/linalg/householder.rs index 1512218d..c7b1ac29 100644 --- a/faer/src/linalg/householder.rs +++ b/faer/src/linalg/householder.rs @@ -449,48 +449,48 @@ fn apply_block_householder_on_the_left_in_place_generic< let N = rhs.nrows(); let K = rhs.ncols(); let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), N); - let indices = simd.indices(); + let (head, indices, tail) = simd.indices(); help!(C); for idx in K.indices() { let mut col0 = rhs0.rb_mut().at_mut(idx); - let mut col = rhs.rb_mut().col_mut(idx).as_array_mut(); - let essential = essential.as_array(); + let mut col = rhs.rb_mut().col_mut(idx); + let essential = essential; let dot = if const { CONJ } { - math(col0 + dot::inner_prod_no_conj_simd(simd, rb!(essential), rb!(col))) + math(col0 + dot::inner_prod_no_conj_simd(simd, essential.rb(), col.rb())) } else { - math(col0 + dot::inner_prod_conj_lhs_simd(simd, rb!(essential), rb!(col))) + math(col0 + dot::inner_prod_conj_lhs_simd(simd, essential.rb(), col.rb())) }; let k = math(-dot * tau_inv); write1!(col0, math(col0 + k)); let k = simd.splat(as_ref!(k)); - for i in indices.clone() { - let mut a = simd.read(rb!(col), i); - let b = simd.read(rb!(essential), i); - - if const { CONJ } { - a = simd.conj_mul_add(b, k, a); - } else { - a = simd.mul_add(b, k, a); - } - - simd.write(rb_mut!(col), i, a); + macro_rules! simd { + ($i: expr) => {{ + let i = $i; + let mut a = simd.read(col.rb(), i); + let b = simd.read(essential.rb(), i); + + if const { CONJ } { + a = simd.conj_mul_add(b, k, a); + } else { + a = simd.mul_add(b, k, a); + } + + simd.write(col.rb_mut(), i, a); + }}; } - if simd.has_tail() { - let mut a = simd.read_tail(rb!(col)); - let b = simd.read_tail(rb!(essential)); - - if const { CONJ } { - a = simd.conj_mul_add(b, k, a); - } else { - a = simd.mul_add(b, k, a); - } - - simd.write_tail(rb_mut!(col), a); + if let Some(i) = head { + simd!(i); + } + for i in indices.clone() { + simd!(i); + } + if let Some(i) = tail { + simd!(i); } } } diff --git a/faer/src/linalg/jacobi.rs b/faer/src/linalg/jacobi.rs index 3c2c4b9e..06848c14 100644 --- a/faer/src/linalg/jacobi.rs +++ b/faer/src/linalg/jacobi.rs @@ -221,32 +221,43 @@ impl> JacobiRotation { return; }); - let mut x = x.as_array_mut(); - let mut y = y.as_array_mut(); + let mut x = x.transpose_mut(); + let mut y = y.transpose_mut(); let c = simd.splat(as_ref!(c)); let s = simd.splat(as_ref!(s)); - for i in simd.indices() { - let mut xx = simd.read(rb!(x), i); - let mut yy = simd.read(rb!(y), i); + let (head, body, tail) = simd.indices(); + + if let Some(i) = head { + let mut xx = simd.read(x.rb(), i); + let mut yy = simd.read(y.rb(), i); xx = simd.mul_add(c, xx, simd.mul(s, yy)); yy = simd.mul_add(c, yy, simd.neg(simd.mul(s, xx))); - simd.write(rb_mut!(x), i, xx); - simd.write(rb_mut!(y), i, yy); + simd.write(x.rb_mut(), i, xx); + simd.write(y.rb_mut(), i, yy); } + for i in body { + let mut xx = simd.read(x.rb(), i); + let mut yy = simd.read(y.rb(), i); - if simd.has_tail() { - let mut xx = simd.read_tail(rb!(x)); - let mut yy = simd.read_tail(rb!(y)); + xx = simd.mul_add(c, xx, simd.mul(s, yy)); + yy = simd.mul_add(c, yy, simd.neg(simd.mul(s, xx))); + + simd.write(x.rb_mut(), i, xx); + simd.write(y.rb_mut(), i, yy); + } + if let Some(i) = tail { + let mut xx = simd.read(x.rb(), i); + let mut yy = simd.read(y.rb(), i); xx = simd.mul_add(c, xx, simd.mul(s, yy)); yy = simd.mul_add(c, yy, simd.neg(simd.mul(s, xx))); - simd.write_tail(rb_mut!(x), xx); - simd.write_tail(rb_mut!(y), yy); + simd.write(x.rb_mut(), i, xx); + simd.write(y.rb_mut(), i, yy); } } diff --git a/faer/src/linalg/matmul/mod.rs b/faer/src/linalg/matmul/mod.rs index 75c81aa5..2194f155 100644 --- a/faer/src/linalg/matmul/mod.rs +++ b/faer/src/linalg/matmul/mod.rs @@ -1,12 +1,10 @@ use super::temp_mat_scratch; use crate::{ col::ColRefGeneric, + internal_prelude::*, mat::{MatMutGeneric, MatRefGeneric}, row::RowRefGeneric, - utils::{ - bound::{Array, Dim}, - simd::SimdCtx, - }, + utils::{bound::Dim, simd::SimdCtx}, Conj, ContiguousFwd, Parallelism, Shape, Stride, }; use dyn_stack::{DynStack, GlobalMemBuffer}; @@ -32,14 +30,7 @@ pub mod dot { conj_rhs: Conj, ) -> C::Of { if let (Some(lhs), Some(rhs)) = (lhs.try_as_row_major(), rhs.try_as_col_major()) { - inner_prod_slice::( - ctx, - lhs.ncols(), - lhs.as_array(), - conj_lhs, - rhs.as_array(), - conj_rhs, - ) + inner_prod_slice::(ctx, lhs.ncols(), lhs.transpose(), conj_lhs, rhs, conj_rhs) } else { let mut acc = math(zero()); for j in lhs.ncols().indices() { @@ -54,9 +45,9 @@ pub mod dot { fn inner_prod_slice<'K, C: ComplexContainer, T: ComplexField>( ctx: &Ctx, len: Dim<'K>, - lhs: C::Of<&Array<'K, T>>, + lhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, conj_lhs: Conj, - rhs: C::Of<&Array<'K, T>>, + rhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, conj_rhs: Conj, ) -> C::Of { help!(C); @@ -64,9 +55,9 @@ pub mod dot { struct Impl<'a, 'K, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, len: Dim<'K>, - lhs: C::Of<&'a Array<'K, T>>, + lhs: ColRef<'a, C, T, Dim<'K>, ContiguousFwd>, conj_lhs: Conj, - rhs: C::Of<&'a Array<'K, T>>, + rhs: ColRef<'a, C, T, Dim<'K>, ContiguousFwd>, conj_rhs: Conj, } impl<'a, 'K, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, '_, C, T> { @@ -101,8 +92,8 @@ pub mod dot { T::Arch::default().dispatch(Impl:: { ctx, len, - lhs: rb!(lhs), - rhs: rb!(rhs), + lhs, + rhs, conj_lhs, conj_rhs, }) @@ -111,8 +102,8 @@ pub mod dot { #[inline(always)] pub fn inner_prod_no_conj_simd<'K, C: ComplexContainer, T: ComplexField, S: Simd>( simd: SimdCtx<'K, C, T, S>, - lhs: C::Of<&Array<'K, T>>, - rhs: C::Of<&Array<'K, T>>, + lhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, + rhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, ) -> C::Of { help!(C); @@ -121,34 +112,40 @@ pub mod dot { let mut acc2 = simd.zero(); let mut acc3 = simd.zero(); - let (idx4, idx) = simd.batch_indices::<4>(); + let (head, idx4, idx, tail) = simd.batch_indices::<4>(); + + if let Some(i0) = head { + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); + + acc0 = simd.mul_add(l0, r0, acc0); + } for [i0, i1, i2, i3] in idx4 { - let l0 = simd.read(rb!(lhs), i0); - let l1 = simd.read(rb!(lhs), i1); - let l2 = simd.read(rb!(lhs), i2); - let l3 = simd.read(rb!(lhs), i3); + let l0 = simd.read(lhs, i0); + let l1 = simd.read(lhs, i1); + let l2 = simd.read(lhs, i2); + let l3 = simd.read(lhs, i3); - let r0 = simd.read(rb!(rhs), i0); - let r1 = simd.read(rb!(rhs), i1); - let r2 = simd.read(rb!(rhs), i2); - let r3 = simd.read(rb!(rhs), i3); + let r0 = simd.read(rhs, i0); + let r1 = simd.read(rhs, i1); + let r2 = simd.read(rhs, i2); + let r3 = simd.read(rhs, i3); acc0 = simd.mul_add(l0, r0, acc0); acc1 = simd.mul_add(l1, r1, acc1); acc2 = simd.mul_add(l2, r2, acc2); acc3 = simd.mul_add(l3, r3, acc3); } - for i0 in idx { - let l0 = simd.read(rb!(lhs), i0); - let r0 = simd.read(rb!(rhs), i0); + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); acc0 = simd.mul_add(l0, r0, acc0); } + if let Some(i0) = tail { + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); - if simd.has_tail() { - let l0 = simd.read_tail(rb!(lhs)); - let r0 = simd.read_tail(rb!(rhs)); acc0 = simd.mul_add(l0, r0, acc0); } acc0 = simd.add(acc0, acc1); @@ -161,8 +158,8 @@ pub mod dot { #[inline(always)] pub fn inner_prod_conj_lhs_simd<'K, C: ComplexContainer, T: ComplexField, S: Simd>( simd: SimdCtx<'K, C, T, S>, - lhs: C::Of<&Array<'K, T>>, - rhs: C::Of<&Array<'K, T>>, + lhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, + rhs: ColRef<'_, C, T, Dim<'K>, ContiguousFwd>, ) -> C::Of { help!(C); @@ -171,35 +168,41 @@ pub mod dot { let mut acc2 = simd.zero(); let mut acc3 = simd.zero(); - let (idx4, idx) = simd.batch_indices::<4>(); + let (head, idx4, idx, tail) = simd.batch_indices::<4>(); + + if let Some(i0) = head { + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); + + acc0 = simd.conj_mul_add(l0, r0, acc0); + } for [i0, i1, i2, i3] in idx4 { - let l0 = simd.read(rb!(lhs), i0); - let l1 = simd.read(rb!(lhs), i1); - let l2 = simd.read(rb!(lhs), i2); - let l3 = simd.read(rb!(lhs), i3); + let l0 = simd.read(lhs, i0); + let l1 = simd.read(lhs, i1); + let l2 = simd.read(lhs, i2); + let l3 = simd.read(lhs, i3); - let r0 = simd.read(rb!(rhs), i0); - let r1 = simd.read(rb!(rhs), i1); - let r2 = simd.read(rb!(rhs), i2); - let r3 = simd.read(rb!(rhs), i3); + let r0 = simd.read(rhs, i0); + let r1 = simd.read(rhs, i1); + let r2 = simd.read(rhs, i2); + let r3 = simd.read(rhs, i3); acc0 = simd.conj_mul_add(l0, r0, acc0); acc1 = simd.conj_mul_add(l1, r1, acc1); acc2 = simd.conj_mul_add(l2, r2, acc2); acc3 = simd.conj_mul_add(l3, r3, acc3); } - for i0 in idx { - let l0 = simd.read(rb!(lhs), i0); - let r0 = simd.read(rb!(rhs), i0); + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); acc0 = simd.conj_mul_add(l0, r0, acc0); } + if let Some(i0) = tail { + let l0 = simd.read(lhs, i0); + let r0 = simd.read(rhs, i0); - if simd.has_tail() { - let l0 = simd.read_tail(rb!(lhs)); - let r0 = simd.read_tail(rb!(rhs)); - acc0 = simd.mul_add(l0, r0, acc0); + acc0 = simd.conj_mul_add(l0, r0, acc0); } acc0 = simd.add(acc0, acc1); acc2 = simd.add(acc2, acc3); @@ -306,17 +309,9 @@ mod matvec_rowmajor { let lhs = lhs.row(i); let rhs = rhs; let mut tmp = if conj_lhs == conj_rhs { - dot::inner_prod_no_conj_simd::( - simd, - lhs.as_array(), - rhs.as_array(), - ) + dot::inner_prod_no_conj_simd::(simd, lhs.transpose(), rhs) } else { - dot::inner_prod_conj_lhs_simd::( - simd, - lhs.as_array(), - rhs.as_array(), - ) + dot::inner_prod_conj_lhs_simd::(simd, lhs.transpose(), rhs) }; if conj_rhs == Conj::Yes { @@ -433,39 +428,46 @@ mod matvec_colmajor { let M = lhs.nrows(); let simd = SimdCtx::::new(simd, M); - let indices = simd.indices(); + let (head, body, tail) = simd.indices(); let mut dst = dst; match beta { Some(beta) => { - let mut dst = dst.rb_mut().as_array_mut(); + let mut dst = dst.rb_mut(); if !math(beta == one()) { let beta = simd.splat(beta); - for i in indices.clone() { - let y = simd.read(rb!(dst), i); - simd.write(rb_mut!(dst), i, simd.mul(beta, y)); + if let Some(i) = head { + let y = simd.read(dst.rb(), i); + simd.write(dst.rb_mut(), i, simd.mul(beta, y)); } - if simd.has_tail() { - let y = simd.read_tail(rb!(dst)); - simd.write_tail(rb_mut!(dst), simd.mul(beta, y)); + for i in body.clone() { + let y = simd.read(dst.rb(), i); + simd.write(dst.rb_mut(), i, simd.mul(beta, y)); + } + if let Some(i) = tail { + let y = simd.read(dst.rb(), i); + simd.write(dst.rb_mut(), i, simd.mul(beta, y)); } } } None => { - let mut dst = dst.rb_mut().as_array_mut(); - for i in indices.clone() { - simd.write(rb_mut!(dst), i, simd.zero()); + let mut dst = dst.rb_mut(); + if let Some(i) = head { + simd.write(dst.rb_mut(), i, simd.zero()); + } + for i in body.clone() { + simd.write(dst.rb_mut(), i, simd.zero()); } - if simd.has_tail() { - simd.write_tail(rb_mut!(dst), simd.zero()); + if let Some(i) = tail { + simd.write(dst.rb_mut(), i, simd.zero()); } } } for j in lhs.ncols().indices() { - let mut dst = dst.rb_mut().as_array_mut(); - let lhs = lhs.col(j).as_array(); + let mut dst = dst.rb_mut(); + let lhs = lhs.col(j); let rhs = rhs.at(j); let rhs = if conj_rhs == Conj::Yes { math.conj(rhs) @@ -476,26 +478,36 @@ mod matvec_colmajor { let vrhs = simd.splat(as_ref!(rhs)); if conj_lhs == Conj::Yes { - for i in indices.clone() { - let y = simd.read(rb!(dst), i); - let x = simd.read(rb!(lhs), i); - simd.write(rb_mut!(dst), i, simd.conj_mul_add(x, vrhs, y)); + if let Some(i) = head { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y)); } - if simd.has_tail() { - let y = simd.read_tail(rb!(dst)); - let x = simd.read_tail(rb!(lhs)); - simd.write_tail(rb_mut!(dst), simd.conj_mul_add(x, vrhs, y)); + for i in body.clone() { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y)); + } + if let Some(i) = tail { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y)); } } else { - for i in indices.clone() { - let y = simd.read(rb!(dst), i); - let x = simd.read(rb!(lhs), i); - simd.write(rb_mut!(dst), i, simd.mul_add(x, vrhs, y)); + if let Some(i) = head { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y)); + } + for i in body.clone() { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y)); } - if simd.has_tail() { - let y = simd.read_tail(rb!(dst)); - let x = simd.read_tail(rb!(lhs)); - simd.write_tail(rb_mut!(dst), simd.mul_add(x, vrhs, y)); + if let Some(i) = tail { + let y = simd.read(dst.rb(), i); + let x = simd.read(lhs, i); + simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y)); } } } diff --git a/faer/src/linalg/reductions/norm_l1.rs b/faer/src/linalg/reductions/norm_l1.rs index 8dfe1971..6569703a 100644 --- a/faer/src/linalg/reductions/norm_l1.rs +++ b/faer/src/linalg/reductions/norm_l1.rs @@ -12,8 +12,7 @@ fn norm_l1_simd<'N, C: ComplexContainer, T: ComplexField>( ) -> ::Of { struct Impl<'a, 'N, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, - data: C::Of<&'a Array<'N, T>>, - len: Dim<'N>, + data: ColRef<'a, C, T, Dim<'N>, ContiguousFwd>, } impl<'N, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, 'N, C, T> { @@ -21,8 +20,8 @@ fn norm_l1_simd<'N, C: ComplexContainer, T: ComplexField>( #[inline(always)] #[math] fn with_simd(self, simd: S) -> Self::Output { - let Self { ctx, data, len } = self; - let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), len); + let Self { ctx, data } = self; + let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), data.nrows()); help!(C); let zero = simd.splat(as_ref!(math.zero())); @@ -32,40 +31,41 @@ fn norm_l1_simd<'N, C: ComplexContainer, T: ComplexField>( let mut acc2 = zero; let mut acc3 = zero; - let (head, tail) = simd.batch_indices::<4>(); - for [i0, i1, i2, i3] in head { - let x0 = simd.abs1(simd.read(rb!(data), i0)); - let x1 = simd.abs1(simd.read(rb!(data), i1)); - let x2 = simd.abs1(simd.read(rb!(data), i2)); - let x3 = simd.abs1(simd.read(rb!(data), i3)); + let (head, body4, body1, tail) = simd.batch_indices::<4>(); + + if let Some(i0) = head { + let x0 = simd.abs1(simd.read(data, i0)); + acc0 = simd.add(acc0, x0.0); + } + for [i0, i1, i2, i3] in body4 { + let x0 = simd.abs1(simd.read(data, i0)); + let x1 = simd.abs1(simd.read(data, i1)); + let x2 = simd.abs1(simd.read(data, i2)); + let x3 = simd.abs1(simd.read(data, i3)); acc0 = simd.add(acc0, x0.0); acc1 = simd.add(acc1, x1.0); acc2 = simd.add(acc2, x2.0); acc3 = simd.add(acc3, x3.0); } + for i0 in body1 { + let x0 = simd.abs1(simd.read(data, i0)); + acc0 = simd.add(acc0, x0.0); + } + if let Some(i0) = tail { + let x0 = simd.abs1(simd.read(data, i0)); + acc0 = simd.add(acc0, x0.0); + } acc0 = simd.add(acc0, acc1); acc2 = simd.add(acc2, acc3); acc0 = simd.add(acc0, acc2); - for i0 in tail { - let x0 = simd.abs1(simd.read(rb!(data), i0)); - acc0 = simd.add(acc0, x0.0); - } - if simd.has_tail() { - let x0 = simd.abs1(simd.read_tail(rb!(data))); - acc0 = simd.add(acc0, x0.0); - } math.real(simd.reduce_sum(acc0)) } } - T::Arch::default().dispatch(Impl { - ctx, - data: data.as_array(), - len: data.nrows(), - }) + T::Arch::default().dispatch(Impl { ctx, data }) } #[math] diff --git a/faer/src/linalg/reductions/norm_l2.rs b/faer/src/linalg/reductions/norm_l2.rs index 097875ed..63ea414e 100644 --- a/faer/src/linalg/reductions/norm_l2.rs +++ b/faer/src/linalg/reductions/norm_l2.rs @@ -12,8 +12,7 @@ fn norm_l2_simd<'N, C: ComplexContainer, T: ComplexField>( ) -> [::Of; 3] { struct Impl<'a, 'N, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, - data: C::Of<&'a Array<'N, T>>, - len: Dim<'N>, + data: ColRef<'a, C, T, Dim<'N>, ContiguousFwd>, } impl<'N, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, 'N, C, T> { @@ -21,8 +20,8 @@ fn norm_l2_simd<'N, C: ComplexContainer, T: ComplexField>( #[inline(always)] #[math] fn with_simd(self, simd: S) -> Self::Output { - let Self { ctx, data, len } = self; - let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), len); + let Self { ctx, data } = self; + let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), data.nrows()); help!(C); help2!(C::Real); @@ -38,10 +37,18 @@ fn norm_l2_simd<'N, C: ComplexContainer, T: ComplexField>( let mut acc0_big = Real(zero); let mut acc1_big = Real(zero); - let (head, tail) = simd.batch_indices::<2>(); - for [i0, i1] in head { - let x0 = simd.abs1(simd.read(rb!(data), i0)); - let x1 = simd.abs1(simd.read(rb!(data), i1)); + let (head, body2, body1, tail) = simd.batch_indices::<2>(); + + if let Some(i0) = head { + let x0 = simd.abs1(simd.read(data, i0)); + + acc0_sml = simd.abs2_add(simd.mul_real(x0.0, sml), acc0_sml); + acc0_med = simd.abs2_add(x0.0, acc0_med); + acc0_big = simd.abs2_add(simd.mul_real(x0.0, big), acc0_big); + } + for [i0, i1] in body2 { + let x0 = simd.abs1(simd.read(data, i0)); + let x1 = simd.abs1(simd.read(data, i1)); acc0_sml = simd.abs2_add(simd.mul_real(x0.0, sml), acc0_sml); acc1_sml = simd.abs2_add(simd.mul_real(x1.0, sml), acc1_sml); @@ -52,26 +59,24 @@ fn norm_l2_simd<'N, C: ComplexContainer, T: ComplexField>( acc0_big = simd.abs2_add(simd.mul_real(x0.0, big), acc0_big); acc1_big = simd.abs2_add(simd.mul_real(x1.0, big), acc1_big); } - - acc0_sml = Real(simd.add(acc0_sml.0, acc1_sml.0)); - acc0_big = Real(simd.add(acc0_big.0, acc1_big.0)); - acc0_med = Real(simd.add(acc0_med.0, acc1_med.0)); - - for i0 in tail { - let x0 = simd.abs1(simd.read(rb!(data), i0)); + for i0 in body1 { + let x0 = simd.abs1(simd.read(data, i0)); acc0_sml = simd.abs2_add(simd.mul_real(x0.0, sml), acc0_sml); acc0_med = simd.abs2_add(x0.0, acc0_med); acc0_big = simd.abs2_add(simd.mul_real(x0.0, big), acc0_big); } - if simd.has_tail() { - let x0 = simd.abs1(simd.read_tail(rb!(data))); + if let Some(i0) = tail { + let x0 = simd.abs1(simd.read(data, i0)); acc0_sml = simd.abs2_add(simd.mul_real(x0.0, sml), acc0_sml); acc0_med = simd.abs2_add(x0.0, acc0_med); acc0_big = simd.abs2_add(simd.mul_real(x0.0, big), acc0_big); } + acc0_sml = Real(simd.add(acc0_sml.0, acc1_sml.0)); + acc0_big = Real(simd.add(acc0_big.0, acc1_big.0)); + acc0_med = Real(simd.add(acc0_med.0, acc1_med.0)); [ math.real(simd.reduce_sum(acc0_sml.0)), math.real(simd.reduce_sum(acc0_med.0)), @@ -80,11 +85,7 @@ fn norm_l2_simd<'N, C: ComplexContainer, T: ComplexField>( } } - T::Arch::default().dispatch(Impl { - ctx, - data: data.as_array(), - len: data.nrows(), - }) + T::Arch::default().dispatch(Impl { ctx, data }) } #[math] diff --git a/faer/src/linalg/reductions/norm_l2_sqr.rs b/faer/src/linalg/reductions/norm_l2_sqr.rs index 98bb16c9..84d32b2c 100644 --- a/faer/src/linalg/reductions/norm_l2_sqr.rs +++ b/faer/src/linalg/reductions/norm_l2_sqr.rs @@ -12,8 +12,7 @@ fn norm_l2_sqr_simd<'N, C: ComplexContainer, T: ComplexField>( ) -> ::Of { struct Impl<'a, 'N, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, - data: C::Of<&'a Array<'N, T>>, - len: Dim<'N>, + data: ColRef<'a, C, T, Dim<'N>, ContiguousFwd>, } impl<'N, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, 'N, C, T> { @@ -21,8 +20,8 @@ fn norm_l2_sqr_simd<'N, C: ComplexContainer, T: ComplexField>( #[inline(always)] #[math] fn with_simd(self, simd: S) -> Self::Output { - let Self { ctx, data, len } = self; - let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), len); + let Self { ctx, data } = self; + let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), data.nrows()); help!(C); let zero = simd.splat(as_ref!(math.zero())); @@ -32,40 +31,40 @@ fn norm_l2_sqr_simd<'N, C: ComplexContainer, T: ComplexField>( let mut acc2 = Real(zero); let mut acc3 = Real(zero); - let (head, tail) = simd.batch_indices::<4>(); - for [i0, i1, i2, i3] in head { - let x0 = simd.read(rb!(data), i0); - let x1 = simd.read(rb!(data), i1); - let x2 = simd.read(rb!(data), i2); - let x3 = simd.read(rb!(data), i3); + let (head, body4, body1, tail) = simd.batch_indices::<4>(); + if let Some(i0) = head { + let x0 = simd.read(data, i0); + acc0 = simd.abs2_add(x0, acc0); + } + for [i0, i1, i2, i3] in body4 { + let x0 = simd.read(data, i0); + let x1 = simd.read(data, i1); + let x2 = simd.read(data, i2); + let x3 = simd.read(data, i3); acc0 = simd.abs2_add(x0, acc0); acc1 = simd.abs2_add(x1, acc1); acc2 = simd.abs2_add(x2, acc2); acc3 = simd.abs2_add(x3, acc3); } + for i0 in body1 { + let x0 = simd.read(data, i0); + acc0 = simd.abs2_add(x0, acc0); + } + if let Some(i0) = tail { + let x0 = simd.read(data, i0); + acc0 = simd.abs2_add(x0, acc0); + } acc0 = Real(simd.add(acc0.0, acc1.0)); acc2 = Real(simd.add(acc2.0, acc3.0)); acc0 = Real(simd.add(acc0.0, acc2.0)); - for i0 in tail { - let x0 = simd.read(rb!(data), i0); - acc0 = simd.abs2_add(x0, acc0); - } - if simd.has_tail() { - let x0 = simd.read_tail(rb!(data)); - acc0 = simd.abs2_add(x0, acc0); - } math.real(simd.reduce_sum(acc0.0)) } } - T::Arch::default().dispatch(Impl { - ctx, - data: data.as_array(), - len: data.nrows(), - }) + T::Arch::default().dispatch(Impl { ctx, data }) } #[math] diff --git a/faer/src/linalg/reductions/norm_max.rs b/faer/src/linalg/reductions/norm_max.rs index c8a7f7bb..de3e559f 100644 --- a/faer/src/linalg/reductions/norm_max.rs +++ b/faer/src/linalg/reductions/norm_max.rs @@ -12,8 +12,7 @@ fn norm_max_simd<'N, C: ComplexContainer, T: ComplexField>( ) -> ::Of { struct Impl<'a, 'N, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, - data: C::Of<&'a Array<'N, T>>, - len: Dim<'N>, + data: ColRef<'a, C, T, Dim<'N>, ContiguousFwd>, } impl<'N, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, 'N, C, T> { @@ -21,8 +20,8 @@ fn norm_max_simd<'N, C: ComplexContainer, T: ComplexField>( #[inline(always)] #[math] fn with_simd(self, simd: S) -> Self::Output { - let Self { ctx, data, len } = self; - let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), len); + let Self { ctx, data } = self; + let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), data.nrows()); help!(C); let zero = simd.splat(as_ref!(math.zero())); @@ -32,40 +31,40 @@ fn norm_max_simd<'N, C: ComplexContainer, T: ComplexField>( let mut acc2 = Real(zero); let mut acc3 = Real(zero); - let (head, tail) = simd.batch_indices::<4>(); - for [i0, i1, i2, i3] in head { - let x0 = simd.abs_max(simd.read(rb!(data), i0)); - let x1 = simd.abs_max(simd.read(rb!(data), i1)); - let x2 = simd.abs_max(simd.read(rb!(data), i2)); - let x3 = simd.abs_max(simd.read(rb!(data), i3)); + let (head, body4, body1, tail) = simd.batch_indices::<4>(); + + if let Some(i0) = head { + let x0 = simd.abs_max(simd.read(data, i0)); + acc0 = simd.max(acc0, x0); + } + for [i0, i1, i2, i3] in body4 { + let x0 = simd.abs_max(simd.read(data, i0)); + let x1 = simd.abs_max(simd.read(data, i1)); + let x2 = simd.abs_max(simd.read(data, i2)); + let x3 = simd.abs_max(simd.read(data, i3)); acc0 = simd.max(acc0, x0); acc1 = simd.max(acc1, x1); acc2 = simd.max(acc2, x2); acc3 = simd.max(acc3, x3); } - - acc0 = simd.max(acc0, acc1); - acc2 = simd.max(acc2, acc3); - acc0 = simd.max(acc0, acc2); - - for i0 in tail { - let x0 = simd.abs_max(simd.read(rb!(data), i0)); + for i0 in body1 { + let x0 = simd.abs_max(simd.read(data, i0)); acc0 = simd.max(acc0, x0); } - if simd.has_tail() { - let x0 = simd.abs_max(simd.read_tail(rb!(data))); + if let Some(i0) = tail { + let x0 = simd.abs_max(simd.read(data, i0)); acc0 = simd.max(acc0, x0); } + acc0 = simd.max(acc0, acc1); + acc2 = simd.max(acc2, acc3); + acc0 = simd.max(acc0, acc2); + math.real(simd.reduce_max(acc0)) } } - T::Arch::default().dispatch(Impl { - ctx, - data: data.as_array(), - len: data.nrows(), - }) + T::Arch::default().dispatch(Impl { ctx, data }) } #[math] diff --git a/faer/src/linalg/reductions/sum.rs b/faer/src/linalg/reductions/sum.rs index df4a9f23..8cf0787b 100644 --- a/faer/src/linalg/reductions/sum.rs +++ b/faer/src/linalg/reductions/sum.rs @@ -12,8 +12,7 @@ fn sum_simd<'N, C: ComplexContainer, T: ComplexField>( ) -> C::Of { struct Impl<'a, 'N, C: ComplexContainer, T: ComplexField> { ctx: &'a Ctx, - data: C::Of<&'a Array<'N, T>>, - len: Dim<'N>, + data: ColRef<'a, C, T, Dim<'N>, ContiguousFwd>, } impl<'N, C: ComplexContainer, T: ComplexField> pulp::WithSimd for Impl<'_, 'N, C, T> { @@ -21,8 +20,8 @@ fn sum_simd<'N, C: ComplexContainer, T: ComplexField>( #[inline(always)] #[math] fn with_simd(self, simd: S) -> Self::Output { - let Self { ctx, data, len } = self; - let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), len); + let Self { ctx, data } = self; + let simd = SimdCtx::::new(T::simd_ctx(ctx, simd), data.nrows()); help!(C); let zero = simd.splat(as_ref!(math.zero())); @@ -32,40 +31,39 @@ fn sum_simd<'N, C: ComplexContainer, T: ComplexField>( let mut acc2 = zero; let mut acc3 = zero; - let (head, tail) = simd.batch_indices::<4>(); - for [i0, i1, i2, i3] in head { - let x0 = simd.read(rb!(data), i0); - let x1 = simd.read(rb!(data), i1); - let x2 = simd.read(rb!(data), i2); - let x3 = simd.read(rb!(data), i3); + let (head, body4, body1, tail) = simd.batch_indices::<4>(); + if let Some(i0) = head { + let x0 = simd.read(data, i0); + acc0 = simd.add(acc0, x0); + } + for [i0, i1, i2, i3] in body4 { + let x0 = simd.read(data, i0); + let x1 = simd.read(data, i1); + let x2 = simd.read(data, i2); + let x3 = simd.read(data, i3); acc0 = simd.add(acc0, x0); acc1 = simd.add(acc1, x1); acc2 = simd.add(acc2, x2); acc3 = simd.add(acc3, x3); } - - acc0 = simd.add(acc0, acc1); - acc2 = simd.add(acc2, acc3); - acc0 = simd.add(acc0, acc2); - - for i0 in tail { - let x0 = simd.read(rb!(data), i0); + for i0 in body1 { + let x0 = simd.read(data, i0); acc0 = simd.add(acc0, x0); } - if simd.has_tail() { - let x0 = simd.read_tail(rb!(data)); + if let Some(i0) = tail { + let x0 = simd.read(data, i0); acc0 = simd.add(acc0, x0); } + acc0 = simd.add(acc0, acc1); + acc2 = simd.add(acc2, acc3); + acc0 = simd.add(acc0, acc2); + simd.reduce_sum(acc0) } } - T::Arch::default().dispatch(Impl { - ctx, - data: data.as_array(), - len: data.nrows(), - }) + T::Arch::default().dispatch(Impl { ctx, data }) } #[math] diff --git a/faer/src/stats/mod.rs b/faer/src/stats/mod.rs index a697196b..a5c06408 100644 --- a/faer/src/stats/mod.rs +++ b/faer/src/stats/mod.rs @@ -32,13 +32,13 @@ pub struct CwiseMatDistribution { #[derive(Copy, Clone, Debug)] pub struct CwiseColDistribution { pub nrows: Rows, - pub distribution: D, + pub dist: D, } #[derive(Copy, Clone, Debug)] pub struct CwiseRowDistribution { pub ncols: Cols, - pub distribution: D, + pub dist: D, } #[derive(Copy, Clone, Debug)] @@ -60,7 +60,7 @@ impl>> Distribution(&self, rng: &mut R) -> Col { - Col::from_fn(self.nrows, |_| self.distribution.sample(rng)) + Col::from_fn(self.nrows, |_| self.dist.sample(rng)) } } @@ -69,6 +69,6 @@ impl>> Distribution(&self, rng: &mut R) -> Row { - Row::from_fn(self.ncols, |_| self.distribution.sample(rng)) + Row::from_fn(self.ncols, |_| self.dist.sample(rng)) } } diff --git a/faer/src/utils/approx.rs b/faer/src/utils/approx.rs index eeebff36..a187360c 100644 --- a/faer/src/utils/approx.rs +++ b/faer/src/utils/approx.rs @@ -184,7 +184,7 @@ impl< &rhs, &alloc::format!("{rhs_source} at ({i:?}, {j:?})"), crate::hacks::hijack_debug(unsafe { - crate::hacks::coerce::<&C::Of, &C::OfDebug>(&lhs) + crate::hacks::coerce::<&C::Of, &C::OfDebug>(&rhs) }), f, )?; diff --git a/faer/src/utils/simd.rs b/faer/src/utils/simd.rs index 44fc1e92..b6f22fdb 100644 --- a/faer/src/utils/simd.rs +++ b/faer/src/utils/simd.rs @@ -1,14 +1,19 @@ +use crate::internal_prelude::*; use core::marker::PhantomData; -use faer_traits::{help, ComplexContainer, ComplexField, SimdCapabilities}; +use faer_traits::SimdCapabilities; use pulp::Simd; -use super::bound::{Array, Dim, Idx}; +use super::bound::{Dim, Idx}; pub struct SimdCtx<'N, C: ComplexContainer, T: ComplexField, S: Simd> { pub ctx: T::SimdCtx, pub len: Dim<'N>, - simd_len: usize, - mask: T::SimdMask, + offset: usize, + head_end: usize, + body_end: usize, + tail_end: usize, + head_mask: T::SimdMask, + tail_mask: T::SimdMask, } impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> core::fmt::Debug @@ -17,8 +22,12 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> core::fmt::Debug fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("SimdCtx") .field("len", &self.len) - .field("simd_len", &self.simd_len) - .field("mask", &self.mask) + .field("offset", &self.offset) + .field("head_end", &self.head_end) + .field("body_end", &self.body_end) + .field("tail_end", &self.tail_end) + .field("head_mask", &self.head_mask) + .field("tail_mask", &self.tail_mask) .finish() } } @@ -40,64 +49,152 @@ impl, S: Simd> core::ops::Deref for Simd } } -impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> { - #[inline] - pub fn new(simd: T::SimdCtx, len: Dim<'N>) -> Self { - let stride = const { size_of::>() / size_of::() }; - Self { - ctx: simd, - len, - simd_len: *len / stride * stride, - mask: T::simd_tail_mask(&simd, *len % stride), +pub trait SimdIndex<'N, C: ComplexContainer, T: ComplexField, S: Simd> { + fn read( + simd: &SimdCtx<'N, C, T, S>, + slice: ColRef<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + ) -> C::OfSimd>; + + fn write( + simd: &SimdCtx<'N, C, T, S>, + slice: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + value: C::OfSimd>, + ); +} + +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdIndex<'N, C, T, S> + for SimdBody<'N, C, T, S> +{ + #[inline(always)] + fn read( + simd: &SimdCtx<'N, C, T, S>, + slice: ColRef<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + ) -> ::OfSimd> { + help!(C); + unsafe { + simd.load(map!( + slice.as_ptr(), + slice, + &*(slice.wrapping_offset(index.start) as *const T::SimdVec) + )) } } - #[inline] - pub fn new_force_mask(simd: T::SimdCtx, len: Dim<'N>) -> Self { - crate::assert!(*len != 0); - let new_len = *len - 1; - - let stride = const { size_of::>() / size_of::() }; - Self { - ctx: simd, - len, - simd_len: new_len / stride * stride, - mask: T::simd_tail_mask(&simd, (new_len % stride) + 1), + #[inline(always)] + fn write( + simd: &SimdCtx<'N, C, T, S>, + slice: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + value: ::OfSimd>, + ) { + help!(C); + unsafe { + simd.store( + map!( + slice.as_ptr_mut(), + slice, + &mut *(slice.wrapping_offset(index.start) as *mut T::SimdVec) + ), + value, + ); } } +} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdIndex<'N, C, T, S> + for SimdHead<'N, C, T, S> +{ #[inline(always)] - pub fn read( - &self, - slice: C::Of<&Array<'N, T>>, - index: SimdIdx<'N, C, T, S>, - ) -> C::OfSimd> { - core::assert!( - const { - matches!( - T::SIMD_CAPABILITIES, - SimdCapabilities::All | SimdCapabilities::Shuffled - ) - } - ); + fn read( + simd: &SimdCtx<'N, C, T, S>, + slice: ColRef<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + ) -> ::OfSimd> { + help!(C); + unsafe { + simd.mask_load( + simd.head_mask, + map!( + slice.as_ptr(), + slice, + (slice.wrapping_offset(index.start) as *const T::SimdVec) + ), + ) + } + } + #[inline(always)] + fn write( + simd: &SimdCtx<'N, C, T, S>, + slice: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + value: ::OfSimd>, + ) { help!(C); unsafe { - self.load(map!( - slice, - slice, - &*(slice.as_ref().as_ptr().add(index.start.unbound()) as *const T::SimdVec) - )) + simd.mask_store( + simd.head_mask, + map!( + slice.as_ptr_mut(), + slice, + slice.wrapping_offset(index.start) as *mut T::SimdVec + ), + value, + ); } } +} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdIndex<'N, C, T, S> + for SimdTail<'N, C, T, S> +{ #[inline(always)] - pub fn has_tail(&self) -> bool { - self.len.unbound() > self.simd_len + fn read( + simd: &SimdCtx<'N, C, T, S>, + slice: ColRef<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + ) -> ::OfSimd> { + help!(C); + unsafe { + simd.mask_load( + simd.tail_mask, + map!( + slice.as_ptr(), + slice, + (slice.wrapping_offset(index.start) as *const T::SimdVec) + ), + ) + } } #[inline(always)] - pub fn read_tail(&self, slice: C::Of<&Array<'N, T>>) -> C::OfSimd> { + fn write( + simd: &SimdCtx<'N, C, T, S>, + slice: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + index: Self, + value: ::OfSimd>, + ) { + help!(C); + unsafe { + simd.mask_store( + simd.tail_mask, + map!( + slice.as_ptr_mut(), + slice, + slice.wrapping_offset(index.start) as *mut T::SimdVec + ), + value, + ); + } + } +} + +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> { + #[inline] + pub fn new(simd: T::SimdCtx, len: Dim<'N>) -> Self { core::assert!( const { matches!( @@ -106,23 +203,22 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> ) } ); - debug_assert!(self.has_tail()); - help!(C); - unsafe { - self.mask_load( - self.mask, - map!( - slice, - slice, - slice.as_ref().as_ptr().add(self.simd_len) as *const T::SimdVec - ), - ) + let stride = const { size_of::>() / size_of::() }; + Self { + ctx: simd, + len, + offset: 0, + head_end: 0, + body_end: *len / stride, + tail_end: (*len + stride - 1) / stride, + head_mask: T::simd_head_mask(&simd, 0), + tail_mask: T::simd_tail_mask(&simd, *len % stride), } } - #[inline(always)] - pub fn write_tail(&self, slice: C::Of<&mut Array<'N, T>>, value: C::OfSimd>) { + #[inline] + pub fn new_align(simd: T::SimdCtx, len: Dim<'N>, align_offset: usize) -> Self { core::assert!( const { matches!( @@ -131,29 +227,49 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> ) } ); - debug_assert!(self.has_tail()); - help!(C); - unsafe { - self.mask_store( - self.mask, - map!( - slice, - slice, - slice.as_mut().as_mut_ptr().add(self.simd_len) as *mut T::SimdVec - ), - value, - ); + let stride = const { size_of::>() / size_of::() }; + let align_offset = align_offset % stride; + + if align_offset == 0 { + Self::new(simd, len) + } else { + let offset = stride - align_offset; + let full_len = offset + *len; + let head_mask = T::simd_head_mask(&simd, align_offset); + let tail_mask = T::simd_tail_mask(&simd, full_len % stride); + + if align_offset <= *len { + Self { + ctx: simd, + len, + offset, + head_end: 1, + body_end: full_len / stride, + tail_end: (full_len + stride - 1) / stride, + head_mask, + tail_mask, + } + } else { + let head_mask = T::simd_and_mask(&simd, head_mask, tail_mask); + let tail_mask = T::simd_tail_mask(&simd, 0); + + Self { + ctx: simd, + len, + offset, + head_end: 1, + body_end: 1, + tail_end: 1, + head_mask, + tail_mask, + } + } } } - #[inline(always)] - pub fn write( - &self, - slice: C::Of<&mut Array<'N, T>>, - index: SimdIdx<'N, C, T, S>, - value: C::OfSimd>, - ) { + #[inline] + pub fn new_force_mask(simd: T::SimdCtx, len: Dim<'N>) -> Self { core::assert!( const { matches!( @@ -163,38 +279,79 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> } ); - help!(C); - unsafe { - self.store( - map!( - slice, - slice, - &mut *(slice.as_mut().as_mut_ptr().add(index.start.unbound()) - as *mut T::SimdVec) - ), - value, - ); + crate::assert!(*len != 0); + let new_len = *len - 1; + + let stride = const { size_of::>() / size_of::() }; + Self { + ctx: simd, + len, + offset: 0, + head_end: 0, + body_end: new_len / stride, + tail_end: new_len / stride + 1, + head_mask: T::simd_head_mask(&simd, 0), + tail_mask: T::simd_tail_mask(&simd, (new_len % stride) + 1), } } + #[inline(always)] + pub fn read>( + &self, + slice: ColRef<'_, C, T, Dim<'N>, ContiguousFwd>, + index: I, + ) -> C::OfSimd> { + I::read(self, slice, index) + } + + #[inline(always)] + pub fn write>( + &self, + slice: ColMut<'_, C, T, Dim<'N>, ContiguousFwd>, + index: I, + value: C::OfSimd>, + ) { + I::write(self, slice, index, value) + } + #[inline] pub fn indices( &self, - ) -> impl Clone + ExactSizeIterator + DoubleEndedIterator> { + ) -> ( + Option>, + impl Clone + ExactSizeIterator + DoubleEndedIterator>, + Option>, + ) { macro_rules! stride { () => { const { size_of::>() / size_of::() } }; } - let stride = stride!(); - let len = self.simd_len; - - (0..len / stride).map( - #[inline(always)] - move |i| SimdIdx { - start: unsafe { Idx::new_unbound(i * stride!()) }, - mask: PhantomData, + let offset = -(self.offset as isize); + ( + if 0 == self.head_end { + None + } else { + Some(SimdHead { + start: offset, + mask: PhantomData, + }) + }, + (self.head_end..self.body_end).map( + #[inline(always)] + move |i| SimdBody { + start: offset + (i * stride!()) as isize, + mask: PhantomData, + }, + ), + if self.body_end == self.tail_end { + None + } else { + Some(SimdTail { + start: offset + (self.body_end * stride!()) as isize, + mask: PhantomData, + }) }, ) } @@ -203,8 +360,10 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> pub fn batch_indices( &self, ) -> ( - impl Clone + ExactSizeIterator + DoubleEndedIterator; BATCH]>, - impl Clone + ExactSizeIterator + DoubleEndedIterator>, + Option>, + impl Clone + ExactSizeIterator + DoubleEndedIterator; BATCH]>, + impl Clone + ExactSizeIterator + DoubleEndedIterator>, + Option>, ) { const { core::assert!(BATCH.is_power_of_two()) }; @@ -214,38 +373,82 @@ impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> SimdCtx<'N, C, T, S> }; } - let stride = stride!(); - let len = *self.len; + let len = self.body_end - self.head_end; + + let offset = -(self.offset as isize); ( - (0..len / (BATCH * stride)).map(|i| { + if 0 == self.head_end { + None + } else { + Some(SimdHead { + start: offset, + mask: PhantomData, + }) + }, + (self.head_end..self.head_end + len / BATCH).map(move |i| { core::array::from_fn( #[inline(always)] - |k| SimdIdx { - start: unsafe { Idx::new_unbound((i * BATCH + k) * stride!()) }, + |k| SimdBody { + start: offset + ((i * BATCH + k) * stride!()) as isize, mask: PhantomData, }, ) }), - ((len / stride) / BATCH * BATCH..len / stride).map( + (self.head_end + len / BATCH * BATCH..self.body_end).map( #[inline(always)] - move |i| SimdIdx { - start: unsafe { Idx::new_unbound(i * stride!()) }, + move |i| SimdBody { + start: offset + (i * stride!()) as isize, mask: PhantomData, }, ), + if self.body_end == self.tail_end { + None + } else { + Some(SimdTail { + start: offset + (self.body_end * stride!()) as isize, + mask: PhantomData, + }) + }, ) } } #[repr(transparent)] -pub struct SimdIdx<'N, C: ComplexContainer, T: ComplexField, S: Simd> { - start: Idx<'N>, - mask: PhantomData>, +#[derive(Debug)] +pub struct SimdBody<'N, C: ComplexContainer, T: ComplexField, S: Simd> { + start: isize, + mask: PhantomData<(Idx<'N>, T::SimdMask)>, +} +#[repr(transparent)] +#[derive(Debug)] +pub struct SimdHead<'N, C: ComplexContainer, T: ComplexField, S: Simd> { + start: isize, + mask: PhantomData<(Idx<'N>, T::SimdMask)>, +} +#[repr(transparent)] +#[derive(Debug)] +pub struct SimdTail<'N, C: ComplexContainer, T: ComplexField, S: Simd> { + start: isize, + mask: PhantomData<(Idx<'N>, T::SimdMask)>, } -impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Copy for SimdIdx<'N, C, T, S> {} -impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Clone for SimdIdx<'N, C, T, S> { +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Copy for SimdBody<'N, C, T, S> {} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Clone for SimdBody<'N, C, T, S> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Copy for SimdHead<'N, C, T, S> {} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Clone for SimdHead<'N, C, T, S> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Copy for SimdTail<'N, C, T, S> {} +impl<'N, C: ComplexContainer, T: ComplexField, S: Simd> Clone for SimdTail<'N, C, T, S> { #[inline] fn clone(&self) -> Self { *self