Skip to content

Commit

Permalink
merged main
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 31, 2024
2 parents 6056838 + d4f6524 commit c1e2297
Showing 1 changed file with 214 additions and 2 deletions.
216 changes: 214 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,19 @@ def fit(self, X: FeatureMatrix, y=None):
-------
self :
The transformer object.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(100, 2))
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X)
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand All @@ -224,6 +237,28 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
-------
:
The data transformed by the basis functions.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(10000, 2))
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_
>>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)
>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1])
"""
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
Expand All @@ -249,6 +284,21 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
array-like
The data transformed by the basis functions, after fitting the basis
functions to the data.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(100, 1))
>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X)
"""
return self._basis.compute_features(*self._unpack_inputs(X))

Expand Down Expand Up @@ -755,6 +805,19 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
input samples with the basis functions. The output shape varies based on
the subclass and mode.
Examples
-------
>>> import numpy as np
>>> from nemos.basis import BSplineBasis
>>> # Generate data
>>> num_samples = 10000
>>> X = np.random.normal(size=(num_samples, )) # raw time series
>>> basis = BSplineBasis(10)
>>> features = basis.compute_features(X) # basis transformed time series
>>> features.shape
(10000, 10)
Notes
-----
Subclasses should implement how to handle the transformation specific to their
Expand Down Expand Up @@ -932,6 +995,19 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
This differs from the numpy.meshgrid default, which uses Cartesian indexing.
For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$.
Examples
--------
>>> # Evaluate and visualize 4 M-spline basis functions of order 3:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import MSplineBasis
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> p = plt.plot(sample_points, basis_values)
>>> _ = plt.title('M-Spline Basis Functions')
>>> _ = plt.xlabel('Domain')
>>> _ = plt.ylabel('Basis Function Value')
>>> _ = plt.legend([f'Function {i+1}' for i in range(4)]);
"""
self._check_input_dimensionality(n_samples)

Expand Down Expand Up @@ -1121,7 +1197,22 @@ class AdditiveBasis(Basis):
n_basis_funcs : int
Number of basis functions.
Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 2))
>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = basis_1 + basis_2
>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1233,6 +1324,22 @@ class MultiplicativeBasis(Basis):
n_basis_funcs : int
Number of basis functions.
Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 3))
>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = basis_1 * basis_2
>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1351,7 +1458,6 @@ class SplineBasis(Basis, abc.ABC):
----------
order : int
Spline order.
"""

def __init__(
Expand Down Expand Up @@ -1673,6 +1779,14 @@ class BSplineBasis(SplineBasis):
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import BSplineBasis
>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1752,6 +1866,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import BSplineBasis
>>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1790,6 +1912,16 @@ class CyclicBSplineBasis(SplineBasis):
Number of basis functions.
order : int
Order of the splines used in basis functions.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import CyclicBSplineBasis
>>> X = np.random.normal(size=(1000, 1))
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1897,6 +2029,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import CyclicBSplineBasis
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1929,6 +2069,16 @@ class RaisedCosineBasisLinear(Basis):
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLinear
>>> X = np.random.normal(size=(1000, 1))
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2068,6 +2218,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
basis_funcs :
Raised cosine basis functions, shape (n_samples, n_basis_funcs)
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import RaisedCosineBasisLinear
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -2125,6 +2282,16 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLog
>>> X = np.random.normal(size=(1000, 1))
>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2281,6 +2448,18 @@ class OrthExponentialBasis(Basis):
For example, changing the `predictor_causality`, which by default is set to `"causal"`.
Note that one cannot change the default value for the `axis` parameter. Basis assumes
that the convolution axis is `axis=0`.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import OrthExponentialBasis
>>> X = np.random.normal(size=(1000, 1))
>>> n_basis_funcs = 5
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = ortho_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -2436,6 +2615,16 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Evaluated exponentially decaying basis functions, numerically
orthogonalized, shape (n_samples, n_basis_funcs)
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import OrthExponentialBasis
>>> n_basis_funcs = 5
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand All @@ -2458,6 +2647,17 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray:
-------
spline
M-spline basis function, shape (n_sample_points, ).
Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import mspline
>>> sample_points = linspace(0, 1, 100)
>>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline
>>> mspline_eval.shape
(100,)
"""
# Boundary conditions.
if (T[i + k] - T[i]) < 1e-6:
Expand Down Expand Up @@ -2524,6 +2724,18 @@ def bspline(
Notes
-----
The function uses splev function from scipy.interpolate library for the basis evaluation.
Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import bspline
>>> sample_points = linspace(0, 1, 100)
>>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1])
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
>>> bspline_eval.shape
(100, 10)
"""
knots.sort()
nk = knots.shape[0]
Expand Down

0 comments on commit c1e2297

Please sign in to comment.