Bspline #39

Merged
merged 24 commits into from
Sep 14, 2023
Changes from 11 commits
d31fc86
added and tested B-spline and c-bspline
BalzaniEdoardo Jul 28, 2023
0a69205
bspline external function
BalzaniEdoardo Jul 29, 2023
13cc208
removed if at the end
BalzaniEdoardo Jul 29, 2023
7c9c6df
removed level from contour
BalzaniEdoardo Jul 29, 2023
9900124
merged main
BalzaniEdoardo Aug 21, 2023
cf33067
"transposed" bspline
BalzaniEdoardo Aug 21, 2023
5680b45
added BSpline and CyclicBspline
BalzaniEdoardo Aug 21, 2023
42f846f
sorted the classes
BalzaniEdoardo Aug 21, 2023
097bcdd
added cyclic bspline to __all__
BalzaniEdoardo Aug 21, 2023
1767250
black linter
BalzaniEdoardo Aug 21, 2023
e3384f9
replaced == to allclose
BalzaniEdoardo Aug 21, 2023
fa582aa
code review fixes
BalzaniEdoardo Aug 28, 2023
6a02fce
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
4b9c246
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
9fd5f9d
simplified algorithm for cyclic splines
BalzaniEdoardo Aug 28, 2023
a8f75b5
fixed regex test
BalzaniEdoardo Aug 28, 2023
378d890
merged
BalzaniEdoardo Aug 28, 2023
2949d85
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
4103632
split tests
BalzaniEdoardo Aug 28, 2023
f988b1a
Merge branch 'bspline' of github.com:flatironinstitute/generalized-li…
BalzaniEdoardo Aug 28, 2023
efb3c6a
removed commented block
BalzaniEdoardo Sep 10, 2023
ba815cc
fixed leftovers comments
BalzaniEdoardo Sep 13, 2023
0e8b0e0
fixed test adapting it to the new exception logic
BalzaniEdoardo Sep 14, 2023
60ef87a
Merge branch 'main' into bspline
BalzaniEdoardo Sep 14, 2023
6 changes: 5 additions & 1 deletion docs/developers_notes/basis_module.md
@@ -13,7 +13,11 @@ Abstract Class Basis
├─ Abstract Subclass SplineBasis
│ │
│ └─ Concrete Subclass MSplineBasis
│ ├─ Concrete Subclass MSplineBasis
│ │
│ ├─ Concrete Subclass BSplineBasis
│ │
│ └─ Concrete Subclass CyclicBSplineBasis
├─ Abstract Subclass RaisedCosineBasis
│ │
308 changes: 302 additions & 6 deletions src/neurostatslib/basis.py
@@ -9,11 +9,14 @@
import numpy as np
import scipy.linalg
from numpy.typing import NDArray
from scipy.interpolate import splev

from neurostatslib.utils import row_wise_kron

__all__ = [
"MSplineBasis",
"BSplineBasis",
"CyclicBSplineBasis",
"RaisedCosineBasisLinear",
"RaisedCosineBasisLog",
"OrthExponentialBasis",
@@ -91,12 +94,6 @@ def evaluate(self, *xi: NDArray) -> NDArray:
-------
:
The generated basis functions.

Raises
------
ValueError
If the time point number is inconsistent between inputs or if the number of inputs doesn't match what
the Basis object requires.
"""
# checks on input and outputs
self._check_samples_consistency(*xi)
@@ -465,6 +462,13 @@ def _generate_knots(
)
return self.knot_locs

@staticmethod
def _check_samples_non_empty(sample_pts):
if sample_pts.shape[0] == 0:
raise ValueError(
"Empty sample array provided. At least one sample is required for evaluation!"
Review comment (Member):

why is this here and only for BSpline?

from the tests, looks like you're expecting it when evaluate_on_grid(x) for x<=0, but is that really an issue? for the negative case, we'll get an error in np.linspace and for x=0, you'd get an empty array otherwise, which seems reasonable. Is there a specific reason to check for this?

and in general, don't like that we'll have different behavior for BSpline than for the rest, when really, it's the same concerns, right?

Reply (Collaborator, Author):

If you don't check, with x=0 you get an "invalid input data" error message from the scipy splev call.

I am fine with that but:

  • it is a bit nested and less clear for a user what may have caused it.
  • MSpline, RaisedCosineLog, RaisedCosineLinear return an empty array, not an error.
  • OrthExp raises another class-specific error, which is that you need at least as many samples as basis functions to define the basis (otherwise the orthonormalization would be undefined).

Uniformity in behavior is hard to get given that different basis constructions have different requirements.
My rationale here was to try to surface the error and give a basis-type-specific explanation of the error which is understandable.
If we try to get something that looks more uniform there are a few ways:

  • return an empty array for B-spline too instead of calling splev in case of empty input
  • raise the error instead of returning an empty array for all of the basis functions we have

Which one do you prefer?

Reply (Member):

Gotcha. I'm inclined to work towards uniformity in behavior, and I don't have a strong preference here, but I think raising the error for all of them is the right way? because if someone calls evaluate_on_grid(0), they've really misunderstood what it does. right?

)
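
One way to get the uniform behavior discussed in the thread above is a single guard that every basis runs before evaluating. The following is only a minimal sketch: `check_samples_non_empty` and `evaluate_with_check` are hypothetical names used for illustration and are not part of this PR.

```python
import numpy as np


def check_samples_non_empty(sample_pts):
    """Raise a ValueError if no samples were provided (hypothetical shared guard)."""
    # atleast_1d lets a scalar input count as one sample
    sample_pts = np.atleast_1d(sample_pts)
    if sample_pts.shape[0] == 0:
        raise ValueError(
            "Empty sample array provided. At least one sample is required for evaluation!"
        )


def evaluate_with_check(evaluate_fn, sample_pts):
    """Run the guard first, then delegate to any basis-specific evaluation function."""
    check_samples_non_empty(sample_pts)
    return evaluate_fn(np.atleast_1d(sample_pts))
```

With a guard like this, an empty input fails with the same message for every basis type, instead of surfacing a nested error from the underlying evaluation routine.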


class MSplineBasis(SplineBasis):
"""M-spline 1-dimensional basis functions.
@@ -533,6 +537,218 @@ def _check_n_basis_min(self) -> None:
)


class BSplineBasis(SplineBasis):
"""
B-spline 1-dimensional basis functions.

Parameters
----------
n_basis_funcs :
Number of basis functions.
order :
Order of the splines used in basis functions. Must lie within [1, n_basis_funcs].
The B-splines have (order-2) continuous derivatives at each interior knot.
The higher this number, the smoother the basis representation will be.

Attributes
----------
order :
Spline order.

Methods
-------
_evaluate(x_tuple)
Evaluate the basis function at the samples x_tuple[0]. x_tuple must be of length 1 in order to pass the checks
of super().evaluate

References
----------
.. [2] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

"""

def __init__(self, n_basis_funcs: int, order: int = 2):
super().__init__(n_basis_funcs, order=order)

def _evaluate(self, sample_pts: NDArray) -> NDArray:
"""
Evaluate the B-spline basis functions with given sample points.

Parameters
----------
sample_pts :
The sample points at which the B-spline is evaluated.

Returns
-------
NDArray
The basis functions evaluated at the samples, shape (n_samples, n_basis_funcs).

Raises
------
AssertionError
If the sample points are not within the B-spline knots range unless `outer_ok=True`.

Notes
-----
This method evaluates the B-spline basis functions at the given sample points.
It requires the knots to be defined through the `_generate_knots` method.

The evaluation is performed by looping over each element and using `splev`
from SciPy to compute the basis values.
"""
super()._check_samples_non_empty(sample_pts)

# add knots
self._generate_knots(sample_pts, 0.0, 1.0)

basis_eval = bspline(
sample_pts, self.knot_locs, order=self.order, der=0, outer_ok=False
)

return basis_eval

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Check that BSplineBasis has at least as many basis functions as the order of the spline.

Raises
------
ValueError
If an insufficient number of basis elements is requested for the basis type.
"""
if self.n_basis_funcs < self.order:
raise ValueError(
f"{self.__class__.__name__} `order` parameter cannot be larger "
"than `n_basis_funcs` parameter."
)


class CyclicBSplineBasis(SplineBasis):
"""
B-spline 1-dimensional basis functions for cyclic splines.

Parameters
----------
n_basis_funcs :
Number of basis functions.
order :
Order of the splines used in basis functions. Must lie within [1, n_basis_funcs].
The B-splines have (order-2) continuous derivatives at each interior knot.
The higher this number, the smoother the basis representation will be.

Attributes
----------
n_basis_funcs : int
Number of basis functions.
order : int
Order of the splines used in basis functions.

Methods
-------
_evaluate(sample_pts, der)
Evaluate the B-spline basis functions with given sample points.
"""

def __init__(self, n_basis_funcs: int, order: int = 2):
super().__init__(n_basis_funcs, order=order)
if self.order < 2:
raise ValueError(
f"Order >= 2 required for cyclic B-spline, "
f"order {self.order} specified instead!"
)

def _evaluate(self, sample_pts: NDArray) -> NDArray:
"""
Evaluate the B-spline basis functions with given sample points.

Parameters
----------
sample_pts :
The sample points at which the B-spline is evaluated. Must be a tuple of length 1.

Returns
-------
NDArray
The basis functions evaluated at the samples, shape (n_samples, n_basis_funcs).

Raises
------
AssertionError
If the sample points are not within the B-spline knots range unless `outer_ok=True`.

Notes
-----
This method evaluates the B-spline basis functions at the given sample points.
It requires the knots to be defined through the `_generate_knots` method.

The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
"""
super()._check_samples_non_empty(sample_pts)

self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True)

# for cyclic, do not repeat knots
self.knot_locs = np.unique(self.knot_locs)

knots_orig = self.knot_locs.copy()

nk = knots_orig.shape[0]

# make sure knots are sorted
knots_orig.sort()

# extend knots
xc = knots_orig[nk - 2 * self.order + 1]
knots = np.hstack(
(
self.knot_locs[0]
- self.knot_locs[-1]
+ self.knot_locs[nk - self.order : nk - 1],
self.knot_locs,
)
)
ind = sample_pts > xc

# temporarily set the extended knots as attribute
self.knot_locs = knots

basis_eval = bspline(
sample_pts, self.knot_locs, order=self.order, der=0, outer_ok=True
)
sample_pts[ind] = sample_pts[ind] - knots.max() + knots_orig[0]

if np.sum(ind):
basis_eval[ind] = basis_eval[ind] + bspline(
sample_pts[ind], self.knot_locs, order=self.order, outer_ok=True, der=0
)
# restore points
sample_pts[ind] = sample_pts[ind] + knots.max() - knots_orig[0]
# restore the original knots
self.knot_locs = knots_orig

return basis_eval
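
The method above shifts the wrapped samples in place and then restores them. A possible alternative, shown here only as a sketch, is to work on a copy so the caller's array is never modified. `shift_by_period` is a hypothetical helper, and `period` stands for the `knots.max() - knots_orig[0]` quantity used above.

```python
import numpy as np


def shift_by_period(sample_pts, xc, period):
    """Return a shifted copy of the samples and the mask of shifted entries.

    Samples strictly greater than `xc` are moved back by one `period`;
    the input array is left untouched.
    """
    sample_pts = np.asarray(sample_pts, dtype=float)
    ind = sample_pts > xc
    shifted = sample_pts.copy()
    shifted[ind] -= period
    return shifted, ind


x = np.array([0.1, 0.5, 0.9])
shifted, ind = shift_by_period(x, xc=0.6, period=1.0)
```

Because the shift happens on a copy, no restore step is needed afterwards, and integer or read-only inputs do not cause surprises.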

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Check that CyclicBSplineBasis has at least as many basis functions as the order of the spline + 2
and at least 2*order - 2 basis functions.

Raises
------
ValueError
If an insufficient number of basis elements is requested for the basis type.
"""
if self.n_basis_funcs < max(self.order * 2 - 2, self.order + 2):
raise ValueError(
f"Insufficient basis elements for {self.__class__.__name__} instantiation."
)
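
As a worked example of the condition checked above, the smallest admissible `n_basis_funcs` for a few orders can be tabulated. `min_cyclic_basis` is a hypothetical helper that only restates the `max(order * 2 - 2, order + 2)` expression from the diff.

```python
def min_cyclic_basis(order: int) -> int:
    """Smallest n_basis_funcs accepted by the check above for a given order."""
    return max(2 * order - 2, order + 2)


# order -> minimum number of basis functions
minimums = {order: min_cyclic_basis(order) for order in (2, 3, 4, 5)}
print(minimums)  # {2: 4, 3: 5, 4: 6, 5: 8}
```

The `order + 2` term dominates for orders below 4, the two terms coincide at order 4, and `2 * order - 2` dominates from order 5 on.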


class RaisedCosineBasis(Basis, abc.ABC):
def __init__(self, n_basis_funcs: int) -> None:
super().__init__(n_basis_funcs)
@@ -870,3 +1086,83 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray):
)
/ ((k - 1) * (T[i + k] - T[i]))
)


def bspline(
    sample_pts: NDArray,
    knots: NDArray,
    order: int = 4,
    der: int = 0,
    outer_ok: bool = False,
):
    """
    Calculate and return the evaluation of B-spline basis.

    This function evaluates B-spline basis for given sample points. It checks for
    out of range points and optionally handles them. It also handles the NaNs if present.

    Parameters
    ----------
    sample_pts :
        An array containing sample points for which B-spline basis needs to be evaluated.
    knots :
        An array containing knots for the B-spline basis. The knots are sorted in ascending order.
    order :
        The order of the B-spline basis. Default is 4.
    der :
        The derivative of the B-spline basis to be evaluated. Default is 0.
    outer_ok :
        If True, allows for evaluation at points outside the range of knots.
        Default is False, in which case an assertion error is raised when
        points outside the knots range are encountered.

    Returns
    -------
    basis_eval :
        An array containing the evaluation of B-spline basis for the given sample points.
        Shape (n_samples, n_basis_funcs).

    Raises
    ------
    AssertionError
        If `outer_ok` is False and the sample points lie outside the B-spline knots range.

    Notes
    -----
    The function uses splev function from scipy.interpolate library for the basis evaluation.
    """
    knots.sort()
    nk = knots.shape[0]

    # check for out of range points (for cyclic B-splines, outer_ok must be set to True)
    need_outer = any(sample_pts < knots[order - 1]) or any(
        sample_pts > knots[nk - order]
    )
    assert (
        not need_outer
    ) or outer_ok, 'sample points must lie within the B-spline knots range unless "outer_ok==True".'

    # select sample points that are within the knots range (this takes care of eventual NaNs)
    in_sample = (sample_pts >= knots[0]) & (sample_pts <= knots[-1])

    if need_outer:
        reps = order - 1
        knots = np.hstack((np.ones(reps) * knots[0], knots, np.ones(reps) * knots[-1]))
        nk = knots.shape[0]
    else:
        reps = 0

    # number of basis elements
    n_basis = nk - order

    # initialize the basis element container
    basis_eval = np.zeros((n_basis - 2 * reps, sample_pts.shape[0]))

    # loop one element at a time and evaluate the basis using splev
    id_basis = np.eye(n_basis, nk, dtype=np.int8)
    for i in range(reps, len(knots) - order - reps):
        basis_eval[i - reps, in_sample] = splev(
            sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der
        )

    return basis_eval.T
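
The core trick in `bspline` is to call `splev` once per basis element with a one-hot coefficient vector, so that each call returns a single basis function. Below is a minimal standalone sketch of that idea. The clamped knot vector and the sample grid are assumptions chosen for illustration, not values taken from the PR.

```python
import numpy as np
from scipy.interpolate import splev

order = 4  # cubic B-splines (degree = order - 1)

# Assumed clamped knot vector on [0, 1]: boundary knots repeated `order` times
interior = np.linspace(0.0, 1.0, 5)
knots = np.hstack((np.zeros(order - 1), interior, np.ones(order - 1)))

n_basis = knots.shape[0] - order  # 11 knots - 4 = 7 basis functions
x = np.linspace(0.01, 0.99, 50)   # sample points strictly inside the knot range

# Row i is a one-hot coefficient vector selecting the i-th basis function;
# rows are padded to the knot-vector length, as in the function above.
coeffs = np.eye(n_basis, knots.shape[0])
basis = np.stack(
    [splev(x, (knots, coeffs[i], order - 1)) for i in range(n_basis)], axis=1
)

# B-splines on a clamped knot vector sum to one inside the knot range
row_sums = basis.sum(axis=1)
```

The resulting `basis` array has one row per sample and one column per basis function, matching the `(n_samples, n_basis_funcs)` layout returned by `bspline`. Checking `row_sums` against 1 is a quick sanity test for the knot layout.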