Skip to content

Commit

Permalink
Merge pull request #67 from SpM-lab/terasaki/rewrite-augment.hpp
Browse files Browse the repository at this point in the history
Update augment.hpp
  • Loading branch information
terasakisatoshi authored Dec 22, 2024
2 parents 99d774b + e53890a commit 697ae88
Showing 1 changed file with 158 additions and 102 deletions.
260 changes: 158 additions & 102 deletions include/sparseir/augment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,156 +10,212 @@

namespace sparseir {

template <typename T>
class AbstractAugmentation {
// Abstract Augmentation
class AbstractAugmentation : public std::enable_shared_from_this<AbstractAugmentation> {
public:
virtual T operator()(T tau) const = 0; // Ensure this is virtual
virtual T deriv(T tau, int n = 1) const = 0; // Add this line
virtual std::complex<T> hat(int n) const = 0; // Add this line
virtual double operator()(double tau) const = 0;
virtual double operator()(int bosonicFreq) const = 0;
virtual std::function<double(double)> deriv(int order = 1) const = 0;
virtual ~AbstractAugmentation() = default;
virtual std::unique_ptr<AbstractAugmentation<T>> clone() const = 0;
};

template <typename S>
class AugmentedBasis : public AbstractBasis<S> {
// TauConst Class
class TauConst : public AbstractAugmentation {
public:
AugmentedBasis(std::shared_ptr<AbstractBasis<S>> basis,
const std::vector<std::unique_ptr<AbstractAugmentation<S>>>& augmentations)
: _basis(basis), _augmentations(augmentations), _naug(augmentations.size()) {
//Error Handling: Check for null basis pointer
if (!_basis) {
throw std::invalid_argument("Basis cannot be null");
double beta;

explicit TauConst(double beta) : beta(beta) {
if (beta <= 0) {
throw std::domain_error("Temperature must be positive.");
}
//Check for valid augmentations
for (const auto& aug : _augmentations) {
if (!aug) {
throw std::invalid_argument("Augmentation cannot be null");
}
}

double operator()(double tau) const override {
if (tau < 0 || tau > beta) {
throw std::domain_error("tau must be in [0, beta].");
}
return 1.0 / std::sqrt(beta);
}

size_t size() const { return _basis->size() + _naug; }
double operator()(int bosonicFreq) const override {
return (bosonicFreq == 0) ? std::sqrt(beta) : 0.0;
}

Eigen::VectorXd u(const Eigen::VectorXd& tau) const {
Eigen::VectorXd result(size());
for (size_t i = 0; i < _naug; ++i) {
result(i) = (*_augmentations[i])(tau(i));
}
for (size_t i = _naug; i < size(); ++i) {
result(i) = _basis->u(tau(i - _naug))(i - _naug);
std::function<double(double)> deriv(int order = 1) const override {
if (order == 0) {
return [this](double tau) { return (*this)(tau); };
}
return result;
return [](double) { return 0.0; };
}
};

// TauLinear Class
class TauLinear : public AbstractAugmentation {
public:
double beta;
double norm;

Eigen::VectorXcf uhat(const Eigen::VectorXcf& wn) const {
Eigen::VectorXcf result(size());
for (size_t i = 0; i < _naug; ++i) {
result(i) = (*_augmentations[i]).hat(wn(i));
explicit TauLinear(double beta) : beta(beta), norm(std::sqrt(3.0 / beta)) {
if (beta <= 0) {
throw std::domain_error("Temperature must be positive.");
}
for (size_t i = _naug; i < size(); ++i) {
result(i) = _basis->uhat(wn(i - _naug))(i - _naug);
}

double operator()(double tau) const override {
if (tau < 0 || tau > beta) {
throw std::domain_error("tau must be in [0, beta].");
}
return result;
double x = 2.0 / beta * tau - 1.0;
return norm * x;
}

Eigen::VectorXd v(const Eigen::VectorXd& w) const {
return _basis->v(w);
double operator()(int bosonicFreq) const override {
double inv_w = (bosonicFreq == 0) ? std::numeric_limits<double>::infinity() : 1.0 / bosonicFreq;
std::complex<double> imag_unit(0.0, 1.0); // 複素数の虚数単位
return norm * 2.0 / imag_unit.imag() * inv_w;
}

Eigen::VectorXd s() const { return _basis->s(); }
std::function<double(double)> deriv(int order = 1) const override {
if (order == 0) {
return [this](double tau) { return (*this)(tau); };
} else if (order == 1) {
return [this](double) { return norm * 2.0 / beta; };
}
return [](double) { return 0.0; };
}
};

double beta() const { return _basis->beta(); }
// MatsubaraConst Class
class MatsubaraConst : public AbstractAugmentation {
public:
double beta;

double wmax() const { return _basis->wmax(); }
explicit MatsubaraConst(double beta) : beta(beta) {
if (beta <= 0) {
throw std::domain_error("Temperature must be positive.");
}
}

std::shared_ptr<Statistics> statistics() const {
return _basis->statistics(); // Assuming _basis also returns a shared_ptr
double operator()(double tau) const override {
if (tau < 0 || tau > beta) {
throw std::domain_error("tau must be in [0, beta].");
}
return std::numeric_limits<double>::quiet_NaN();
}

private:
std::shared_ptr<AbstractBasis<S>> _basis;
std::vector<std::unique_ptr<AbstractAugmentation<S>>> _augmentations;
size_t _naug;
double operator()(int matsubaraFreq) const override {
return 1.0;
}

std::function<double(double)> deriv(int order = 1) const override {
return [this](double tau) { return (*this)(tau); };
}
};

// AbstractAugmentedFunction
class AbstractAugmentedFunction {
public:
virtual size_t size() const = 0;
virtual Eigen::VectorXd operator()(double x) const = 0;
virtual Eigen::MatrixXd operator()(const Eigen::VectorXd &x) const = 0;
virtual ~AbstractAugmentedFunction() = default;
};

template <typename T>
class TauConst : public AbstractAugmentation<T> {
// AugmentedFunction
template <typename FB, typename FA>
class AugmentedFunction : public AbstractAugmentedFunction {
public:
TauConst(T beta) : beta_(beta), norm_(1.0 / std::sqrt(beta)) {
if (beta_ <= 0) {
throw std::invalid_argument("beta must be positive");
}
}
FB fbasis;
std::vector<FA> faug;

T operator()(T tau) const override {
check_tau_range(tau, beta_);
return norm_;
}
T deriv(T tau, int n = 1) const {
if (n == 0) return (*this)(tau);
return 0.0;
}
std::complex<T> hat(int n) const {
return norm_ * std::sqrt(beta_);
AugmentedFunction(FB fbasis, std::vector<FA> faug) : fbasis(fbasis), faug(faug) {}

size_t size() const override {
return faug.size() + fbasis.size();
}
std::unique_ptr<AbstractAugmentation<T>> clone() const override {
return std::make_unique<TauConst<T>>(*this);

Eigen::VectorXd operator()(double x) const override {
Eigen::VectorXd fbasis_x = fbasis(x);
Eigen::VectorXd faug_x(faug.size());
for (size_t i = 0; i < faug.size(); ++i) {
faug_x[i] = faug[i](x);
}
Eigen::VectorXd result(faug.size() + fbasis_x.size());
result << faug_x, fbasis_x;
return result;
}

private:
T beta_;
T norm_;
Eigen::MatrixXd operator()(const Eigen::VectorXd &x) const override {
Eigen::MatrixXd fbasis_x = fbasis(x);
Eigen::MatrixXd faug_x(faug.size(), x.size());
for (size_t i = 0; i < faug.size(); ++i) {
for (size_t j = 0; j < x.size(); ++j) {
faug_x(i, j) = faug[i](x[j]);
}
}
Eigen::MatrixXd result(faug.size() + fbasis_x.rows(), x.size());
result << faug_x, fbasis_x;
return result;
}
};

template <typename T>
class TauLinear : public AbstractAugmentation<T> {
// AugmentedBasis
template <typename S, typename B, typename A, typename F, typename FHAT>
class AugmentedBasis : public AbstractBasis<S> {
public:
TauLinear(T beta) : beta_(beta), norm_(std::sqrt(3.0 / beta)) {
if (beta_ <= 0) {
throw std::invalid_argument("beta must be positive");
}
std::shared_ptr<B> basis;
std::vector<std::shared_ptr<A>> augmentations;
F u;
FHAT uhat;

AugmentedBasis(std::shared_ptr<B> basis,
std::vector<std::shared_ptr<A>> augmentations,
F u, FHAT uhat)
: AbstractBasis<S>(basis->beta), basis(basis), augmentations(augmentations), u(u), uhat(uhat) {}

size_t size() const override {
return nAug() + basis->size();
}

T operator()(T tau) const override {
check_tau_range(tau, beta_);
return norm_ * (2.0 * tau / beta_ - 1.0);
size_t nAug() const {
return augmentations.size();
}
T deriv(T tau, int n = 1) const override {
if (n == 1) return norm_ * 2.0 / beta_;
return 0.0;

double accuracy() const override {
return basis->accuracy();
}
std::complex<T> hat(int n) const override {
if (n == 0) return 0.0;
return norm_ * 2.0 * beta_ / (n * M_PI * std::complex<T>(0, 1));

double omegaMax() const override {
return basis->omegaMax();
}
std::unique_ptr<AbstractAugmentation<T>> clone() const override {
return std::make_unique<TauLinear<T>>(*this);

static std::shared_ptr<AugmentedBasis> create(std::shared_ptr<B> basis,
std::vector<std::shared_ptr<A>> augmentations) {
auto augs = createAugmentations(augmentations, basis);
auto u = createAugmentedTauFunction(basis->u, augs);
auto uhat = createAugmentedMatsubaraFunction(basis->uhat, augs);
return std::make_shared<AugmentedBasis>(basis, augs, u, uhat);
}

private:
T beta_;
T norm_;
};

template <typename T>
class MatsubaraConst : public AbstractAugmentation<T> {
public:
MatsubaraConst(T beta) : beta_(beta) {
if (beta_ <= 0) {
throw std::invalid_argument("beta must be positive");
static std::vector<std::shared_ptr<A>> createAugmentations(const std::vector<std::shared_ptr<A>> &augmentations,
std::shared_ptr<B> basis) {
std::vector<std::shared_ptr<A>> augs;
for (const auto &aug : augmentations) {
augs.push_back(aug->create(basis));
}
return augs;
}

T operator()(T tau) const override { return std::numeric_limits<T>::quiet_NaN(); }
T deriv(T tau, int n = 1) const override { return std::numeric_limits<T>::quiet_NaN(); }
std::complex<T> hat(int n) const override { return 1.0; }
std::unique_ptr<AbstractAugmentation<T>> clone() const override {
return std::make_unique<MatsubaraConst<T>>(*this);
static F createAugmentedTauFunction(const F &basisFunc, const std::vector<std::shared_ptr<A>> &augmentations) {
// Placeholder for actual implementation
return basisFunc;
}

private:
T beta_;
static FHAT createAugmentedMatsubaraFunction(const FHAT &basisFunc, const std::vector<std::shared_ptr<A>> &augmentations) {
// Placeholder for actual implementation
return basisFunc;
}
};


Expand Down

0 comments on commit 697ae88

Please sign in to comment.