From 282a2fa3d1a4d7eb06bd9274fe8f0fe41365f9eb Mon Sep 17 00:00:00 2001 From: SatoshiTerasaki Date: Sat, 21 Dec 2024 10:42:00 +0000 Subject: [PATCH] Port basis_set to C++ --- include/sparseir/basis_set.hpp | 81 +++++++++++++++++++++++ include/sparseir/sparseir-header-only.hpp | 1 + test/CMakeLists.txt | 1 + test/basis_set.cxx | 61 +++++++++++++++++ 4 files changed, 144 insertions(+) create mode 100644 include/sparseir/basis_set.hpp create mode 100644 test/basis_set.cxx diff --git a/include/sparseir/basis_set.hpp b/include/sparseir/basis_set.hpp new file mode 100644 index 0000000..4758001 --- /dev/null +++ b/include/sparseir/basis_set.hpp @@ -0,0 +1,81 @@ +#pragma once + +#include +#include + +namespace sparseir { + +template +class FiniteTempBasisSet { +public: + using Scalar = T; + using BasisPtr = std::shared_ptr>; + using TauSamplingPtr = std::shared_ptr>; + using MatsubaraSamplingPtr = std::shared_ptr>>; + + // Constructors + FiniteTempBasisSet(Scalar beta, Scalar omega_max, Scalar epsilon = std::numeric_limits::quiet_NaN(), + const SVEResult& sve_result = SVEResult()) + : beta_(beta), omega_max_(omega_max), epsilon_(epsilon) { + initialize(sve_result); + } + + // Accessors + Scalar beta() const { return beta_; } + Scalar omega_max() const { return omega_max_; } + Scalar accuracy() const { return epsilon_; } + + const BasisPtr& basis_f() const { return basis_f_; } + const BasisPtr& basis_b() const { return basis_b_; } + + const TauSamplingPtr& smpl_tau_f() const { return smpl_tau_f_; } + const TauSamplingPtr& smpl_tau_b() const { return smpl_tau_b_; } + + const MatsubaraSamplingPtr& smpl_wn_f() const { return smpl_wn_f_; } + const MatsubaraSamplingPtr& smpl_wn_b() const { return smpl_wn_b_; } + + const Eigen::VectorXd& tau() const { return smpl_tau_f_->sampling_points(); } + const std::vector& wn_f() const { return smpl_wn_f_->sampling_frequencies(); } + const std::vector& wn_b() const { return smpl_wn_b_->sampling_frequencies(); } + + const SVEResult& sve_result() const { return sve_result_; } + +private: + void initialize(const SVEResult& sve_result_input) { + if (std::isnan(epsilon_)) { + epsilon_ = std::numeric_limits::epsilon(); + } + + LogisticKernel kernel(beta_ * omega_max_); + sve_result_ = sve_result_input.is_valid() ? sve_result_input : compute_sve(kernel); + + basis_f_ = std::make_shared>( + beta_, omega_max_, epsilon_, kernel, sve_result_); + basis_b_ = std::make_shared>( + beta_, omega_max_, epsilon_, kernel, sve_result_); + + // Initialize sampling objects + smpl_tau_f_ = std::make_shared>(basis_f_); + smpl_tau_b_ = std::make_shared>(basis_b_); + + smpl_wn_f_ = std::make_shared>>(basis_f_); + smpl_wn_b_ = std::make_shared>>(basis_b_); + } + + Scalar beta_; + Scalar omega_max_; + Scalar epsilon_; + + BasisPtr basis_f_; + BasisPtr basis_b_; + + TauSamplingPtr smpl_tau_f_; + TauSamplingPtr smpl_tau_b_; + + MatsubaraSamplingPtr smpl_wn_f_; + MatsubaraSamplingPtr smpl_wn_b_; + + SVEResult sve_result_; +}; + +} // namespace sparseir \ No newline at end of file diff --git a/include/sparseir/sparseir-header-only.hpp b/include/sparseir/sparseir-header-only.hpp index 6e0ec42..aa85f78 100644 --- a/include/sparseir/sparseir-header-only.hpp +++ b/include/sparseir/sparseir-header-only.hpp @@ -13,3 +13,4 @@ #include "./basis.hpp" #include "./augment.hpp" #include "./sampling.hpp" +#include "./basis_set.hpp" diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4a5a816..b6bcabb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable(libsparseirtests basis.cxx augment.cxx sampling.cxx + basis_set.cxx ) target_link_libraries(libsparseirtests PRIVATE Catch2::Catch2WithMain) diff --git a/test/basis_set.cxx b/test/basis_set.cxx new file mode 100644 index 0000000..8e924c9 --- /dev/null +++ b/test/basis_set.cxx @@ -0,0 +1,61 @@ +#include +#include + +#include +#include + + +#include +#include + +using namespace sparseir; + +TEST_CASE("FiniteTempBasisSet consistency tests", "[basis_set]") { + + SECTION("Consistency") { + // Define parameters + double beta = 1.0; // Inverse temperature + double omega_max = 10.0; // Maximum frequency + double epsilon = 1e-5; // Desired accuracy + + // Create kernels + sparseir::LogisticKernel kernel; + + // Create shared_ptr instances of FiniteTempBasis + auto basis_f = std::make_shared>( + beta, omega_max, epsilon, kernel); + auto basis_b = std::make_shared>( + beta, omega_max, epsilon, kernel); + + // Create TauSampling instances + //sparseir::TauSampling smpl_tau_f(basis_f); + //sparseir::TauSampling smpl_tau_b(basis_b); + //TauSampling smpl_tau_f(basis_f); + //auto basis_b = std::make_shared>(/* constructor arguments */); + //TauSampling smpl_tau_b(basis_b); + /* + // Create TauSampling objects + TauSampling smpl_tau_f(basis_f); + TauSampling smpl_tau_b(basis_b); + + // Create MatsubaraSampling objects + MatsubaraSampling> smpl_wn_f(basis_f); + MatsubaraSampling> smpl_wn_b(basis_b); + + // Create FiniteTempBasisSet + FiniteTempBasisSet bs(beta, omega_max, epsilon, sve_result); + + // Check that sampling points are equal + REQUIRE(smpl_tau_f.sampling_points().isApprox(smpl_tau_b.sampling_points())); + REQUIRE(smpl_tau_f.sampling_points().isApprox(bs.tau())); + + // Check that matrices are equal + REQUIRE(smpl_tau_f.matrix().isApprox(smpl_tau_b.matrix())); + REQUIRE(bs.smpl_tau_f()->matrix().isApprox(smpl_tau_f.matrix())); + REQUIRE(bs.smpl_tau_b()->matrix().isApprox(smpl_tau_b.matrix())); + + REQUIRE(bs.smpl_wn_f()->matrix().isApprox(smpl_wn_f.matrix())); + REQUIRE(bs.smpl_wn_b()->matrix().isApprox(smpl_wn_b.matrix())); + */ + } +} \ No newline at end of file