From e53890a30001eb121fd3db8f3911c0ef18b04c46 Mon Sep 17 00:00:00 2001 From: SatoshiTerasaki Date: Sun, 22 Dec 2024 12:14:16 +0900 Subject: [PATCH] Update augment.hpp --- include/sparseir/augment.hpp | 260 +++++++++++++++++++++-------------- 1 file changed, 158 insertions(+), 102 deletions(-) diff --git a/include/sparseir/augment.hpp b/include/sparseir/augment.hpp index 21c2799..c6e1a70 100644 --- a/include/sparseir/augment.hpp +++ b/include/sparseir/augment.hpp @@ -10,156 +10,212 @@ namespace sparseir { -template -class AbstractAugmentation { +// Abstract Augmentation +class AbstractAugmentation : public std::enable_shared_from_this { public: - virtual T operator()(T tau) const = 0; // Ensure this is virtual - virtual T deriv(T tau, int n = 1) const = 0; // Add this line - virtual std::complex hat(int n) const = 0; // Add this line + virtual double operator()(double tau) const = 0; + virtual double operator()(int bosonicFreq) const = 0; + virtual std::function deriv(int order = 1) const = 0; virtual ~AbstractAugmentation() = default; - virtual std::unique_ptr> clone() const = 0; }; -template -class AugmentedBasis : public AbstractBasis { +// TauConst Class +class TauConst : public AbstractAugmentation { public: - AugmentedBasis(std::shared_ptr> basis, - const std::vector>>& augmentations) - : _basis(basis), _augmentations(augmentations), _naug(augmentations.size()) { - //Error Handling: Check for null basis pointer - if (!_basis) { - throw std::invalid_argument("Basis cannot be null"); + double beta; + + explicit TauConst(double beta) : beta(beta) { + if (beta <= 0) { + throw std::domain_error("Temperature must be positive."); } - //Check for valid augmentations - for (const auto& aug : _augmentations) { - if (!aug) { - throw std::invalid_argument("Augmentation cannot be null"); - } + } + + double operator()(double tau) const override { + if (tau < 0 || tau > beta) { + throw std::domain_error("tau must be in [0, beta]."); } + return 1.0 / std::sqrt(beta); } - size_t size() const { return _basis->size() + _naug; } + double operator()(int bosonicFreq) const override { + return (bosonicFreq == 0) ? std::sqrt(beta) : 0.0; + } - Eigen::VectorXd u(const Eigen::VectorXd& tau) const { - Eigen::VectorXd result(size()); - for (size_t i = 0; i < _naug; ++i) { - result(i) = (*_augmentations[i])(tau(i)); - } - for (size_t i = _naug; i < size(); ++i) { - result(i) = _basis->u(tau(i - _naug))(i - _naug); + std::function deriv(int order = 1) const override { + if (order == 0) { + return [this](double tau) { return (*this)(tau); }; } - return result; + return [](double) { return 0.0; }; } +}; + +// TauLinear Class +class TauLinear : public AbstractAugmentation { +public: + double beta; + double norm; - Eigen::VectorXcf uhat(const Eigen::VectorXcf& wn) const { - Eigen::VectorXcf result(size()); - for (size_t i = 0; i < _naug; ++i) { - result(i) = (*_augmentations[i]).hat(wn(i)); + explicit TauLinear(double beta) : beta(beta), norm(std::sqrt(3.0 / beta)) { + if (beta <= 0) { + throw std::domain_error("Temperature must be positive."); } - for (size_t i = _naug; i < size(); ++i) { - result(i) = _basis->uhat(wn(i - _naug))(i - _naug); + } + + double operator()(double tau) const override { + if (tau < 0 || tau > beta) { + throw std::domain_error("tau must be in [0, beta]."); } - return result; + double x = 2.0 / beta * tau - 1.0; + return norm * x; } - Eigen::VectorXd v(const Eigen::VectorXd& w) const { - return _basis->v(w); + double operator()(int bosonicFreq) const override { + double inv_w = (bosonicFreq == 0) ? std::numeric_limits::infinity() : 1.0 / bosonicFreq; + std::complex imag_unit(0.0, 1.0); // 複素数の虚数単位 + return norm * 2.0 / imag_unit.imag() * inv_w; } - Eigen::VectorXd s() const { return _basis->s(); } + std::function deriv(int order = 1) const override { + if (order == 0) { + return [this](double tau) { return (*this)(tau); }; + } else if (order == 1) { + return [this](double) { return norm * 2.0 / beta; }; + } + return [](double) { return 0.0; }; + } +}; - double beta() const { return _basis->beta(); } +// MatsubaraConst Class +class MatsubaraConst : public AbstractAugmentation { +public: + double beta; - double wmax() const { return _basis->wmax(); } + explicit MatsubaraConst(double beta) : beta(beta) { + if (beta <= 0) { + throw std::domain_error("Temperature must be positive."); + } + } - std::shared_ptr statistics() const { - return _basis->statistics(); // Assuming _basis also returns a shared_ptr + double operator()(double tau) const override { + if (tau < 0 || tau > beta) { + throw std::domain_error("tau must be in [0, beta]."); + } + return std::numeric_limits::quiet_NaN(); } -private: - std::shared_ptr> _basis; - std::vector>> _augmentations; - size_t _naug; + double operator()(int matsubaraFreq) const override { + return 1.0; + } + + std::function deriv(int order = 1) const override { + return [this](double tau) { return (*this)(tau); }; + } }; +// AbstractAugmentedFunction +class AbstractAugmentedFunction { +public: + virtual size_t size() const = 0; + virtual Eigen::VectorXd operator()(double x) const = 0; + virtual Eigen::MatrixXd operator()(const Eigen::VectorXd &x) const = 0; + virtual ~AbstractAugmentedFunction() = default; +}; -template -class TauConst : public AbstractAugmentation { +// AugmentedFunction +template +class AugmentedFunction : public AbstractAugmentedFunction { public: - TauConst(T beta) : beta_(beta), norm_(1.0 / std::sqrt(beta)) { - if (beta_ <= 0) { - throw std::invalid_argument("beta must be positive"); - } - } + FB fbasis; + std::vector faug; - T operator()(T tau) const override { - check_tau_range(tau, beta_); - return norm_; - } - T deriv(T tau, int n = 1) const { - if (n == 0) return (*this)(tau); - return 0.0; - } - std::complex hat(int n) const { - return norm_ * std::sqrt(beta_); + AugmentedFunction(FB fbasis, std::vector faug) : fbasis(fbasis), faug(faug) {} + + size_t size() const override { + return faug.size() + fbasis.size(); } - std::unique_ptr> clone() const override { - return std::make_unique>(*this); + + Eigen::VectorXd operator()(double x) const override { + Eigen::VectorXd fbasis_x = fbasis(x); + Eigen::VectorXd faug_x(faug.size()); + for (size_t i = 0; i < faug.size(); ++i) { + faug_x[i] = faug[i](x); + } + Eigen::VectorXd result(faug.size() + fbasis_x.size()); + result << faug_x, fbasis_x; + return result; } -private: - T beta_; - T norm_; + Eigen::MatrixXd operator()(const Eigen::VectorXd &x) const override { + Eigen::MatrixXd fbasis_x = fbasis(x); + Eigen::MatrixXd faug_x(faug.size(), x.size()); + for (size_t i = 0; i < faug.size(); ++i) { + for (size_t j = 0; j < x.size(); ++j) { + faug_x(i, j) = faug[i](x[j]); + } + } + Eigen::MatrixXd result(faug.size() + fbasis_x.rows(), x.size()); + result << faug_x, fbasis_x; + return result; + } }; -template -class TauLinear : public AbstractAugmentation { +// AugmentedBasis +template +class AugmentedBasis : public AbstractBasis { public: - TauLinear(T beta) : beta_(beta), norm_(std::sqrt(3.0 / beta)) { - if (beta_ <= 0) { - throw std::invalid_argument("beta must be positive"); - } + std::shared_ptr basis; + std::vector> augmentations; + F u; + FHAT uhat; + + AugmentedBasis(std::shared_ptr basis, + std::vector> augmentations, + F u, FHAT uhat) + : AbstractBasis(basis->beta), basis(basis), augmentations(augmentations), u(u), uhat(uhat) {} + + size_t size() const override { + return nAug() + basis->size(); } - T operator()(T tau) const override { - check_tau_range(tau, beta_); - return norm_ * (2.0 * tau / beta_ - 1.0); + size_t nAug() const { + return augmentations.size(); } - T deriv(T tau, int n = 1) const override { - if (n == 1) return norm_ * 2.0 / beta_; - return 0.0; + + double accuracy() const override { + return basis->accuracy(); } - std::complex hat(int n) const override { - if (n == 0) return 0.0; - return norm_ * 2.0 * beta_ / (n * M_PI * std::complex(0, 1)); + + double omegaMax() const override { + return basis->omegaMax(); } - std::unique_ptr> clone() const override { - return std::make_unique>(*this); + + static std::shared_ptr create(std::shared_ptr basis, + std::vector> augmentations) { + auto augs = createAugmentations(augmentations, basis); + auto u = createAugmentedTauFunction(basis->u, augs); + auto uhat = createAugmentedMatsubaraFunction(basis->uhat, augs); + return std::make_shared(basis, augs, u, uhat); } private: - T beta_; - T norm_; -}; - -template -class MatsubaraConst : public AbstractAugmentation { -public: - MatsubaraConst(T beta) : beta_(beta) { - if (beta_ <= 0) { - throw std::invalid_argument("beta must be positive"); + static std::vector> createAugmentations(const std::vector> &augmentations, + std::shared_ptr basis) { + std::vector> augs; + for (const auto &aug : augmentations) { + augs.push_back(aug->create(basis)); } + return augs; } - T operator()(T tau) const override { return std::numeric_limits::quiet_NaN(); } - T deriv(T tau, int n = 1) const override { return std::numeric_limits::quiet_NaN(); } - std::complex hat(int n) const override { return 1.0; } - std::unique_ptr> clone() const override { - return std::make_unique>(*this); + static F createAugmentedTauFunction(const F &basisFunc, const std::vector> &augmentations) { + // Placeholder for actual implementation + return basisFunc; } -private: - T beta_; + static FHAT createAugmentedMatsubaraFunction(const FHAT &basisFunc, const std::vector> &augmentations) { + // Placeholder for actual implementation + return basisFunc; + } };