Skip to content

Commit

Permalink
replace the old constrained api
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-quinones committed Oct 2, 2024
1 parent 2f4bd6a commit ff832b2
Show file tree
Hide file tree
Showing 31 changed files with 2,239 additions and 2,763 deletions.
84 changes: 63 additions & 21 deletions src/iter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{linalg::entity::GroupFor, ColMut, ColRef, Entity, MatMut, MatRef, RowMut, RowRef};
use crate::{
linalg::entity::GroupFor, mat, ColMut, ColRef, Entity, MatMut, MatRef, RowMut, RowRef, Shape,
};

use self::chunks::ChunkPolicy;

Expand Down Expand Up @@ -173,23 +175,23 @@ pub struct ElemIterMut<'a, E: Entity> {

/// Iterator over the columns of a matrix.
#[derive(Debug, Clone)]
pub struct ColIter<'a, E: Entity> {
pub(crate) inner: MatRef<'a, E>,
pub struct ColIter<'a, E: Entity, R: Shape = usize> {
pub(crate) inner: MatRef<'a, E, R>,
}
/// Iterator over the columns of a matrix.
#[derive(Debug)]
pub struct ColIterMut<'a, E: Entity> {
pub(crate) inner: MatMut<'a, E>,
pub struct ColIterMut<'a, E: Entity, R: Shape = usize> {
pub(crate) inner: MatMut<'a, E, R>,
}
/// Iterator over the rows of a matrix.
#[derive(Debug, Clone)]
pub struct RowIter<'a, E: Entity> {
pub(crate) inner: MatRef<'a, E>,
pub struct RowIter<'a, E: Entity, C: Shape = usize> {
pub(crate) inner: MatRef<'a, E, usize, C>,
}
/// Iterator over the rows of a matrix.
#[derive(Debug)]
pub struct RowIterMut<'a, E: Entity> {
pub(crate) inner: MatMut<'a, E>,
pub struct RowIterMut<'a, E: Entity, C: Shape = usize> {
pub(crate) inner: MatMut<'a, E, usize, C>,
}

impl<'a, E: Entity> Iterator for ElemIter<'a, E> {
Expand Down Expand Up @@ -272,12 +274,22 @@ impl<'a, E: Entity> ExactSizeIterator for ElemIterMut<'a, E> {
}
}

