diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 40ad80d..585fabc 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -229,6 +229,7 @@ def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None): self._x = x self._y = y self._coefficients = coefficients + self._endpoints = endpoints # Operations for flattening/unflattening representation def tree_flatten(self):