Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bspline #39

Merged
merged 24 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d31fc86
added and tested B-spline and c-bspline
BalzaniEdoardo Jul 28, 2023
0a69205
bspline external function
BalzaniEdoardo Jul 29, 2023
13cc208
removed if at the end
BalzaniEdoardo Jul 29, 2023
7c9c6df
removed level from contour
BalzaniEdoardo Jul 29, 2023
9900124
merged main
BalzaniEdoardo Aug 21, 2023
cf33067
"transposed" bspline
BalzaniEdoardo Aug 21, 2023
5680b45
added BSpline and CyclicBspline
BalzaniEdoardo Aug 21, 2023
42f846f
sorted the classes
BalzaniEdoardo Aug 21, 2023
097bcdd
added cyclic bspline to __all__
BalzaniEdoardo Aug 21, 2023
1767250
black linter
BalzaniEdoardo Aug 21, 2023
e3384f9
replaced == to allclose
BalzaniEdoardo Aug 21, 2023
fa582aa
code review fixes
BalzaniEdoardo Aug 28, 2023
6a02fce
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
4b9c246
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
9fd5f9d
simplified algorithm for cyclic splines
BalzaniEdoardo Aug 28, 2023
a8f75b5
fixed regex test
BalzaniEdoardo Aug 28, 2023
378d890
merged
BalzaniEdoardo Aug 28, 2023
2949d85
Update src/neurostatslib/basis.py
BalzaniEdoardo Aug 28, 2023
4103632
split tests
BalzaniEdoardo Aug 28, 2023
f988b1a
Merge branch 'bspline' of github.com:flatironinstitute/generalized-li…
BalzaniEdoardo Aug 28, 2023
efb3c6a
removed commented block
BalzaniEdoardo Sep 10, 2023
ba815cc
fixed leftovers comments
BalzaniEdoardo Sep 13, 2023
0e8b0e0
fixed test adapting it to the new exception logic
BalzaniEdoardo Sep 14, 2023
60ef87a
Merge branch 'main' into bspline
BalzaniEdoardo Sep 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/developers_notes/basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ Abstract Class Basis
├─ Abstract Subclass SplineBasis
│ │
│ └─ Concrete Subclass MSplineBasis
│ ├─ Concrete Subclass MSplineBasis
│ │
│ ├─ Concrete Subclass BSplineBasis
│ │
│ └─ Concrete Subclass CyclicBSplineBasis
├─ Abstract Subclass RaisedCosineBasis
│ │
Expand Down
289 changes: 273 additions & 16 deletions src/neurostatslib/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
import abc
from typing import Generator, Tuple

import jax.numpy
import numpy as np
import scipy.linalg
from numpy.typing import NDArray
from numpy.typing import ArrayLike, NDArray
from scipy.interpolate import splev

from neurostatslib.utils import row_wise_kron

