Skip to content

Commit

Permalink
Merge pull request #46 from SpM-lab/terasaki/write-tests-truncate
Browse files Browse the repository at this point in the history
Write tests for `truncate`
  • Loading branch information
terasakisatoshi authored Dec 11, 2024
2 parents 73975a9 + c1a69f9 commit 271ad56
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
41 changes: 26 additions & 15 deletions include/sparseir/sve.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,16 @@ std::shared_ptr<AbstractSVE<K, T>> determine_sve(const K& kernel, double safe_ep

// Function to truncate singular values
template <typename T>
inline void truncate_singular_values(
inline std::tuple<std::vector<Eigen::MatrixX<T>>, std::vector<Eigen::VectorX<T>>, std::vector<Eigen::MatrixX<T>>> truncate(
std::vector<Eigen::MatrixX<T>> &u_list,
std::vector<Eigen::VectorX<T>> &s_list,
std::vector<Eigen::MatrixX<T>> &v_list,
T rtol,
int lmax)
T rtol=0.0,
int lmax=std::numeric_limits<int>::max())
{
std::vector<Eigen::MatrixX<T>> u_list_truncated;
std::vector<Eigen::VectorX<T>> s_list_truncated;
std::vector<Eigen::MatrixX<T>> v_list_truncated;
// Collect all singular values
std::vector<T> all_singular_values;
for (const auto &s : s_list)
Expand Down Expand Up @@ -456,11 +459,12 @@ inline void truncate_singular_values(
}
if (scount < s.size())
{
u_list[idx] = u_list[idx].leftCols(scount);
s_list[idx] = s_list[idx].head(scount);
v_list[idx] = v_list[idx].leftCols(scount);
u_list_truncated.push_back(u_list[idx].leftCols(scount));
s_list_truncated.push_back(s_list[idx].head(scount));
v_list_truncated.push_back(v_list[idx].leftCols(scount));
}
}
return std::make_tuple(u_list_truncated, s_list_truncated, v_list_truncated);
}


Expand All @@ -470,26 +474,33 @@ auto pre_postprocess(K &kernel, double safe_epsilon, int n_gauss, double cutoff
auto sve = determine_sve<K, T>(kernel, safe_epsilon, n_gauss);
// Compute SVDs
std::vector<Eigen::MatrixX<T>> matrices = sve->matrices();
std::vector<Eigen::BDCSVD<Eigen::MatrixX<T>>> svds;
// TODO: implement SVD Resutls
std::vector<std::tuple<Eigen::MatrixX<T>, Eigen::MatrixX<T>, Eigen::MatrixX<T>>> svds;
for (const auto& mat : matrices) {
Eigen::BDCSVD<Eigen::MatrixX<T>> svd(mat, Eigen::ComputeThinU | Eigen::ComputeThinV);
auto svd = sparseir::compute_svd(mat);
svds.push_back(svd);
}

// Extract singular values and vectors
std::vector<Eigen::MatrixX<T>> u_list, v_list;
std::vector<Eigen::VectorX<T>> s_list;
std::vector<Eigen::MatrixX<T>> u_list_, v_list_;
std::vector<Eigen::VectorX<T>> s_list_;
for (const auto& svd : svds) {
u_list.push_back(svd.matrixU());
s_list.push_back(svd.singularValues());
v_list.push_back(svd.matrixV());
auto u = std::get<0>(svd);
auto s = std::get<1>(svd);
auto v = std::get<2>(svd);
u_list_.push_back(u);
s_list_.push_back(s);
v_list_.push_back(v);
}

// Apply cutoff and lmax
T cutoff_actual = std::isnan(cutoff) ? 2 * T(std::numeric_limits<double>::epsilon()) : T(cutoff);
truncate_singular_values(u_list, s_list, v_list, cutoff_actual, lmax);
std::vector<Eigen::MatrixX<T>> u_list_truncated;
std::vector<Eigen::VectorX<T>> s_list_truncated;
std::vector<Eigen::MatrixX<T>> v_list_truncated;
std::tie(u_list_truncated, s_list_truncated, v_list_truncated) = truncate(u_list_, s_list_, v_list_, cutoff_actual, lmax);
// Postprocess to get the SVEResult
return sve->postprocess(u_list, s_list, v_list);
return sve->postprocess(u_list_truncated, s_list_truncated, v_list_truncated);
}

// Function to compute SVE result
Expand Down
49 changes: 29 additions & 20 deletions test/sve.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -177,30 +177,39 @@ TEST_CASE("sve.cpp", "[choose_accuracy]") {
REQUIRE(sparseir::choose_accuracy(1e-6, "Float64", "auto") == std::make_tuple(1.0e-6, "Float64", "default"));
REQUIRE(sparseir::choose_accuracy(1e-6, "Float64", "accurate") == std::make_tuple(1.0e-6, "Float64", "accurate"));

/*

SECTION("truncate") {
sparseir::CentrosymmSVE sve(LogisticKernel(5), 1e-6, "Float64");
sparseir::CentrosymmSVE<sparseir::LogisticKernel, double> sve(sparseir::LogisticKernel(5), 1e-6);
std::vector<Eigen::MatrixX<double>> matrices = sve.matrices();
REQUIRE(matrices.size() == 2);
std::vector<std::tuple<Eigen::MatrixX<double>, Eigen::MatrixX<double>, Eigen::MatrixX<double>>> svds;
for (const auto& mat : matrices) {
auto svd = sparseir::compute_svd(mat);
svds.push_back(svd);
}

auto svds = sparseir::compute_svd(sparseir::matrices(sve));
std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> svd_tuple = svds;
std::vector<double> u_ = std::get<0>(svd_tuple);
std::vector<double> s_ = std::get<1>(svd_tuple);
std::vector<double> v_ = std::get<2>(svd_tuple);
// Extract singular values and vectors
std::vector<Eigen::MatrixX<double>> u_list, v_list;
std::vector<Eigen::VectorX<double>> s_list;
for (const auto& svd : svds) {
auto u = std::get<0>(svd);
auto s = std::get<1>(svd);
auto v = std::get<2>(svd);
u_list.push_back(u);
s_list.push_back(s);
v_list.push_back(v);
}

for (int lmax = 3; lmax <= 20; ++lmax) {
std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> truncated = sparseir::truncate(u_, s_, v_, lmax);
std::vector<double> u = std::get<0>(truncated);
std::vector<double> s = std::get<1>(truncated);
std::vector<double> v = std::get<2>(truncated);
std::tuple<std::vector<double>, std::vector<double>, std::vector<double>> postprocessed = sparseir::postprocess(sve, u, s, v);
std::vector<double> u_post = std::get<0>(postprocessed);
std::vector<double> s_post = std::get<1>(postprocessed);
std::vector<double> v_post = std::get<2>(postprocessed);
REQUIRE(u_post.size() == s_post.size());
REQUIRE(s_post.size() == v_post.size());
REQUIRE(s_post.size() <= static_cast<size_t>(lmax - 1));
auto truncated = sparseir::truncate(u_list, s_list, v_list, 1e-8, lmax);
auto u = std::get<0>(truncated);
auto s = std::get<1>(truncated);
auto v = std::get<2>(truncated);

auto sveresult = sve.postprocess(u, s, v);
REQUIRE(sveresult.u.size() == sveresult.s.size());
REQUIRE(sveresult.s.size() == sveresult.v.size());
REQUIRE(sveresult.s.size() <= static_cast<size_t>(lmax - 1));
}
}
*/
}

0 comments on commit 271ad56

Please sign in to comment.