Skip to content

Commit

Permalink
fix sparse qr bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sarah-quinones committed Sep 5, 2024
1 parent 9b896ea commit cb0d69e
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 126 deletions.
2 changes: 2 additions & 0 deletions src/col/col_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
// usize

use super::*;
use crate::assert;
use core::ops::RangeFull;

type Range = core::ops::Range<usize>;
type RangeInclusive = core::ops::RangeInclusive<usize>;
type RangeFrom = core::ops::RangeFrom<usize>;
Expand Down
2 changes: 1 addition & 1 deletion src/linalg/evd/hessenberg_real_evd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ fn schur_move<E: RealField>(
0
}

fn schur_swap<E: RealField>(
pub fn schur_swap<E: RealField>(
mut a: MatMut<E>,
mut q: Option<MatMut<E>>,
j0: usize,
Expand Down
198 changes: 102 additions & 96 deletions src/linalg/evd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,11 @@ fn solve_shifted_upper_quasi_triangular_system<E: RealField>(
x.write(i, x.read(i).faer_sub(dot1));

// solve
// [a b [x0 [r0
// c a]× x1] = r1]
// [a b] [x0] [r0]
// [c a]×[x1] = [r1]
//
// [x0 [a -b [r0
// x1] = -c a]× r1] / det
// [x0] [a -b] [r0]
// [x1] = [-c a]×[r1] / det
let a = h.read(i, i).faer_sub(p);
let b = h.read(i - 1, i);
let c = h.read(i, i - 1);
Expand Down Expand Up @@ -1103,25 +1103,7 @@ pub fn compute_evd_real_custom_epsilon<E: RealField>(
params,
);

let (mut x, _) = temp_mat_zeroed::<E>(n, n, stack);
let mut x = x.as_mut();

let mut norm = zero_threshold;
zipped!(h.rb()).for_each_triangular_upper(
crate::linalg::zip::Diag::Include,
|unzipped!(x)| {
norm = norm.faer_add(x.read().faer_abs());
},
);
// subdiagonal
zipped!(h
.rb()
.submatrix(1, 0, n - 1, n - 1)
.diagonal()
.column_vector())
.for_each(|unzipped!(x)| {
norm = norm.faer_add(x.read().faer_abs());
});
let (mut x, _) = temp_mat_uninit::<E>(n, n, stack);

let mut h = h.transpose_mut();

Expand All @@ -1140,79 +1122,7 @@ pub fn compute_evd_real_custom_epsilon<E: RealField>(
}
let h = h.rb();

{
let mut k = n;
loop {
if k == 0 {
break;
}
k -= 1;

if k == 0 || h.read(k, k - 1) == E::faer_zero() {
// real eigenvalue
let p = h.read(k, k);

x.write(k, k, E::faer_one());

// solve (h[:k, :k] - p I) X = -h[:k, k]
// form RHS
for i in 0..k {
x.write(i, k, h.read(i, k).faer_neg());
}

solve_shifted_upper_quasi_triangular_system(
h.get(..k, ..k),
p,
x.rb_mut().get_mut(..k, k),
epsilon,
norm,
parallelism,
);
} else {
// complex eigenvalue pair

let p = h.read(k, k);
let q = h
.read(k, k - 1)
.faer_abs()
.faer_sqrt()
.faer_mul(h.read(k - 1, k).faer_abs().faer_sqrt());

if h.read(k - 1, k).faer_abs() >= h.read(k, k - 1) {
x.write(k - 1, k - 1, E::faer_one());
x.write(k, k, q.faer_div(h.read(k - 1, k)));
} else {
x.write(k - 1, k - 1, q.faer_neg().faer_div(h.read(k, k - 1)));
x.write(k, k, E::faer_one());
}
x.write(k - 1, k, E::faer_zero());
x.write(k, k - 1, E::faer_zero());

// solve (h[:k-1, :k-1] - (p + iq) I) X = RHS
// form RHS
for i in 0..k - 1 {
x.write(
i,
k - 1,
x.read(k - 1, k - 1).faer_neg().faer_mul(h.read(i, k - 1)),
);
x.write(i, k, x.read(k, k).faer_neg().faer_mul(h.read(i, k)));
}

solve_complex_shifted_upper_quasi_triangular_system(
h.get(..k - 1, ..k - 1),
p,
q,
x.rb_mut().get_mut(..k - 1, k - 1..k + 1),
epsilon,
norm,
parallelism,
);

k -= 1;
}
}
}
real_schur_to_eigen(h, x.rb_mut(), parallelism);

triangular::matmul(
u.rb_mut(),
Expand Down Expand Up @@ -1243,6 +1153,102 @@ pub fn compute_evd_real_custom_epsilon<E: RealField>(
}
}

#[doc(hidden)]
pub fn real_schur_to_eigen<E: RealField>(S: MatRef<'_, E>, Q: MatMut<E>, parallelism: Parallelism) {
let epsilon = E::faer_epsilon();
let zero_threshold = E::Real::faer_zero_threshold();
let n = S.nrows();

let mut Q = Q;
Q.fill_zero();

let mut norm = zero_threshold;
zipped!(S.rb()).for_each_triangular_upper(crate::linalg::zip::Diag::Include, |unzipped!(x)| {
norm = norm.faer_add(x.read().faer_abs());
});
// subdiagonal
zipped!(S
.rb()
.submatrix(1, 0, n - 1, n - 1)
.diagonal()
.column_vector())
.for_each(|unzipped!(x)| {
norm = norm.faer_add(x.read().faer_abs());
});

let mut k = n;
loop {
if k == 0 {
break;
}
k -= 1;

if k == 0 || S.read(k, k - 1) == E::faer_zero() {
// real eigenvalue
let p = S.read(k, k);

Q.write(k, k, E::faer_one());

// solve (h[:k, :k] - p I) X = -h[:k, k]
// form RHS
for i in 0..k {
Q.write(i, k, S.read(i, k).faer_neg());
}

solve_shifted_upper_quasi_triangular_system(
S.get(..k, ..k),
p,
Q.rb_mut().get_mut(..k, k),
epsilon,
norm,
parallelism,
);
} else {
// complex eigenvalue pair

let p = S.read(k, k);
let q = S
.read(k, k - 1)
.faer_abs()
.faer_sqrt()
.faer_mul(S.read(k - 1, k).faer_abs().faer_sqrt());

if S.read(k - 1, k).faer_abs() >= S.read(k, k - 1) {
Q.write(k - 1, k - 1, E::faer_one());
Q.write(k, k, q.faer_div(S.read(k - 1, k)));
} else {
Q.write(k - 1, k - 1, q.faer_neg().faer_div(S.read(k, k - 1)));
Q.write(k, k, E::faer_one());
}
Q.write(k - 1, k, E::faer_zero());
Q.write(k, k - 1, E::faer_zero());

// solve (h[:k-1, :k-1] - (p + iq) I) X = RHS
// form RHS
for i in 0..k - 1 {
Q.write(
i,
k - 1,
Q.read(k - 1, k - 1).faer_neg().faer_mul(S.read(i, k - 1)),
);
Q.write(i, k, Q.read(k, k).faer_neg().faer_mul(S.read(i, k)));
}

solve_complex_shifted_upper_quasi_triangular_system(
S.get(..k - 1, ..k - 1),
p,
q,
Q.rb_mut().get_mut(..k - 1, k - 1..k + 1),
epsilon,
norm,
parallelism,
);

k -= 1;
}
}
}

/// Computes the size and alignment of required workspace for performing an eigenvalue
/// decomposition. The eigenvectors may be optionally computed.
pub fn compute_evd_req<E: ComplexField>(
Expand Down
6 changes: 3 additions & 3 deletions src/mat/mat_index.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::*;
use crate::assert;

// RangeFull
// Range
// RangeInclusive
// RangeTo
// RangeToInclusive
// usize

use super::*;
use crate::assert;
use core::ops::RangeFull;

type Range = core::ops::Range<usize>;
type RangeInclusive = core::ops::RangeInclusive<usize>;
type RangeFrom = core::ops::RangeFrom<usize>;
Expand Down
1 change: 1 addition & 0 deletions src/row/row_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// usize

use super::*;
use crate::assert;
use core::ops::RangeFull;

type Range = core::ops::Range<usize>;
Expand Down
Loading

0 comments on commit cb0d69e

Please sign in to comment.