impl<'a, E: Entity> Iterator for ColIter<'a, E> {
type Item = ColRef<'a, E>;
impl<'a, E: Entity, R: Shape> Iterator for ColIter<'a, E, R> {
type Item = ColRef<'a, E, R>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
match core::mem::take(&mut self.inner).split_first_col() {
let nrows = self.inner.nrows();
match core::mem::replace(
&mut self.inner,
mat::from_column_major_slice_generic(
E::faer_map(E::UNIT, |()| &[] as &[E::Unit]),
nrows,
0,
),
)
.split_first_col()
{
Some((head, tail)) => {
self.inner = tail;
Some(head)
Expand All @@ -291,10 +303,20 @@ impl<'a, E: Entity> Iterator for ColIter<'a, E> {
(self.inner.ncols(), Some(self.inner.ncols()))
}
}
impl<'a, E: Entity> DoubleEndedIterator for ColIter<'a, E> {
impl<'a, E: Entity, R: Shape> DoubleEndedIterator for ColIter<'a, E, R> {
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
match core::mem::take(&mut self.inner).split_last_col() {
let nrows = self.inner.nrows();
match core::mem::replace(
&mut self.inner,
mat::from_column_major_slice_generic(
E::faer_map(E::UNIT, |()| &[] as &[E::Unit]),
nrows,
0,
),
)
.split_last_col()
{
Some((head, tail)) => {
self.inner = tail;
Some(head)
Expand All @@ -303,19 +325,29 @@ impl<'a, E: Entity> DoubleEndedIterator for ColIter<'a, E> {
}
}
}
impl<'a, E: Entity> ExactSizeIterator for ColIter<'a, E> {
impl<'a, E: Entity, R: Shape> ExactSizeIterator for ColIter<'a, E, R> {
#[inline]
fn len(&self) -> usize {
self.inner.ncols()
}
}

impl<'a, E: Entity> Iterator for ColIterMut<'a, E> {
type Item = ColMut<'a, E>;
impl<'a, E: Entity, R: Shape> Iterator for ColIterMut<'a, E, R> {
type Item = ColMut<'a, E, R>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
match core::mem::take(&mut self.inner).split_first_col_mut() {
let nrows = self.inner.nrows();
match core::mem::replace(
&mut self.inner,
mat::from_column_major_slice_mut_generic(
E::faer_map(E::UNIT, |()| &mut [] as &mut [E::Unit]),
nrows,
0,
),
)
.split_first_col_mut()
{
Some((head, tail)) => {
self.inner = tail;
Some(head)
Expand All @@ -329,10 +361,20 @@ impl<'a, E: Entity> Iterator for ColIterMut<'a, E> {
(self.inner.ncols(), Some(self.inner.ncols()))
}
}
impl<'a, E: Entity> DoubleEndedIterator for ColIterMut<'a, E> {
impl<'a, E: Entity, R: Shape> DoubleEndedIterator for ColIterMut<'a, E, R> {
#[inline]
fn next_back(&mut self) -> Option<Self::Item> {
match core::mem::take(&mut self.inner).split_last_col_mut() {
let nrows = self.inner.nrows();
match core::mem::replace(
&mut self.inner,
mat::from_column_major_slice_mut_generic(
E::faer_map(E::UNIT, |()| &mut [] as &mut [E::Unit]),
nrows,
0,
),
)
.split_last_col_mut()
{
Some((head, tail)) => {
self.inner = tail;
Some(head)
Expand All @@ -341,7 +383,7 @@ impl<'a, E: Entity> DoubleEndedIterator for ColIterMut<'a, E> {
}
}
}
impl<'a, E: Entity> ExactSizeIterator for ColIterMut<'a, E> {
impl<'a, E: Entity, R: Shape> ExactSizeIterator for ColIterMut<'a, E, R> {
#[inline]
fn len(&self) -> usize {
self.inner.ncols()
Expand Down
14 changes: 5 additions & 9 deletions src/linalg/matmul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,22 +555,18 @@ pub mod inner_prod {
}
} else {
with_dim!(nrows, nrows);
with_dim!(ncols, 1);
let zero_idx = ncols.check(0);

let a = crate::utils::constrained::mat::MatRef::new(a.as_2d(), nrows, ncols);
let b = crate::utils::constrained::mat::MatRef::new(b.as_2d(), nrows, ncols);
let a = a.as_shape(nrows);
let b = b.as_shape(nrows);

let mut acc = E::faer_zero();
if conj_lhs == conj_rhs {
for i in nrows.indices() {
acc = acc.faer_add(E::faer_mul(a.read(i, zero_idx), b.read(i, zero_idx)));
acc = acc.faer_add(E::faer_mul(a.read(i), b.read(i)));
}
} else {
for i in nrows.indices() {
acc = acc.faer_add(E::faer_mul(
a.read(i, zero_idx).faer_conj(),
b.read(i, zero_idx),
));
acc = acc.faer_add(E::faer_mul(a.read(i).faer_conj(), b.read(i)));
}
}
acc
Expand Down
21 changes: 18 additions & 3 deletions src/mat/matmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
linalg::zip,
unzipped, zipped, Idx, IdxInc, Unbind,
};
use core::ops::Range;

/// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice].
///
Expand Down Expand Up @@ -1460,6 +1461,18 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> {
unsafe { self.into_const().subcols(col_start, ncols).const_cast() }
}

#[track_caller]
#[inline(always)]
pub fn subcols_range(self, cols: Range<IdxInc<C>>) -> MatRef<'a, E, R, usize> {
self.into_const().subcols_range(cols)
}

#[track_caller]
#[inline(always)]
pub fn subcols_range_mut(self, cols: Range<IdxInc<C>>) -> MatMut<'a, E, R, usize> {
unsafe { self.into_const().subcols_range(cols).const_cast() }
}

/// Returns a view over the row at the given index.
///
/// # Safety
Expand Down Expand Up @@ -1755,7 +1768,7 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> {

/// Returns an iterator over the columns of the matrix.
#[inline]
pub fn col_iter(self) -> iter::ColIter<'a, E> {
pub fn col_iter(self) -> iter::ColIter<'a, E, R> {
self.into_const().col_iter()
}

Expand All @@ -1767,9 +1780,11 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatMut<'a, E, R, C> {

/// Returns an iterator over the columns of the matrix.
#[inline]
pub fn col_iter_mut(self) -> iter::ColIterMut<'a, E> {
pub fn col_iter_mut(self) -> iter::ColIterMut<'a, E, R> {
let nrows = self.nrows();
let ncols = self.ncols();
iter::ColIterMut {
inner: self.as_dyn_mut(),
inner: self.as_shape_mut(nrows, ncols.unbound()),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/mat/matown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ impl<E: Entity, R: Shape, C: Shape> Mat<E, R, C> {

/// Returns an iterator over the columns of the matrix.
#[inline]
pub fn col_iter(&self) -> iter::ColIter<'_, E> {
pub fn col_iter(&self) -> iter::ColIter<'_, E, R> {
self.as_ref().col_iter()
}

Expand All @@ -441,7 +441,7 @@ impl<E: Entity, R: Shape, C: Shape> Mat<E, R, C> {

/// Returns an iterator over the columns of the matrix.
#[inline]
pub fn col_iter_mut(&mut self) -> iter::ColIterMut<'_, E> {
pub fn col_iter_mut(&mut self) -> iter::ColIterMut<'_, E, R> {
self.as_mut().col_iter_mut()
}

Expand Down
15 changes: 13 additions & 2 deletions src/mat/matref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
assert, debug_assert, diag::DiagRef, iter, iter::chunks::ChunkPolicy, unzipped,
utils::bound::*, zipped, Idx, IdxInc, Shape, Unbind,
};
use core::ops::Range;
use generativity::make_guard;

/// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice].
Expand Down Expand Up @@ -818,6 +819,14 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatRef<'a, E, R, C> {
unsafe { self.subcols_unchecked(col_start, ncols) }
}

#[track_caller]
#[inline(always)]
pub fn subcols_range(self, cols: Range<IdxInc<C>>) -> MatRef<'a, E, R, usize> {
assert!(all(cols.start <= self.ncols(), cols.end <= self.ncols()));
let ncols = cols.end.unbound().saturating_sub(cols.start.unbound());
unsafe { self.subcols_unchecked(cols.start, ncols) }
}

/// Returns a view over the row at the given index.
///
/// # Safety
Expand Down Expand Up @@ -1167,9 +1176,11 @@ impl<'a, E: Entity, R: Shape, C: Shape> MatRef<'a, E, R, C> {

/// Returns an iterator over the columns of the matrix.
#[inline]
pub fn col_iter(self) -> iter::ColIter<'a, E> {
pub fn col_iter(self) -> iter::ColIter<'a, E, R> {
let nrows = self.nrows();
let ncols = self.ncols();
iter::ColIter {
inner: self.as_dyn(),
inner: self.as_shape(nrows, ncols.unbound()),
}
}

Expand Down
16 changes: 7 additions & 9 deletions src/perm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, utils::constrained, *};
use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, *};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;

Expand Down Expand Up @@ -217,22 +217,20 @@ pub fn permute_rows<I: Index, E: ComplexField>(

with_dim!(m, src.nrows());
with_dim!(n, src.ncols());
let mut dst = constrained::mat::MatMut::new(dst, m, n);
let src = constrained::mat::MatRef::new(src, m, n);
let perm = constrained::perm::PermRef::new(perm_indices, m).arrays().0;
let mut dst = dst.as_shape_mut(m, n);
let src = src.as_shape(m, n);
let perm = perm_indices.as_shape(m).bound_arrays().0;

if dst.rb().into_inner().row_stride().unsigned_abs()
< dst.rb().into_inner().col_stride().unsigned_abs()
{
if dst.rb().row_stride().unsigned_abs() < dst.rb().col_stride().unsigned_abs() {
for j in n.indices() {
for i in m.indices() {
dst.rb_mut().write(i, j, src.read(perm[i].zx(), j));
}
}
} else {
for i in m.indices() {
let src_i = src.into_inner().row(perm[i].zx().unbound());
let mut dst_i = dst.rb_mut().into_inner().row_mut(i.unbound());
let src_i = src.row(perm[i].zx());
let mut dst_i = dst.rb_mut().row_mut(i);

dst_i.copy_from(src_i);
}
Expand Down
23 changes: 22 additions & 1 deletion src/perm/permref.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::*;
use crate::assert;
use crate::{
assert,
utils::bound::{Array, Dim},
};

/// Immutable permutation matrix view.
#[derive(Debug)]
Expand Down Expand Up @@ -176,3 +179,21 @@ impl<'a, I: Index, N: Shape> PermRef<'a, I, N> {
}
}
}

impl<'a, 'N, I: Index> PermRef<'a, I, Dim<'N>> {
/// Returns the permutation as an array.
#[inline]
pub fn bound_arrays(
self,
) -> (
&'a Array<'N, Idx<Dim<'N>, I>>,
&'a Array<'N, Idx<Dim<'N>, I>>,
) {
unsafe {
(
&*(self.forward as *const [Idx<Dim<'N>, I>] as *const Array<'N, Idx<Dim<'N>, I>>),
&*(self.inverse as *const [Idx<Dim<'N>, I>] as *const Array<'N, Idx<Dim<'N>, I>>),
)
}
}
}
2 changes: 1 addition & 1 deletion src/sparse/csc/matmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseColMatMut<'a, I, E, R, C

/// Returns the row indices.
#[inline]
pub fn row_indices(&self) -> &'a [Idx<R, I>] {
pub fn row_indices(&self) -> &'a [I] {
self.symbolic.row_ind
}

Expand Down
2 changes: 1 addition & 1 deletion src/sparse/csc/matown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ impl<I: Index, E: Entity, R: Shape, C: Shape> SparseColMat<I, E, R, C> {

/// Returns the row indices.
#[inline]
pub fn row_indices(&self) -> &'_ [Idx<R, I>] {
pub fn row_indices(&self) -> &'_ [I] {
&self.symbolic.row_ind
}

Expand Down
2 changes: 1 addition & 1 deletion src/sparse/csc/matref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ impl<'a, I: Index, E: Entity, R: Shape, C: Shape> SparseColMatRef<'a, I, E, R, C

/// Returns the row indices.
#[inline]
pub fn row_indices(&self) -> &'a [Idx<R, I>] {
pub fn row_indices(&self) -> &'a [I] {
self.symbolic.row_ind
}

Expand Down
Loading

0 comments on commit ff832b2

Please sign in to comment.