diff --git a/src/col/colmut.rs b/src/col/colmut.rs index a50200dd..16ba5f1b 100644 --- a/src/col/colmut.rs +++ b/src/col/colmut.rs @@ -419,6 +419,43 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> { #[inline(always)] #[track_caller] pub fn at_mut(self, row: R::Idx) -> Mut<'a, E> { + assert!(row < self.nrows()); + unsafe { + E::faer_map( + self.ptr_inbounds_at_mut(row), + #[inline(always)] + |ptr| &mut *ptr, + ) + } + } + + /// Returns references to the element at the given index. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub unsafe fn at_unchecked(self, row: R::Idx) -> Ref<'a, E> { + self.into_const().at_unchecked(row) + } + + /// Returns references to the element at the given index. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub unsafe fn at_mut_unchecked(self, row: R::Idx) -> Mut<'a, E> { unsafe { E::faer_map( self.ptr_inbounds_at_mut(row), diff --git a/src/row/mod.rs b/src/row/mod.rs index bd237f53..a73ab962 100644 --- a/src/row/mod.rs +++ b/src/row/mod.rs @@ -2,7 +2,7 @@ use crate::{ col::{VecImpl, VecOwnImpl}, mat::*, utils::slice::*, - Conj, + Conj, Shape, }; use coe::Coerce; use core::{marker::PhantomData, ptr::NonNull}; @@ -25,28 +25,34 @@ pub trait RowIndex: crate::seal::Seal + Sized { /// Trait for types that can be converted to a row view. pub trait AsRowRef { + type C: Shape; + /// Convert to a row view. - fn as_row_ref(&self) -> RowRef<'_, E>; + fn as_row_ref(&self) -> RowRef<'_, E, Self::C>; } /// Trait for types that can be converted to a mutable row view. pub trait AsRowMut: AsRowRef { /// Convert to a mutable row view. - fn as_row_mut(&mut self) -> RowMut<'_, E>; + fn as_row_mut(&mut self) -> RowMut<'_, E, Self::C>; } impl> AsRowRef for &T { - fn as_row_ref(&self) -> RowRef<'_, E> { + type C = T::C; + + fn as_row_ref(&self) -> RowRef<'_, E, Self::C> { (**self).as_row_ref() } } impl> AsRowRef for &mut T { - fn as_row_ref(&self) -> RowRef<'_, E> { + type C = T::C; + + fn as_row_ref(&self) -> RowRef<'_, E, Self::C> { (**self).as_row_ref() } } impl> AsRowMut for &mut T { - fn as_row_mut(&mut self) -> RowMut<'_, E> { + fn as_row_mut(&mut self) -> RowMut<'_, E, Self::C> { (**self).as_row_mut() } } diff --git a/src/row/rowmut.rs b/src/row/rowmut.rs index e5aa01e6..d8121691 100644 --- a/src/row/rowmut.rs +++ b/src/row/rowmut.rs @@ -233,9 +233,7 @@ impl<'a, E: Entity, C: Shape> RowMut<'a, E, C> { let col_stride = self.col_stride(); unsafe { from_raw_parts_mut(self.as_ptr_mut(), ncols, col_stride) } } -} -impl<'a, E: Entity> RowMut<'a, E> { /// Splits the column vector at the given index into two parts and /// returns an array of each subvector, in the following order: /// * left. @@ -246,7 +244,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn split_at_unchecked(self, col: usize) -> (RowRef<'a, E>, RowRef<'a, E>) { + pub unsafe fn split_at_unchecked(self, col: C::IdxInc) -> (RowRef<'a, E>, RowRef<'a, E>) { self.into_const().split_at_unchecked(col) } @@ -260,7 +258,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn split_at_mut_unchecked(self, col: usize) -> (Self, Self) { + pub unsafe fn split_at_mut_unchecked(self, col: C::IdxInc) -> (RowMut<'a, E>, RowMut<'a, E>) { let (left, right) = self.into_const().split_at_unchecked(col); unsafe { (left.const_cast(), right.const_cast()) } } @@ -275,7 +273,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub fn split_at(self, col: usize) -> (RowRef<'a, E>, RowRef<'a, E>) { + pub fn split_at(self, col: C::IdxInc) -> (RowRef<'a, E>, RowRef<'a, E>) { self.into_const().split_at(col) } @@ -289,7 +287,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub fn split_at_mut(self, col: usize) -> (Self, Self) { + pub fn split_at_mut(self, col: C::IdxInc) -> (RowMut<'a, E>, RowMut<'a, E>) { assert!(col <= self.ncols()); unsafe { self.split_at_mut_unchecked(col) } } @@ -309,9 +307,9 @@ impl<'a, E: Entity> RowMut<'a, E> { pub unsafe fn get_unchecked( self, col: ColRange, - ) -> as RowIndex>::Target + ) -> as RowIndex>::Target where - RowRef<'a, E>: RowIndex, + RowRef<'a, E, C>: RowIndex, { self.into_const().get_unchecked(col) } @@ -328,9 +326,9 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col` must be contained in `[0, self.ncols())`. #[inline(always)] #[track_caller] - pub fn get(self, col: ColRange) -> as RowIndex>::Target + pub fn get(self, col: ColRange) -> as RowIndex>::Target where - RowRef<'a, E>: RowIndex, + RowRef<'a, E, C>: RowIndex, { self.into_const().get(col) } @@ -376,6 +374,79 @@ impl<'a, E: Entity> RowMut<'a, E> { >::get(self, col) } + /// Returns references to the element at the given index, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn at(self, col: C::Idx) -> Ref<'a, E> { + self.into_const().at(col) + } + + /// Returns references to the element at the given index, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn at_mut(self, col: C::Idx) -> Mut<'a, E> { + assert!(col < self.ncols()); + unsafe { + E::faer_map( + self.ptr_inbounds_at_mut(col), + #[inline(always)] + |ptr| &mut *ptr, + ) + } + } + + /// Returns references to the element at the given index. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn at_unchecked(self, col: C::Idx) -> Ref<'a, E> { + self.into_const().at_unchecked(col) + } + + /// Returns references to the element at the given index. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn at_mut_unchecked(self, col: C::Idx) -> Mut<'a, E> { + unsafe { + E::faer_map( + self.ptr_inbounds_at_mut(col), + #[inline(always)] + |ptr| &mut *ptr, + ) + } + } + /// Reads the value of the element at the given index. /// /// # Safety @@ -383,7 +454,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn read_unchecked(&self, col: usize) -> E { + pub unsafe fn read_unchecked(&self, col: C::Idx) -> E { self.rb().read_unchecked(col) } @@ -394,7 +465,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub fn read(&self, col: usize) -> E { + pub fn read(&self, col: C::Idx) -> E { self.rb().read(col) } @@ -405,7 +476,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn write_unchecked(&mut self, col: usize, value: E) { + pub unsafe fn write_unchecked(&mut self, col: C::Idx, value: E) { let units = value.faer_into_units(); let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(col)); E::faer_map( @@ -422,7 +493,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub fn write(&mut self, col: usize, value: E) { + pub fn write(&mut self, col: C::Idx, value: E) { assert!(col < self.ncols()); unsafe { self.write_unchecked(col, value) }; } @@ -433,14 +504,17 @@ impl<'a, E: Entity> RowMut<'a, E> { /// The function panics if any of the following conditions are violated: /// * `self.ncols() == other.ncols()`. #[track_caller] - pub fn copy_from>(&mut self, other: impl AsRowRef) { + pub fn copy_from>( + &mut self, + other: impl AsRowRef, + ) { #[track_caller] #[inline(always)] - fn implementation>( - this: RowMut<'_, E>, - other: RowRef<'_, ViewE>, + fn implementation>( + this: RowMut<'_, E, C>, + other: RowRef<'_, ViewE, C>, ) { - zipped!(this.as_2d_mut(), other.as_2d()) + zipped!(this, other) .for_each(|unzipped!(mut dst, src)| dst.write(src.read().canonicalize())); } implementation(self.rb_mut(), other.as_row_ref()) @@ -470,21 +544,21 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the transpose of `self`. #[inline(always)] #[must_use] - pub fn transpose(self) -> ColRef<'a, E> { + pub fn transpose(self) -> ColRef<'a, E, C> { self.into_const().transpose() } /// Returns a view over the transpose of `self`. #[inline(always)] #[must_use] - pub fn transpose_mut(self) -> ColMut<'a, E> { + pub fn transpose_mut(self) -> ColMut<'a, E, C> { unsafe { self.into_const().transpose().const_cast() } } /// Returns a view over the conjugate of `self`. #[inline(always)] #[must_use] - pub fn conjugate(self) -> RowRef<'a, E::Conj> + pub fn conjugate(self) -> RowRef<'a, E::Conj, C> where E: Conjugate, { @@ -494,7 +568,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the conjugate of `self`. #[inline(always)] #[must_use] - pub fn conjugate_mut(self) -> RowMut<'a, E::Conj> + pub fn conjugate_mut(self) -> RowMut<'a, E::Conj, C> where E: Conjugate, { @@ -503,7 +577,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the conjugate transpose of `self`. #[inline(always)] - pub fn adjoint(self) -> ColRef<'a, E::Conj> + pub fn adjoint(self) -> ColRef<'a, E::Conj, C> where E: Conjugate, { @@ -512,7 +586,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the conjugate transpose of `self`. #[inline(always)] - pub fn adjoint_mut(self) -> ColMut<'a, E::Conj> + pub fn adjoint_mut(self) -> ColMut<'a, E::Conj, C> where E: Conjugate, { @@ -522,7 +596,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the canonical representation of `self`, as well as a flag declaring /// whether `self` is implicitly conjugated or not. #[inline(always)] - pub fn canonicalize(self) -> (RowRef<'a, E::Canonical>, Conj) + pub fn canonicalize(self) -> (RowRef<'a, E::Canonical, C>, Conj) where E: Conjugate, { @@ -532,7 +606,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the canonical representation of `self`, as well as a flag declaring /// whether `self` is implicitly conjugated or not. #[inline(always)] - pub fn canonicalize_mut(self) -> (RowMut<'a, E::Canonical>, Conj) + pub fn canonicalize_mut(self) -> (RowMut<'a, E::Canonical, C>, Conj) where E: Conjugate, { @@ -543,7 +617,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the `self`, with the columns in reversed order. #[inline(always)] #[must_use] - pub fn reverse_cols(self) -> RowRef<'a, E> { + pub fn reverse_cols(self) -> RowRef<'a, E, C> { self.into_const().reverse_cols() } @@ -563,7 +637,11 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> RowRef<'a, E> { + pub unsafe fn subcols_unchecked( + self, + col_start: C::IdxInc, + ncols: H, + ) -> RowRef<'a, E, H> { self.into_const().subcols_unchecked(col_start, ncols) } @@ -576,7 +654,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub fn subcols(self, col_start: usize, ncols: usize) -> RowRef<'a, E> { + pub fn subcols(self, col_start: C::IdxInc, ncols: H) -> RowRef<'a, E, H> { self.into_const().subcols(col_start, ncols) } @@ -589,7 +667,11 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub unsafe fn subcols_mut_unchecked(self, col_start: usize, ncols: usize) -> Self { + pub unsafe fn subcols_mut_unchecked( + self, + col_start: C::IdxInc, + ncols: H, + ) -> RowMut<'a, E, H> { self.into_const() .subcols_unchecked(col_start, ncols) .const_cast() @@ -604,7 +686,7 @@ impl<'a, E: Entity> RowMut<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub fn subcols_mut(self, col_start: usize, ncols: usize) -> Self { + pub fn subcols_mut(self, col_start: C::IdxInc, ncols: H) -> RowMut<'a, E, H> { unsafe { self.into_const().subcols(col_start, ncols).const_cast() } } @@ -710,7 +792,7 @@ impl<'a, E: Entity> RowMut<'a, E> { #[inline] pub fn try_as_slice_mut(self) -> Option> { if self.col_stride() == 1 { - let len = self.ncols(); + let len = self.ncols().unbound(); Some(E::faer_map( self.as_ptr_mut(), #[inline(always)] @@ -730,7 +812,7 @@ impl<'a, E: Entity> RowMut<'a, E> { self, ) -> Option]>> { if self.col_stride() == 1 { - let len = self.ncols(); + let len = self.ncols().unbound(); Some(E::faer_map( self.as_ptr_mut(), #[inline(always)] @@ -743,13 +825,13 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns a view over the matrix. #[inline] - pub fn as_ref(&self) -> RowRef<'_, E> { + pub fn as_ref(&self) -> RowRef<'_, E, C> { (*self).rb() } /// Returns a mutable view over the matrix. #[inline] - pub fn as_mut(&mut self) -> RowMut<'_, E> { + pub fn as_mut(&mut self) -> RowMut<'_, E, C> { (*self).rb_mut() } @@ -771,11 +853,12 @@ impl<'a, E: Entity> RowMut<'a, E> { /// non-empty, otherwise `None`. #[inline] pub fn split_first_mut(self) -> Option<(GroupFor, RowMut<'a, E>)> { - if self.ncols() == 0 { + let this = self.as_dyn_mut(); + if this.ncols() == 0 { None } else { unsafe { - let (head, tail) = { self.split_at_mut_unchecked(1) }; + let (head, tail) = { this.split_at_mut_unchecked(1) }; Some((head.get_mut_unchecked(0), tail)) } } @@ -785,12 +868,13 @@ impl<'a, E: Entity> RowMut<'a, E> { /// non-empty, otherwise `None`. #[inline] pub fn split_last_mut(self) -> Option<(GroupFor, RowMut<'a, E>)> { - if self.ncols() == 0 { + let this = self.as_dyn_mut(); + if this.ncols() == 0 { None } else { - let ncols = self.ncols(); + let ncols = this.ncols(); unsafe { - let (head, tail) = { self.split_at_mut_unchecked(ncols - 1) }; + let (head, tail) = { this.split_at_mut_unchecked(ncols - 1) }; Some((tail.get_mut_unchecked(0), head)) } } @@ -799,16 +883,14 @@ impl<'a, E: Entity> RowMut<'a, E> { /// Returns an iterator over the elements of the row. #[inline] pub fn iter(self) -> iter::ElemIter<'a, E> { - iter::ElemIter { - inner: self.into_const().transpose(), - } + self.into_const().iter() } /// Returns an iterator over the elements of the row. #[inline] pub fn iter_mut(self) -> iter::ElemIterMut<'a, E> { iter::ElemIterMut { - inner: self.transpose_mut(), + inner: self.transpose_mut().as_dyn_mut(), } } @@ -864,9 +946,9 @@ impl<'a, E: Entity> RowMut<'a, E> { #[track_caller] pub fn chunks_mut(self, chunk_size: usize) -> iter::RowElemChunksMut<'a, E> { assert!(chunk_size > 0); - let ncols = self.ncols(); + let ncols = self.ncols().unbound(); iter::RowElemChunksMut { - inner: self, + inner: self.as_dyn_mut(), policy: iter::chunks::ChunkSizePolicy::new(ncols, iter::chunks::ChunkSize(chunk_size)), } } @@ -877,9 +959,9 @@ impl<'a, E: Entity> RowMut<'a, E> { #[track_caller] pub fn partition_mut(self, count: usize) -> iter::RowElemPartitionMut<'a, E> { assert!(count > 0); - let ncols = self.ncols(); + let ncols = self.ncols().unbound(); iter::RowElemPartitionMut { - inner: self, + inner: self.as_dyn_mut(), policy: iter::chunks::PartitionCountPolicy::new( ncols, iter::chunks::PartitionCount(count), @@ -925,7 +1007,7 @@ impl<'a, E: Entity> RowMut<'a, E> { #[doc(hidden)] #[inline(always)] - pub unsafe fn const_cast(self) -> RowMut<'a, E> { + pub unsafe fn const_cast(self) -> RowMut<'a, E, C> { self } } @@ -1009,16 +1091,18 @@ impl core::ops::IndexMut for RowMut<'_, E> { } } -impl AsRowRef for RowMut<'_, E> { +impl AsRowRef for RowMut<'_, E, C> { + type C = C; + #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { + fn as_row_ref(&self) -> RowRef<'_, E, C> { (*self).rb() } } -impl AsRowMut for RowMut<'_, E> { +impl AsRowMut for RowMut<'_, E, C> { #[inline] - fn as_row_mut(&mut self) -> RowMut<'_, E> { + fn as_row_mut(&mut self) -> RowMut<'_, E, C> { (*self).rb_mut() } } diff --git a/src/row/rowown.rs b/src/row/rowown.rs index 407bfa07..fa2fee83 100644 --- a/src/row/rowown.rs +++ b/src/row/rowown.rs @@ -675,7 +675,10 @@ impl Row { /// Copies the values from `other` into `self`. #[inline(always)] #[track_caller] - pub fn copy_from>(&mut self, other: impl AsRowRef) { + pub fn copy_from>( + &mut self, + other: impl AsRowRef, + ) { #[track_caller] #[inline(always)] fn implementation>( @@ -1056,6 +1059,8 @@ impl As2DMut for Row { } impl AsRowRef for Row { + type C = usize; + #[inline] fn as_row_ref(&self) -> RowRef<'_, E> { (*self).as_ref() diff --git a/src/row/rowref.rs b/src/row/rowref.rs index 43c5851d..b31185dd 100644 --- a/src/row/rowref.rs +++ b/src/row/rowref.rs @@ -187,9 +187,7 @@ impl<'a, E: Entity, C: Shape> RowRef<'a, E, C> { __marker: PhantomData, } } -} -impl<'a, E: Entity> RowRef<'a, E> { /// Splits the column vector at the given index into two parts and /// returns an array of each subvector, in the following order: /// * left. @@ -200,26 +198,28 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn split_at_unchecked(self, col: usize) -> (Self, Self) { + pub unsafe fn split_at_unchecked( + self, + col: C::IdxInc, + ) -> (RowRef<'a, E, usize>, RowRef<'a, E, usize>) { debug_assert!(col <= self.ncols()); - let col_stride = self.col_stride(); - - let ncols = self.ncols(); + let ncols = self.ncols().unbound(); unsafe { let top = self.as_ptr(); let bot = self.overflowing_ptr_at(col); + let col = col.unbound(); ( - Self::__from_raw_parts(top, col, col_stride), - Self::__from_raw_parts(bot, ncols - col, col_stride), + RowRef::__from_raw_parts(top, col, col_stride), + RowRef::__from_raw_parts(bot, ncols - col, col_stride), ) } } /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: + /// returns an array of each subvector, in the following order, C: /// * top. /// * bottom. /// @@ -228,7 +228,7 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `col <= self.ncols()`. #[inline(always)] #[track_caller] - pub fn split_at(self, col: usize) -> (Self, Self) { + pub fn split_at(self, col: C::IdxInc) -> (RowRef<'a, E, usize>, RowRef<'a, E, usize>) { assert!(col <= self.ncols()); unsafe { self.split_at_unchecked(col) } } @@ -242,7 +242,7 @@ impl<'a, E: Entity> RowRef<'a, E> { /// /// # Safety /// The behavior is undefined if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. + /// * `col` must be contained in `[0, self.nc, Cols())`. #[inline(always)] #[track_caller] pub unsafe fn get_unchecked( @@ -274,6 +274,38 @@ impl<'a, E: Entity> RowRef<'a, E> { >::get(self, col) } + /// Returns references to the element at the given index, or subvector if `row` is a + /// range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.nc, Cols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn at_unchecked(self, col: C::Idx) -> Ref<'a, E> { + self.transpose().at_unchecked(col) + } + + /// Returns references to the element at the given index, or subvector if `col` is a + /// range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn at(self, col: C::Idx) -> Ref<'a, E> { + self.transpose().at(col) + } + /// Reads the value of the element at the given index. /// /// # Safety @@ -281,9 +313,9 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub unsafe fn read_unchecked(&self, col: usize) -> E { + pub unsafe fn read_unchecked(&self, col: C::Idx) -> E { E::faer_from_units(E::faer_map( - self.get_unchecked(col), + self.at_unchecked(col), #[inline(always)] |ptr| *ptr, )) @@ -296,9 +328,9 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `col < self.ncols()`. #[inline(always)] #[track_caller] - pub fn read(&self, col: usize) -> E { + pub fn read(&self, col: C::Idx) -> E { E::faer_from_units(E::faer_map( - self.get(col), + self.at(col), #[inline(always)] |ptr| *ptr, )) @@ -307,14 +339,14 @@ impl<'a, E: Entity> RowRef<'a, E> { /// Returns a view over the transpose of `self`. #[inline(always)] #[must_use] - pub fn transpose(self) -> ColRef<'a, E> { + pub fn transpose(self) -> ColRef<'a, E, C> { unsafe { ColRef::__from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride()) } } /// Returns a view over the conjugate of `self`. #[inline(always)] #[must_use] - pub fn conjugate(self) -> RowRef<'a, E::Conj> + pub fn conjugate(self) -> RowRef<'a, E::Conj, C> where E: Conjugate, { @@ -334,7 +366,7 @@ impl<'a, E: Entity> RowRef<'a, E> { /// Returns a view over the conjugate transpose of `self`. #[inline(always)] - pub fn adjoint(self) -> ColRef<'a, E::Conj> + pub fn adjoint(self) -> ColRef<'a, E::Conj, C> where E: Conjugate, { @@ -344,7 +376,7 @@ impl<'a, E: Entity> RowRef<'a, E> { /// Returns a view over the canonical representation of `self`, as well as a flag declaring /// whether `self` is implicitly conjugated or not. #[inline(always)] - pub fn canonicalize(self) -> (RowRef<'a, E::Canonical>, Conj) + pub fn canonicalize(self) -> (RowRef<'a, E::Canonical, C>, Conj) where E: Conjugate, { @@ -375,7 +407,7 @@ impl<'a, E: Entity> RowRef<'a, E> { let ncols = self.ncols(); let col_stride = self.col_stride().wrapping_neg(); - let ptr = unsafe { self.ptr_at_unchecked(ncols.saturating_sub(1)) }; + let ptr = unsafe { self.ptr_at_unchecked(ncols.unbound().saturating_sub(1)) }; unsafe { Self::__from_raw_parts(ptr, ncols, col_stride) } } @@ -388,11 +420,19 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> Self { + pub unsafe fn subcols_unchecked( + self, + col_start: C::IdxInc, + ncols: H, + ) -> RowRef<'a, E, H> { debug_assert!(col_start <= self.ncols()); - debug_assert!(ncols <= self.ncols() - col_start); + { + let ncols = ncols.unbound(); + let col_start = col_start.unbound(); + debug_assert!(ncols <= self.ncols().unbound() - col_start); + } let col_stride = self.col_stride(); - unsafe { Self::__from_raw_parts(self.overflowing_ptr_at(col_start), ncols, col_stride) } + unsafe { RowRef::__from_raw_parts(self.overflowing_ptr_at(col_start), ncols, col_stride) } } /// Returns a view over the subvector starting at col `col_start`, and with number of cols @@ -404,9 +444,13 @@ impl<'a, E: Entity> RowRef<'a, E> { /// * `ncols <= self.ncols() - col_start`. #[track_caller] #[inline(always)] - pub fn subcols(self, col_start: usize, ncols: usize) -> Self { + pub fn subcols(self, col_start: C::IdxInc, ncols: H) -> RowRef<'a, E, H> { assert!(col_start <= self.ncols()); - assert!(ncols <= self.ncols() - col_start); + { + let ncols = ncols.unbound(); + let col_start = col_start.unbound(); + assert!(ncols <= self.ncols().unbound() - col_start); + } unsafe { self.subcols_unchecked(col_start, ncols) } } @@ -416,13 +460,7 @@ impl<'a, E: Entity> RowRef<'a, E> { where E: Conjugate, { - let mut mat = Row::new(); - mat.resize_with( - self.ncols(), - #[inline(always)] - |col| unsafe { self.read_unchecked(col).canonicalize() }, - ); - mat + crate::zipped!(self).map(|crate::unzipped!(x)| x.read().canonicalize()) } /// Returns `true` if any of the elements is NaN, otherwise returns `false`. @@ -497,7 +535,7 @@ impl<'a, E: Entity> RowRef<'a, E> { where E: ComplexField, { - self.as_2d_ref().kron(rhs) + self.as_2d().kron(rhs) } /// Returns the row as a contiguous slice if its column stride is equal to `1`. @@ -508,7 +546,7 @@ impl<'a, E: Entity> RowRef<'a, E> { #[inline] pub fn try_as_slice(self) -> Option> { if self.col_stride() == 1 { - let len = self.ncols(); + let len = self.ncols().unbound(); Some(E::faer_map( self.as_ptr(), #[inline(always)] @@ -521,7 +559,7 @@ impl<'a, E: Entity> RowRef<'a, E> { /// Returns a view over the matrix. #[inline] - pub fn as_ref(&self) -> RowRef<'_, E> { + pub fn as_ref(&self) -> RowRef<'_, E, C> { *self } @@ -529,11 +567,12 @@ impl<'a, E: Entity> RowRef<'a, E> { /// non-empty, otherwise `None`. #[inline] pub fn split_first(self) -> Option<(GroupFor, RowRef<'a, E>)> { - if self.ncols() == 0 { + let this = self.as_dyn(); + if this.ncols() == 0 { None } else { unsafe { - let (head, tail) = { self.split_at_unchecked(1) }; + let (head, tail) = { this.split_at_unchecked(1) }; Some((head.get_unchecked(0), tail)) } } @@ -543,11 +582,12 @@ impl<'a, E: Entity> RowRef<'a, E> { /// non-empty, otherwise `None`. #[inline] pub fn split_last(self) -> Option<(GroupFor, RowRef<'a, E>)> { - if self.ncols() == 0 { + let this = self.as_dyn(); + if this.ncols() == 0 { None } else { unsafe { - let (head, tail) = { self.split_at_unchecked(self.ncols() - 1) }; + let (head, tail) = { this.split_at_unchecked(this.ncols() - 1) }; Some((tail.get_unchecked(0), head)) } } @@ -557,7 +597,7 @@ impl<'a, E: Entity> RowRef<'a, E> { #[inline] pub fn iter(self) -> iter::ElemIter<'a, E> { iter::ElemIter { - inner: self.transpose(), + inner: self.transpose().as_dyn(), } } @@ -568,9 +608,9 @@ impl<'a, E: Entity> RowRef<'a, E> { pub fn chunks(self, chunk_size: usize) -> iter::RowElemChunks<'a, E> { assert!(chunk_size > 0); iter::RowElemChunks { - inner: self, + inner: self.as_dyn(), policy: iter::chunks::ChunkSizePolicy::new( - self.ncols(), + self.ncols().unbound(), iter::chunks::ChunkSize(chunk_size), ), } @@ -583,9 +623,9 @@ impl<'a, E: Entity> RowRef<'a, E> { pub fn partition(self, count: usize) -> iter::RowElemPartition<'a, E> { assert!(count > 0); iter::RowElemPartition { - inner: self, + inner: self.as_dyn(), policy: iter::chunks::PartitionCountPolicy::new( - self.ncols(), + self.ncols().unbound(), iter::chunks::PartitionCount(count), ), } @@ -603,15 +643,11 @@ impl<'a, E: Entity> RowRef<'a, E> { self, chunk_size: usize, ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - use crate::utils::DivCeil; use rayon::prelude::*; - assert!(chunk_size > 0); - let chunk_count = self.ncols().msrv_div_ceil(chunk_size); - (0..chunk_count).into_par_iter().map(move |chunk_idx| { - let pos = chunk_size * chunk_idx; - self.subcols(pos, Ord::min(chunk_size, self.ncols() - pos)) - }) + self.transpose() + .par_chunks(chunk_size) + .map(|x| x.transpose()) } /// Returns an iterator that provides exactly `count` successive chunks of the elements of this @@ -628,12 +664,7 @@ impl<'a, E: Entity> RowRef<'a, E> { ) -> impl 'a + rayon::iter::IndexedParallelIterator> { use rayon::prelude::*; - assert!(count > 0); - (0..count).into_par_iter().map(move |chunk_idx| { - let (start, len) = - crate::utils::thread::par_split_indices(self.ncols(), chunk_idx, count); - self.subcols(start, len) - }) + self.transpose().par_partition(count).map(|x| x.transpose()) } } @@ -685,9 +716,11 @@ impl As2D for RowRef<'_, E> { } } -impl AsRowRef for RowRef<'_, E> { +impl AsRowRef for RowRef<'_, E, C> { + type C = C; + #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { + fn as_row_ref(&self) -> RowRef<'_, E, C> { *self } }