Skip to content

Commit

Permalink
add MatRef|MatMut::from_*_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah committed Sep 13, 2023
1 parent 2ecb1ba commit aa93ff5
Showing 1 changed file with 224 additions and 4 deletions.
228 changes: 224 additions & 4 deletions faer-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3153,7 +3153,7 @@ impl<E: Entity> Clone for MatImpl<E> {

/// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice].
///
/// # Note:
/// # Note
///
/// Unlike a slice, the data pointed to by `MatRef<'_, E>` is allowed to be partially or fully
/// uninitialized under certain conditions ([`std::mem::needs_drop::<E::Unit>()`] must be false). In
Expand All @@ -3175,13 +3175,48 @@ impl<E: Entity> Clone for MatRef<'_, E> {

/// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice].
///
/// # Note:
/// # Note
///
/// Unlike a slice, the data pointed to by `MatMut<'_, E>` is allowed to be partially or fully
/// uninitialized under certain conditions ([`std::mem::needs_drop::<E::Unit>()`] must be false). In
/// this case, care must be taken to not perform any operations that read the uninitialized values,
/// or form references to them, either directly through [`MatMut::read`], or indirectly through any
/// of the numerical library routines, unless it is explicitly permitted.
///
/// # Move semantics
/// Since `MatMut` mutably borrows data, it cannot be [`Copy`]. This means that if we pass a
/// `MatMut` to a function that takes it by value, or use a method that consumes `self` like
/// [`MatMut::transpose`], this renders the original variable unusable.
/// ```compile_fail
/// use faer_core::{Mat, MatMut};
///
/// fn takes_matmut(view: MatMut<'_, f64>) {}
///
/// let mut matrix = Mat::new();
/// let view = matrix.as_mut();
///
/// takes_matmut(view); // `view` is moved (passed by value)
/// takes_matmut(view); // this fails to compile since `view` was moved
/// ```
/// The way to get around it is to use the [`reborrow::ReborrowMut`] trait, which allows us to
/// mutably borrow a `MatMut` to obtain another `MatMut` for the lifetime of the borrow.
/// It's also similarly possible to immutably borrow a `MatMut` to obtain a `MatRef` for the
/// lifetime of the borrow, using [`reborrow::Reborrow`].
/// ```
/// use faer_core::{Mat, MatMut, MatRef};
/// use reborrow::*;
///
/// fn takes_matmut(view: MatMut<'_, f64>) {}
/// fn takes_matref(view: MatRef<'_, f64>) {}
///
/// let mut matrix = Mat::new();
/// let mut view = matrix.as_mut();
///
/// takes_matmut(view.rb_mut());
/// takes_matmut(view.rb_mut());
/// takes_matref(view.rb());
/// // view is still usable here
/// ```
pub struct MatMut<'a, E: Entity> {
inner: MatImpl<E>,
__marker: PhantomData<&'a mut E>,
Expand Down Expand Up @@ -3248,6 +3283,84 @@ pub fn par_split_indices(n: usize, idx: usize, chunk_count: usize) -> (usize, us
}

impl<'a, E: Entity> MatRef<'a, E> {
/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions.
/// The data is interpreted in a column-major format, so that the first chunk of `nrows`
/// values from the slices goes in the first column of the matrix, the second chunk of `nrows`
/// values goes in the second column, and so on.
///
/// # Panics
/// The function panics if any of the following conditions are violated:
/// * `nrows * ncols == slice.len()`
///
/// # Example
/// ```
/// use faer_core::{mat, MatRef};
///
/// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64];
/// let view = MatRef::<f64>::from_column_major_slice(&slice, 3, 2);
///
/// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
/// assert_eq!(expected, view);
/// ```
#[track_caller]
pub fn from_column_major_slice(
slice: E::Group<&'a [E::Unit]>,
nrows: usize,
ncols: usize,
) -> Self {
let size = usize::checked_mul(nrows, ncols).unwrap_or(usize::MAX);
// we don't have to worry about size == usize::MAX == slice.len(), because the length of a
// slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case
// we don't care
E::map(
E::copy(&slice),
#[inline(always)]
|slice| assert!(size == slice.len()),
);
unsafe {
Self::from_raw_parts(
E::map(
slice,
#[inline(always)]
|slice| slice.as_ptr(),
),
nrows,
ncols,
1,
nrows as isize,
)
}
}

/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions.
/// The data is interpreted in a row-major format, so that the first chunk of `ncols`
/// values from the slices goes in the first column of the matrix, the second chunk of `ncols`
/// values goes in the second column, and so on.
///
/// # Panics
/// The function panics if any of the following conditions are violated:
/// * `nrows * ncols == slice.len()`
///
/// # Example
/// ```
/// use faer_core::{mat, MatRef};
///
/// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64];
/// let view = MatRef::<f64>::from_row_major_slice(&slice, 3, 2);
///
/// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// assert_eq!(expected, view);
/// ```
#[inline(always)]
#[track_caller]
pub fn from_row_major_slice(
slice: E::Group<&'a [E::Unit]>,
nrows: usize,
ncols: usize,
) -> Self {
Self::from_column_major_slice(slice, ncols, nrows).transpose()
}

/// Creates a `MatRef` from pointers to the matrix data, dimensions, and strides.
///
/// The row (resp. column) stride is the offset from the memory address of a given matrix
Expand Down Expand Up @@ -3295,7 +3408,6 @@ impl<'a, E: Entity> MatRef<'a, E> {
row_stride: isize,
col_stride: isize,
) -> Self {
E::map(E::as_ref(&ptr), |ptr| debug_assert!(!ptr.is_null()));
Self {
inner: MatImpl {
ptr: E::into_copy(E::map(ptr, |ptr| ptr as *mut E::Unit)),
Expand Down Expand Up @@ -3852,6 +3964,84 @@ impl<'a, E: Entity> MatRef<'a, E> {
}

impl<'a, E: Entity> MatMut<'a, E> {
/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions.
/// The data is interpreted in a column-major format, so that the first chunk of `nrows`
/// values from the slices goes in the first column of the matrix, the second chunk of `nrows`
/// values goes in the second column, and so on.
///
/// # Panics
/// The function panics if any of the following conditions are violated:
/// * `nrows * ncols == slice.len()`
///
/// # Example
/// ```
/// use faer_core::{mat, MatMut};
///
/// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64];
/// let view = MatMut::<f64>::from_column_major_slice(&mut slice, 3, 2);
///
/// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
/// assert_eq!(expected, view);
/// ```
#[track_caller]
pub fn from_column_major_slice(
slice: E::Group<&'a mut [E::Unit]>,
nrows: usize,
ncols: usize,
) -> Self {
let size = usize::checked_mul(nrows, ncols).unwrap_or(usize::MAX);
// we don't have to worry about size == usize::MAX == slice.len(), because the length of a
// slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case
// we don't care
E::map(
E::as_ref(&slice),
#[inline(always)]
|slice| assert!(size == slice.len()),
);
unsafe {
Self::from_raw_parts(
E::map(
slice,
#[inline(always)]
|slice| slice.as_mut_ptr(),
),
nrows,
ncols,
1,
nrows as isize,
)
}
}

/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions.
/// The data is interpreted in a row-major format, so that the first chunk of `ncols`
/// values from the slices goes in the first column of the matrix, the second chunk of `ncols`
/// values goes in the second column, and so on.
///
/// # Panics
/// The function panics if any of the following conditions are violated:
/// * `nrows * ncols == slice.len()`
///
/// # Example
/// ```
/// use faer_core::{mat, MatMut};
///
/// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64];
/// let view = MatMut::<f64>::from_row_major_slice(&mut slice, 3, 2);
///
/// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// assert_eq!(expected, view);
/// ```
#[inline(always)]
#[track_caller]
pub fn from_row_major_slice(
slice: E::Group<&'a mut [E::Unit]>,
nrows: usize,
ncols: usize,
) -> Self {
Self::from_column_major_slice(slice, ncols, nrows).transpose()
}

/// Creates a `MatMut` from pointers to the matrix data, dimensions, and strides.
///
/// The row (resp. column) stride is the offset from the memory address of a given matrix
Expand Down Expand Up @@ -3902,7 +4092,6 @@ impl<'a, E: Entity> MatMut<'a, E> {
row_stride: isize,
col_stride: isize,
) -> Self {
E::map(E::as_ref(&ptr), |ptr| debug_assert!(!ptr.is_null()));
Self {
inner: MatImpl {
ptr: E::into_copy(ptr),
Expand Down Expand Up @@ -6160,4 +6349,35 @@ mod tests {
x_mut *= 2.0;
assert_eq!(x, expected);
}

#[test]
fn from_slice() {
let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64];

let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]];
let view = MatRef::<f64>::from_column_major_slice(&slice, 3, 2);
assert_eq!(expected, view);
let view = MatMut::<f64>::from_column_major_slice(&mut slice, 3, 2);
assert_eq!(expected, view);

let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let view = MatRef::<f64>::from_row_major_slice(&slice, 3, 2);
assert_eq!(expected, view);
let view = MatMut::<f64>::from_row_major_slice(&mut slice, 3, 2);
assert_eq!(expected, view);
}

#[test]
#[should_panic]
fn from_slice_too_big() {
let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0_f64];
MatRef::<f64>::from_column_major_slice(&slice, 3, 2);
}

#[test]
#[should_panic]
fn from_slice_too_small() {
let slice = [1.0, 2.0, 3.0, 4.0, 5.0_f64];
MatRef::<f64>::from_column_major_slice(&slice, 3, 2);
}
}

0 comments on commit aa93ff5

Please sign in to comment.