Skip to content

Commit

Permalink
Merge pull request #24 from SpM-lab/tsvd_debug
Browse files Browse the repository at this point in the history
Fix bugs in `tsvd`
  • Loading branch information
terasakisatoshi authored Nov 7, 2024
2 parents ecd1926 + ad98e33 commit c5988dc
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 29 deletions.
16 changes: 13 additions & 3 deletions include/sparseir/_linalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ void reflectorApply(Eigen::VectorBlock<Eigen::Block<Eigen::MatrixX<T>, -1, 1, tr
}
}

/*
A will be modified in-place.
*/
template <typename T>
std::pair<QRPivoted<T>, int> rrqr(MatrixX<T>& A, T rtol = std::numeric_limits<T>::epsilon()) {
using std::abs;
Expand Down Expand Up @@ -491,11 +494,12 @@ truncateQRResult(const Eigen::MatrixX<T>& Q, const Eigen::MatrixX<T>& R, int k)
// Truncated SVD (TSVD)
template <typename T>
std::tuple<Eigen::MatrixX<T>, Eigen::VectorX<T>, Eigen::MatrixX<T>>
tsvd(Eigen::MatrixX<T>& A, T rtol = std::numeric_limits<T>::epsilon()) {
tsvd(const Eigen::MatrixX<T>& A, T rtol = std::numeric_limits<T>::epsilon()) {
// Step 1: Apply RRQR to A
QRPivoted<T> A_qr;
int k;
std::tie(A_qr, k) = rrqr<T>(A, rtol);
Eigen::MatrixX<T> A_ = A; // create a copy of A
std::tie(A_qr, k) = rrqr<T>(A_, rtol);
// Step 2: Truncate QR Result to rank k
auto tqr = truncate_qr_result<T>(A_qr, k);
auto p = A_qr.jpvt;
Expand All @@ -511,16 +515,22 @@ tsvd(Eigen::MatrixX<T>& A, T rtol = std::numeric_limits<T>::epsilon()) {
// Do not use the svd_jacobi function directly.
// Better to write a wrrapper function for the SVD.
Eigen::JacobiSVD<decltype(R_trunc)> svd;

// The following comment is taken from Julia's implementation
// # RRQR is an excellent preconditioner for Jacobi. One should then perform
// # Jacobi on RT
svd.compute(R_trunc.transpose(), Eigen::ComputeThinU | Eigen::ComputeThinV);

// Reconstruct A from QR factorization
Eigen::PermutationMatrix<Dynamic, Dynamic> perm(p.size());
perm.indices() = invperm(p);
perm.indices() = p;

Eigen::MatrixX<T> U = Q_trunc * svd.matrixV();
// implement invperm
Eigen::MatrixX<T> V = (perm * svd.matrixU());

Eigen::VectorX<T> s = svd.singularValues();
// TODO: Create a return type for truncated SVD
return std::make_tuple(U, s, V);
}

Expand Down
55 changes: 30 additions & 25 deletions test/_linalg.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ TEST_CASE("RRQR Trunc", "[linalg]") {
auto QR = truncate_qr_result<DDouble>(A_qr, k);
auto Q = QR.first;
auto R = QR.second;

REQUIRE(Q.rows() == m);
REQUIRE(Q.cols() == k);
REQUIRE(R.rows() == k);
REQUIRE(R.cols() == n);
MatrixX<DDouble> A_rec = Q * R * getPropertyP(A_qr).transpose();
REQUIRE(A_rec.isApprox(Aorig, 1e-5 * A.norm()));
}
Expand All @@ -266,43 +269,45 @@ TEST_CASE("TSVD", "[linalg]") {
using std::pow;
// double
{
for (auto tol : {1e-14, 1e-13}) {
VectorX<double> x = VectorX<double>::LinSpaced(201, -1, 1);
MatrixX<double> Aorig(201, 51);
for (auto tol : {1e-14}) {
int N1 = 201;
int N2 = 51;
VectorX<double> x = VectorX<double>::LinSpaced(N1, -1, 1);
//MatrixX<double> Aorig(201, 51);
MatrixX<double> Aorig(N1, N2);
for (int i = 0; i < Aorig.cols(); i++) {
//Aorig.col(i) = x.array().pow(i);
for (int j = 0; j < Aorig.rows(); j++) {
Aorig(j, i) = pow(x(j), i);
}
Aorig.col(i) = x.array().pow(i);
//for (int j = 0; j < Aorig.rows(); j++) {
//Aorig(j, i) = pow(x(j), i);
//}
}

MatrixX<double> A = Aorig;
MatrixX<double> A = Aorig; // create a copy of Aorig

auto tsvd_result = tsvd<double>(A, double(tol));
tsvd<double>(Aorig, double(tol));
auto U = std::get<0>(tsvd_result);
auto s = std::get<1>(tsvd_result);
auto V = std::get<2>(tsvd_result);
int k = s.size();

auto S_diag = s.asDiagonal();
// U * S_diag * V.transpose();
// std::cout << "U: " << U.rows() << "," << U.cols() << std::endl;
// std::cout << "S: " << S_diag.rows() << "," << S_diag.cols() << std::endl;
// std::cout << "V: " << V.rows() << "," << V.cols() << std::endl;
auto B = U * S_diag * V.transpose() - Aorig;
std::cout << Aorig.norm() << std::endl; // Oh...?
std::cout << B.norm() << std::endl; // Oh...?
bool test_completed = false;
REQUIRE(!test_completed);
//REQUIRE((U * S_diag * V).isApprox(Aorig, tol * A.norm()));
/*
auto Areconst = U * S_diag * V.transpose();
auto diff = (A - Areconst).norm() / A.norm();
// std::cout << "diff " << diff << std::endl;
// std::cout << "Areconst " << Areconst.norm() << std::endl;
// std::cout << "Aorig " << Aorig.norm() << std::endl;
// std::cout << "norm diff" << Aorig.norm() - Areconst.norm() << std::endl;

REQUIRE(Areconst.isApprox(Aorig, tol * Aorig.norm()));
REQUIRE((U.transpose() * U).isIdentity());
REQUIRE((V.transpose() * V).isIdentity());
REQUIRE(std::is_sorted(S.data(), S.data() + S.size(), std::greater<DDouble>()));
REQUIRE(std::is_sorted(s.data(), s.data() + s.size(), std::greater<DDouble>()));
REQUIRE(k < std::min(A.rows(), A.cols()));

Eigen::JacobiSVD<MatrixX<DDouble>> svd(A.cast<DDouble>());
REQUIRE(S.isApprox(svd.singularValues().head(k).cast<DDouble>()));
*/
Eigen::JacobiSVD<MatrixX<double>> svd(Aorig.cast<double>());
REQUIRE(s.isApprox(svd.singularValues().head(k)));
REQUIRE(S_diag.toDenseMatrix().isApprox(svd.singularValues().head(k).asDiagonal().toDenseMatrix()));
}
}
}
Expand Down Expand Up @@ -382,7 +387,7 @@ TEST_CASE("SVD of VERY triangular 2x2", "[linalg]") {
REQUIRE(cu.hi() == 1.0);
REQUIRE(su.hi() == 1e-100);
REQUIRE(smax.hi() == 1.0);
REQUIRE(smin.hi() == 1e-200); // so cloe
REQUIRE(smin.hi() == 1e-200);
REQUIRE(cv.hi() == 1e-100);
REQUIRE(sv.hi() == 1.0);
U << cu, -su, su, cu;
Expand Down
2 changes: 1 addition & 1 deletion test/freq.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST_CASE("freq", "[Imaginary value calculation for MatsubaraFreq]")
double beta = 3.0;
BosonicFreq bf(2);
std::complex<double> expected_value_im(0, 2 * M_PI / beta);
std::cout << bf.value_im(beta) << std::endl;
// std::cout << bf.value_im(beta) << std::endl;
REQUIRE(bf.value_im(beta) == expected_value_im);
}

Expand Down

0 comments on commit c5988dc

Please sign in to comment.