diff --git a/include/sparseir/sve.hpp b/include/sparseir/sve.hpp index 4ece025..a75d349 100644 --- a/include/sparseir/sve.hpp +++ b/include/sparseir/sve.hpp @@ -413,13 +413,16 @@ std::shared_ptr> determine_sve(const K& kernel, double safe_ep // Function to truncate singular values template -inline void truncate_singular_values( +inline std::tuple>, std::vector>, std::vector>> truncate( std::vector> &u_list, std::vector> &s_list, std::vector> &v_list, - T rtol, - int lmax) + T rtol=0.0, + int lmax=std::numeric_limits::max()) { + std::vector> u_list_truncated; + std::vector> s_list_truncated; + std::vector> v_list_truncated; // Collect all singular values std::vector all_singular_values; for (const auto &s : s_list) @@ -456,11 +459,12 @@ inline void truncate_singular_values( } if (scount < s.size()) { - u_list[idx] = u_list[idx].leftCols(scount); - s_list[idx] = s_list[idx].head(scount); - v_list[idx] = v_list[idx].leftCols(scount); + u_list_truncated.push_back(u_list[idx].leftCols(scount)); + s_list_truncated.push_back(s_list[idx].head(scount)); + v_list_truncated.push_back(v_list[idx].leftCols(scount)); } } + return std::make_tuple(u_list_truncated, s_list_truncated, v_list_truncated); } @@ -470,26 +474,33 @@ auto pre_postprocess(K &kernel, double safe_epsilon, int n_gauss, double cutoff auto sve = determine_sve(kernel, safe_epsilon, n_gauss); // Compute SVDs std::vector> matrices = sve->matrices(); - std::vector>> svds; + // TODO: implement SVD Resutls + std::vector, Eigen::MatrixX, Eigen::MatrixX>> svds; for (const auto& mat : matrices) { - Eigen::BDCSVD> svd(mat, Eigen::ComputeThinU | Eigen::ComputeThinV); + auto svd = sparseir::compute_svd(mat); svds.push_back(svd); } // Extract singular values and vectors - std::vector> u_list, v_list; - std::vector> s_list; + std::vector> u_list_, v_list_; + std::vector> s_list_; for (const auto& svd : svds) { - u_list.push_back(svd.matrixU()); - s_list.push_back(svd.singularValues()); - v_list.push_back(svd.matrixV()); + auto u = std::get<0>(svd); + auto s = std::get<1>(svd); + auto v = std::get<2>(svd); + u_list_.push_back(u); + s_list_.push_back(s); + v_list_.push_back(v); } // Apply cutoff and lmax T cutoff_actual = std::isnan(cutoff) ? 2 * T(std::numeric_limits::epsilon()) : T(cutoff); - truncate_singular_values(u_list, s_list, v_list, cutoff_actual, lmax); + std::vector> u_list_truncated; + std::vector> s_list_truncated; + std::vector> v_list_truncated; + std::tie(u_list_truncated, s_list_truncated, v_list_truncated) = truncate(u_list_, s_list_, v_list_, cutoff_actual, lmax); // Postprocess to get the SVEResult - return sve->postprocess(u_list, s_list, v_list); + return sve->postprocess(u_list_truncated, s_list_truncated, v_list_truncated); } // Function to compute SVE result diff --git a/test/sve.cxx b/test/sve.cxx index 22663ea..50c921d 100644 --- a/test/sve.cxx +++ b/test/sve.cxx @@ -177,30 +177,39 @@ TEST_CASE("sve.cpp", "[choose_accuracy]") { REQUIRE(sparseir::choose_accuracy(1e-6, "Float64", "auto") == std::make_tuple(1.0e-6, "Float64", "default")); REQUIRE(sparseir::choose_accuracy(1e-6, "Float64", "accurate") == std::make_tuple(1.0e-6, "Float64", "accurate")); - /* + SECTION("truncate") { - sparseir::CentrosymmSVE sve(LogisticKernel(5), 1e-6, "Float64"); + sparseir::CentrosymmSVE sve(sparseir::LogisticKernel(5), 1e-6); + std::vector> matrices = sve.matrices(); + REQUIRE(matrices.size() == 2); + std::vector, Eigen::MatrixX, Eigen::MatrixX>> svds; + for (const auto& mat : matrices) { + auto svd = sparseir::compute_svd(mat); + svds.push_back(svd); + } - auto svds = sparseir::compute_svd(sparseir::matrices(sve)); - std::tuple, std::vector, std::vector> svd_tuple = svds; - std::vector u_ = std::get<0>(svd_tuple); - std::vector s_ = std::get<1>(svd_tuple); - std::vector v_ = std::get<2>(svd_tuple); + // Extract singular values and vectors + std::vector> u_list, v_list; + std::vector> s_list; + for (const auto& svd : svds) { + auto u = std::get<0>(svd); + auto s = std::get<1>(svd); + auto v = std::get<2>(svd); + u_list.push_back(u); + s_list.push_back(s); + v_list.push_back(v); + } for (int lmax = 3; lmax <= 20; ++lmax) { - std::tuple, std::vector, std::vector> truncated = sparseir::truncate(u_, s_, v_, lmax); - std::vector u = std::get<0>(truncated); - std::vector s = std::get<1>(truncated); - std::vector v = std::get<2>(truncated); - - std::tuple, std::vector, std::vector> postprocessed = sparseir::postprocess(sve, u, s, v); - std::vector u_post = std::get<0>(postprocessed); - std::vector s_post = std::get<1>(postprocessed); - std::vector v_post = std::get<2>(postprocessed); - REQUIRE(u_post.size() == s_post.size()); - REQUIRE(s_post.size() == v_post.size()); - REQUIRE(s_post.size() <= static_cast(lmax - 1)); + auto truncated = sparseir::truncate(u_list, s_list, v_list, 1e-8, lmax); + auto u = std::get<0>(truncated); + auto s = std::get<1>(truncated); + auto v = std::get<2>(truncated); + + auto sveresult = sve.postprocess(u, s, v); + REQUIRE(sveresult.u.size() == sveresult.s.size()); + REQUIRE(sveresult.s.size() == sveresult.v.size()); + REQUIRE(sveresult.s.size() <= static_cast(lmax - 1)); } } - */ }