__all__ = [
"MSplineBasis",
"BSplineBasis",
"CyclicBSplineBasis",
"RaisedCosineBasisLinear",
"RaisedCosineBasisLog",
"OrthExponentialBasis",
Expand Down Expand Up @@ -78,7 +82,7 @@ def _get_samples(*n_samples: int) -> Generator[NDArray, ...]:
"""
return (np.linspace(0, 1, n_samples[k]) for k in range(len(n_samples)))

def evaluate(self, *xi: NDArray) -> NDArray:
def evaluate(self, *xi: ArrayLike) -> NDArray:
"""
Evaluate the basis set at the given samples x[0],...,x[n] using the subclass-specific "_evaluate" method.

Expand All @@ -95,9 +99,22 @@ def evaluate(self, *xi: NDArray) -> NDArray:
Raises
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there's supposed to be an empty line here

------
ValueError
If the time point number is inconsistent between inputs or if the number of inputs doesn't match what
billbrod marked this conversation as resolved.
Show resolved Hide resolved
the Basis object requires.
- If the time point number is inconsistent between inputs.
- If the number of inputs doesn't match what the Basis object requires.
"""
# check that the input is array-like
if any(
not isinstance(x, (list, tuple, np.ndarray, jax.numpy.ndarray)) for x in xi
):
raise TypeError("Input samples must be array-like!")

# convert to numpy.array of floats
xi = tuple(np.asarray(x, dtype=float) for x in xi)

# check for non-empty samples
if self._has_zero_samples(tuple(len(x) for x in xi)):
raise ValueError("All sample provided must be non empty.")

# checks on input and outputs
self._check_samples_consistency(*xi)
self._check_input_dimensionality(xi)
Expand Down Expand Up @@ -145,6 +162,9 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
"""
self._check_input_dimensionality(n_samples)

if self._has_zero_samples(n_samples):
raise ValueError("All sample counts provided must be greater than zero.")

# get the samples
sample_tuple = self._get_samples(*n_samples)
Xs = np.meshgrid(*sample_tuple, indexing="ij")
Expand All @@ -156,6 +176,10 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:

return *Xs, Y

@staticmethod
def _has_zero_samples(n_samples: Tuple[int]) -> bool:
return any([n <= 0 for n in n_samples])

def _check_input_dimensionality(self, xi: Tuple) -> None:
"""
Check that the number of inputs provided by the user matches the number of inputs required.
Expand Down Expand Up @@ -456,14 +480,30 @@ def _generate_knots(
mn = np.nanpercentile(sample_pts, np.clip(perc_low * 100, 0, 100))
mx = np.nanpercentile(sample_pts, np.clip(perc_high * 100, 0, 100)) + 10**-8

self.knot_locs = np.concatenate(
knot_locs = np.concatenate(
(
mn * np.ones(self.order - 1),
np.linspace(mn, mx, num_interior_knots + 2),
mx * np.ones(self.order - 1),
)
)
return self.knot_locs
return knot_locs

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Check that the spline-basis has at least as many basis as the order.

Raises
------
ValueError
If an insufficient number of basis element is requested for the basis type
"""
if self.n_basis_funcs < self.order:
raise ValueError(
f"{self.__class__.__name__} `order` parameter cannot be larger "
"than `n_basis_funcs` parameter."
)


class MSplineBasis(SplineBasis):
Expand Down Expand Up @@ -506,32 +546,169 @@ def _evaluate(self, sample_pts: NDArray) -> NDArray:

"""
# add knots if not passed
self._generate_knots(sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False)
knot_locs = self._generate_knots(
sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False
)

return np.stack(
[
mspline(sample_pts, self.order, i, self.knot_locs)
mspline(sample_pts, self.order, i, knot_locs)
for i in range(self.n_basis_funcs)
],
axis=1,
)

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Check that MSplineBasis has at least as many basis as the order of the spline.
class BSplineBasis(SplineBasis):
"""
B-spline 1-dimensional basis functions.

Parameters
----------
n_basis_funcs :
Number of basis functions.
order :
Order of the splines used in basis functions. Must lie within [1, n_basis_funcs].
The B-splines have (order-2) continuous derivatives at each interior knot.
The higher this number, the smoother the basis representation will be.

Attributes
----------
order :
Spline order.


References
----------
..[2] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5

"""

def __init__(self, n_basis_funcs: int, order: int = 2):
super().__init__(n_basis_funcs, order=order)

def _evaluate(self, sample_pts: NDArray) -> NDArray:
"""
Evaluate the B-spline basis functions with given sample points.

Parameters
----------
sample_pts :
The sample points at which the B-spline is evaluated.

Returns
-------
NDArray
The basis function evaluated at the samples, shape (n_samples, n_basis_funcs)

Raises
------
ValueError
If an insufficient number of basis element is requested for the basis type
AssertionError
If the sample points are not within the B-spline knots range unless `outer_ok=True`.

Notes
-----
The evaluation is performed by looping over each element and using `splev`
from SciPy to compute the basis values.
"""
if self.n_basis_funcs < self.order:

# add knots
knot_locs = self._generate_knots(sample_pts, 0.0, 1.0)

basis_eval = bspline(
sample_pts, knot_locs, order=self.order, der=0, outer_ok=False
)

return basis_eval


class CyclicBSplineBasis(SplineBasis):
"""
B-spline 1-dimensional basis functions for cyclic splines.

Parameters
----------
n_basis_funcs :
Number of basis functions.
order :
Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs].
The B-splines have (order-2) continuous derivatives at each interior knot.
The higher this number, the smoother the basis representation will be.

