diff --git a/Cargo.toml b/Cargo.toml index 6e612b63..515df772 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "faer-evd", ] exclude = ["faer-bench"] +resolver = "2" [workspace.dependencies] coe-rs = "0.1" diff --git a/faer-cholesky/Cargo.toml b/faer-cholesky/Cargo.toml index afd2f0c4..23bdea2c 100644 --- a/faer-cholesky/Cargo.toml +++ b/faer-cholesky/Cargo.toml @@ -27,7 +27,7 @@ std = ["faer-core/std", "pulp/std"] nightly = ["faer-core/nightly", "pulp/nightly"] [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" diff --git a/faer-cholesky/src/ldlt_diagonal/update.rs b/faer-cholesky/src/ldlt_diagonal/update.rs index 899f10d5..c9ea3ba9 100644 --- a/faer-cholesky/src/ldlt_diagonal/update.rs +++ b/faer-cholesky/src/ldlt_diagonal/update.rs @@ -650,6 +650,10 @@ pub fn rank_r_update_clobber( assert!(w.nrows() == n); assert!(alpha.nrows() == k); + if n == 0 { + return; + } + RankRUpdate { ld: cholesky_factors, w, @@ -832,7 +836,7 @@ pub fn insert_rows_and_cols_clobber( let r = inserted_matrix.ncols(); assert!(cholesky_factors_extended.ncols() == new_n); - assert!(r < new_n); + assert!(r <= new_n); let old_n = new_n - r; assert!(insertion_index <= old_n); diff --git a/faer-cholesky/src/llt/update.rs b/faer-cholesky/src/llt/update.rs index 27b98351..deaca6e6 100644 --- a/faer-cholesky/src/llt/update.rs +++ b/faer-cholesky/src/llt/update.rs @@ -918,6 +918,10 @@ pub fn rank_r_update_clobber( assert!(w.nrows() == n); assert!(alpha.nrows() == k); + if n == 0 { + return Ok(()); + } + RankRUpdate { l: cholesky_factor, w, diff --git a/faer-core/Cargo.toml b/faer-core/Cargo.toml index 61917d11..f753a4df 100644 --- a/faer-core/Cargo.toml +++ b/faer-core/Cargo.toml @@ -31,7 +31,7 @@ std = ["gemm/std", "pulp/std"] nightly = ["gemm/nightly", "pulp/nightly"] [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" diff --git a/faer-core/src/add.rs b/faer-core/src/add.rs deleted file mode 100644 index 82ce2574..00000000 --- a/faer-core/src/add.rs +++ /dev/null @@ -1,213 +0,0 @@ -//! addition and subtraction of matrices - -use crate::{ComplexField, Mat, MatRef}; -use core::ops::{Add, Sub}; - -// add two matrices together -impl<'a, T> Add> for MatRef<'a, T> -where - T: ComplexField, -{ - type Output = Mat; - /// create a new matrix corresponding to the addition of `rhs` to `self`. - /// # Panics - /// Panics if the matrix dimensions do not match. - fn add(self, rhs: MatRef<'_, T>) -> Self::Output { - assert_eq!( - (self.nrows(), self.ncols()), - (rhs.nrows(), rhs.ncols()), - "Matrix dimensions must match" - ); - Self::Output::with_dims(self.nrows(), self.ncols(), |i, j| { - self.read(i, j).add(&rhs.read(i, j)) - }) - } -} - -impl<'a, T> Sub> for MatRef<'a, T> -where - T: ComplexField, -{ - type Output = Mat; - /// create a new matrix corresponding to the subtraction of `rhs` from `self`. - /// # Panics - /// Panics if the matrix dimensions do not match. - fn sub(self, rhs: MatRef<'_, T>) -> Self::Output { - assert_eq!( - (self.nrows(), self.ncols()), - (rhs.nrows(), rhs.ncols()), - "Matrix dimensions must match" - ); - Self::Output::with_dims(self.nrows(), self.ncols(), |i, j| { - self.read(i, j).sub(&rhs.read(i, j)) - }) - } -} - -// implement the add trait for cases where one of the operands is -// an owned matrix by deferring to the case where both are references -// @todo: this will allocate even if one of the operands could be reuse -// and in future we should consider adding an efficient add- and sub-assign -// implementations that are used instead as backends -macro_rules! impl_binary_op { - ($Op:ty) => { - paste::paste! { - impl $Op> for Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: MatRef<'_,T>) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs) - } - } - - impl $Op> for MatRef<'_,T> - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: Mat) -> Self::Output { - self.[<$Op:lower>] (rhs.as_ref()) - } - } - - impl $Op> for Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: Mat) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs.as_ref()) - } - } - - impl $Op<&Mat> for MatRef<'_,T> - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: &Mat) -> Self::Output { - self.[<$Op:lower>] (rhs.as_ref()) - } - } - - impl $Op> for &Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: MatRef<'_,T>) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs) - } - } - - impl $Op<&Mat> for &'_ Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: &Mat) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs.as_ref()) - } - } - - impl $Op> for &Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: Mat) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs.as_ref()) - } - } - - impl $Op<&Mat> for Mat - where - T: ComplexField, - { - type Output = Mat; - fn [<$Op:lower>](self, rhs: &Mat) -> Self::Output { - self.as_ref().[<$Op:lower>] (rhs.as_ref()) - } - } - } - }; -} - -impl_binary_op!(Add); -impl_binary_op!(Sub); - -#[cfg(test)] -#[allow(non_snake_case)] -mod test { - use crate::{mat, Mat}; - use assert_approx_eq::assert_approx_eq; - - fn matrices() -> (Mat, Mat) { - let A = mat![[2.8, -3.3], [-1.7, 5.2], [4.6, -8.3],]; - - let B = mat![[-7.9, 8.3], [4.7, -3.2], [3.8, -5.2],]; - (A, B) - } - - #[test] - #[should_panic] - fn test_adding_matrices_of_different_sizes_should_panic() { - let A = mat![[1.0, 2.0], [3.0, 4.0]]; - let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - _ = A + B; - } - - #[test] - #[should_panic] - fn test_subtracting_two_matrices_of_different_sizes_should_panic() { - let A = mat![[1.0, 2.0], [3.0, 4.0]]; - let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - _ = A - B; - } - - #[test] - fn test_add() { - let (A, B) = matrices(); - - let expected = mat![[-5.1, 5.0], [3.0, 2.0], [8.4, -13.5],]; - - assert_matrix_approx_eq(A.as_ref() + B.as_ref(), &expected); - assert_matrix_approx_eq(&A + &B, &expected); - assert_matrix_approx_eq(A.as_ref() + &B, &expected); - assert_matrix_approx_eq(&A + B.as_ref(), &expected); - assert_matrix_approx_eq(A.as_ref() + B.clone(), &expected); - assert_matrix_approx_eq(&A + B.clone(), &expected); - assert_matrix_approx_eq(A.clone() + B.as_ref(), &expected); - assert_matrix_approx_eq(A.clone() + &B, &expected); - assert_matrix_approx_eq(A + B, &expected); - } - - #[test] - fn test_sub() { - let (A, B) = matrices(); - - let expected = mat![[10.7, -11.6], [-6.4, 8.4], [0.8, -3.1],]; - - assert_matrix_approx_eq(A.as_ref() - B.as_ref(), &expected); - assert_matrix_approx_eq(&A - &B, &expected); - assert_matrix_approx_eq(A.as_ref() - &B, &expected); - assert_matrix_approx_eq(&A - B.as_ref(), &expected); - assert_matrix_approx_eq(A.as_ref() - B.clone(), &expected); - assert_matrix_approx_eq(&A - B.clone(), &expected); - assert_matrix_approx_eq(A.clone() - B.as_ref(), &expected); - assert_matrix_approx_eq(A.clone() - &B, &expected); - assert_matrix_approx_eq(A - B, &expected); - } - - fn assert_matrix_approx_eq(given: Mat, expected: &Mat) { - assert_eq!(given.nrows(), expected.nrows()); - assert_eq!(given.ncols(), expected.ncols()); - for i in 0..given.nrows() { - for j in 0..given.ncols() { - assert_approx_eq!(given.read(i, j), expected.read(i, j)); - } - } - } -} diff --git a/faer-core/src/inverse.rs b/faer-core/src/inverse.rs index 89406a5f..067165d7 100644 --- a/faer-core/src/inverse.rs +++ b/faer-core/src/inverse.rs @@ -128,7 +128,7 @@ unsafe fn invert_unit_lower_triangular_impl( solve::solve_unit_lower_triangular_in_place(src_br, dst_bl, parallelism); } -/// Computes the \[conjugate\] inverse of the lower triangular matrix `src` (with implicit unit +/// Computes the inverse of the lower triangular matrix `src` (with implicit unit /// diagonal) and stores the strictly lower triangular part of the result to `dst`. /// /// # Panics @@ -147,7 +147,7 @@ pub fn invert_unit_lower_triangular( unsafe { invert_unit_lower_triangular_impl(dst, src, parallelism) } } -/// Computes the \[conjugate\] inverse of the lower triangular matrix `src` and stores the +/// Computes the inverse of the lower triangular matrix `src` and stores the /// lower triangular part of the result to `dst`. /// /// # Panics @@ -166,7 +166,7 @@ pub fn invert_lower_triangular( unsafe { invert_lower_triangular_impl(dst, src, parallelism) } } -/// Computes the \[conjugate\] inverse of the upper triangular matrix `src` (with implicit unit +/// Computes the inverse of the upper triangular matrix `src` (with implicit unit /// diagonal) and stores the strictly upper triangular part of the result to `dst`. /// /// # Panics @@ -185,7 +185,7 @@ pub fn invert_unit_upper_triangular( ) } -/// Computes the \[conjugate\] inverse of the upper triangular matrix `src` and stores the +/// Computes the inverse of the upper triangular matrix `src` and stores the /// upper triangular part of the result to `dst`. /// /// # Panics diff --git a/faer-core/src/lib.rs b/faer-core/src/lib.rs index bc4a554e..153bd493 100644 --- a/faer-core/src/lib.rs +++ b/faer-core/src/lib.rs @@ -24,7 +24,7 @@ pub mod permutation; pub mod solve; pub mod zip; -pub mod add; +mod matrix_ops; #[doc(hidden)] pub mod simd; @@ -56,22 +56,116 @@ fn sqrt_impl(re: E, im: E) -> (E, E) { (a, b) } -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c32 { - pub re: f32, - pub im: f32, +/// Native complex floating point types whose real and imaginary parts are stored contiguously. +/// +/// The types [`c32`] and [`c64`] respectively have the same layout as [`num_complex::Complex32`] +/// and [`num_complex::Complex64`]. +/// +/// They differ in the way they are treated by the `faer` library: When stored in a matrix, +/// `Mat` and `Mat` internally contain a single container of contiguously stored +/// `c32` and `c64` values, whereas `Mat` and +/// `Mat` internally contain two containers, separately storing the real +/// and imaginary parts of the complex values. +/// +/// Matrix operations using `c32` and `c64` are usually more efficient and should be preferred in +/// most cases. `num_complex::Complex` matrices have better support for generic data types. +/// +/// The drawing below represents a simplified layout of the `Mat` structure for each of `c32` and +/// `num_complex::Complex32`. +/// +/// ```notcode +/// ┌──────────────────┐ +/// │ Mat │ +/// ├──────────────────┤ +/// │ ptr: *mut c32 ─ ─│─ ─ ─ ─ ┐ +/// │ nrows: usize │ ┌─────────┐ +/// │ ncols: usize │ │ z0: c32 │ +/// │ ... │ │ z1: c32 │ +/// └──────────────────┘ │ z2: c32 │ +/// │ ... │ +/// └─────────┘ +/// +/// ┌───────────────────────┐ +/// │ Mat │ +/// ├───────────────────────┤ +/// │ ptr_real: *mut f32 ─ ─│─ ─ ─ ─ ┐ +/// │ ptr_imag: *mut f32 ─ ─│─ ─ ─ ─ ┼ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ nrows: usize │ ┌──────────┐ ┌──────────┐ +/// │ ncols: usize │ │ re0: f32 │ │ im0: f32 │ +/// │ ... │ │ re1: f32 │ │ im1: f32 │ +/// └───────────────────────┘ │ re2: f32 │ │ im2: f32 │ +/// │ ... │ │ ... │ +/// └──────────┘ └──────────┘ +/// ``` +pub mod complex_native { + // 32-bit complex floating point type. See the module-level documentation for more details. + #[allow(non_camel_case_types)] + #[derive(Copy, Clone, PartialEq)] + #[repr(C)] + pub struct c32 { + pub re: f32, + pub im: f32, + } + + // 64-bit complex floating point type. See the module-level documentation for more details. + #[allow(non_camel_case_types)] + #[derive(Copy, Clone, PartialEq)] + #[repr(C)] + pub struct c64 { + pub re: f64, + pub im: f64, + } + + // 32-bit implicitly conjugated complex floating point type. + #[allow(non_camel_case_types)] + #[derive(Copy, Clone, PartialEq)] + #[repr(C)] + pub struct c32conj { + pub re: f32, + pub neg_im: f32, + } + + // 64-bit implicitly conjugated complex floating point type. + #[allow(non_camel_case_types)] + #[derive(Copy, Clone, PartialEq)] + #[repr(C)] + pub struct c64conj { + pub re: f64, + pub neg_im: f64, + } } -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c64 { - pub re: f64, - pub im: f64, +/// Utilities for split complex number types whose real and imaginary parts are stored separately. +pub mod complex_split { + + /// This structure contains the real and imaginary parts of an implicity conjugated value. + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + #[repr(C)] + pub struct ComplexConj { + pub re: T, + pub neg_im: T, + } + + /// This structure contains a pair of iterators that allow simultaneous iteration over the real + /// and imaginary parts of a collection of complex values. + #[derive(Clone, Debug)] + pub struct ComplexIter { + pub(crate) re: I, + pub(crate) im: I, + } + + /// This structure contains a pair of iterators that allow simultaneous iteration over the real + /// and imaginary parts of a collection of implicitly conjugated complex values. + #[derive(Clone, Debug)] + pub struct ComplexConjIter { + pub(crate) re: I, + pub(crate) neg_im: I, + } } +pub use complex_native::*; +use complex_split::*; + impl c32 { #[inline(always)] pub fn new(re: f32, im: f32) -> Self { @@ -172,22 +266,6 @@ impl From for c64 { } } -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c32conj { - pub re: f32, - pub neg_im: f32, -} - -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c64conj { - pub re: f64, - pub neg_im: f64, -} - unsafe impl bytemuck::Zeroable for c32 {} unsafe impl bytemuck::Zeroable for c32conj {} unsafe impl bytemuck::Zeroable for c64 {} @@ -230,6 +308,12 @@ impl Debug for c64conj { } } +/// Unstable core trait for describing how a scalar value may be split up into individual +/// component. +/// +/// For example, `f64` is treated as a single indivisible unit, but [`num_complex::Complex`] +/// is split up into its real and imaginary components, with each one being stored in a separate +/// container. pub unsafe trait Entity: Clone + PartialEq + Send + Sync + Debug + 'static { type Unit: Clone + Send + Sync + Debug + 'static; type Index: Copy + Send + Sync + Debug + 'static; @@ -238,7 +322,9 @@ pub unsafe trait Entity: Clone + PartialEq + Send + Sync + Debug + 'static { type SimdIndex: Copy + Send + Sync + Debug + 'static; type Group; + /// Must be the same as `Group`. type GroupCopy: Copy; + /// Must be the same as `Group`. type GroupThreadSafe: Send + Sync; type Iter: Iterator>; @@ -398,36 +484,54 @@ impl Conj { } } +/// Trait for types that may be implicitly conjugated. pub unsafe trait Conjugate: Entity { + /// Must have the same layout as `Self`, and `Conj::Unit` must have the same layout as `Unit`. type Conj: Entity + Conjugate; + /// Must have the same layout as `Self`, and `Canonical::Unit` must have the same layout as + /// `Unit`. type Canonical: Entity + Conjugate; + /// Performs the implicit conjugation operation on the given value, returning the canonical + /// form. fn canonicalize(self) -> Self::Canonical; } type SimdGroup = ::Group<::SimdUnit>; +/// Unstable trait containing the operations that a number type needs to implement. pub trait ComplexField: Entity + Conjugate { type Real: RealField; + /// Converts `value` from `f64` to `Self`. + /// The conversion may be lossy when converting to a type with less precision. fn from_f64(value: f64) -> Self; + /// Returns `self + rhs`. fn add(&self, rhs: &Self) -> Self; + /// Returns `self - rhs`. fn sub(&self, rhs: &Self) -> Self; + /// Returns `self * rhs`. fn mul(&self, rhs: &Self) -> Self; + /// Returns an estimate of `lhs * rhs + acc`. #[inline(always)] fn mul_adde(lhs: &Self, rhs: &Self, acc: &Self) -> Self { acc.add(&lhs.mul(rhs)) } + /// Returns an estimate of `conjugate(lhs) * rhs + acc`. #[inline(always)] fn conj_mul_adde(lhs: &Self, rhs: &Self, acc: &Self) -> Self { acc.add(&lhs.conj().mul(rhs)) } + /// Returns `-self`. fn neg(&self) -> Self; + /// Returns `1.0/self`. fn inv(&self) -> Self; + /// Returns `conjugate(self)`. fn conj(&self) -> Self; + /// Returns the square root of `self`. fn sqrt(&self) -> Self; /// Returns the input, scaled by `rhs`. @@ -440,10 +544,15 @@ pub trait ComplexField: Entity + Conjugate { /// /// An implementation may choose either, so long as it chooses consistently. fn score(&self) -> Self::Real; + /// Returns the absolute value of `self`. fn abs(&self) -> Self::Real; + /// Returns the squared absolute value of `self`. fn abs2(&self) -> Self::Real; + /// Returns a NaN value. fn nan() -> Self; + + /// Returns true if `self` is a NaN value, or false otherwise. #[inline(always)] fn is_nan(&self) -> bool { #[allow(clippy::eq_op)] @@ -460,7 +569,9 @@ pub trait ComplexField: Entity + Conjugate { /// Returns the imaginary part. fn imag(&self) -> Self::Real; + /// Returns `0.0`. fn zero() -> Self; + /// Returns `1.0`. fn one() -> Self; fn slice_as_simd(slice: &[Self::Unit]) -> (&[Self::SimdUnit], &[Self::Unit]); @@ -600,6 +711,7 @@ pub trait ComplexField: Entity + Conjugate { } } +/// Unstable trait containing the operations that a real number type needs to implement. pub trait RealField: ComplexField + PartialOrd { fn div(&self, rhs: &Self) -> Self; @@ -2716,24 +2828,6 @@ unsafe impl Entity for f64 { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(C)] -pub struct ComplexConj { - pub re: T, - pub neg_im: T, -} - -#[derive(Clone, Debug)] -pub struct ComplexIter { - re: I, - im: I, -} -#[derive(Clone, Debug)] -pub struct ComplexConjIter { - re: I, - neg_im: I, -} - impl Iterator for ComplexIter { type Item = Complex; @@ -3042,7 +3136,7 @@ unsafe impl Conjugate for ComplexConj { } struct MatImpl { - ptr: E::GroupCopy>, + ptr: E::GroupCopy<*mut E::Unit>, nrows: usize, ncols: usize, row_stride: isize, @@ -3057,6 +3151,15 @@ impl Clone for MatImpl { } } +/// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice]. +/// +/// # Note: +/// +/// Unlike a slice, the data pointed to by `MatRef<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions ([`std::mem::needs_drop::()`] must be false). In +/// this case, care must be taken to not perform any operations that read the uninitialized values, +/// or form references to them, either directly through [`MatRef::read`], or indirectly through any +/// of the numerical library routines, unless it is explicitly permitted. pub struct MatRef<'a, E: Entity> { inner: MatImpl, __marker: PhantomData<&'a E>, @@ -3070,6 +3173,15 @@ impl Clone for MatRef<'_, E> { } } +/// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice]. +/// +/// # Note: +/// +/// Unlike a slice, the data pointed to by `MatMut<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions ([`std::mem::needs_drop::()`] must be false). In +/// this case, care must be taken to not perform any operations that read the uninitialized values, +/// or form references to them, either directly through [`MatMut::read`], or indirectly through any +/// of the numerical library routines, unless it is explicitly permitted. pub struct MatMut<'a, E: Entity> { inner: MatImpl, __marker: PhantomData<&'a mut E>, @@ -3136,6 +3248,44 @@ pub fn par_split_indices(n: usize, idx: usize, chunk_count: usize) -> (usize, us } impl<'a, E: Entity> MatRef<'a, E> { + /// Creates a `MatRef` from pointers to the matrix data, dimensions, and strides. + /// + /// The row (resp. column) stride is the offset from the memory address of a given matrix + /// element at indices `(row: i, col: j)`, to the memory address of the matrix element at + /// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in + /// number of elements, not in bytes. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * For each matrix unit, the entire memory region addressed by the matrix must be contained + /// within a single allocation, accessible in its entirety by the corresponding pointer in + /// `ptr`. + /// * For each matrix unit, the corresponding pointer must be properly aligned, + /// even for a zero-sized matrix. + /// * If [`std::mem::needs_drop::()`], then all the addresses accessible by each + /// matrix unit must point to initialized elements of type `E::Unit`. Otherwise, the values + /// accessible by the matrix must be initialized at some point before they are read, or + /// references to them are formed. + /// * No mutable aliasing is allowed. In other words, none of the elements accessible by any + /// matrix unit may be accessed for writes by any other means for the duration of the lifetime + /// `'a`. + /// + /// # Example + /// + /// ``` + /// use faer_core::{mat, MatRef}; + /// + /// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. + /// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, + /// // which is 4. + /// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, + /// // which is 1. + /// let data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; + /// let matrix = unsafe { MatRef::::from_raw_parts(data.as_ptr() as *const f64, 2, 3, 4, 1) }; + /// + /// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// assert_eq!(expected.as_ref(), matrix); + /// ``` #[inline(always)] #[track_caller] pub unsafe fn from_raw_parts( @@ -3148,9 +3298,7 @@ impl<'a, E: Entity> MatRef<'a, E> { E::map(E::as_ref(&ptr), |ptr| debug_assert!(!ptr.is_null())); Self { inner: MatImpl { - ptr: E::into_copy(E::map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), + ptr: E::into_copy(E::map(ptr, |ptr| ptr as *mut E::Unit)), nrows, ncols, row_stride, @@ -3160,33 +3308,37 @@ impl<'a, E: Entity> MatRef<'a, E> { } } + /// Returns pointers to the matrix data. #[inline(always)] pub fn as_ptr(self) -> E::Group<*const E::Unit> { - E::map(E::from_copy(self.inner.ptr), |ptr| { - ptr.as_ptr() as *const E::Unit - }) + E::map(E::from_copy(self.inner.ptr), |ptr| ptr as *const E::Unit) } + /// Returns the number of rows of the matrix. #[inline(always)] pub fn nrows(&self) -> usize { self.inner.nrows } + /// Returns the number of columns of the matrix. #[inline(always)] pub fn ncols(&self) -> usize { self.inner.ncols } + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. #[inline(always)] pub fn row_stride(&self) -> isize { self.inner.row_stride } + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. #[inline(always)] pub fn col_stride(&self) -> isize { self.inner.col_stride } + /// Returns raw pointers to the element at the given indices. #[inline(always)] pub fn ptr_at(self, row: usize, col: usize) -> E::Group<*const E::Unit> { E::map(self.as_ptr(), |ptr| { @@ -3195,6 +3347,13 @@ impl<'a, E: Entity> MatRef<'a, E> { }) } + /// Returns raw pointers to the element at the given indices, assuming the provided indices + /// are within the matrix dimensions. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn ptr_inbounds_at(self, row: usize, col: usize) -> E::Group<*const E::Unit> { @@ -3206,6 +3365,17 @@ impl<'a, E: Entity> MatRef<'a, E> { }) } + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] pub fn split_at(self, row: usize, col: usize) -> [Self; 4] { @@ -3233,6 +3403,10 @@ impl<'a, E: Entity> MatRef<'a, E> { } } + /// Splits the matrix horizontally at the given row into two parts and returns an array of each + /// submatrix, in the following order: + /// * top. + /// * bottom. #[inline(always)] #[track_caller] pub fn split_at_row(self, row: usize) -> [Self; 2] { @@ -3240,6 +3414,10 @@ impl<'a, E: Entity> MatRef<'a, E> { [top, bot] } + /// Splits the matrix vertically at the given row into two parts and returns an array of each + /// submatrix, in the following order: + /// * left. + /// * right. #[inline(always)] #[track_caller] pub fn split_at_col(self, col: usize) -> [Self; 2] { @@ -3247,12 +3425,32 @@ impl<'a, E: Entity> MatRef<'a, E> { [left, right] } + /// Returns references to the element at the given indices. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn get_unchecked(self, row: usize, col: usize) -> E::Group<&'a E::Unit> { E::map(self.ptr_inbounds_at(row, col), |ptr| &*ptr) } + /// Returns references to the element at the given indices, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn get(self, row: usize, col: usize) -> E::Group<&'a E::Unit> { @@ -3261,18 +3459,43 @@ impl<'a, E: Entity> MatRef<'a, E> { unsafe { self.get_unchecked(row, col) } } + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { E::from_units(E::map(self.get_unchecked(row, col), |ptr| (*ptr).clone())) } + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn read(&self, row: usize, col: usize) -> E { E::from_units(E::map(self.get(row, col), |ptr| (*ptr).clone())) } + /// Returns a view over the transpose of `self`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let transpose = view.transpose(); + /// + /// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + /// assert_eq!(expected.as_ref(), transpose); + /// ``` #[inline(always)] #[must_use] pub fn transpose(self) -> Self { @@ -3288,6 +3511,7 @@ impl<'a, E: Entity> MatRef<'a, E> { } } + /// Returns a view over the conjugate of `self`. #[inline(always)] #[must_use] pub fn conjugate(self) -> MatRef<'a, E::Conj> @@ -3300,8 +3524,8 @@ impl<'a, E: Entity> MatRef<'a, E> { MatRef { inner: MatImpl { ptr: transmute_unchecked::< - E::GroupCopy>, - ::GroupCopy::Unit>>, + E::GroupCopy<*mut E::Unit>, + ::GroupCopy<*mut ::Unit>, >(self.inner.ptr), nrows: self.inner.nrows, ncols: self.inner.ncols, @@ -3313,6 +3537,7 @@ impl<'a, E: Entity> MatRef<'a, E> { } } + /// Returns a view over the conjugate transpose of `self`. #[inline(always)] pub fn adjoint(self) -> MatRef<'a, E::Conj> where @@ -3321,6 +3546,8 @@ impl<'a, E: Entity> MatRef<'a, E> { self.transpose().conjugate() } + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. #[inline(always)] pub fn canonicalize(self) -> (MatRef<'a, E::Canonical>, Conj) where @@ -3332,9 +3559,9 @@ impl<'a, E: Entity> MatRef<'a, E> { MatRef { inner: MatImpl { ptr: transmute_unchecked::< - E::GroupCopy>, + E::GroupCopy<*mut E::Unit>, ::GroupCopy< - NonNull<::Unit>, + *mut ::Unit, >, >(self.inner.ptr), nrows: self.inner.nrows, @@ -3353,6 +3580,19 @@ impl<'a, E: Entity> MatRef<'a, E> { ) } + /// Returns a view over the `self`, with the rows in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed_rows = view.reverse_rows(); + /// + /// let expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; + /// assert_eq!(expected.as_ref(), reversed_rows); + /// ``` #[inline(always)] #[must_use] pub fn reverse_rows(self) -> Self { @@ -3365,6 +3605,19 @@ impl<'a, E: Entity> MatRef<'a, E> { unsafe { Self::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } } + /// Returns a view over the `self`, with the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed_cols = view.reverse_cols(); + /// + /// let expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; + /// assert_eq!(expected.as_ref(), reversed_cols); + /// ``` #[inline(always)] #[must_use] pub fn reverse_cols(self) -> Self { @@ -3376,6 +3629,19 @@ impl<'a, E: Entity> MatRef<'a, E> { unsafe { Self::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } } + /// Returns a view over the `self`, with the rows and the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed = view.reverse_rows_and_cols(); + /// + /// let expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; + /// assert_eq!(expected.as_ref(), reversed); + /// ``` #[inline(always)] #[must_use] pub fn reverse_rows_and_cols(self) -> Self { @@ -3391,6 +3657,33 @@ impl<'a, E: Entity> MatRef<'a, E> { unsafe { Self::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } } + /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with + /// dimensions `(nrows, ncols)`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `col_start <= self.ncols()`. + /// * `nrows <= self.nrows() - row_start`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let submatrix = view.submatrix(2, 1, 2, 2); + /// + /// let expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; + /// assert_eq!(expected.as_ref(), submatrix); + /// ``` #[track_caller] #[inline(always)] pub fn submatrix(self, row_start: usize, col_start: usize, nrows: usize, ncols: usize) -> Self { @@ -3411,30 +3704,109 @@ impl<'a, E: Entity> MatRef<'a, E> { } } + /// Returns a view over the submatrix starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let subrows = view.subrows(1, 2); + /// + /// let expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; + /// assert_eq!(expected.as_ref(), subrows); + /// ``` #[track_caller] #[inline(always)] pub fn subrows(self, row_start: usize, nrows: usize) -> Self { self.submatrix(row_start, 0, nrows, self.ncols()) } + /// Returns a view over the submatrix starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let subcols = view.subcols(2, 1); + /// + /// let expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; + /// assert_eq!(expected.as_ref(), subcols); + /// ``` #[track_caller] #[inline(always)] pub fn subcols(self, col_start: usize, ncols: usize) -> Self { self.submatrix(0, col_start, self.nrows(), ncols) } + /// Returns a view over the row at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_idx < self.nrows()`. #[track_caller] #[inline(always)] pub fn row(self, row_idx: usize) -> Self { self.subrows(row_idx, 1) } + /// Returns a view over the column at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_idx < self.ncols()`. #[track_caller] #[inline(always)] pub fn col(self, col_idx: usize) -> Self { self.subcols(col_idx, 1) } + /// Returns a view over the main diagonal of the matrix. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let diagonal = view.diagonal(); + /// + /// let expected = mat![[1.0], [6.0], [11.0]]; + /// assert_eq!(expected.as_ref(), diagonal); + /// ``` #[track_caller] #[inline(always)] pub fn diagonal(self) -> Self { @@ -3444,7 +3816,7 @@ impl<'a, E: Entity> MatRef<'a, E> { unsafe { Self::from_raw_parts(self.as_ptr(), size, 1, row_stride + col_stride, 0) } } - /// Returns an owning [`Mat`] of the data + /// Returns an owning [`Mat`] of the data. #[inline] pub fn to_owned(&self) -> Mat where @@ -3457,12 +3829,18 @@ impl<'a, E: Entity> MatRef<'a, E> { mat } - /// Returns a thin wrapper that can be used to execute coefficientwise operations on matrices. + /// Returns a thin wrapper that can be used to execute coefficient-wise operations on matrices. #[inline] pub fn cwise(self) -> Zip<(Self,)> { Zip { tuple: (self,) } } + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> MatRef<'_, E> { + *self + } + #[doc(hidden)] #[inline(always)] pub unsafe fn const_cast(self) -> MatMut<'a, E> { @@ -3474,6 +3852,47 @@ impl<'a, E: Entity> MatRef<'a, E> { } impl<'a, E: Entity> MatMut<'a, E> { + /// Creates a `MatMut` from pointers to the matrix data, dimensions, and strides. + /// + /// The row (resp. column) stride is the offset from the memory address of a given matrix + /// element at indices `(row: i, col: j)`, to the memory address of the matrix element at + /// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in + /// number of elements, not in bytes. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * For each matrix unit, the entire memory region addressed by the matrix must be contained + /// within a single allocation, accessible in its entirety by the corresponding pointer in + /// `ptr`. + /// * For each matrix unit, the corresponding pointer must be properly aligned, + /// even for a zero-sized matrix. + /// * If [`std::mem::needs_drop::()`], then all the addresses accessible by each + /// matrix unit must point to initialized elements of type `E::Unit`. Otherwise, the values + /// accessible by the matrix must be initialized at some point before they are read, or + /// references to them are formed. + /// * No aliasing (including self aliasing) is allowed. In other words, none of the elements + /// accessible by any matrix unit may be accessed for reads or writes by any other means for + /// the duration of the lifetime `'a`. No two elements within a single matrix unit may point to + /// the same address (such a thing can be achieved with a zero stride, for example), and no two + /// matrix units may point to the same address. + /// + /// # Example + /// + /// ``` + /// use faer_core::{mat, MatMut}; + /// + /// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. + /// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, + /// // which is 4. + /// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, + /// // which is 1. + /// let mut data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; + /// let mut matrix = + /// unsafe { MatMut::::from_raw_parts(data.as_mut_ptr() as *mut f64, 2, 3, 4, 1) }; + /// + /// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// assert_eq!(expected.as_ref(), matrix); + /// ``` #[inline(always)] #[track_caller] pub unsafe fn from_raw_parts( @@ -3486,9 +3905,7 @@ impl<'a, E: Entity> MatMut<'a, E> { E::map(E::as_ref(&ptr), |ptr| debug_assert!(!ptr.is_null())); Self { inner: MatImpl { - ptr: E::into_copy(E::map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), + ptr: E::into_copy(ptr), nrows, ncols, row_stride, @@ -3498,33 +3915,37 @@ impl<'a, E: Entity> MatMut<'a, E> { } } + /// Returns pointers to the matrix data. #[inline(always)] pub fn as_ptr(self) -> E::Group<*mut E::Unit> { - E::map(E::from_copy(self.inner.ptr), |ptr| { - ptr.as_ptr() as *mut E::Unit - }) + E::from_copy(self.inner.ptr) } + /// Returns the number of rows of the matrix. #[inline(always)] pub fn nrows(&self) -> usize { self.inner.nrows } + /// Returns the number of columns of the matrix. #[inline(always)] pub fn ncols(&self) -> usize { self.inner.ncols } + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. #[inline(always)] pub fn row_stride(&self) -> isize { self.inner.row_stride } + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. #[inline(always)] pub fn col_stride(&self) -> isize { self.inner.col_stride } + /// Returns raw pointers to the element at the given indices. #[inline(always)] pub fn ptr_at(self, row: usize, col: usize) -> E::Group<*mut E::Unit> { let row_stride = self.inner.row_stride; @@ -3535,6 +3956,13 @@ impl<'a, E: Entity> MatMut<'a, E> { }) } + /// Returns raw pointers to the element at the given indices, assuming the provided indices + /// are within the matrix dimensions. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn ptr_inbounds_at(self, row: usize, col: usize) -> E::Group<*mut E::Unit> { @@ -3548,6 +3976,17 @@ impl<'a, E: Entity> MatMut<'a, E> { }) } + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] pub fn split_at(self, row: usize, col: usize) -> [Self; 4] { @@ -3562,6 +4001,10 @@ impl<'a, E: Entity> MatMut<'a, E> { } } + /// Splits the matrix horizontally at the given row into two parts and returns an array of each + /// submatrix, in the following order: + /// * top. + /// * bottom. #[inline(always)] #[track_caller] pub fn split_at_row(self, row: usize) -> [Self; 2] { @@ -3569,6 +4012,10 @@ impl<'a, E: Entity> MatMut<'a, E> { [top, bot] } + /// Splits the matrix vertically at the given row into two parts and returns an array of each + /// submatrix, in the following order: + /// * left. + /// * right. #[inline(always)] #[track_caller] pub fn split_at_col(self, col: usize) -> [Self; 2] { @@ -3576,12 +4023,32 @@ impl<'a, E: Entity> MatMut<'a, E> { [left, right] } + /// Returns mutable references to the element at the given indices. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn get_unchecked(self, row: usize, col: usize) -> E::Group<&'a mut E::Unit> { E::map(self.ptr_inbounds_at(row, col), |ptr| &mut *ptr) } + /// Returns mutable references to the element at the given indices, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn get(self, row: usize, col: usize) -> E::Group<&'a mut E::Unit> { @@ -3590,18 +4057,36 @@ impl<'a, E: Entity> MatMut<'a, E> { unsafe { self.get_unchecked(row, col) } } + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { self.rb().read_unchecked(row, col) } + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn read(&self, row: usize, col: usize) -> E { self.rb().read(row, col) } + /// Writes the value to the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { @@ -3610,6 +4095,12 @@ impl<'a, E: Entity> MatMut<'a, E> { E::map(zipped, |(unit, ptr)| *ptr = unit); } + /// Writes the value to the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn write(&mut self, row: usize, col: usize, value: E) { @@ -3618,12 +4109,19 @@ impl<'a, E: Entity> MatMut<'a, E> { unsafe { self.write_unchecked(row, col, value) }; } + /// Copies the values from `other` into `self`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. #[inline(always)] #[track_caller] pub fn clone_from(&mut self, other: MatRef<'_, E>) { zipped!(self.rb_mut(), other).for_each(|mut dst, src| dst.write(src.read())); } + /// Fills the elements of `self` with zeros. #[inline(always)] #[track_caller] pub fn set_zeros(&mut self) @@ -3633,12 +4131,26 @@ impl<'a, E: Entity> MatMut<'a, E> { zipped!(self.rb_mut()).for_each(|mut x| x.write(E::zero())); } + /// Fills the elements of `self` with copies of `constant`. #[inline(always)] #[track_caller] - pub fn set_constant(&mut self, c: E) { - zipped!(self.rb_mut()).for_each(|mut x| x.write(c.clone())); + pub fn set_constant(&mut self, constant: E) { + zipped!(self.rb_mut()).for_each(|mut x| x.write(constant.clone())); } + /// Returns a view over the transpose of `self`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let transpose = view.transpose(); + /// + /// let mut expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + /// assert_eq!(expected.as_mut(), transpose); + /// ``` #[inline(always)] #[must_use] pub fn transpose(self) -> Self { @@ -3654,6 +4166,7 @@ impl<'a, E: Entity> MatMut<'a, E> { } } + /// Returns a view over the conjugate of `self`. #[inline(always)] #[must_use] pub fn conjugate(self) -> MatMut<'a, E::Conj> @@ -3663,6 +4176,7 @@ impl<'a, E: Entity> MatMut<'a, E> { unsafe { self.into_const().conjugate().const_cast() } } + /// Returns a view over the conjugate transpose of `self`. #[inline(always)] #[must_use] pub fn adjoint(self) -> MatMut<'a, E::Conj> @@ -3672,6 +4186,8 @@ impl<'a, E: Entity> MatMut<'a, E> { self.transpose().conjugate() } + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. #[inline(always)] #[must_use] pub fn canonicalize(self) -> (MatMut<'a, E::Canonical>, Conj) @@ -3682,24 +4198,90 @@ impl<'a, E: Entity> MatMut<'a, E> { unsafe { (canonical.const_cast(), conj) } } + /// Returns a view over the `self`, with the rows in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed_rows = view.reverse_rows(); + /// + /// let mut expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; + /// assert_eq!(expected.as_mut(), reversed_rows); + /// ``` #[inline(always)] #[must_use] pub fn reverse_rows(self) -> Self { unsafe { self.into_const().reverse_rows().const_cast() } } + /// Returns a view over the `self`, with the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed_cols = view.reverse_cols(); + /// + /// let mut expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; + /// assert_eq!(expected.as_mut(), reversed_cols); + /// ``` #[inline(always)] #[must_use] pub fn reverse_cols(self) -> Self { unsafe { self.into_const().reverse_cols().const_cast() } } + /// Returns a view over the `self`, with the rows and the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed = view.reverse_rows_and_cols(); + /// + /// let mut expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; + /// assert_eq!(expected.as_mut(), reversed); + /// ``` #[inline(always)] #[must_use] pub fn reverse_rows_and_cols(self) -> Self { unsafe { self.into_const().reverse_rows_and_cols().const_cast() } } + /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with + /// dimensions `(nrows, ncols)`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `col_start <= self.ncols()`. + /// * `nrows <= self.nrows() - row_start`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let submatrix = view.submatrix(2, 1, 2, 2); + /// + /// let mut expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; + /// assert_eq!(expected.as_mut(), submatrix); + /// ``` #[track_caller] #[inline(always)] pub fn submatrix(self, row_start: usize, col_start: usize, nrows: usize, ncols: usize) -> Self { @@ -3710,12 +4292,63 @@ impl<'a, E: Entity> MatMut<'a, E> { } } + /// Returns a view over the submatrix starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let subrows = view.subrows(1, 2); + /// + /// let mut expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; + /// assert_eq!(expected.as_mut(), subrows); + /// ``` #[track_caller] #[inline(always)] pub fn subrows(self, row_start: usize, nrows: usize) -> Self { let ncols = self.ncols(); self.submatrix(row_start, 0, nrows, ncols) } + + /// Returns a view over the submatrix starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let subcols = view.subcols(2, 1); + /// + /// let mut expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; + /// assert_eq!(expected.as_mut(), subcols); + /// ``` #[track_caller] #[inline(always)] pub fn subcols(self, col_start: usize, ncols: usize) -> Self { @@ -3723,18 +4356,47 @@ impl<'a, E: Entity> MatMut<'a, E> { self.submatrix(0, col_start, nrows, ncols) } + /// Returns a view over the row at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_idx < self.nrows()`. #[track_caller] #[inline(always)] pub fn row(self, row_idx: usize) -> Self { self.subrows(row_idx, 1) } + /// Returns a view over the column at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_idx < self.ncols()`. #[track_caller] #[inline(always)] pub fn col(self, col_idx: usize) -> Self { self.subcols(col_idx, 1) } + /// Returns a view over the main diagonal of the matrix. + /// + /// # Example + /// ``` + /// use faer_core::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let diagonal = view.diagonal(); + /// + /// let mut expected = mat![[1.0], [6.0], [11.0]]; + /// assert_eq!(expected.as_mut(), diagonal); + /// ``` #[track_caller] #[inline(always)] pub fn diagonal(self) -> Self { @@ -3750,11 +4412,23 @@ impl<'a, E: Entity> MatMut<'a, E> { self.rb().to_owned() } - /// Returns a thin wrapper that can be used to execute coefficientwise operations on matrices. + /// Returns a thin wrapper that can be used to execute coefficient-wise operations on matrices. #[inline] pub fn cwise(self) -> Zip<(Self,)> { Zip { tuple: (self,) } } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> MatRef<'_, E> { + self.rb() + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> MatMut<'_, E> { + self.rb_mut() + } } impl<'a, E: RealField> MatRef<'a, Complex> { @@ -3787,12 +4461,12 @@ impl<'a, E: RealField> MatMut<'a, Complex> { } } -impl<'a, U: Conjugate, T: Conjugate> PartialEq> - for MatRef<'a, T> +impl> PartialEq> + for MatRef<'_, T> where T::Canonical: ComplexField, { - fn eq(&self, other: &MatRef<'a, U>) -> bool { + fn eq(&self, other: &MatRef<'_, U>) -> bool { let same_dims = self.nrows() == other.nrows() && self.ncols() == other.ncols(); if !same_dims { false @@ -3813,25 +4487,6 @@ where } } -impl<'a, U: Conjugate, T: Conjugate> PartialEq> - for MatMut<'a, T> -where - T::Canonical: ComplexField, -{ - fn eq(&self, other: &MatMut<'a, U>) -> bool { - self.rb().eq(&other.rb()) - } -} - -impl> PartialEq> for Mat -where - T::Canonical: ComplexField, -{ - fn eq(&self, other: &Mat) -> bool { - self.as_ref().eq(&other.as_ref()) - } -} - #[repr(C)] struct RawMatUnit { ptr: NonNull, @@ -4010,6 +4665,33 @@ impl Drop for ColGuard { } } +/// Heap allocated resizable matrix, similar to a 2D [`Vec`]. +/// +/// # Note +/// +/// The memory layout of `Mat` is guaranteed to be column-major, meaning that it has a row stride +/// of `1`, and an unspecified column stride that can be queried with [`Mat::col_stride`]. +/// +/// This implies that while each individual column is stored contiguously in memory, the matrix as +/// a whole may not necessarily be contiguous. The implementation may add padding at the end of +/// each column when overaligning each column can provide a performance gain. +/// +/// Let us consider a 3×4 matrix +/// +/// ```notcode +/// 0 │ 3 │ 6 │ 9 +/// ───┼───┼───┼─── +/// 1 │ 4 │ 7 │ 10 +/// ───┼───┼───┼─── +/// 2 │ 5 │ 8 │ 11 +/// ``` +/// The memory representation of the data held by such a matrix could look like the following: +/// +/// ```notcode +/// 0 1 2 X 3 4 5 X 6 7 8 X 9 10 11 X +/// ``` +/// +/// where X represents padding elements. #[repr(C)] pub struct Mat { raw: RawMat, @@ -4018,7 +4700,7 @@ pub struct Mat { } #[repr(C)] -pub struct MatUnit { +struct MatUnit { raw: RawMatUnit, nrows: usize, ncols: usize, @@ -4171,8 +4853,7 @@ impl Mat { /// the matrix will not allocate. /// /// # Panics - /// - /// Panics if the total capacity in bytes exceeds `isize::MAX`. + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. #[inline] pub fn with_capacity(row_capacity: usize, col_capacity: usize) -> Self { Self { @@ -4185,8 +4866,7 @@ impl Mat { /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with the provided function. /// /// # Panics - /// - /// Panics if the total capacity in bytes exceeds `isize::MAX`. + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. #[inline] pub fn with_dims(nrows: usize, ncols: usize, f: impl FnMut(usize, usize) -> E) -> Self { let mut this = Self::new(); @@ -4197,8 +4877,7 @@ impl Mat { /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with zeros. /// /// # Panics - /// - /// Panics if the total capacity in bytes exceeds `isize::MAX`. + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. #[inline] pub fn zeros(nrows: usize, ncols: usize) -> Self where @@ -4210,9 +4889,9 @@ impl Mat { /// Set the dimensions of the matrix. /// /// # Safety - /// - /// * `nrows` must be less than `self.row_capacity()`. - /// * `ncols` must be less than `self.col_capacity()`. + /// The behavior is undefined if any of the following conditions are violated: + /// * `nrows < self.row_capacity()`. + /// * `ncols < self.col_capacity()`. /// * The elements that were previously out of bounds but are now in bounds must be /// initialized. #[inline] @@ -4316,8 +4995,7 @@ impl Mat { /// columns without reallocating. Does nothing if the capacity is already sufficient. /// /// # Panics - /// - /// Panics if the new total capacity in bytes exceeds `isize::MAX`. + /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. #[inline] pub fn reserve_exact(&mut self, row_capacity: usize, col_capacity: usize) { if self.row_capacity() >= row_capacity && self.col_capacity() >= col_capacity { @@ -4463,7 +5141,7 @@ impl Mat { /// Resizes the matrix in-place so that the new dimensions are `(new_nrows, new_ncols)`. /// Elements that are now out of bounds are dropped, while new elements are created with the - /// given function `f`, so that elements at position `(i, j)` are created by calling `f(i, j)`. + /// given function `f`, so that elements at indices `(i, j)` are created by calling `f(i, j)`. pub fn resize_with( &mut self, new_nrows: usize, @@ -4528,37 +5206,61 @@ impl Mat { } } + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { self.as_ref().read_unchecked(row, col) } + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn read(&self, row: usize, col: usize) -> E { self.as_ref().read(row, col) } + /// Writes the value to the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { self.as_mut().write_unchecked(row, col, value); } + /// Writes the value to the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. #[inline(always)] #[track_caller] pub fn write(&mut self, row: usize, col: usize, value: E) { self.as_mut().write(row, col, value); } - /// Returns the transpose of `self`. + /// Returns a view over the transpose of `self`. #[inline] pub fn transpose(&self) -> MatRef<'_, E> { self.as_ref().transpose() } - /// Returns the conjugate of `self`. + /// Returns a view over the conjugate of `self`. #[inline] pub fn conjugate(&self) -> MatRef<'_, E::Conj> where @@ -4567,7 +5269,7 @@ impl Mat { self.as_ref().conjugate() } - /// Returns the conjugate transpose of `self`. + /// Returns a view over the conjugate transpose of `self`. #[inline] pub fn adjoint(&self) -> MatRef<'_, E::Conj> where @@ -4603,6 +5305,33 @@ macro_rules! __transpose_impl { }; } +/// Creates a [`Mat`] containing the arguments. +/// +/// ``` +/// use faer_core::mat; +/// +/// let matrix = mat![ +/// [1.0, 5.0, 9.0], +/// [2.0, 6.0, 10.0], +/// [3.0, 7.0, 11.0], +/// [4.0, 8.0, 12.0f64], +/// ]; +/// +/// assert_eq!(matrix.read(0, 0), 1.0); +/// assert_eq!(matrix.read(1, 0), 2.0); +/// assert_eq!(matrix.read(2, 0), 3.0); +/// assert_eq!(matrix.read(3, 0), 4.0); +/// +/// assert_eq!(matrix.read(0, 1), 5.0); +/// assert_eq!(matrix.read(1, 1), 6.0); +/// assert_eq!(matrix.read(2, 1), 7.0); +/// assert_eq!(matrix.read(3, 1), 8.0); +/// +/// assert_eq!(matrix.read(0, 2), 9.0); +/// assert_eq!(matrix.read(1, 2), 10.0); +/// assert_eq!(matrix.read(2, 2), 11.0); +/// assert_eq!(matrix.read(3, 2), 12.0); +/// ``` #[macro_export] macro_rules! mat { () => { @@ -4740,6 +5469,9 @@ enum DynMatUnitImpl<'a, T> { Uninit(DynArray<'a, MaybeUninit>), } +/// A temporary matrix allocated from a [`DynStack`]. +/// +/// [`DynStack`]: dyn_stack::DynStack pub struct DynMat<'a, E: Entity> { inner: E::Group>, nrows: usize, @@ -4949,6 +5681,35 @@ impl<'a, FromE: Entity, ToE: Entity> Coerce> for MatMut<'a, From } } +/// Zips together matrix of the same size, so that coefficient-wise operations can be performed on +/// their elements. +/// +/// # Note +/// The order in which the matrix elements are traversed is unspecified. +/// +/// # Example +/// ``` +/// use faer_core::{mat, zipped, Mat}; +/// +/// let nrows = 2; +/// let ncols = 3; +/// +/// let a = mat![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]; +/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]]; +/// let mut sum = Mat::::zeros(nrows, ncols); +/// +/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|mut sum, a, b| { +/// let a = a.read(); +/// let b = b.read(); +/// sum.write(a + b); +/// }); +/// +/// for i in 0..nrows { +/// for j in 0..ncols { +/// assert_eq!(sum.read(i, j), a.read(i, j) + b.read(i, j)); +/// } +/// } +/// ``` #[macro_export] macro_rules! zipped { ($first: expr $(, $rest: expr)* $(,)?) => { @@ -5020,79 +5781,7 @@ where } } -impl> core::ops::Mul> - for MatRef<'_, LhsE> -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: Mat) -> Self::Output { - self.mul(rhs.as_ref()) - } -} - -impl> core::ops::Mul> - for Mat -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { - self.as_ref().mul(rhs) - } -} - -impl> core::ops::Mul> - for Mat -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: Mat) -> Self::Output { - self.as_ref().mul(rhs.as_ref()) - } -} - -impl> core::ops::Mul<&Mat> - for Mat -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: &Mat) -> Self::Output { - self.as_ref().mul(rhs.as_ref()) - } -} - -impl> core::ops::Mul> - for &Mat -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: Mat) -> Self::Output { - self.as_ref().mul(rhs.as_ref()) - } -} - -impl> core::ops::Mul<&Mat> - for &Mat -where - LhsE::Canonical: ComplexField, -{ - type Output = Mat; - - fn mul(self, rhs: &Mat) -> Self::Output { - self.as_ref().mul(rhs.as_ref()) - } -} - -impl<'a, T: ComplexField> core::ops::MulAssign for MatMut<'a, T> { +impl core::ops::MulAssign for MatMut<'_, T> { fn mul_assign(&mut self, rhs: T) { self.rb_mut().cwise().for_each(|mut x| { let val = x.read(); diff --git a/faer-core/src/matrix_ops.rs b/faer-core/src/matrix_ops.rs new file mode 100644 index 00000000..278169b8 --- /dev/null +++ b/faer-core/src/matrix_ops.rs @@ -0,0 +1,365 @@ +//! addition and subtraction of matrices + +use crate::{zipped, ComplexField, Conjugate, Mat, MatMut, MatRef}; +use core::ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}; +use reborrow::*; + +impl> AddAssign> + for MatMut<'_, LhsE> +{ + fn add_assign(&mut self, rhs: MatRef<'_, RhsE>) { + assert_eq!((self.nrows(), self.ncols()), (rhs.nrows(), rhs.ncols())); + zipped!(self.rb_mut(), rhs).for_each(|mut lhs, rhs| { + lhs.write(lhs.read().add(&rhs.read().canonicalize())); + }); + } +} + +impl> SubAssign> + for MatMut<'_, LhsE> +{ + fn sub_assign(&mut self, rhs: MatRef<'_, RhsE>) { + assert_eq!((self.nrows(), self.ncols()), (rhs.nrows(), rhs.ncols())); + zipped!(self.rb_mut(), rhs).for_each(|mut lhs, rhs| { + lhs.write(lhs.read().sub(&rhs.read().canonicalize())); + }); + } +} + +impl> Add> + for MatRef<'_, LhsE> +where + LhsE::Canonical: ComplexField, +{ + type Output = Mat; + + fn add(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + assert_eq!((self.nrows(), self.ncols()), (rhs.nrows(), rhs.ncols())); + // SAFETY: we checked that the lhs and rhs dimensions are the same, so unchecked access is + // fine + unsafe { + Self::Output::with_dims(self.nrows(), self.ncols(), |i, j| { + self.read_unchecked(i, j) + .canonicalize() + .add(&rhs.read_unchecked(i, j).canonicalize()) + }) + } + } +} + +impl> Sub> + for MatRef<'_, LhsE> +where + LhsE::Canonical: ComplexField, +{ + type Output = Mat; + + fn sub(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + assert_eq!((self.nrows(), self.ncols()), (rhs.nrows(), rhs.ncols())); + // SAFETY: we checked that the lhs and rhs dimensions are the same, so unchecked access is + // fine + unsafe { + Self::Output::with_dims(self.nrows(), self.ncols(), |i, j| { + self.read_unchecked(i, j) + .canonicalize() + .sub(&rhs.read_unchecked(i, j).canonicalize()) + }) + } + } +} + +impl Neg for MatRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Mat; + + fn neg(self) -> Self::Output { + // SAFETY: destination and input dimensions are the same + unsafe { + Self::Output::with_dims(self.nrows(), self.ncols(), |i, j| { + self.read_unchecked(i, j).canonicalize().neg() + }) + } + } +} + +// implement unary traits for cases where the operand is +// an owned matrix by deferring to the case where it's a reference +// @todo: this will allocate even if the operand could be reused +// and in the future we should consider adding an efficient Neg +// implementation that is used instead as a backend +macro_rules! impl_unary_op_single { + ($trait_name: ident, $op: ident, $operand: ty) => { + impl $trait_name for $operand + where + E::Canonical: ComplexField, + { + type Output = Mat; + fn $op(self) -> Self::Output { + self.as_ref().$op() + } + } + }; +} + +// implement binary traits for cases where one of the operands is +// an owned matrix by deferring to the case where both are references +// @todo: this will allocate even if one of the operands could be reused +// and in the future we should consider adding an efficient add- and sub-assign +// implementations that are used instead as backends +macro_rules! impl_binary_op_single { + ($trait_name: ident, $op: ident, $lhs: ty, $rhs: ty) => { + impl> $trait_name<$rhs> + for $lhs + where + LhsE::Canonical: ComplexField, + { + type Output = Mat; + fn $op(self, rhs: $rhs) -> Self::Output { + self.as_ref().$op(rhs.as_ref()) + } + } + }; +} + +macro_rules! impl_assign_op_single { + ($trait_name: ident, $op: ident, $lhs: ty, $rhs: ty) => { + impl> $trait_name<$rhs> for $lhs { + fn $op(&mut self, rhs: $rhs) { + self.as_mut().$op(rhs.as_ref()) + } + } + }; +} + +macro_rules! impl_eq_single { + ($lhs: ty, $rhs: ty) => { + impl> PartialEq<$rhs> for $lhs + where + LhsE::Canonical: ComplexField, + { + fn eq(&self, rhs: &$rhs) -> bool { + PartialEq::eq(&self.as_ref(), &rhs.as_ref()) + } + } + }; +} + +macro_rules! impl_unary_op { + ($trait_name: ident, $op: ident) => { + // possible operands: + // + // Mat + // &Mat + // MatRef + // &MatRef + // MatMut + // &MatMut + + impl_unary_op_single!($trait_name, $op, Mat); + impl_unary_op_single!($trait_name, $op, &Mat); + // impl_unary_op_single!($trait_name, $op, MatRef<'_, E>); + impl_unary_op_single!($trait_name, $op, &MatRef<'_, E>); + impl_unary_op_single!($trait_name, $op, MatMut<'_, E>); + impl_unary_op_single!($trait_name, $op, &MatMut<'_, E>); + }; +} + +macro_rules! impl_binary_op { + ($trait_name: ident, $op: ident) => { + impl_binary_op_single!($trait_name, $op, Mat, Mat); + impl_binary_op_single!($trait_name, $op, Mat, &Mat); + impl_binary_op_single!($trait_name, $op, Mat, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, Mat, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, Mat, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, Mat, &MatMut<'_, RhsE>); + + impl_binary_op_single!($trait_name, $op, &Mat, Mat); + impl_binary_op_single!($trait_name, $op, &Mat, &Mat); + impl_binary_op_single!($trait_name, $op, &Mat, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &Mat, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &Mat, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &Mat, &MatMut<'_, RhsE>); + + impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, Mat); + impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, &Mat); + // impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatRef<'_, LhsE>, &MatMut<'_, RhsE>); + + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, Mat); + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, &Mat); + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatRef<'_, LhsE>, &MatMut<'_, RhsE>); + + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, Mat); + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, &Mat); + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, MatMut<'_, LhsE>, &MatMut<'_, RhsE>); + + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, Mat); + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, &Mat); + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, &MatRef<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, MatMut<'_, RhsE>); + impl_binary_op_single!($trait_name, $op, &MatMut<'_, LhsE>, &MatMut<'_, RhsE>); + }; +} + +macro_rules! impl_assign_op { + ($trait_name: ident, $op: ident) => { + impl_assign_op_single!($trait_name, $op, Mat, Mat); + impl_assign_op_single!($trait_name, $op, Mat, &Mat); + impl_assign_op_single!($trait_name, $op, Mat, MatRef<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, Mat, &MatRef<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, Mat, MatMut<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, Mat, &MatMut<'_, RhsE>); + + impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, Mat); + impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, &Mat); + // impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, MatRef<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, &MatRef<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, MatMut<'_, RhsE>); + impl_assign_op_single!($trait_name, $op, MatMut<'_, LhsE>, &MatMut<'_, RhsE>); + }; +} + +impl_eq_single!(Mat, Mat); +impl_eq_single!(Mat, &Mat); +impl_eq_single!(Mat, MatRef<'_, RhsE>); +impl_eq_single!(Mat, &MatRef<'_, RhsE>); +impl_eq_single!(Mat, MatMut<'_, RhsE>); +impl_eq_single!(Mat, &MatMut<'_, RhsE>); + +impl_eq_single!(&Mat, Mat); +// impl_eq_single!(&Mat, &Mat); +impl_eq_single!(&Mat, MatRef<'_, RhsE>); +// impl_eq_single!(&Mat, &MatRef<'_, RhsE>); +impl_eq_single!(&Mat, MatMut<'_, RhsE>); +// impl_eq_single!(&Mat, &MatMut<'_, RhsE>); + +impl_eq_single!(MatRef<'_, LhsE>, Mat); +impl_eq_single!(MatRef<'_, LhsE>, &Mat); +// impl_eq_single!(MatRef<'_, LhsE>, MatRef<'_, RhsE>); +impl_eq_single!(MatRef<'_, LhsE>, &MatRef<'_, RhsE>); +impl_eq_single!(MatRef<'_, LhsE>, MatMut<'_, RhsE>); +impl_eq_single!(MatRef<'_, LhsE>, &MatMut<'_, RhsE>); + +impl_eq_single!(&MatRef<'_, LhsE>, Mat); +// impl_eq_single!(&MatRef<'_, LhsE>, &Mat); +impl_eq_single!(&MatRef<'_, LhsE>, MatRef<'_, RhsE>); +// impl_eq_single!(&MatRef<'_, LhsE>, &MatRef<'_, RhsE>); +impl_eq_single!(&MatRef<'_, LhsE>, MatMut<'_, RhsE>); +// impl_eq_single!(&MatRef<'_, LhsE>, &MatMut<'_, RhsE>); + +impl_eq_single!(MatMut<'_, LhsE>, Mat); +impl_eq_single!(MatMut<'_, LhsE>, &Mat); +impl_eq_single!(MatMut<'_, LhsE>, MatRef<'_, RhsE>); +impl_eq_single!(MatMut<'_, LhsE>, &MatRef<'_, RhsE>); +impl_eq_single!(MatMut<'_, LhsE>, MatMut<'_, RhsE>); +impl_eq_single!(MatMut<'_, LhsE>, &MatMut<'_, RhsE>); + +impl_eq_single!(&MatMut<'_, LhsE>, Mat); +// impl_eq_single!(&MatMut<'_, LhsE>, &Mat); +impl_eq_single!(&MatMut<'_, LhsE>, MatRef<'_, RhsE>); +// impl_eq_single!(&MatMut<'_, LhsE>, &MatRef<'_, RhsE>); +impl_eq_single!(&MatMut<'_, LhsE>, MatMut<'_, RhsE>); +// impl_eq_single!(&MatMut<'_, LhsE>, &MatMut<'_, RhsE>); + +impl_unary_op!(Neg, neg); + +impl_binary_op!(Add, add); +impl_binary_op!(Sub, sub); +impl_binary_op!(Mul, mul); + +impl_assign_op!(AddAssign, add_assign); +impl_assign_op!(SubAssign, sub_assign); + +#[cfg(test)] +#[allow(non_snake_case)] +mod test { + use crate::{mat, Mat}; + use assert_approx_eq::assert_approx_eq; + + fn matrices() -> (Mat, Mat) { + let A = mat![[2.8, -3.3], [-1.7, 5.2], [4.6, -8.3],]; + + let B = mat![[-7.9, 8.3], [4.7, -3.2], [3.8, -5.2],]; + (A, B) + } + + #[test] + #[should_panic] + fn test_adding_matrices_of_different_sizes_should_panic() { + let A = mat![[1.0, 2.0], [3.0, 4.0]]; + let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + _ = A + B; + } + + #[test] + #[should_panic] + fn test_subtracting_two_matrices_of_different_sizes_should_panic() { + let A = mat![[1.0, 2.0], [3.0, 4.0]]; + let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + _ = A - B; + } + + #[test] + fn test_add() { + let (A, B) = matrices(); + + let expected = mat![[-5.1, 5.0], [3.0, 2.0], [8.4, -13.5],]; + + assert_matrix_approx_eq(A.as_ref() + B.as_ref(), &expected); + assert_matrix_approx_eq(&A + &B, &expected); + assert_matrix_approx_eq(A.as_ref() + &B, &expected); + assert_matrix_approx_eq(&A + B.as_ref(), &expected); + assert_matrix_approx_eq(A.as_ref() + B.clone(), &expected); + assert_matrix_approx_eq(&A + B.clone(), &expected); + assert_matrix_approx_eq(A.clone() + B.as_ref(), &expected); + assert_matrix_approx_eq(A.clone() + &B, &expected); + assert_matrix_approx_eq(A + B, &expected); + } + + #[test] + fn test_sub() { + let (A, B) = matrices(); + + let expected = mat![[10.7, -11.6], [-6.4, 8.4], [0.8, -3.1],]; + + assert_matrix_approx_eq(A.as_ref() - B.as_ref(), &expected); + assert_matrix_approx_eq(&A - &B, &expected); + assert_matrix_approx_eq(A.as_ref() - &B, &expected); + assert_matrix_approx_eq(&A - B.as_ref(), &expected); + assert_matrix_approx_eq(A.as_ref() - B.clone(), &expected); + assert_matrix_approx_eq(&A - B.clone(), &expected); + assert_matrix_approx_eq(A.clone() - B.as_ref(), &expected); + assert_matrix_approx_eq(A.clone() - &B, &expected); + assert_matrix_approx_eq(A - B, &expected); + } + + #[test] + fn test_neg() { + let (A, _) = matrices(); + + let expected = mat![[-2.8, 3.3], [1.7, -5.2], [-4.6, 8.3],]; + + assert_eq!(-A, expected); + } + + fn assert_matrix_approx_eq(given: Mat, expected: &Mat) { + assert_eq!(given.nrows(), expected.nrows()); + assert_eq!(given.ncols(), expected.ncols()); + for i in 0..given.nrows() { + for j in 0..given.ncols() { + assert_approx_eq!(given.read(i, j), expected.read(i, j)); + } + } + } +} diff --git a/faer-core/src/mul.rs b/faer-core/src/mul.rs index 5ee693a8..b9ae35e4 100644 --- a/faer-core/src/mul.rs +++ b/faer-core/src/mul.rs @@ -1,3 +1,5 @@ +//! Matrix multiplication. + use crate::{ c32, c64, simd::*, transmute_unchecked, zipped, ComplexField, Conj, Conjugate, MatMut, MatRef, Parallelism, SimdGroup, @@ -7,6 +9,7 @@ use core::{iter::zip, marker::PhantomData, mem::MaybeUninit}; use pulp::Simd; use reborrow::*; +#[doc(hidden)] pub mod inner_prod { use super::*; use assert2::assert; @@ -617,6 +620,7 @@ mod matvec_colmajor { } } +#[doc(hidden)] pub mod matvec { use super::*; @@ -686,6 +690,7 @@ pub mod matvec { } } +#[doc(hidden)] pub mod outer_prod { use super::*; use assert2::assert; @@ -1902,6 +1907,62 @@ pub fn matmul_with_conj_gemm_dispatch( } } +/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating +/// either or both of the input matrices) and stores the result in `acc`. +/// +/// Performs the operation: +/// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the preexisting +/// values in `acc` are not read, so it is allowed to be a view over uninitialized values if `E: +/// Copy`), +/// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`, +/// +/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is +/// `Conj::Yes`. +/// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is +/// `Conj::Yes`. +/// +/// # Panics +/// +/// Panics if the matrix dimensions are not compatible for matrix multiplication. +/// i.e. +/// - `acc.nrows() == lhs.nrows()` +/// - `acc.ncols() == rhs.ncols()` +/// - `lhs.ncols() == rhs.nrows()` +/// +/// # Example +/// +/// ``` +/// use faer_core::{mat, mul::matmul_with_conj, zipped, Conj, Mat, Parallelism}; +/// +/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; +/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; +/// +/// let mut acc = Mat::::zeros(2, 2); +/// let target = mat![ +/// [ +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 1) + lhs.read(0, 1) * rhs.read(1, 1)), +/// ], +/// [ +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), +/// ], +/// ]; +/// +/// matmul_with_conj( +/// acc.as_mut(), +/// lhs.as_ref(), +/// Conj::No, +/// rhs.as_ref(), +/// Conj::No, +/// None, +/// 2.5, +/// Parallelism::None, +/// ); +/// +/// zipped!(acc.as_ref(), target.as_ref()) +/// .for_each(|acc, target| assert!((acc.read() - target.read()).abs() < 1e-10)); +/// ``` #[inline] #[track_caller] pub fn matmul_with_conj( @@ -1930,6 +1991,54 @@ pub fn matmul_with_conj( ); } +/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and +/// stores the result in `acc`. +/// +/// Performs the operation: +/// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in `acc` +/// are not read, so it is allowed to be a view over uninitialized values if `E: Copy`), +/// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`, +/// +/// # Panics +/// +/// Panics if the matrix dimensions are not compatible for matrix multiplication. +/// i.e. +/// - `acc.nrows() == lhs.nrows()` +/// - `acc.ncols() == rhs.ncols()` +/// - `lhs.ncols() == rhs.nrows()` +/// +/// # Example +/// +/// ``` +/// use faer_core::{mat, mul::matmul, zipped, Mat, Parallelism}; +/// +/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; +/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; +/// +/// let mut acc = Mat::::zeros(2, 2); +/// let target = mat![ +/// [ +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 1) + lhs.read(0, 1) * rhs.read(1, 1)), +/// ], +/// [ +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), +/// ], +/// ]; +/// +/// matmul( +/// acc.as_mut(), +/// lhs.as_ref(), +/// rhs.as_ref(), +/// None, +/// 2.5, +/// Parallelism::None, +/// ); +/// +/// zipped!(acc.as_ref(), target.as_ref()) +/// .for_each(|acc, target| assert!((acc.read() - target.read()).abs() < 1e-10)); +/// ``` #[track_caller] pub fn matmul, RhsE: Conjugate>( acc: MatMut<'_, E>, @@ -2766,8 +2875,14 @@ pub mod triangular { } } - /// Computes the matrix product `[alpha * acc] + beta * Op_lhs(lhs) * Op_rhs(rhs)` and - /// stores the result in `acc`. + /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating + /// either or both of the input matrices) and stores the result in `acc`. + /// + /// Performs the operation: + /// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the + /// preexisting values in `acc` are not read, so it is allowed to be a view over uninitialized + /// values if `E: Copy`), + /// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`, /// /// The left hand side and right hand side may be interpreted as triangular depending on the /// given corresponding matrix structure. @@ -2779,9 +2894,6 @@ pub mod triangular { /// - only the strict triangular half (excluding the diagonal) is computed if the structure is /// strictly triangular or unit triangular. /// - /// If `alpha` is not provided, he preexisting values in `acc` are not read so it is allowed to - /// be a view over uninitialized values if `E: Copy`. - /// /// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is /// `Conj::Yes`. /// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is @@ -2885,6 +2997,74 @@ pub mod triangular { } } + /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and stores the result in + /// `acc`. + /// + /// Performs the operation: + /// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in + /// `acc` are not read, so it is allowed to be a view over uninitialized values if `E: Copy`), + /// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`, + /// + /// The left hand side and right hand side may be interpreted as triangular depending on the + /// given corresponding matrix structure. + /// + /// For the destination matrix, the result is: + /// - fully computed if the structure is rectangular, + /// - only the triangular half (including the diagonal) is computed if the structure is + /// triangular, + /// - only the strict triangular half (excluding the diagonal) is computed if the structure is + /// strictly triangular or unit triangular. + /// + /// # Panics + /// + /// Panics if the matrix dimensions are not compatible for matrix multiplication. + /// i.e. + /// - `acc.nrows() == lhs.nrows()` + /// - `acc.ncols() == rhs.ncols()` + /// - `lhs.ncols() == rhs.nrows()` + /// + /// Additionally, matrices that are marked as triangular must be square, i.e., they must have + /// the same number of rows and columns. + /// + /// # Example + /// + /// ``` + /// use faer_core::{ + /// mat, + /// mul::triangular::{matmul, BlockStructure}, + /// zipped, Conj, Mat, Parallelism, + /// }; + /// + /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; + /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; + /// + /// let mut acc = Mat::::zeros(2, 2); + /// let target = mat![ + /// [ + /// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), + /// 0.0, + /// ], + /// [ + /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), + /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), + /// ], + /// ]; + /// + /// matmul( + /// acc.as_mut(), + /// BlockStructure::TriangularLower, + /// lhs.as_ref(), + /// BlockStructure::Rectangular, + /// rhs.as_ref(), + /// BlockStructure::Rectangular, + /// None, + /// 2.5, + /// Parallelism::None, + /// ); + /// + /// zipped!(acc.as_ref(), target.as_ref()) + /// .for_each(|acc, target| assert!((acc.read() - target.read()).abs() < 1e-10)); + /// ``` #[track_caller] #[inline] pub fn matmul< diff --git a/faer-core/src/zip.rs b/faer-core/src/zip.rs index 4cff1ed5..cb949972 100644 --- a/faer-core/src/zip.rs +++ b/faer-core/src/zip.rs @@ -1,3 +1,5 @@ +//! Matrix zipping module. + use crate::{Entity, MatMut, MatRef}; use assert2::{assert, debug_assert}; use core::mem::MaybeUninit; @@ -8,6 +10,8 @@ mod seal { pub trait Seal {} } +/// Specifies whether the main diagonal should be traversed, when iterating over a triangular chunk +/// of the matrix. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Diag { /// Do not include diagonal of matrix @@ -16,9 +20,11 @@ pub enum Diag { Include, } +/// Read only view over a single matrix element. pub struct Read<'a, E: Entity> { ptr: E::Group<&'a MaybeUninit>, } +/// Read-write view over a single matrix element. pub struct ReadWrite<'a, E: Entity> { ptr: E::Group<&'a mut MaybeUninit>, } @@ -55,6 +61,7 @@ impl ReadWrite<'_, E> { } } +/// Internal trait for abstracting over [`MatRef`] and [`MatMut`]. pub trait Mat<'short, Outlives = &'short Self>: Seal { type Item; type RawSlice; @@ -74,6 +81,7 @@ pub trait Mat<'short, Outlives = &'short Self>: Seal { n_elems: usize, ) -> Self::RawSlice; + #[doc(hidden)] // this is a bad api since it needs to extend the lifetime of slice, but this is somewhat fine // since we only use it internally in this module unsafe fn get_slice_elem(slice: &mut Self::RawSlice, idx: usize) -> Self::Item; @@ -219,6 +227,7 @@ impl<'a, 'short, E: Entity> Mat<'short> for MatMut<'a, E> { } } +/// Structure holding matrix views with matching dimensions. pub struct Zip { pub(crate) tuple: Tuple, } diff --git a/faer-evd/Cargo.toml b/faer-evd/Cargo.toml index ccc22e77..84dee04c 100644 --- a/faer-evd/Cargo.toml +++ b/faer-evd/Cargo.toml @@ -24,7 +24,7 @@ coe-rs = { workspace = true } dbgf = "0.1.0" [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" diff --git a/faer-lu/Cargo.toml b/faer-lu/Cargo.toml index 99bed768..a025d421 100644 --- a/faer-lu/Cargo.toml +++ b/faer-lu/Cargo.toml @@ -29,7 +29,7 @@ std = ["faer-core/std", "pulp/std"] nightly = ["faer-core/nightly", "pulp/nightly"] [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" diff --git a/faer-qr/Cargo.toml b/faer-qr/Cargo.toml index 745da35a..19b85c87 100644 --- a/faer-qr/Cargo.toml +++ b/faer-qr/Cargo.toml @@ -27,7 +27,7 @@ std = ["faer-core/std", "pulp/std"] nightly = ["faer-core/nightly", "pulp/nightly"] [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0" diff --git a/faer-svd/Cargo.toml b/faer-svd/Cargo.toml index ad4548db..087fa0e8 100644 --- a/faer-svd/Cargo.toml +++ b/faer-svd/Cargo.toml @@ -23,7 +23,7 @@ bytemuck = { workspace = true } coe-rs = { workspace = true } [dev-dependencies] -criterion = "0.4" +criterion = "0.5" rand = "0.8.5" nalgebra = "0.32.2" assert_approx_eq = "1.1.0"