Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port basis_set to C++ #66

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions include/sparseir/basis_set.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#pragma once

#include <memory>
#include <Eigen/Dense>

namespace sparseir {

template <typename T>
class FiniteTempBasisSet {
public:
using Scalar = T;
using BasisPtr = std::shared_ptr<AbstractBasis<Scalar>>;
using TauSamplingPtr = std::shared_ptr<TauSampling<Scalar>>;
using MatsubaraSamplingPtr = std::shared_ptr<MatsubaraSampling<std::complex<Scalar>>>;

// Constructors
FiniteTempBasisSet(Scalar beta, Scalar omega_max, Scalar epsilon = std::numeric_limits<Scalar>::quiet_NaN(),
const SVEResult<Scalar>& sve_result = SVEResult<Scalar>())
: beta_(beta), omega_max_(omega_max), epsilon_(epsilon) {
initialize(sve_result);
}

// Accessors
Scalar beta() const { return beta_; }
Scalar omega_max() const { return omega_max_; }
Scalar accuracy() const { return epsilon_; }

const BasisPtr& basis_f() const { return basis_f_; }
const BasisPtr& basis_b() const { return basis_b_; }

const TauSamplingPtr& smpl_tau_f() const { return smpl_tau_f_; }
const TauSamplingPtr& smpl_tau_b() const { return smpl_tau_b_; }

const MatsubaraSamplingPtr& smpl_wn_f() const { return smpl_wn_f_; }
const MatsubaraSamplingPtr& smpl_wn_b() const { return smpl_wn_b_; }

const Eigen::VectorXd& tau() const { return smpl_tau_f_->sampling_points(); }
const std::vector<int>& wn_f() const { return smpl_wn_f_->sampling_frequencies(); }
const std::vector<int>& wn_b() const { return smpl_wn_b_->sampling_frequencies(); }

const SVEResult<Scalar>& sve_result() const { return sve_result_; }

private:
void initialize(const SVEResult<Scalar>& sve_result_input) {
if (std::isnan(epsilon_)) {
epsilon_ = std::numeric_limits<Scalar>::epsilon();
}

LogisticKernel kernel(beta_ * omega_max_);
sve_result_ = sve_result_input.is_valid() ? sve_result_input : compute_sve(kernel);

basis_f_ = std::make_shared<FiniteTempBasis<Fermionic, LogisticKernel>>(
beta_, omega_max_, epsilon_, kernel, sve_result_);
basis_b_ = std::make_shared<FiniteTempBasis<Bosonic, LogisticKernel>>(
beta_, omega_max_, epsilon_, kernel, sve_result_);

// Initialize sampling objects
smpl_tau_f_ = std::make_shared<TauSampling<Scalar>>(basis_f_);
smpl_tau_b_ = std::make_shared<TauSampling<Scalar>>(basis_b_);

smpl_wn_f_ = std::make_shared<MatsubaraSampling<std::complex<Scalar>>>(basis_f_);
smpl_wn_b_ = std::make_shared<MatsubaraSampling<std::complex<Scalar>>>(basis_b_);
}

Scalar beta_;
Scalar omega_max_;
Scalar epsilon_;

BasisPtr basis_f_;
BasisPtr basis_b_;

TauSamplingPtr smpl_tau_f_;
TauSamplingPtr smpl_tau_b_;

MatsubaraSamplingPtr smpl_wn_f_;
MatsubaraSamplingPtr smpl_wn_b_;

SVEResult<Scalar> sve_result_;
};

} // namespace sparseir
1 change: 1 addition & 0 deletions include/sparseir/sparseir-header-only.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
#include "./basis.hpp"
#include "./augment.hpp"
#include "./sampling.hpp"
#include "./basis_set.hpp"
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_executable(libsparseirtests
basis.cxx
augment.cxx
sampling.cxx
basis_set.cxx
)

target_link_libraries(libsparseirtests PRIVATE Catch2::Catch2WithMain)
Expand Down
61 changes: 61 additions & 0 deletions test/basis_set.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <Eigen/Dense>
#include <catch2/catch_test_macros.hpp>

#include <sparseir/sparseir-header-only.hpp>
#include <xprec/ddouble-header-only.hpp>


#include <memory>
#include <complex>

using namespace sparseir;

TEST_CASE("FiniteTempBasisSet consistency tests", "[basis_set]") {

SECTION("Consistency") {
// Define parameters
double beta = 1.0; // Inverse temperature
double omega_max = 10.0; // Maximum frequency
double epsilon = 1e-5; // Desired accuracy

// Create kernels
sparseir::LogisticKernel kernel;

// Create shared_ptr instances of FiniteTempBasis
auto basis_f = std::make_shared<sparseir::FiniteTempBasis<sparseir::Fermionic, sparseir::LogisticKernel>>(
beta, omega_max, epsilon, kernel);
auto basis_b = std::make_shared<sparseir::FiniteTempBasis<sparseir::Bosonic, sparseir::LogisticKernel>>(
beta, omega_max, epsilon, kernel);

// Create TauSampling instances
//sparseir::TauSampling<double> smpl_tau_f(basis_f);
//sparseir::TauSampling<double> smpl_tau_b(basis_b);
//TauSampling<double> smpl_tau_f(basis_f);
//auto basis_b = std::make_shared<FiniteTempBasis<Bosonic>>(/* constructor arguments */);
//TauSampling<double> smpl_tau_b(basis_b);
/*
// Create TauSampling objects
TauSampling<double> smpl_tau_f(basis_f);
TauSampling<double> smpl_tau_b(basis_b);

// Create MatsubaraSampling objects
MatsubaraSampling<std::complex<double>> smpl_wn_f(basis_f);
MatsubaraSampling<std::complex<double>> smpl_wn_b(basis_b);

// Create FiniteTempBasisSet
FiniteTempBasisSet<double> bs(beta, omega_max, epsilon, sve_result);

// Check that sampling points are equal
REQUIRE(smpl_tau_f.sampling_points().isApprox(smpl_tau_b.sampling_points()));
REQUIRE(smpl_tau_f.sampling_points().isApprox(bs.tau()));

// Check that matrices are equal
REQUIRE(smpl_tau_f.matrix().isApprox(smpl_tau_b.matrix()));
REQUIRE(bs.smpl_tau_f()->matrix().isApprox(smpl_tau_f.matrix()));
REQUIRE(bs.smpl_tau_b()->matrix().isApprox(smpl_tau_b.matrix()));

REQUIRE(bs.smpl_wn_f()->matrix().isApprox(smpl_wn_f.matrix()));
REQUIRE(bs.smpl_wn_b()->matrix().isApprox(smpl_wn_b.matrix()));
*/
}
}
Loading