Skip to content

Commit

Permalink
make root finding a little more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
tk committed Apr 18, 2021
1 parent 324c30a commit 5894bea
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
23 changes: 23 additions & 0 deletions examples/simple_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,28 @@ int main(int argc, char** argv)

printf("spline(%.3f): cubic (C^2) = %.3f, cubic hermite = %.3f, linear = %.3f\n", x, s1(x), s2(x), s3(x));

// solve for x so that f(x) = y
double y=0.6;
std::vector<double> sol; // solutions
printf("solutions to f(x) = %f\n", y);

sol=s1.solve(y);
printf("C^2 spline: ");
for(size_t i=0; i<sol.size(); i++)
printf("%f\t", sol[i]);
printf("\n");

sol=s2.solve(y);
printf("C^1 Hermite spline: ");
for(size_t i=0; i<sol.size(); i++)
printf("%f\t", sol[i]);
printf("\n");

sol=s3.solve(y);
printf("linear : ");
for(size_t i=0; i<sol.size(); i++)
printf("%f\t", sol[i]);
printf("\n");

return EXIT_SUCCESS;
}
57 changes: 57 additions & 0 deletions examples/tests/unit_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,3 +724,60 @@ BOOST_AUTO_TEST_CASE( SplineSolve )
}
}
}
BOOST_AUTO_TEST_CASE( SplineSolve2 )
{
const double max_func = 2e-12; // f(numerical root) <= max_func+noise
const double dy = 2e-15; // change in right hand side: y=y+dy
const int loops = 100;

std::vector<double> X = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14};
std::vector<double> Y = {0, 1, 2, 1, 2, 1, 2, 1, 2, 0, 2, 0};

// setup all possible types of splines which are at least C^0

const double y=1.0;
{
tk::spline s(X,Y, tk::spline::cspline);
for(int i=-loops/2; i<loops/2; i++) {
BOOST_TEST_CONTEXT("spline: C2: f(x)=" << y << " + " << (dy*i)) {
std::vector<double> root = s.solve(y+i*dy);
BOOST_CHECK(root.size()==10);
if(i==0) {
BOOST_CHECK(std::find(root.begin(),root.end(),1.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),3.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),5.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),7.0)!=root.end());
}
for(size_t i=0; i<root.size(); i++) {
double y0 = s(root[i]);
BOOST_CHECK_SMALL(y0-y, max_func);
}
} // BOOST_TEST_CONTEXT
}
}
{
tk::spline s(X,Y, tk::spline::cspline_hermite);
for(int i=-loops/2; i<loops/2; i++) {
BOOST_TEST_CONTEXT("spline: C1 Hermite: f(x)=" << y << " + " << (dy*i)) {
std::vector<double> root = s.solve(y+i*dy);
if(i<0) {
BOOST_CHECK(root.size()==4);
} else if(i==0) {
BOOST_CHECK(root.size()==7);
} else {
BOOST_CHECK(root.size()==10);
}
if(i==0) {
BOOST_CHECK(std::find(root.begin(),root.end(),1.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),3.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),5.0)!=root.end());
BOOST_CHECK(std::find(root.begin(),root.end(),7.0)!=root.end());
}
for(size_t i=0; i<root.size(); i++) {
double y0 = s(root[i]);
BOOST_CHECK_SMALL(y0-y, max_func);
}
} // BOOST_TEST_CONTEXT
}
}
}
29 changes: 26 additions & 3 deletions src/spline.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ class band_matrix

};

double get_eps();

std::vector<double> solve_cubic(double a, double b, double c, double d,
int newton_iter=0);

Expand Down Expand Up @@ -569,10 +571,17 @@ std::vector<double> spline::solve(double y, bool ignore_extrapolation) const
// brute force check if piecewise cubic has roots in their resp. segment
// TODO: make more efficient
for(size_t i=0; i<n-1; i++) {
root = internal::solve_cubic(m_y[i]-y,m_b[i],m_c[i],m_d[i],2);
root = internal::solve_cubic(m_y[i]-y,m_b[i],m_c[i],m_d[i],1);
for(size_t j=0; j<root.size(); j++) {
if( (0.0<=root[j]) && (root[j]<m_x[i+1]-m_x[i]) ) {
x.push_back(m_x[i]+root[j]);
double h = (i>0) ? (m_x[i]-m_x[i-1]) : 0.0;
double eps = internal::get_eps()*512.0*std::min(h,1.0);
if( (-eps<=root[j]) && (root[j]<m_x[i+1]-m_x[i]) ) {
double new_root = m_x[i]+root[j];
if(x.size()>0 && x.back()+eps > new_root) {
x.back()=new_root; // avoid spurious duplicate roots
} else {
x.push_back(new_root);
}
}
}
}
Expand Down Expand Up @@ -906,6 +915,20 @@ std::vector<double> solve_cubic(double a, double b, double c, double d,
}
}
}
// ensure if a=0 we get exactly x=0 as root
// TODO: remove this fudge
if(a==0.0) {
assert(z.size()>0); // cubic should always have at least one root
double xmin=fabs(z[0]);
size_t imin=0;
for(size_t i=1; i<z.size(); i++) {
if(xmin>fabs(z[i])) {
xmin=fabs(z[i]);
imin=i;
}
}
z[imin]=0.0; // replace the smallest absolute value with 0
}
std::sort(z.begin(), z.end());
return z;
}
Expand Down

0 comments on commit 5894bea

Please sign in to comment.