Skip to content

Commit

Permalink
Resolve roots
Browse files Browse the repository at this point in the history
  • Loading branch information
terasakisatoshi committed Dec 12, 2024
1 parent a44f74f commit 8103ed1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 80 deletions.
80 changes: 11 additions & 69 deletions include/sparseir/poly.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,80 +245,22 @@ class PiecewiseLegendrePoly {
}
}

// Equivalent to find_all function in julia above
Eigen::VectorXd find_all(const Eigen::VectorXd& xgrid) const {
Eigen::VectorXd fx(xgrid.size());
for (size_t i = 0; i < xgrid.size(); ++i) {
fx(i) = static_cast<double>((*this)(xgrid[i]));
}
// Find direct hits (zeros)
std::vector<bool> hit(fx.size(), false);
std::vector<double> x_hit;
for (Eigen::Index i = 0; i < fx.size(); ++i) {
hit[i] = (fx(i) == 0.0);
if (hit[i]) {
x_hit.push_back(xgrid(i));
}
}

// Check for sign changes
std::vector<bool> sign_change(fx.size() - 1, false);
bool found_sign_change = false;

for (Eigen::Index i = 0; i < fx.size() - 1; ++i) {
bool different_signs = std::signbit(fx(i)) != std::signbit(fx(i + 1));
bool neither_zero = fx(i) != 0.0 && fx(i + 1) != 0.0;
sign_change[i] = different_signs && neither_zero;
if (sign_change[i]) {
found_sign_change = true;
}
}

if (!found_sign_change) {
// convert to Eigen::VectorXd
return Eigen::Map<Eigen::VectorXd>(x_hit.data(), x_hit.size());
}

// Collect points for bisection
std::vector<double> a, b;
std::vector<double> fa;

for (size_t i = 0; i < sign_change.size(); ++i) {
if (sign_change[i]) {
a.push_back(xgrid(i));
b.push_back(xgrid(i + 1));
fa.push_back(fx(i));
}
}

// Calculate epsilon for floating point types
double eps_x;
if (std::is_floating_point<double>::value) {
eps_x = std::numeric_limits<double>::epsilon() * xgrid.cwiseAbs().maxCoeff();
} else {
eps_x = 0;
}

// Perform bisection for each interval
std::vector<double> x_bisect;
for (size_t i = 0; i < a.size(); ++i) {
x_bisect.push_back(bisect(a[i], b[i], fa[i], eps_x));
}

// Combine and sort results
x_hit.insert(x_hit.end(), x_bisect.begin(), x_bisect.end());
std::sort(x_hit.begin(), x_hit.end());
// convert to Eigen::VectorXd
return Eigen::Map<Eigen::VectorXd>(x_hit.data(), x_hit.size());
}

// Roots function
Eigen::VectorXd roots(double tol = 1e-10) const
{
Eigen::VectorXd grid = this->knots;

std::cout << "grid: " << grid.transpose() << "\n";

Eigen::VectorXd refined_grid = refine_grid(grid, 2);
std::cout << "refined_grid: " << refined_grid.size() << "\n";
return find_all(refined_grid);
auto f = [this](double x) { return this->operator()(x); };
// convert to std::vector<double>
std::vector<double> refined_grid_vec(refined_grid.data(),
refined_grid.data() +
refined_grid.size());
std::vector<double> roots = find_all(f, refined_grid_vec);
std::cout << "roots: " << roots.size() << "\n";
return Eigen::Map<Eigen::VectorXd>(roots.data(), roots.size());
}

// Overloaded operators
Expand Down
22 changes: 11 additions & 11 deletions test/poly.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ TEST_CASE("Overlap")
TEST_CASE("Roots")
{
// Initialize data and knots (from Julia code)
Eigen::MatrixXd data(16, 2);
data << 0.16774734206553019, 0.49223680914312595, -0.8276728567928646,
Eigen::VectorXd v(32);
v << 0.16774734206553019, 0.49223680914312595, -0.8276728567928646,
0.16912891046582143, -0.0016231275318572044, 0.00018381683946452256,
-9.699355027805034e-7, 7.60144228530804e-8, -2.8518324490258146e-10,
1.7090590205708293e-11, -5.0081401126025e-14, 2.1244236198427895e-15,
Expand All @@ -311,14 +311,15 @@ TEST_CASE("Roots")
2.8518324490258146e-10, 1.7090590205708293e-11, 5.0081401126025e-14,
2.1244236198427895e-15, -2.0478095258000225e-16, -2.676573801530628e-16,
-2.338165820094204e-16, -1.2050663212312096e-16;
//data.resize(16, 2);

Eigen::Map<Eigen::MatrixXd> data(v.data(), 16, 2);

Eigen::VectorXd knots(3);
knots << 0.0, 0.5, 1.0;
int l = 3;

sparseir::PiecewiseLegendrePoly pwlp(data, knots, l);
/*

// Find roots
Eigen::VectorXd roots = pwlp.roots();

Expand All @@ -329,14 +330,13 @@ TEST_CASE("Roots")
Eigen::VectorXd expected_roots(3);
expected_roots << 0.1118633448586015, 0.4999999999999998, 0.8881366551413985;

// REQUIRE(roots.size() == expected_roots.size());
REQUIRE(roots.size() == expected_roots.size());
for(Eigen::Index i = 0; i < roots.size(); ++i) {
//REQUIRE(std::abs(roots[i] - expected_roots[i]) < 1e-10);
// Verify roots are in domain
//REQUIRE(roots[i] >= knots[0]);
//REQUIRE(roots[i] <= knots[knots.size()-1]);
REQUIRE(std::abs(roots[i] - expected_roots[i]) < 1e-10);
//Verify roots are in domain
REQUIRE(roots[i] >= knots[0]);
REQUIRE(roots[i] <= knots[knots.size()-1]);
// Verify these are actually roots
//REQUIRE(std::abs(pwlp(roots[i])) < 1e-10);
REQUIRE(std::abs(pwlp(roots[i])) < 1e-10);
}
*/
}

0 comments on commit 8103ed1

Please sign in to comment.