From e67487ebe53794adb7c76f62988828adcda7bd45 Mon Sep 17 00:00:00 2001 From: Satoshi Terasaki Date: Tue, 24 Dec 2024 11:57:01 +0900 Subject: [PATCH 1/2] Improve test/augment.cxx --- test/augment.cxx | 97 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/test/augment.cxx b/test/augment.cxx index d4c7013..64526a9 100644 --- a/test/augment.cxx +++ b/test/augment.cxx @@ -7,48 +7,79 @@ #include #include -#include +#include // for Approx +#include #include #include using namespace sparseir; using namespace std; -TEST_CASE("AbstractAugmentation") { - SECTION("TauConst") { - double beta = 1000.0; - auto tc = TauConst(beta); - double tau = 0.5; - double y = tc(tau); - auto dtc = tc.deriv(); - double x = 2.0; - REQUIRE(tc.beta == beta); - REQUIRE(dtc(x) == 0.0); +TEST_CASE("SparseIR Basis Functions", "[SparseIR]") +{ + using Catch::Approx; + using namespace sparseir; + + SECTION("TauConst") + { + REQUIRE_THROWS_AS(TauConst(-34), std::domain_error); + + TauConst tc(123); + REQUIRE(tc.beta == 123.0); + + REQUIRE_THROWS_AS(tc(-3), std::domain_error); + REQUIRE_THROWS_AS(tc(321), std::domain_error); + + REQUIRE(tc(100) == 1 / std::sqrt(123)); + //REQUIRE(tc(MatsubaraFreq(0)) == std::sqrt(123)); + //REQUIRE(tc(MatsubaraFreq(92)) == 0.0); + //REQUIRE_THROWS_AS(tc(MatsubaraFreq(93)), std::runtime_error); + + //REQUIRE(sparseir::deriv(tc)(4.2) == 0.0); + //REQUIRE(sparseir::deriv(tc, 0) == tc); } - SECTION("TauLinear") { - double beta = 1000.0; - double tau_0 = 0.5; - double tau_1 = 1.0; - auto tl = TauLinear(beta); - double tau = 0.75; - double y = tl(tau); - auto dtl = tl.deriv(); - double x = 2.0; - REQUIRE(true); - //REQUIRE(dtl(x) == -beta * (tau - tau_0) / pow(tau_1 - tau_0, 2)); + + SECTION("TauLinear") + { + REQUIRE_THROWS_AS(TauLinear(-34), std::domain_error); + + TauLinear tl(123); + REQUIRE(tl.beta == Approx(123.0)); + + REQUIRE_THROWS_AS(tl(-3), std::domain_error); + REQUIRE_THROWS_AS(tl(321), std::domain_error); + REQUIRE(tl.norm == Approx(std::sqrt(3.0 / 123.0))); + REQUIRE(tl(100) == std::sqrt(3.0 / 123.0) * (2.0 / (123. * 100.) - 1.)); + // REQUIRE(tl(MatsubaraFreq(0)) == Approx(0.0)); + // REQUIRE(tl(MatsubaraFreq(92)) == + // Approx(std::sqrt(3 / 123) * 2 / std::complex(0, 1) * + // 123 / (92 * M_PI))); + // REQUIRE_THROWS_AS(tl(MatsubaraFreq(93)), std::runtime_error); + + //REQUIRE(sparseir::deriv(tl, 0) == tl); + //REQUIRE(sparseir::deriv(tl)(4.2) == + // Approx(std::sqrt(3 / 123) * 2 / 123)); + //REQUIRE(sparseir::deriv(tl, 2)(4.2) == Approx(0.0)); } - SECTION("MatsubaraConst") { - double beta = 1000.0; - double w_0 = 0.5; - double w_1 = 1.0; - auto mc = MatsubaraConst(beta); - double w = 0.75; - double y = mc(w); - auto dmc = mc.deriv(); - double x = 2.0; - REQUIRE(true); - //REQUIRE(dmc(x) == -beta * (w - w_0) / pow(w_1 - w_0, 2)); + + SECTION("MatsubaraConst") + { + REQUIRE_THROWS_AS(MatsubaraConst(-34), std::domain_error); + + MatsubaraConst mc(123); + REQUIRE(mc.beta == Approx(123.0)); + + REQUIRE_THROWS_AS(mc(-3), std::domain_error); + REQUIRE_THROWS_AS(mc(321), std::domain_error); + + REQUIRE(std::isnan(mc(100))); + //REQUIRE(mc(MatsubaraFreq(0)) == Approx(1.0)); + //REQUIRE(mc(MatsubaraFreq(92)) == Approx(1.0)); + //REQUIRE(mc(MatsubaraFreq(93)) == Approx(1.0)); + + //REQUIRE(sparseir::deriv(mc) == mc); + //REQUIRE(sparseir::deriv(mc, 0) == mc); } } From 79201c268949b52c7bef5dbcf466472f282432e4 Mon Sep 17 00:00:00 2001 From: SatoshiTerasaki Date: Tue, 24 Dec 2024 12:34:38 +0900 Subject: [PATCH 2/2] Update --- include/sparseir/augment.hpp | 4 ++-- include/sparseir/freq.hpp | 31 +++++++++++++++++++++++-------- test/augment.cxx | 31 ++++++++++++++++++++++--------- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/include/sparseir/augment.hpp b/include/sparseir/augment.hpp index ebe5152..73f333b 100644 --- a/include/sparseir/augment.hpp +++ b/include/sparseir/augment.hpp @@ -82,8 +82,8 @@ class TauLinear : public AbstractAugmentation { } std::complex operator()(MatsubaraFreq n) const override { - std::invalid_argument("TauConst is not a Fermionic basis."); - return std::numeric_limits::quiet_NaN(); + throw std::invalid_argument("TauConst is not a Fermionic basis."); + return std::numeric_limits>::quiet_NaN(); } std::function deriv(int order = 1) const override { diff --git a/include/sparseir/freq.hpp b/include/sparseir/freq.hpp index 7bdb3f6..c519441 100644 --- a/include/sparseir/freq.hpp +++ b/include/sparseir/freq.hpp @@ -41,7 +41,7 @@ inline std::unique_ptr create_statistics(int zeta) throw std::domain_error("Unknown statistics type"); } -// Matsubara frequency class template +// MatsubaraFreq class template template class MatsubaraFreq { static_assert(std::is_base_of::value, @@ -50,21 +50,36 @@ class MatsubaraFreq { public: int n; - inline MatsubaraFreq(int n) : n(n) - { + // コンストラクタ + inline MatsubaraFreq(int n) : n(n) { + static_assert(std::is_same::value || std::is_same::value, + "S must be Fermionic or Bosonic"); S stat; - if (!stat.allowed(n)) + if (!stat.allowed(n)) { throw std::domain_error("Frequency is not allowed for this type"); + } + instance_ = std::make_shared(stat); // 適切な型のインスタンスを保持 } - inline double value(double beta) const { return n * M_PI / beta; } - inline std::complex value_im(double beta) const - { + // 値の計算 + inline double value(double beta) const { + return n * M_PI / beta; + } + + // 複素数の値 + inline std::complex value_im(double beta) const { return std::complex(0, value(beta)); } - inline S statistics() const { return S(); } + // 統計型の取得 + inline S statistics() const { return *std::static_pointer_cast(instance_); } + + // n の取得 inline int get_n() const { return n; } + +private: + // インスタンスを保持する共有ポインタ + std::shared_ptr instance_; }; // Typedefs for convenience diff --git a/test/augment.cxx b/test/augment.cxx index 64526a9..4f6e66d 100644 --- a/test/augment.cxx +++ b/test/augment.cxx @@ -50,12 +50,22 @@ TEST_CASE("SparseIR Basis Functions", "[SparseIR]") REQUIRE_THROWS_AS(tl(-3), std::domain_error); REQUIRE_THROWS_AS(tl(321), std::domain_error); REQUIRE(tl.norm == Approx(std::sqrt(3.0 / 123.0))); - REQUIRE(tl(100) == std::sqrt(3.0 / 123.0) * (2.0 / (123. * 100.) - 1.)); - // REQUIRE(tl(MatsubaraFreq(0)) == Approx(0.0)); - // REQUIRE(tl(MatsubaraFreq(92)) == - // Approx(std::sqrt(3 / 123) * 2 / std::complex(0, 1) * - // 123 / (92 * M_PI))); - // REQUIRE_THROWS_AS(tl(MatsubaraFreq(93)), std::runtime_error); + double tau = 100; + REQUIRE(tl(tau) == std::sqrt(3.0 / 123.0) * (2.0 / 123. * tau - 1.)); + + MatsubaraFreq freq0(0); + REQUIRE(tl(freq0) == 0.0); + MatsubaraFreq freq92(92); + // Calculate the expected complex value + std::complex expected_value = std::sqrt(3. / 123.) * 2. / std::complex(0, 1) * 123. / (92. * M_PI); + // Get the actual value from the function + std::complex actual_value = tl(freq92); + + REQUIRE(actual_value.real() == Approx(expected_value.real())); + REQUIRE(actual_value.imag() == Approx(expected_value.imag())); + + MatsubaraFreq freq93(93); + REQUIRE_THROWS_AS(tl(freq93), std::invalid_argument); //REQUIRE(sparseir::deriv(tl, 0) == tl); //REQUIRE(sparseir::deriv(tl)(4.2) == @@ -74,9 +84,12 @@ TEST_CASE("SparseIR Basis Functions", "[SparseIR]") REQUIRE_THROWS_AS(mc(321), std::domain_error); REQUIRE(std::isnan(mc(100))); - //REQUIRE(mc(MatsubaraFreq(0)) == Approx(1.0)); - //REQUIRE(mc(MatsubaraFreq(92)) == Approx(1.0)); - //REQUIRE(mc(MatsubaraFreq(93)) == Approx(1.0)); + MatsubaraFreq freq0(0); + REQUIRE(mc(freq0) == 1.0); + MatsubaraFreq freq92(0); + REQUIRE(mc(freq92) == 1.0); + MatsubaraFreq freq93(93); + REQUIRE(mc(freq93) == 1.0); //REQUIRE(sparseir::deriv(mc) == mc); //REQUIRE(sparseir::deriv(mc, 0) == mc);