Attributes
----------
n_basis_funcs : int
Number of basis functions.
order : int
Order of the splines used in basis functions.
"""

def __init__(self, n_basis_funcs: int, order: int = 2):
super().__init__(n_basis_funcs, order=order)
if self.order < 2:
raise ValueError(
f"{self.__class__.__name__} `order` parameter cannot be larger "
"than `n_basis_funcs` parameter."
f"Order >= 2 required for cyclic B-spline, "
f"order {self.order} specified instead!"
)

def _evaluate(self, sample_pts: NDArray) -> NDArray:
"""
Evaluate the B-spline basis functions with given sample points.

Parameters
----------
sample_pts :
The sample points at which the B-spline is evaluated. Must be a tuple of length 1.

Returns
-------
NDArray
The basis function evaluated at the samples, shape (n_samples, n_basis_funcs)

Raises
------
AssertionError
If the sample points are not within the B-spline knots range unless `outer_ok=True`.

Notes
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
"""

knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True)

# for cyclic, do not repeat knots
knot_locs = np.unique(knot_locs)

nk = knot_locs.shape[0]

# make sure knots are sorted
knot_locs.sort()

# extend knots
xc = knot_locs[nk - self.order]
knots = np.hstack(
(
knot_locs[0] - knot_locs[-1] + knot_locs[nk - self.order : nk - 1],
knot_locs,
)
)
ind = sample_pts > xc

basis_eval = bspline(sample_pts, knots, order=self.order, der=0, outer_ok=True)
sample_pts[ind] = sample_pts[ind] - knots.max() + knot_locs[0]

if np.sum(ind):
basis_eval[ind] = basis_eval[ind] + bspline(
sample_pts[ind], knots, order=self.order, outer_ok=True, der=0
)
# restore points
sample_pts[ind] = sample_pts[ind] + knots.max() - knot_locs[0]

return basis_eval


class RaisedCosineBasis(Basis, abc.ABC):
def __init__(self, n_basis_funcs: int) -> None:
Expand Down Expand Up @@ -870,3 +1047,83 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray):
)
/ ((k - 1) * (T[i + k] - T[i]))
)


def bspline(
sample_pts: NDArray,
knots: NDArray,
order: int = 4,
der: int = 0,
outer_ok: bool = False,
):
"""
Calculate and return the evaluation of B-spline basis.

This function evaluates B-spline basis for given sample points. It checks for
out of range points and optionally handles them. It also handles the NaNs if present.

Parameters
----------
sample_pts :
An array containing sample points for which B-spline basis needs to be evaluated.
knots :
An array containing knots for the B-spline basis. The knots are sorted in ascending order.
order :
The order of the B-spline basis.
der :
The derivative of the B-spline basis to be evaluated.
outer_ok :
If True, allows for evaluation at points outside the range of knots.
Default is False, in which case an assertion error is raised when
points outside the knots range are encountered.

Returns
-------
basis_eval :
An array containing the evaluation of B-spline basis for the given sample points.
Shape (n_samples, n_basis_funcs).

Raises
------
AssertionError
If `outer_ok` is False and the sample points lie outside the B-spline knots range.

Notes
-----
The function uses splev function from scipy.interpolate library for the basis evaluation.
"""
knots.sort()
nk = knots.shape[0]

# check for out of range points (in cyclic b-spline need_outer must be set to False)
need_outer = any(sample_pts < knots[order - 1]) or any(
sample_pts > knots[nk - order]
)
assert (
not need_outer
) | outer_ok, 'sample points must lie within the B-spline knots range unless "outer_ok==True".'

# select knots that are within the knots range (this takes care of eventual NaNs)
in_sample = (sample_pts >= knots[0]) & (sample_pts <= knots[-1])

if need_outer:
reps = order - 1
knots = np.hstack((np.ones(reps) * knots[0], knots, np.ones(reps) * knots[-1]))
nk = knots.shape[0]
else:
reps = 0

# number of basis elements
n_basis = nk - order

# initialize the basis element container
basis_eval = np.zeros((n_basis - 2 * reps, sample_pts.shape[0]))

# loop one element at the time and evaluate the basis using splev
id_basis = np.eye(n_basis, nk, dtype=np.int8)
for i in range(reps, len(knots) - order - reps):
basis_eval[i - reps, in_sample] = splev(
sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der
)

return basis_eval.T
Loading