From 1e2932a2fdbc2fcb4d0fe90c7199ce42f5216c15 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 14 Nov 2024 11:04:13 -0500 Subject: [PATCH 001/109] initial moving pieces --- README.md | 10 +- docs/background/plot_01_1D_basis_function.py | 12 +- docs/background/plot_02_ND_basis_function.py | 12 +- docs/background/plot_03_1D_convolution.py | 8 +- docs/how_to_guide/plot_02_glm_demo.py | 2 +- docs/how_to_guide/plot_03_glm_pytree.py | 4 +- docs/how_to_guide/plot_05_batch_glm.py | 2 +- .../plot_06_sklearn_pipeline_cv_demo.py | 18 +- docs/quickstart.md | 45 +- docs/tutorials/plot_02_head_direction.py | 4 +- docs/tutorials/plot_03_grid_cells.py | 4 +- docs/tutorials/plot_04_v1_cells.py | 2 +- docs/tutorials/plot_05_place_cells.py | 6 +- docs/tutorials/plot_06_calcium_imaging.py | 6 +- src/nemos/_documentation_utils/plotting.py | 2 +- src/nemos/basis/__init__.py | 3 + src/nemos/{basis.py => basis/_basis.py} | 3979 +++++------------ src/nemos/basis/_basis_mixin.py | 156 + src/nemos/basis/_raised_cosine_basis.py | 408 ++ src/nemos/basis/_spline_basis.py | 768 ++++ src/nemos/basis/basis.py | 358 ++ src/nemos/identifiability_constraints.py | 2 +- src/nemos/typing.py | 4 + tests/conftest.py | 2 +- tests/test_basis.py | 64 +- tests/test_identifiability_constraints.py | 2 +- tests/test_pipeline.py | 20 +- tests/test_simulation.py | 2 +- 28 files changed, 3035 insertions(+), 2870 deletions(-) create mode 100644 src/nemos/basis/__init__.py rename src/nemos/{basis.py => basis/_basis.py} (52%) create mode 100644 src/nemos/basis/_basis_mixin.py create mode 100644 src/nemos/basis/_raised_cosine_basis.py create mode 100644 src/nemos/basis/_spline_basis.py create mode 100644 src/nemos/basis/basis.py diff --git a/README.md b/README.md index e9abf895..1dbeab9f 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,9 @@ In this example, we'll construct a time-series of features using the basis objec import nemos as nmo # Instantiate the basis -basis_1 = nmo.basis.MSplineBasis(n_basis_funcs=5) -basis_2 = nmo.basis.CyclicBSplineBasis(n_basis_funcs=6) -basis_3 = nmo.basis.MSplineBasis(n_basis_funcs=7) +basis_1 = nemos.basis.basis.EvalMSpline(n_basis_funcs=5) +basis_2 = nemos.basis.basis.CyclicBSplineBasis(n_basis_funcs=6) +basis_3 = nemos.basis.basis.EvalMSpline(n_basis_funcs=7) basis = basis_1 * basis_2 + basis_3 @@ -111,8 +111,8 @@ import nemos as nmo # generate 5 basis functions of 100 time-bins, # and convolve the counts with the basis. -X = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100 - ).compute_features(spike_counts) +X = nemos.basis.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100 + ).compute_features(spike_counts) ``` #### Population GLM diff --git a/docs/background/plot_01_1D_basis_function.py b/docs/background/plot_01_1D_basis_function.py index 3e22f052..7759b30b 100644 --- a/docs/background/plot_01_1D_basis_function.py +++ b/docs/background/plot_01_1D_basis_function.py @@ -16,14 +16,12 @@ import numpy as np import pynapple as nap -import nemos as nmo - # Initialize hyperparameters order = 4 n_basis = 10 # Define the 1D basis function object -bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order) +bspline = nemos.basis.basis.BSplineBasis(n_basis_funcs=n_basis, order=order) # %% # ## Evaluating a Basis @@ -52,7 +50,7 @@ # You can specify a range for the support of your basis by setting the `bounds` # parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN. 
-bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) +bspline_range = nemos.basis.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) print("Evaluated basis:") # 0.5 is within the support, 0.1 is outside the support @@ -82,8 +80,8 @@ # # Let's see how this two modalities operate. -eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval") -conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100) +eval_mode = nemos.basis.basis.EvalMSpline(n_basis_funcs=n_basis, mode="eval") +conv_mode = nemos.basis.basis.EvalMSpline(n_basis_funcs=n_basis, mode="conv", window_size=100) # define an input angles = np.linspace(0, np.pi*4, 201) @@ -153,7 +151,7 @@ # evaluate a log-spaced cosine raised function basis. # Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) +raised_cosine_log = nemos.basis.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/docs/background/plot_02_ND_basis_function.py b/docs/background/plot_02_ND_basis_function.py index 095633e9..eab9c125 100644 --- a/docs/background/plot_02_ND_basis_function.py +++ b/docs/background/plot_02_ND_basis_function.py @@ -66,11 +66,9 @@ import matplotlib.pyplot as plt import numpy as np -import nemos as nmo - # Define 1D basis objects -a_basis = nmo.basis.MSplineBasis(n_basis_funcs=15, order=3) -b_basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=14) +a_basis = nemos.basis.basis.EvalMSpline(n_basis_funcs=15, order=3) +b_basis = nemos.basis.basis.RaisedCosineBasisLog(n_basis_funcs=14) # Define the 2D additive basis object additive_basis = a_basis + b_basis @@ -239,9 +237,9 @@ T = 10 n_basis = 8 -a_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -b_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -c_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +a_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +b_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +c_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) diff --git a/docs/background/plot_03_1D_convolution.py b/docs/background/plot_03_1D_convolution.py index 79f6ddd9..66e3a1fe 100644 --- a/docs/background/plot_03_1D_convolution.py +++ b/docs/background/plot_03_1D_convolution.py @@ -39,7 +39,7 @@ # create three filters -basis_obj = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=3) +basis_obj = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=3) _, w = basis_obj.evaluate_on_grid(ws) plt.plot(w) @@ -124,17 +124,17 @@ # with `Basis.compute_features`. Let's see how we can get the same results through `Basis`. 
# define basis with different predictor causality -causal_basis = nmo.basis.RaisedCosineBasisLinear( +causal_basis = nemos.basis.basis.RaisedCosineBasisLinear( n_basis_funcs=3, mode="conv", window_size=ws, predictor_causality="causal" ) -acausal_basis = nmo.basis.RaisedCosineBasisLinear( +acausal_basis = nemos.basis.basis.RaisedCosineBasisLinear( n_basis_funcs=3, mode="conv", window_size=ws, predictor_causality="acausal" ) -anticausal_basis = nmo.basis.RaisedCosineBasisLinear( +anticausal_basis = nemos.basis.basis.RaisedCosineBasisLinear( n_basis_funcs=3, mode="conv", window_size=ws, predictor_causality="anti-causal" ) diff --git a/docs/how_to_guide/plot_02_glm_demo.py b/docs/how_to_guide/plot_02_glm_demo.py index f1c6e3b2..733744e4 100644 --- a/docs/how_to_guide/plot_02_glm_demo.py +++ b/docs/how_to_guide/plot_02_glm_demo.py @@ -263,7 +263,7 @@ # define a basis function n_basis_funcs = 20 -basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs) +basis = nemos.basis.basis.RaisedCosineBasisLog(n_basis_funcs) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/docs/how_to_guide/plot_03_glm_pytree.py b/docs/how_to_guide/plot_03_glm_pytree.py index 02181d32..d1a4edba 100644 --- a/docs/how_to_guide/plot_03_glm_pytree.py +++ b/docs/how_to_guide/plot_03_glm_pytree.py @@ -186,7 +186,7 @@ unit_no = 7 spikes = nwb['units'][unit_no] -basis = nmo.basis.CyclicBSplineBasis(10, order=5) +basis = nemos.basis.basis.CyclicBSplineBasis(10, order=5) x = np.linspace(-np.pi, np.pi, 100) plt.figure() plt.plot(x, basis(x)) @@ -251,7 +251,7 @@ # feature, we need a 2d basis. Let's use some raised cosine bumps and organize # our data similarly. -pos_basis = nmo.basis.RaisedCosineBasisLinear(10) * nmo.basis.RaisedCosineBasisLinear(10) +pos_basis = nemos.basis.basis.RaisedCosineBasisLinear(10) * nemos.basis.basis.RaisedCosineBasisLinear(10) spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data) X['spatial_position'] = pos_basis(*spatial_pos.values.T) diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py index 7454d6f1..12e55638 100644 --- a/docs/how_to_guide/plot_05_batch_glm.py +++ b/docs/how_to_guide/plot_05_batch_glm.py @@ -54,7 +54,7 @@ # # Here we instantiate the basis with a window size of 40 time bins. It corresponds to a 200ms windows # for a 5ms bin size. 
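# (Equivalently: window_size = 0.200 s of history / 0.005 s per bin = 40 bins.)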
-basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=40) +basis = nemos.basis.basis.RaisedCosineBasisLog(5, mode="conv", window_size=40) # %% # ## Batch definition diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py index ca9b167a..f2b83d4a 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py @@ -110,9 +110,9 @@ # Instantiating a `TransformerBasis` can be done either using the constructor directly or with `Basis.to_transformer()`: # %% -bas = nmo.basis.RaisedCosineBasisLinear(5, mode="conv", window_size=5) +bas = nemos.basis.basis.RaisedCosineBasisLinear(5, mode="conv", window_size=5) # these two ways of creating the TransformerBasis are equivalent -trans_bas_a = nmo.basis.TransformerBasis(bas) +trans_bas_a = nemos.basis.basis.TransformerBasis(bas) trans_bas_b = bas.to_transformer() # %% @@ -141,7 +141,7 @@ [ ( "transformerbasis", - nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6)), + nemos.basis.basis.TransformerBasis(nemos.basis.basis.RaisedCosineBasisLinear(6)), ), ( "glm", @@ -364,12 +364,12 @@ param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis___basis=( - nmo.basis.RaisedCosineBasisLinear(5), - nmo.basis.RaisedCosineBasisLinear(10), - nmo.basis.RaisedCosineBasisLog(5), - nmo.basis.RaisedCosineBasisLog(10), - nmo.basis.MSplineBasis(5), - nmo.basis.MSplineBasis(10), + nemos.basis.basis.RaisedCosineBasisLinear(5), + nemos.basis.basis.RaisedCosineBasisLinear(10), + nemos.basis.basis.RaisedCosineBasisLog(5), + nemos.basis.basis.RaisedCosineBasisLog(10), + nemos.basis.basis.EvalMSpline(5), + nemos.basis.basis.EvalMSpline(10), ), ) diff --git a/docs/quickstart.md b/docs/quickstart.md index aa234489..eedca3ab 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -160,10 +160,10 @@ you need to specify the number of basis functions. For some `basis` objects, add ```python ->>> import nemos as nmo +>> > import nemos as nmo ->>> n_basis_funcs = 10 ->>> basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs) +>> > n_basis_funcs = 10 +>> > basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs) ``` @@ -199,11 +199,11 @@ number of sample points. ```python ->>> import nemos as nmo +>> > import nemos as nmo ->>> n_basis_funcs = 10 ->>> # define a filter bank of 10 basis function, 200 samples long. ->>> basis = nmo.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) +>> > n_basis_funcs = 10 +>> > # define a filter bank of 10 basis function, 200 samples long. +>> > basis = nemos.basis.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) ``` @@ -324,30 +324,29 @@ Let's see how you can greatly streamline your analysis pipeline by integrating ` !!! note You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1). 
- ```python ->>> import nemos as nmo ->>> import pynapple as nap +>> > import nemos as nmo +>> > import pynapple as nap ->>> path = nmo.fetch.fetch_data("A2929-200711.nwb") ->>> data = nap.load_file(path) +>> > path = nmo.fetch.fetch_data("A2929-200711.nwb") +>> > data = nap.load_file(path) ->>> # load spikes and head direction ->>> spikes = data["units"] ->>> head_dir = data["ry"] +>> > # load spikes and head direction +>> > spikes = data["units"] +>> > head_dir = data["ry"] ->>> # restrict and bin ->>> counts = spikes[6].count(0.01, ep=head_dir.time_support) +>> > # restrict and bin +>> > counts = spikes[6].count(0.01, ep=head_dir.time_support) ->>> # down-sample head direction ->>> upsampled_head_dir = head_dir.bin_average(0.01) +>> > # down-sample head direction +>> > upsampled_head_dir = head_dir.bin_average(0.01) ->>> # create your features ->>> X = nmo.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir) +>> > # create your features +>> > X = nemos.basis.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir) ->>> # add a neuron axis and fit model ->>> model = nmo.glm.GLM().fit(X, counts) +>> > # add a neuron axis and fit model +>> > model = nmo.glm.GLM().fit(X, counts) ``` diff --git a/docs/tutorials/plot_02_head_direction.py b/docs/tutorials/plot_02_head_direction.py index dc4740d0..b8386b3c 100644 --- a/docs/tutorials/plot_02_head_direction.py +++ b/docs/tutorials/plot_02_head_direction.py @@ -349,7 +349,7 @@ # the cost of adding additional parameters. # a basis object can be instantiated in "conv" mode for convolving the input. -basis = nmo.basis.RaisedCosineBasisLog( +basis = nemos.basis.basis.RaisedCosineBasisLog( n_basis_funcs=8, mode="conv", window_size=window_size ) @@ -512,7 +512,7 @@ # define a basis function that expects an input of shape (num_samples, num_neurons). num_neurons = count.shape[1] -basis = nmo.basis.RaisedCosineBasisLog( +basis = nemos.basis.basis.RaisedCosineBasisLog( n_basis_funcs=8, mode="conv", window_size=window_size, label="convolved counts" ) diff --git a/docs/tutorials/plot_03_grid_cells.py b/docs/tutorials/plot_03_grid_cells.py index f8532835..8b51b1dc 100644 --- a/docs/tutorials/plot_03_grid_cells.py +++ b/docs/tutorials/plot_03_grid_cells.py @@ -95,9 +95,9 @@ # We can define a two-dimensional basis for position by multiplying two one-dimensional bases, # see [here](../../background/plot_02_ND_basis_function) for more details. -basis_2d = nmo.basis.RaisedCosineBasisLinear( +basis_2d = nemos.basis.basis.RaisedCosineBasisLinear( n_basis_funcs=10 -) * nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=10) +) * nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=10) # %% # Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid. 
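# A minimal sketch of that evaluation: `evaluate_on_grid` returns the two meshgrid axes plus the
# evaluated basis, whose last axis indexes the 10 x 10 = 100 elements of the product basis.

xx, yy, zz = basis_2d.evaluate_on_grid(100, 100)
print(xx.shape, yy.shape, zz.shape)  # (100, 100) (100, 100) (100, 100, 100)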
diff --git a/docs/tutorials/plot_04_v1_cells.py b/docs/tutorials/plot_04_v1_cells.py index ea3da54a..eb2d4cfd 100644 --- a/docs/tutorials/plot_04_v1_cells.py +++ b/docs/tutorials/plot_04_v1_cells.py @@ -255,7 +255,7 @@ def fill_forward(time_series, data, ep=None, out_of_range=np.nan): # GLM: window_size = 100 -basis = nmo.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size) +basis = nemos.basis.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size) convolved_input = basis.compute_features(filtered_stimulus) diff --git a/docs/tutorials/plot_05_place_cells.py b/docs/tutorials/plot_05_place_cells.py index 24ea3f63..b6ed5da4 100644 --- a/docs/tutorials/plot_05_place_cells.py +++ b/docs/tutorials/plot_05_place_cells.py @@ -236,9 +236,9 @@ # - theta phase : `nmo.basis.CyclicBSplineBasis` # - speed : `nmo.basis.MSplineBasis` -position_basis = nmo.basis.MSplineBasis(n_basis_funcs=10) -phase_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12) -speed_basis = nmo.basis.MSplineBasis(n_basis_funcs=15) +position_basis = nemos.basis.basis.EvalMSpline(n_basis_funcs=10) +phase_basis = nemos.basis.basis.CyclicBSplineBasis(n_basis_funcs=12) +speed_basis = nemos.basis.basis.EvalMSpline(n_basis_funcs=15) # %% # In addition, we will consider position and phase to be a joint variable. In NeMoS, we can combine basis by multiplying them and adding them. In this case the final basis object for our model can be made in one line : diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py index 54d83666..41b6cca6 100644 --- a/docs/tutorials/plot_06_calcium_imaging.py +++ b/docs/tutorials/plot_06_calcium_imaging.py @@ -10,12 +10,10 @@ """ import warnings -from warnings import catch_warnings import jax import jax.numpy as jnp import matplotlib.pyplot as plt -import numpy as np import pynapple as nap from sklearn.linear_model import LinearRegression @@ -127,12 +125,12 @@ # # We can combine the two bases. -heading_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12) +heading_basis = nemos.basis.basis.CyclicBSplineBasis(n_basis_funcs=12) # define a basis that expect all the other neurons as predictors, i.e. shape (num_samples, num_neurons - 1) num_neurons = Y.shape[1] ws = 10 -coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=ws) +coupling_basis = nemos.basis.basis.RaisedCosineBasisLog(3, mode="conv", window_size=ws) # %% # We need to combine the bases. diff --git a/src/nemos/_documentation_utils/plotting.py b/src/nemos/_documentation_utils/plotting.py index 5514787f..ed202867 100644 --- a/src/nemos/_documentation_utils/plotting.py +++ b/src/nemos/_documentation_utils/plotting.py @@ -33,7 +33,7 @@ from matplotlib.patches import Rectangle from numpy.typing import NDArray -from ..basis import RaisedCosineBasisLog +from nemos.basis.basis import RaisedCosineBasisLog warnings.warn( "plotting functions contained within `_documentation_utils` are intended for nemos's documentation. 
" diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py new file mode 100644 index 00000000..29cbf9af --- /dev/null +++ b/src/nemos/basis/__init__.py @@ -0,0 +1,3 @@ +from .basis import (EvalMSpline, ConvMSpline, EvalCyclicBSpline, ConvCyclicBSpline, EvalBSpline, ConvBSpline, + RaisedCosineBasisLog, RaisedCosineBasisLog, RaisedCosineBasisLog, RaisedCosineBasisLog, + OrthExponentialBasis) diff --git a/src/nemos/basis.py b/src/nemos/basis/_basis.py similarity index 52% rename from src/nemos/basis.py rename to src/nemos/basis/_basis.py index c709177d..8b1e2686 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis/_basis.py @@ -1,5 +1,3 @@ -"""Bases classes.""" - # required to get ArrayLike to render correctly from __future__ import annotations @@ -11,35 +9,16 @@ import jax import numpy as np -import scipy.linalg from numpy.typing import ArrayLike, NDArray from pynapple import Tsd, TsdFrame -from scipy.interpolate import splev - -from .base_class import Base -from .convolve import create_convolutional_predictor -from .type_casting import support_pynapple -from .utils import row_wise_kron -from .validation import check_fraction_valid_samples - -FeatureMatrix = Union[NDArray, TsdFrame] - -__all__ = [ - "MSplineBasis", - "BSplineBasis", - "CyclicBSplineBasis", - "RaisedCosineBasisLinear", - "RaisedCosineBasisLog", - "OrthExponentialBasis", - "AdditiveBasis", - "MultiplicativeBasis", - "TransformerBasis", -] +from ..base_class import Base +from ..convolve import create_convolutional_predictor +from ..type_casting import support_pynapple -def __dir__() -> list[str]: - return __all__ - +from ..utils import row_wise_kron +from ..typing import FeatureMatrix +from ..validation import check_fraction_valid_samples def check_transform_input(func: Callable) -> Callable: """Check input before calling basis. @@ -106,367 +85,527 @@ def min_max_rescale_samples( return sample_pts, scaling -class TransformerBasis: - """Basis as `scikit-learn` transformers. - - This class abstracts the underlying basis function details, offering methods - similar to scikit-learn's transformers but specifically designed for basis - transformations. It supports fitting to data (calculating any necessary parameters - of the basis functions), transforming data (applying the basis functions to - data), and both fitting and transforming in one step. +class Basis(Base, abc.ABC): + """ + Abstract base class for defining basis functions for feature transformation. - `TransformerBasis`, unlike `Basis`, is compatible with scikit-learn pipelining and - model selection, enabling the cross-validation of the basis type and parameters, - for example `n_basis_funcs`. See the example section below. + Basis functions are mathematical constructs that can represent data in alternative, + often more compact or interpretable forms. This class provides a template for such + transformations, with specific implementations defining the actual behavior. Parameters ---------- - basis : - A concrete subclass of `Basis`. + n_basis_funcs : + The number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. 
+ label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. - Examples - -------- - >>> from nemos.basis import BSplineBasis, TransformerBasis - >>> from nemos.glm import GLM - >>> from sklearn.pipeline import Pipeline - >>> from sklearn.model_selection import GridSearchCV - >>> import numpy as np - >>> np.random.seed(123) + Raises + ------ + ValueError: + - If `mode` is not 'eval' or 'conv'. + - If `kwargs` are not None and `mode =="eval". + - If `kwargs` include parameters not recognized or do not have + default values in `create_convolutional_predictor`. + - If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis). + """ - >>> # Generate data - >>> num_samples, num_features = 10000, 1 - >>> x = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) - >>> features = basis.compute_features(x) # basis transformed time series - >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights - >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts + def __init__( + self, + n_basis_funcs: int, + mode: Literal["eval", "conv"] = "eval", + label: Optional[str] = None, + **kwargs, + ) -> None: + self.n_basis_funcs = n_basis_funcs + self._n_input_dimensionality = 0 - >>> # transformer can be used in pipelines - >>> transformer = TransformerBasis(basis) - >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),]) - >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API - >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas - >>> # TransformerBasis parameter can be cross-validated. - >>> # 5-fold cross-validate the number of basis - >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10]) - >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5) - >>> grid_cv = grid_cv.fit(x[:, None], y) - >>> print("Cross-validated number of basis:", grid_cv.best_params_) - Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} - """ + self._conv_kwargs = kwargs - def __init__(self, basis: Basis): - self._basis = copy.deepcopy(basis) + self._mode = mode - @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack impute without using transpose. + self._n_basis_input = None - Unpack horizontally stacked inputs using slicing. This works gracefully with `pynapple`, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a `pynapple` - exception since `pynapple` assumes that the time axis is the first axis. + # these parameters are going to be set at the first call of `compute_features` + # since we cannot know a-priori how many features may be convolved + self._n_output_features = None + self._input_shape = None - Parameters - ---------- - X: - The inputs horizontally stacked. + if label is None: + self._label = self.__class__.__name__ + else: + self._label = str(label) - Returns - ------- - : - A tuple of each individual input. 
+ # pop the two mode dependent kwargs + window_size = kwargs.pop("window_size", None) + if window_size: + self._window_size = window_size + bounds = kwargs.pop("bounds", None) + if bounds: + self._bounds = bounds + + # the rest should be convolutional kwargs + self._check_convolution_kwargs() + + self.kernel_ = None + + def _check_convolution_kwargs(self): + """Check convolution kwargs settings. + + Raises + ------ + ValueError: + - If `self._conv_kwargs` are not None and `mode =="eval". + - If `axis` is provided as an argument, and it is different from 0 + (samples must always be in the first axis). + - If `self._conv_kwargs` include parameters not recognized or that do not have + default values in `create_convolutional_predictor`. + """ + # this should not be hit since **kwargs are not allowed at EvalBasis init. + if self._mode == "eval" and self._conv_kwargs: + raise ValueError( + f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" + ) + + if "axis" in self._conv_kwargs: + raise ValueError( + "Setting the `axis` parameter is not allowed. Basis requires the " + "convolution to be applied along the first axis (`axis=0`).\n" + "Please transpose your input so that the desired axis for " + "convolution is the first dimension (axis=0)." + ) + convolve_params = inspect.signature(create_convolutional_predictor).parameters + convolve_configs = { + key + for key, param in convolve_params.items() + if param.default + # prevent user from passing + # `basis_matrix` or `time_series` in kwargs. + is not inspect.Parameter.empty + } + if not set(self._conv_kwargs.keys()).issubset(convolve_configs): + # do not encourage to set axis. + convolve_configs = convolve_configs.difference({"axis"}) + # remove the parameter in case axis=0 was passed, since it is allowed. + invalid = ( + set(self._conv_kwargs.keys()) + .difference(convolve_configs) + .difference({"axis"}) + ) + raise ValueError( + f"Unrecognized keyword arguments: {invalid}. " + f"Allowed convolution keyword arguments are: {convolve_configs}." + ) + @property + def n_output_features(self) -> int | None: """ - return (X[:, k] for k in range(X.shape[1])) + Read-only property indicating the number of features returned by the basis, when available. - def fit(self, X: FeatureMatrix, y=None): + Notes + ----- + The number of output features can be determined only when the number of inputs + provided to the basis is known. Therefore, before the first call to `compute_features`, + this property will return `None`. After that call, `n_output_features` will be available. """ - Compute the convolutional kernels. + return self._n_output_features - If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. + @property + def label(self) -> str: + return self._label - Parameters - ---------- - X : - The data to fit the basis functions to, shape (num_samples, num_input). - y : ignored - Not used, present for API consistency by convention. + @property + def n_basis_input(self) -> tuple | None: + if self._n_basis_input is None: + return + return self._n_basis_input - Returns - ------- - self : - The transformer object. 
+ @property + def n_basis_funcs(self): + return self._n_basis_funcs - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis + @n_basis_funcs.setter + def n_basis_funcs(self, value): + orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None)) + self._n_basis_funcs = value + try: + self._check_n_basis_min() + except ValueError as e: + self._n_basis_funcs = orig_n_basis + raise e - >>> # Example input - >>> X = np.random.normal(size=(100, 2)) + @property + def mode(self): + return self._mode - >>> # Define and fit tranformation basis - >>> basis = MSplineBasis(10) - >>> transformer = TransformerBasis(basis) - >>> transformer_fitted = transformer.fit(X) - """ - self._basis._set_kernel(*self._unpack_inputs(X)) - return self + @staticmethod + def _apply_identifiability_constraints(X: NDArray): + """Apply identifiability constraints to a design matrix `X`. - def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: - """ - Transform the data using the fitted basis functions. + Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness + of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly + crucial for models using bases like BSplines and CyclicBspline, which, due to their + construction, sum to 1 and can cause rank deficiency when combined with an intercept. + + For GLMs, this rank deficiency means that different sets of coefficients might yield + identical predicted rates and log-likelihood, complicating parameter learning, especially + in the absence of regularization. Parameters ---------- - X : - The data to transform using the basis functions, shape (num_samples, num_input). - y : - Not used, present for API consistency by convention. + X: + The design matrix before applying the identifiability constraints. Returns ------- : - The data transformed by the basis functions. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis + The adjusted design matrix with redundant columns dropped and columns mean-centered. + """ - >>> # Example input - >>> X = np.random.normal(size=(10000, 2)) + def add_constant(x): + return np.hstack((np.ones((x.shape[0], 1)), x)) - >>> # Define and fit tranformation basis - >>> basis = MSplineBasis(10, mode="conv", window_size=200) - >>> transformer = TransformerBasis(basis) - >>> # Before calling `fit` the convolution kernel is not set - >>> transformer.kernel_ - - >>> transformer_fitted = transformer.fit(X) - >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) - >>> transformer_fitted.kernel_.shape - (200, 10) - - >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) - """ - # transpose does not work with pynapple - # can't use func(*X.T) to unwrap - - return self._basis._compute_features(*self._unpack_inputs(X)) + rank = np.linalg.matrix_rank(add_constant(X)) + # mean center + X = X - np.nanmean(X, axis=0) + while rank < X.shape[1] + 1: + # drop a column + X = X[:, :-1] + # recompute rank + rank = np.linalg.matrix_rank(add_constant(X)) + return X - def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: + @check_transform_input + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ - Compute the kernels and the features. + Compute the basis functions and transform input data into model features. - This method is a convenience that combines fit and transform into - one step. 
+ This method is designed to be a high-level interface for transforming input + data using the basis functions defined by the subclass. Depending on the basis' + mode ('eval' or 'conv'), it either evaluates the basis functions at the sample + points or performs a convolution operation between the input data and the + basis functions. Parameters ---------- - X : - The data to fit the basis functions to and then transform. - y : - Not used, present for API consistency by convention. + *xi : + Input data arrays to be transformed. The shape and content requirements + depend on the subclass and mode of operation ('eval' or 'conv'). Returns ------- - array-like - The data transformed by the basis functions, after fitting the basis - functions to the data. + : + Transformed features. In 'eval' mode, it corresponds to the basis functions + evaluated at the input samples. In 'conv' mode, it consists of convolved + input samples with the basis functions. The output shape varies based on + the subclass and mode. Examples -------- >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis - - >>> # Example input - >>> X = np.random.normal(size=(100, 1)) + >>> from nemos.basis import BSplineBasis - >>> # Define tranformation basis - >>> basis = MSplineBasis(10) - >>> transformer = TransformerBasis(basis) + >>> # Generate data + >>> num_samples = 10000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = BSplineBasis(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (10000, 10) - >>> # Fit and transform basis - >>> feature_transformed = transformer.fit_transform(X) + Notes + ----- + Subclasses should implement how to handle the transformation specific to their + basis function types and operation modes. """ - return self._basis.compute_features(*self._unpack_inputs(X)) + self._set_num_output_features(*xi) + self._set_kernel() + return self._compute_features(*xi) - def __getstate__(self): - """ - Explicitly define how to pickle TransformerBasis object. + @abc.abstractmethod + def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """Convolve or evaluate the basis.""" + pass - See https://docs.python.org/3/library/pickle.html#object.__getstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - return {"_basis": self._basis} + @abc.abstractmethod + def _set_kernel(self): + """Set kernel for conv basis and return self or just return self for eval.""" + pass - def __setstate__(self, state): + @abc.abstractmethod + def __call__(self, *xi: ArrayLike) -> FeatureMatrix: """ - Define how to populate the object's state when unpickling. + Abstract method to evaluate the basis functions at given points. - Note that during unpickling a new object is created without calling __init__. - Needed to avoid infinite recursion in __getattr__ when unpickling. + This method must be implemented by subclasses to define the specific behavior + of the basis transformation. The implementation depends on the type of basis + (e.g., spline, raised cosine), and it should evaluate the basis functions at + the specified points in the domain. - See https://docs.python.org/3/library/pickle.html#object.__setstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - self._basis = state["_basis"] + Parameters + ---------- + *xi : + Variable number of arguments, each representing an array of points at which + to evaluate the basis functions. 
The dimensions and requirements of these + inputs vary depending on the specific basis implementation. - def __getattr__(self, name: str): + Returns + ------- + : + An array containing the evaluated values of the basis functions at the input + points. The shape and structure of this array are specific to the subclass + implementation. """ - Enable easy access to attributes of the underlying Basis object. + pass - Examples - -------- - >>> from nemos import basis - >>> bas = basis.RaisedCosineBasisLinear(5) - >>> trans_bas = basis.TransformerBasis(bas) - >>> bas.n_basis_funcs - 5 - >>> trans_bas.n_basis_funcs - 5 - """ - return getattr(self._basis, name) + def _get_samples(self, *n_samples: int) -> Generator[NDArray]: + """Get equi-spaced samples for all the input dimensions. - def __setattr__(self, name: str, value) -> None: - r""" - Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. + This will be used to evaluate the basis on a grid of + points derived by the samples. - Setting any other attribute is not allowed. + Parameters + ---------- + n_samples[0],...,n_samples[n] + The number of samples in each axis of the grid. Returns ------- - None + : + A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by `n_samples`. + """ + # handling of defaults when evaluating on a grid + # (i.e. when we cannot use max and min of samples) + if self.bounds is None: + mn, mx = 0, 1 + else: + mn, mx = self.bounds + return (np.linspace(mn, mx, n_samples[k]) for k in range(len(n_samples))) + + @support_pynapple(conv_type="numpy") + def _check_transform_input( + self, *xi: ArrayLike + ) -> Tuple[Union[NDArray, Tsd, TsdFrame]]: + """Check transform input. + + Parameters + ---------- + xi[0],...,xi[n] : + The input samples, each with shape (number of samples, ). Raises ------ ValueError - If the attribute being set is not `_basis` or an attribute of `_basis`. + - If the time point number is inconsistent between inputs. + - If the number of inputs doesn't match what the Basis object requires. + - At least one of the samples is empty. - Examples - -------- - >>> import nemos as nmo - >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.MSplineBasis(10)) - >>> # allowed - >>> trans_bas._basis = nmo.basis.BSplineBasis(10) - >>> # allowed - >>> trans_bas.n_basis_funcs = 20 - >>> # not allowed - >>> try: - ... trans_bas.random_attribute_name = "some value" - ... except ValueError as e: - ... 
print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') """ - # allow self._basis = basis - if name == "_basis": - super().__setattr__(name, value) - # allow changing existing attributes of self._basis - elif hasattr(self._basis, name): - setattr(self._basis, name, value) - # don't allow setting any other attribute - else: + # check that the input is array-like (i.e., whether we can cast it to + # numeric arrays) + try: + # make sure array is at least 1d (so that we succeed when only + # passed a scalar) + xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) + # ValueError here surfaces the exception with e.g., `x=np.array["a", "b"])` + except (TypeError, ValueError): + raise TypeError("Input samples must be array-like of floats!") + + # check for non-empty samples + if self._has_zero_samples(tuple(len(x) for x in xi)): + raise ValueError("All sample provided must be non empty.") + + # checks on input and outputs + self._check_samples_consistency(*xi) + self._check_input_dimensionality(xi) + + return xi + + def _check_has_kernel(self) -> None: + """Check that the kernel is pre-computed.""" + if self.mode == "conv" and self.kernel_ is None: raise ValueError( - "Only setting _basis or existing attributes of _basis is allowed." + "You must call `_set_kernel` before `_compute_features` when mode =`conv`." ) - def __sklearn_clone__(self) -> TransformerBasis: - """ - Customize how TransformerBasis objects are cloned when used with sklearn.model_selection. - - By default, scikit-learn tries to clone the object by calling __init__ using the output of get_params, - which fails in our case. + def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. - For more info: https://scikit-learn.org/stable/developers/develop.html#cloning - """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) - cloned_obj._basis.kernel_ = None - return cloned_obj + The i-th axis of the grid will be sampled with n_samples[i] equi-spaced points. + The method uses numpy.meshgrid with `indexing="ij"`, returning matrix indexing + instead of the default cartesian indexing, see Notes. - def set_params(self, **parameters) -> TransformerBasis: - """ - Set TransformerBasis parameters. + Parameters + ---------- + n_samples[0],...,n_samples[n] + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. - When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly - or the parameters of the underlying Basis, but not both. + Returns + ------- + *Xs : + A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, + where n equals to the number of inputs. + The size of Xs[i] is (n_samples[0], ... , n_samples[n]). + Y : + The basis function evaluated at the samples, + shape (n_samples[0], ... , n_samples[n], number of basis). - Examples - -------- - >>> from nemos.basis import BSplineBasis, MSplineBasis, TransformerBasis - >>> basis = MSplineBasis(10) - >>> transformer_basis = TransformerBasis(basis=basis) + Raises + ------ + ValueError + - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what + the Basis object requires. + - If one of the n_samples is <= 0. 
- >>> # setting parameters of _basis is allowed - >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) - 8 - >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)) - - >>> # mixing is not allowed, this will raise an exception - >>> try: - ... transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2) - ... except ValueError as e: - ... print(repr(e)) - ValueError('Set either new _basis object or parameters for existing _basis, not both.') + Notes + ----- + Setting "indexing = 'ij'" returns a meshgrid with matrix indexing. In the N-D case with inputs of size + $M_1,...,M_N$, outputs are of shape $(M_1, M_2, M_3, ....,M_N)$. + This differs from the numpy.meshgrid default, which uses Cartesian indexing. + For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$. + + Examples + -------- + >>> # Evaluate and visualize 4 M-spline basis functions of order 3: + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalMSpline + >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) + >>> p = plt.plot(sample_points, basis_values) + >>> _ = plt.title('M-Spline Basis Functions') + >>> _ = plt.xlabel('Domain') + >>> _ = plt.ylabel('Basis Function Value') + >>> _ = plt.legend([f'Function {i+1}' for i in range(4)]); """ - new_basis = parameters.pop("_basis", None) - if new_basis is not None: - self._basis = new_basis - if len(parameters) > 0: - raise ValueError( - "Set either new _basis object or parameters for existing _basis, not both." - ) - else: - self._basis = self._basis.set_params(**parameters) + self._check_input_dimensionality(n_samples) - return self + if self._has_zero_samples(n_samples): + raise ValueError("All sample counts provided must be greater than zero.") - def get_params(self, deep: bool = True) -> dict: - """Extend the dict of parameters from the underlying Basis with _basis.""" - return {"_basis": self._basis, **self._basis.get_params(deep)} + # get the samples + sample_tuple = self._get_samples(*n_samples) + Xs = np.meshgrid(*sample_tuple, indexing="ij") - def __dir__(self) -> list[str]: - """Extend the list of properties of methods with the ones from the underlying Basis.""" - return super().__dir__() + self._basis.__dir__() + # evaluates the basis on a flat NDArray and reshape to match meshgrid output + Y = self.__call__(*tuple(grid_axis.flatten() for grid_axis in Xs)).reshape( + (*n_samples, self.n_basis_funcs) + ) - def __add__(self, other: TransformerBasis) -> TransformerBasis: + return *Xs, Y + + @staticmethod + def _has_zero_samples(n_samples: Tuple[int, ...]) -> bool: + return any([n <= 0 for n in n_samples]) + + def _check_input_dimensionality(self, xi: Tuple) -> None: """ - Add two TransformerBasis objects. + Check that the number of inputs provided by the user matches the number of inputs required. + + Parameters + ---------- + xi[0], ..., xi[n] : + The input samples, shape (number of samples, ). + + Raises + ------ + ValueError + If the number of inputs doesn't match what the Basis object requires. + """ + if len(xi) != self._n_input_dimensionality: + raise TypeError( + f"Input dimensionality mismatch. This basis evaluation requires {self._n_input_dimensionality} inputs, " + f"{len(xi)} inputs provided instead." 
+ ) + + @staticmethod + def _check_samples_consistency(*xi: NDArray) -> None: + """ + Check that each input provided to the Basis object has the same number of time points. + + Parameters + ---------- + xi[0], ..., xi[n] : + The input samples, shape (number of samples, ). + + Raises + ------ + ValueError + If the time point number is inconsistent between inputs. + """ + sample_sizes = [sample.shape[0] for sample in xi] + if any(elem != sample_sizes[0] for elem in sample_sizes): + raise ValueError( + "Sample size mismatch. Input elements have inconsistent sample sizes." + ) + + @abc.abstractmethod + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. + + Most of the basis work with at least 1 element, but some + such as the RaisedCosineBasisLog requires a minimum of 2 basis to be well defined. + + Raises + ------ + ValueError + If an insufficient number of basis element is requested for the basis type + """ + pass + + def __add__(self, other: Basis) -> AdditiveBasis: + """ + Add two Basis objects together. Parameters ---------- other - The other TransformerBasis object to add. + The other Basis object to add. Returns ------- - : TransformerBasis + : AdditiveBasis The resulting Basis object. """ - return TransformerBasis(self._basis + other._basis) + return AdditiveBasis(self, other) - def __mul__(self, other: TransformerBasis) -> TransformerBasis: + def __mul__(self, other: Basis) -> MultiplicativeBasis: """ - Multiply two TransformerBasis objects. + Multiply two Basis objects together. Parameters ---------- other - The other TransformerBasis object to multiply. + The other Basis object to multiply. Returns ------- : The resulting Basis object. """ - return TransformerBasis(self._basis * other._basis) + return MultiplicativeBasis(self, other) - def __pow__(self, exponent: int) -> TransformerBasis: - """Exponentiation of a TransformerBasis object. + def __pow__(self, exponent: int) -> MultiplicativeBasis: + """Exponentiation of a Basis object. - Define the power of a basis by repeatedly applying the method __mul__. + Define the power of a basis by repeatedly applying the method __multiply__. The exponent must be a positive integer. Parameters @@ -486,1171 +625,157 @@ def __pow__(self, exponent: int) -> TransformerBasis: ValueError If the integer is zero or negative. """ - # errors are handled by Basis.__pow__ - return TransformerBasis(self._basis**exponent) + if not isinstance(exponent, int): + raise TypeError("Exponent should be an integer!") + if exponent <= 0: + raise ValueError("Exponent should be a non-negative integer!") -class Basis(Base, abc.ABC): - """ - Abstract base class for defining basis functions for feature transformation. + result = self + for _ in range(exponent - 1): + result = result * self + return result - Basis functions are mathematical constructs that can represent data in alternative, - often more compact or interpretable forms. This class provides a template for such - transformations, with specific implementations defining the actual behavior. + def to_transformer(self) -> TransformerBasis: + """ + Turn the Basis into a TransformerBasis for use with scikit-learn. - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. 
The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. + Examples + -------- + Jointly cross-validating basis and GLM parameters with scikit-learn. - Raises - ------ - ValueError: - - If `mode` is not 'eval' or 'conv'. - - If `kwargs` are not None and `mode =="eval". - - If `kwargs` include parameters not recognized or do not have - default values in `create_convolutional_predictor`. - - If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis). - """ + >>> import nemos as nmo + >>> from sklearn.pipeline import Pipeline + >>> from sklearn.model_selection import GridSearchCV + >>> # load some data + >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) + >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer() + >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) + >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) + >>> param_grid = dict( + ... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), + ... basis__n_basis_funcs=(3, 5, 10, 20, 100), + ... ) + >>> gridsearch = GridSearchCV( + ... pipeline, + ... param_grid=param_grid, + ... cv=5, + ... ) + >>> gridsearch = gridsearch.fit(X, y) + """ + return TransformerBasis(copy.deepcopy(self)) - def __init__( + def _get_feature_slicing( self, - n_basis_funcs: int, - mode: Literal["eval", "conv"] = "eval", - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = None, - **kwargs, - ) -> None: - self.n_basis_funcs = n_basis_funcs - self._n_input_dimensionality = 0 - - self._conv_kwargs = kwargs + n_inputs: Optional[tuple] = None, + start_slice: Optional[int] = None, + split_by_input: bool = True, + ) -> Tuple[dict, int]: + """ + Calculate and return the slicing for features based on the input structure. - # check mode - if mode not in ["conv", "eval"]: - raise ValueError( - f"`mode` should be either 'conv' or 'eval'. '{mode}' provided instead!" - ) + This method determines how to slice the features for different basis types. + If the instance is of `AdditiveBasis` type, the slicing is calculated recursively + for each component basis. Otherwise, it determines the slicing based on + the number of basis functions and `split_by_input` flag. - self._mode = mode + Parameters + ---------- + n_inputs : + The number of input basis for each component, by default it uses `self._n_basis_input`. + start_slice : + The starting index for slicing, by default it starts from 0. + split_by_input : + Flag indicating whether to split the slicing by individual inputs or not. + If `False`, a single slice is generated for all inputs. 
- self._n_basis_input = None + Returns + ------- + split_dict : + Dictionary with keys as labels and values as slices representing + the slicing for each input or additive component, if split_by_input equals to + True or False respectively. + start_slice : + The updated starting index after slicing. - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - - if label is None: - self._label = self.__class__.__name__ - else: - self._label = str(label) - - self.window_size = window_size - self.bounds = bounds - - self._check_convolution_kwargs() - - self.kernel_ = None - - def _check_convolution_kwargs(self): - """Check convolution kwargs settings. - - Raises - ------ - ValueError: - - If `self._conv_kwargs` are not None and `mode =="eval". - - If `axis` is provided as an argument, and it is different from 0 - (samples must always be in the first axis). - - If `self._conv_kwargs` include parameters not recognized or that do not have - default values in `create_convolutional_predictor`. - """ - if self._mode == "eval" and self._conv_kwargs: - raise ValueError( - f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" - ) - - if "axis" in self._conv_kwargs: - raise ValueError( - "Setting the `axis` parameter is not allowed. Basis requires the " - "convolution to be applied along the first axis (`axis=0`).\n" - "Please transpose your input so that the desired axis for " - "convolution is the first dimension (axis=0)." - ) - convolve_params = inspect.signature(create_convolutional_predictor).parameters - convolve_configs = { - key - for key, param in convolve_params.items() - if param.default - # prevent user from passing - # `basis_matrix` or `time_series` in kwargs. - is not inspect.Parameter.empty - } - if not set(self._conv_kwargs.keys()).issubset(convolve_configs): - # do not encourage to set axis. - convolve_configs = convolve_configs.difference({"axis"}) - # remove the parameter in case axis=0 was passed, since it is allowed. - invalid = ( - set(self._conv_kwargs.keys()) - .difference(convolve_configs) - .difference({"axis"}) - ) - raise ValueError( - f"Unrecognized keyword arguments: {invalid}. " - f"Allowed convolution keyword arguments are: {convolve_configs}." - ) - - @property - def n_output_features(self) -> int | None: - """ - Read-only property indicating the number of features returned by the basis, when available. - - Notes - ----- - The number of output features can be determined only when the number of inputs - provided to the basis is known. Therefore, before the first call to `compute_features`, - this property will return `None`. After that call, `n_output_features` will be available. + See Also + -------- + _get_default_slicing : Handles default slicing logic. + _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. 
""" - return self._n_output_features - - @property - def label(self) -> str: - return self._label - - @property - def n_basis_input(self) -> tuple | None: - if self._n_basis_input is None: - return - return self._n_basis_input - - @property - def n_basis_funcs(self): - return self._n_basis_funcs - - @n_basis_funcs.setter - def n_basis_funcs(self, value): - orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None)) - self._n_basis_funcs = value - try: - self._check_n_basis_min() - except ValueError as e: - self._n_basis_funcs = orig_n_basis - raise e - - @property - def bounds(self): - return self._bounds - - @bounds.setter - def bounds(self, values: Union[None, Tuple[float, float]]): - """Setter for bounds.""" - if values is not None and self.mode == "conv": - raise ValueError("`bounds` should only be set when `mode=='eval'`.") + # Set default values for n_inputs and start_slice if not provided + n_inputs = n_inputs or self._n_basis_input + start_slice = start_slice or 0 - if values is not None and len(values) != 2: - raise ValueError( - f"The provided `bounds` must be of length two. Length {len(values)} provided instead!" + # If the instance is of AdditiveBasis type, handle slicing for the additive components + if isinstance(self, AdditiveBasis): + split_dict, start_slice = self._basis1._get_feature_slicing( + n_inputs[: len(self._basis1._n_basis_input)], + start_slice, + split_by_input=split_by_input, ) - - # convert to float and store - try: - self._bounds = values if values is None else tuple(map(float, values)) - except (ValueError, TypeError): - raise TypeError("Could not convert `bounds` to float.") - - if values is not None and values[1] <= values[0]: - raise ValueError( - f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." + sp2, start_slice = self._basis2._get_feature_slicing( + n_inputs[len(self._basis1._n_basis_input) :], + start_slice, + split_by_input=split_by_input, ) - - @property - def mode(self): - return self._mode - - @property - def window_size(self): - return self._window_size - - @window_size.setter - def window_size(self, window_size): - """Setter for the window size parameter.""" - if self.mode == "eval": - if window_size: - raise ValueError( - "If basis is in `mode=='eval'`, `window_size` should be None." - ) - + split_dict = self._merge_slicing_dicts(split_dict, sp2) else: - if window_size is None: - raise ValueError( - "If the basis is in `conv` mode, you must provide a window_size!" - ) - - elif not (isinstance(window_size, int) and window_size > 0): - raise ValueError( - f"`window_size` must be a positive integer. {window_size} provided instead!" - ) - - self._window_size = window_size - - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix `X`. - - Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. 
- - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - r""" - Apply the basis transformation to the input data. - - This method operates in two modes: - - 'eval': Evaluates the basis functions at the given sample points. - - 'conv': Applies a convolution operation between the input data and the basis functions, - using a window size defined at initialization. - - Parameters - ---------- - *xi: - The input samples over which to apply the basis transformation. The samples can be passed - as multiple arguments, each representing a different dimension for multivariate inputs. - - Returns - ------- - : - A matrix with the transformed features. The shape of the output depends on the operation mode: - - If `mode == 'eval'`, the basis evaluated at the samples, or $b_i(*xi)$, where $b_i$ is a - basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. - - - If `mode == 'conv'`, a bank of basis filters (created by calling fit) is convolved with the - samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions - except for the sample-axis are flattened, so that the method always returns a matrix. - For example, if samples are of shape (num_samples, 2, 3), the output will be - (num_samples, num_basis_funcs * 2 * 3). - The time-axis can be specified at basis initialization by setting the keyword argument `axis`. - For example, if `axis == 1` your samples should be (N1, num_samples N3, ...), the output of - transform will be (num_samples, num_basis_funcs * N1 * N3 *...). - - Raises - ------ - ValueError: - - If an invalid mode is specified or necessary parameters for the chosen mode are missing. - - In mode "conv", if the number of inputs to be convolved, doesn't match the number of inputs - set at initialization. - """ - # check if self.kernel_ is not None for mode="conv" - self._check_has_kernel() - if self.mode == "eval": # evaluate at the sample - return self.__call__(*xi) - else: # convolve, called only at the last layer - # before calling the convolve, check that the input matches - # the expectation. We can check xi[0] only, since convolution - # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. - conv = create_convolutional_predictor( - self.kernel_, *xi, **self._conv_kwargs + # Handle the default case for other basis types + split_dict, start_slice = self._get_default_slicing( + split_by_input, start_slice ) - # make sure to return a matrix - return np.reshape(conv, newshape=(conv.shape[0], -1)) - - @check_transform_input - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute the basis functions and transform input data into model features. - - This method is designed to be a high-level interface for transforming input - data using the basis functions defined by the subclass. Depending on the basis' - mode ('eval' or 'conv'), it either evaluates the basis functions at the sample - points or performs a convolution operation between the input data and the - basis functions. - - Parameters - ---------- - *xi : - Input data arrays to be transformed. 
The shape and content requirements - depend on the subclass and mode of operation ('eval' or 'conv'). - - Returns - ------- - : - Transformed features. In 'eval' mode, it corresponds to the basis functions - evaluated at the input samples. In 'conv' mode, it consists of convolved - input samples with the basis functions. The output shape varies based on - the subclass and mode. - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import BSplineBasis + return split_dict, start_slice - >>> # Generate data - >>> num_samples = 10000 - >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) - >>> features = basis.compute_features(X) # basis transformed time series - >>> features.shape - (10000, 10) + def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: + """Merge two slicing dictionaries, handling key conflicts.""" + for key, val in dict2.items(): + if key in dict1: + new_key = self._generate_unique_key(dict1, key) + dict1[new_key] = val + else: + dict1[key] = val + return dict1 - Notes - ----- - Subclasses should implement how to handle the transformation specific to their - basis function types and operation modes. - """ - self._set_num_output_features(*xi) - if self.kernel_ is None: - self._set_kernel(*xi) - return self._compute_features(*xi) + @staticmethod + def _generate_unique_key(existing_dict: dict, key: str) -> str: + """Generate a unique key if there is a conflict.""" + extra = 1 + new_key = f"{key}-{extra}" + while new_key in existing_dict: + extra += 1 + new_key = f"{key}-{extra}" + return new_key - def _set_kernel(self, *xi: ArrayLike) -> Basis: - """ - Prepare or compute the convolutional kernel for the basis functions. - - This method is called to prepare the basis functions for convolution operations - in subclasses where the 'conv' mode is used. It typically involves computing a - kernel based on the basis functions that will be used for convolution with the - input data. The specifics of kernel computation depend on the subclass implementation - and the nature of the basis functions. - - In 'eval' mode, this method might not perform any operation but simply return the - instance itself, as no kernel preparation is necessary. - - Parameters - ---------- - *xi : - The input data based on which the kernel might be computed. The actual use of - these inputs is subclass-specific and might not be applicable for all basis types. - - Returns - ------- - self : - The instance itself, modified to include the computed kernel if applicable. This - allows for method chaining and integration into transformation pipelines. - - Notes - ----- - Subclasses implementing this method should detail the specifics of how the kernel is - computed and how the input parameters are utilized. If the basis operates in 'eval' - mode exclusively, this method should simply return `self` without modification. - """ - if self.mode == "conv": - self.kernel_ = self.__call__(np.linspace(0, 1, self.window_size)) - return self - - @abc.abstractmethod - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Abstract method to evaluate the basis functions at given points. - - This method must be implemented by subclasses to define the specific behavior - of the basis transformation. The implementation depends on the type of basis - (e.g., spline, raised cosine), and it should evaluate the basis functions at - the specified points in the domain. 
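To make the two operating modes described in the `compute_features` docstrings above concrete, here is a small sketch contrasting evaluation with convolution; it assumes the pre-refactor `BSplineBasis` API used in the surrounding doctests:

```python
import numpy as np
from nemos.basis import BSplineBasis

x = np.random.normal(size=1000)

# 'eval' mode: evaluate each basis element at the sample points
eval_basis = BSplineBasis(n_basis_funcs=8, mode="eval")
X_eval = eval_basis.compute_features(x)   # shape (1000, 8)

# 'conv' mode: convolve the samples with a bank of 8 kernels of length `window_size`
conv_basis = BSplineBasis(n_basis_funcs=8, mode="conv", window_size=100)
X_conv = conv_basis.compute_features(x)   # shape (1000, 8); extra trailing dims would be flattened
print(X_eval.shape, X_conv.shape)
```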
- - Parameters - ---------- - *xi : - Variable number of arguments, each representing an array of points at which - to evaluate the basis functions. The dimensions and requirements of these - inputs vary depending on the specific basis implementation. - - Returns - ------- - : - An array containing the evaluated values of the basis functions at the input - points. The shape and structure of this array are specific to the subclass - implementation. - """ - pass - - def _get_samples(self, *n_samples: int) -> Generator[NDArray]: - """Get equi-spaced samples for all the input dimensions. - - This will be used to evaluate the basis on a grid of - points derived by the samples. - - Parameters - ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. - - Returns - ------- - : - A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by `n_samples`. - """ - # handling of defaults when evaluating on a grid - # (i.e. when we cannot use max and min of samples) - if self.bounds is None: - mn, mx = 0, 1 - else: - mn, mx = self.bounds - return (np.linspace(mn, mx, n_samples[k]) for k in range(len(n_samples))) - - @support_pynapple(conv_type="numpy") - def _check_transform_input( - self, *xi: ArrayLike - ) -> Tuple[Union[NDArray, Tsd, TsdFrame]]: - """Check transform input. - - Parameters - ---------- - xi[0],...,xi[n] : - The input samples, each with shape (number of samples, ). - - Raises - ------ - ValueError - - If the time point number is inconsistent between inputs. - - If the number of inputs doesn't match what the Basis object requires. - - At least one of the samples is empty. - - """ - # check that the input is array-like (i.e., whether we can cast it to - # numeric arrays) - try: - # make sure array is at least 1d (so that we succeed when only - # passed a scalar) - xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) - # ValueError here surfaces the exception with e.g., `x=np.array["a", "b"])` - except (TypeError, ValueError): - raise TypeError("Input samples must be array-like of floats!") - - # check for non-empty samples - if self._has_zero_samples(tuple(len(x) for x in xi)): - raise ValueError("All sample provided must be non empty.") - - # checks on input and outputs - self._check_samples_consistency(*xi) - self._check_input_dimensionality(xi) - - return xi - - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - The i-th axis of the grid will be sampled with n_samples[i] equi-spaced points. - The method uses numpy.meshgrid with `indexing="ij"`, returning matrix indexing - instead of the default cartesian indexing, see Notes. - - Parameters - ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of - n_samples must equal the number of combined bases. - - Returns - ------- - *Xs : - A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, - where n equals to the number of inputs. - The size of Xs[i] is (n_samples[0], ... , n_samples[n]). - Y : - The basis function evaluated at the samples, - shape (n_samples[0], ... , n_samples[n], number of basis). 
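A short sketch of the grid evaluation described above for a basis over two inputs; the additive combination and the 13-column output follow from the composition rules documented elsewhere in this file (assumed pre-refactor API):

```python
from nemos.basis import BSplineBasis, RaisedCosineBasisLinear

# two 1D bases combined additively -> a basis that expects two inputs
basis_2d = BSplineBasis(6) + RaisedCosineBasisLinear(7)

# one sample count per input; meshgrid axes use 'ij' indexing
X1, X2, Y = basis_2d.evaluate_on_grid(50, 40)
print(X1.shape, X2.shape, Y.shape)  # (50, 40) (50, 40) (50, 40, 13)
```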
- - Raises - ------ - ValueError - - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what - the Basis object requires. - - If one of the n_samples is <= 0. - - Notes - ----- - Setting "indexing = 'ij'" returns a meshgrid with matrix indexing. In the N-D case with inputs of size - $M_1,...,M_N$, outputs are of shape $(M_1, M_2, M_3, ....,M_N)$. - This differs from the numpy.meshgrid default, which uses Cartesian indexing. - For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$. - - Examples - -------- - >>> # Evaluate and visualize 4 M-spline basis functions of order 3: - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import MSplineBasis - >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> p = plt.plot(sample_points, basis_values) - >>> _ = plt.title('M-Spline Basis Functions') - >>> _ = plt.xlabel('Domain') - >>> _ = plt.ylabel('Basis Function Value') - >>> _ = plt.legend([f'Function {i+1}' for i in range(4)]); - """ - self._check_input_dimensionality(n_samples) - - if self._has_zero_samples(n_samples): - raise ValueError("All sample counts provided must be greater than zero.") - - # get the samples - sample_tuple = self._get_samples(*n_samples) - Xs = np.meshgrid(*sample_tuple, indexing="ij") - - # evaluates the basis on a flat NDArray and reshape to match meshgrid output - Y = self.__call__(*tuple(grid_axis.flatten() for grid_axis in Xs)).reshape( - (*n_samples, self.n_basis_funcs) - ) - - return *Xs, Y - - @staticmethod - def _has_zero_samples(n_samples: Tuple[int, ...]) -> bool: - return any([n <= 0 for n in n_samples]) - - def _check_input_dimensionality(self, xi: Tuple) -> None: - """ - Check that the number of inputs provided by the user matches the number of inputs required. - - Parameters - ---------- - xi[0], ..., xi[n] : - The input samples, shape (number of samples, ). - - Raises - ------ - ValueError - If the number of inputs doesn't match what the Basis object requires. - """ - if len(xi) != self._n_input_dimensionality: - raise TypeError( - f"Input dimensionality mismatch. This basis evaluation requires {self._n_input_dimensionality} inputs, " - f"{len(xi)} inputs provided instead." - ) - - @staticmethod - def _check_samples_consistency(*xi: NDArray) -> None: - """ - Check that each input provided to the Basis object has the same number of time points. - - Parameters - ---------- - xi[0], ..., xi[n] : - The input samples, shape (number of samples, ). - - Raises - ------ - ValueError - If the time point number is inconsistent between inputs. - """ - sample_sizes = [sample.shape[0] for sample in xi] - if any(elem != sample_sizes[0] for elem in sample_sizes): - raise ValueError( - "Sample size mismatch. Input elements have inconsistent sample sizes." - ) - - @abc.abstractmethod - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Most of the basis work with at least 1 element, but some - such as the RaisedCosineBasisLog requires a minimum of 2 basis to be well defined. - - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - pass - - def __add__(self, other: Basis) -> AdditiveBasis: - """ - Add two Basis objects together. - - Parameters - ---------- - other - The other Basis object to add. - - Returns - ------- - : AdditiveBasis - The resulting Basis object. 
- """ - return AdditiveBasis(self, other) - - def __mul__(self, other: Basis) -> MultiplicativeBasis: - """ - Multiply two Basis objects together. - - Parameters - ---------- - other - The other Basis object to multiply. - - Returns - ------- - : - The resulting Basis object. - """ - return MultiplicativeBasis(self, other) - - def __pow__(self, exponent: int) -> MultiplicativeBasis: - """Exponentiation of a Basis object. - - Define the power of a basis by repeatedly applying the method __multiply__. - The exponent must be a positive integer. - - Parameters - ---------- - exponent : - Positive integer exponent - - Returns - ------- - : - The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. - - Raises - ------ - TypeError - If the provided exponent is not an integer. - ValueError - If the integer is zero or negative. - """ - if not isinstance(exponent, int): - raise TypeError("Exponent should be an integer!") - - if exponent <= 0: - raise ValueError("Exponent should be a non-negative integer!") - - result = self - for _ in range(exponent - 1): - result = result * self - return result - - def to_transformer(self) -> TransformerBasis: - """ - Turn the Basis into a TransformerBasis for use with scikit-learn. - - Examples - -------- - Jointly cross-validating basis and GLM parameters with scikit-learn. - - >>> import nemos as nmo - >>> from sklearn.pipeline import Pipeline - >>> from sklearn.model_selection import GridSearchCV - >>> # load some data - >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer() - >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) - >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) - >>> param_grid = dict( - ... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), - ... basis__n_basis_funcs=(3, 5, 10, 20, 100), - ... ) - >>> gridsearch = GridSearchCV( - ... pipeline, - ... param_grid=param_grid, - ... cv=5, - ... ) - >>> gridsearch = gridsearch.fit(X, y) - """ - return TransformerBasis(copy.deepcopy(self)) - - def _get_feature_slicing( - self, - n_inputs: Optional[tuple] = None, - start_slice: Optional[int] = None, - split_by_input: bool = True, - ) -> Tuple[dict, int]: - """ - Calculate and return the slicing for features based on the input structure. - - This method determines how to slice the features for different basis types. - If the instance is of `AdditiveBasis` type, the slicing is calculated recursively - for each component basis. Otherwise, it determines the slicing based on - the number of basis functions and `split_by_input` flag. - - Parameters - ---------- - n_inputs : - The number of input basis for each component, by default it uses `self._n_basis_input`. - start_slice : - The starting index for slicing, by default it starts from 0. - split_by_input : - Flag indicating whether to split the slicing by individual inputs or not. - If `False`, a single slice is generated for all inputs. - - Returns - ------- - split_dict : - Dictionary with keys as labels and values as slices representing - the slicing for each input or additive component, if split_by_input equals to - True or False respectively. - start_slice : - The updated starting index after slicing. - - See Also - -------- - _get_default_slicing : Handles default slicing logic. - _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. 
- """ - # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input - start_slice = start_slice or 0 - - # If the instance is of AdditiveBasis type, handle slicing for the additive components - if isinstance(self, AdditiveBasis): - split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], - start_slice, - split_by_input=split_by_input, - ) - sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], - start_slice, - split_by_input=split_by_input, - ) - split_dict = self._merge_slicing_dicts(split_dict, sp2) - else: - # Handle the default case for other basis types - split_dict, start_slice = self._get_default_slicing( - split_by_input, start_slice - ) - - return split_dict, start_slice - - def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: - """Merge two slicing dictionaries, handling key conflicts.""" - for key, val in dict2.items(): - if key in dict1: - new_key = self._generate_unique_key(dict1, key) - dict1[new_key] = val - else: - dict1[key] = val - return dict1 - - @staticmethod - def _generate_unique_key(existing_dict: dict, key: str) -> str: - """Generate a unique key if there is a conflict.""" - extra = 1 - new_key = f"{key}-{extra}" - while new_key in existing_dict: - extra += 1 - new_key = f"{key}-{extra}" - return new_key - - def _get_default_slicing( - self, split_by_input: bool, start_slice: int - ) -> Tuple[dict, int]: - """Handle default slicing logic.""" - if split_by_input: - # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): - split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) - } - else: - split_dict = { - self.label: { - f"{i}": slice( - start_slice + i * self.n_basis_funcs, - start_slice + (i + 1) * self.n_basis_funcs, - ) - for i in range(self._n_basis_input[0]) - } - } - else: - split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) - } - start_slice += self._n_output_features - return split_dict, start_slice - - def split_by_feature( - self, - x: NDArray, - axis: int = 1, - ): - r""" - Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. - - This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. - - **How it works:** - - - If the basis expects an input shape `(n_samples, n_inputs)`, then the feature axis length will - be `total_n_features = n_inputs * n_basis_funcs`. This axis is reshaped into dimensions - `(n_inputs, n_basis_funcs)`. - - If the basis expects an input of shape `(n_samples,)`, then the feature axis length will - be `total_n_features = n_basis_funcs`. This axis is reshaped into `(1, n_basis_funcs)`. - - For example, if the input array `x` has shape `(1, 2, total_n_features, 4, 5)`, - then after applying this method, it will be reshaped into `(1, 2, n_inputs, n_basis_funcs, 4, 5)`. - - The specified axis (`axis`) determines where the split occurs, and all other dimensions - remain unchanged. See the example section below for the most common use cases. - - Parameters - ---------- - x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of `x` along the specified axis must match the total - number of features generated by the basis, i.e., `self.n_output_features`. 
- - **Examples:** - - For a design matrix: `(n_samples, total_n_features)` - - For model coefficients: `(total_n_features,)` or `(total_n_features, n_neurons)`. - - axis : int, optional - The axis along which to split the features. Defaults to 1. - Use `axis=1` for design matrices (features along columns) and `axis=0` for - coefficient arrays (features along rows). All other dimensions are preserved. - - Raises - ------ - ValueError - If the shape of `x` along the specified axis does not match `self.n_output_features`. - - Returns - ------- - dict - A dictionary where: - - **Key**: Label of the basis. - - **Value**: the array reshaped to: `(..., n_inputs, n_basis_funcs, ...) - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import BSplineBasis - >>> from nemos.glm import GLM - >>> # Define an additive basis - >>> basis = BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature") - >>> # Generate a sample input array and compute features - >>> x = np.random.randn(20) - >>> X = basis.compute_features(x) - >>> # Split the feature matrix along axis 1 - >>> split_features = basis.split_by_feature(X, axis=1) - >>> for feature, arr in split_features.items(): - ... print(f"{feature}: shape {arr.shape}") - feature: shape (20, 1, 5) - >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, - ... label="multi_input") - >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) - >>> for feature, sub_dict in split_features_multi.items(): - ... print(f"{feature}, shape {sub_dict.shape}") - multi_input, shape (20, 2, 6) - >>> # the method can be used to decompose the glm coefficients in the various features - >>> counts = np.random.poisson(size=20) - >>> model = GLM().fit(X, counts) - >>> split_coef = basis.split_by_feature(model.coef_, axis=0) - >>> for feature, coef in split_coef.items(): - ... print(f"{feature}: shape {coef.shape}") - feature: shape (1, 5) - - """ - if x.shape[axis] != self.n_output_features: - raise ValueError( - "`x.shape[axis]` does not match the expected number of features." 
- f" `x.shape[axis] == {x.shape[axis]}`, while the expected number " - f"of features is {self.n_output_features}" - ) - - # Get the slice dictionary based on predefined feature slicing - slice_dict = self._get_feature_slicing(split_by_input=False)[0] - - # Helper function to build index tuples for each slice - def build_index_tuple(slice_obj, axis: int, ndim: int): - """Create an index tuple to apply a slice on the given axis.""" - index = [slice(None)] * ndim # Initialize index for all dimensions - index[axis] = slice_obj # Replace the axis with the slice object - return tuple(index) - - # Get the dict for slicing the correct axis - index_dict = jax.tree_util.tree_map( - lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict - ) - - # Custom leaf function to identify index tuples as leaves - def is_leaf(val): - # Check if it's a tuple, length matches ndim, and all elements are slice objects - if isinstance(val, tuple) and len(val) == x.ndim: - return all(isinstance(v, slice) for v in val) - return False - - # Apply the slicing using the custom leaf function - out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - - # reshape the arrays to spilt by n_basis_input - reshaped_out = dict() - for i, vals in enumerate(out.items()): - key, val = vals - shape = list(val.shape) - reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] - ) - return reshaped_out - - def _check_input_shape_consistency(self, x: NDArray): - """Check input consistency across calls.""" - # remove sample axis - shape = x.shape[1:] - if self._input_shape is not None and self._input_shape != shape: - expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] - expected_shape_str = expected_shape_str.replace(",)", ")") - raise ValueError( - f"Input shape mismatch detected.\n\n" - f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " - f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" - f" Expected: {expected_shape_str}\n" - f" But got: {x.shape}.\n\n" - "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " - "but all other dimensions must remain the same. If you need to process inputs with a " - "different shape, please create a new basis instance." - ) - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - """ - Pre-compute the number of inputs and output features. - - This function computes the number of inputs that are provided to the basis and uses - that number, and the n_basis_funcs to calculate the number of output features that - `self.compute_features` will return. These quantities and the input shape (excluding the sample axis) - are stored in `self._n_basis_input` and `self._n_output_features`, and `self._input_shape` - respectively. - - Parameters - ---------- - xi: - The input arrays. - - Returns - ------- - : - The basis itself, for chaining. - - Raises - ------ - ValueError: - If the number of inputs do not match `self._n_basis_input`, if `self._n_basis_input` was - not None. - - Notes - ----- - Once a `compute_features` is called, we enforce that for all subsequent calls of the method, - the input that the basis receives preserves the shape of all axes, except for the sample axis. - This condition guarantees the consistency of the feature axis, and therefore that - `self.split_by_feature` behaves appropriately. 
- - """ - # Check that the input shape matches expectation - # Note that this method is reimplemented in AdditiveBasis and MultiplicativeBasis - # so we can assume that len(xi) == 1 - xi = xi[0] - self._check_input_shape_consistency(xi) - - # remove sample axis (samples are allowed to vary) - shape = xi.shape[1:] - - self._input_shape = shape - - # remove sample axis & get the total input number - n_inputs = (1,) if xi.ndim == 1 else (np.prod(shape),) - - self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] - return self - - -class AdditiveBasis(Basis): - """ - Class representing the addition of two Basis objects. - - Parameters - ---------- - basis1 : - First basis object to add. - basis2 : - Second basis object to add. - - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. - - Examples - -------- - >>> # Generate sample data - >>> import numpy as np - >>> import nemos as nmo - >>> X = np.random.normal(size=(30, 2)) - - >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.BSplineBasis(10) - >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) - >>> additive_basis = basis_1 + basis_2 - - >>> # can add another basis to the AdditiveBasis object - >>> X = np.random.normal(size=(30, 3)) - >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) - >>> additive_basis_2 = additive_basis + basis_3 - """ - - def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") - self._n_input_dimensionality = ( - basis1._n_input_dimensionality + basis2._n_input_dimensionality - ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - return - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features - ) - return self - - def _check_n_basis_min(self) -> None: - pass - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Evaluate the basis at the input samples. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - - """ - X = np.hstack( - ( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), - ) - ) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute features for added bases and concatenate. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The features, shape (n_samples, n_basis_funcs) - - """ - # the numpy conversion is important, there is some in-place - # array modification in basis. 
- hstack_pynapple = support_pynapple(conv_type="numpy")(np.hstack) - X = hstack_pynapple( - ( - self._basis1._compute_features( - *xi[: self._basis1._n_input_dimensionality] - ), - self._basis2._compute_features( - *xi[self._basis1._n_input_dimensionality :] - ), - ), - ) - return X - - def _set_kernel(self, *xi: ArrayLike) -> Basis: - """Call fit on the added basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to `scikit-learn` API. - - Returns - ------- - : - The AdditiveBasis ready to be evaluated. - """ - self._basis1._set_kernel(*xi) - self._basis2._set_kernel(*xi) - return self + def _get_default_slicing( + self, split_by_input: bool, start_slice: int + ) -> Tuple[dict, int]: + """Handle default slicing logic.""" + if split_by_input: + # should we remove this option? + if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): + split_dict = { + self.label: slice( + start_slice, start_slice + self._n_output_features + ) + } + else: + split_dict = { + self.label: { + f"{i}": slice( + start_slice + i * self.n_basis_funcs, + start_slice + (i + 1) * self.n_basis_funcs, + ) + for i in range(self._n_basis_input[0]) + } + } + else: + split_dict = { + self.label: slice(start_slice, start_slice + self._n_output_features) + } + start_slice += self._n_output_features + return split_dict, start_slice def split_by_feature( self, @@ -1658,36 +783,21 @@ def split_by_feature( axis: int = 1, ): r""" - Decompose an array along a specified axis into sub-arrays based on the basis components. + Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. Each split corresponds to a different additive component of the basis, - preserving all dimensions except the specified axis. - - **How It Works:** - - Suppose the basis is made up of **m components**, each with $b_i$ basis functions and $n_i$ inputs. - The total number of features, $N$, is calculated as: - - $$ - N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m - $$ - - This method splits any axis of length $N$ into sub-arrays, one for each basis component. - - The sub-array for the i-th basis component is reshaped into dimensions - $(n_i, b_i)$. - - For example, if the array shape is $(1, 2, N, 4, 5)$, then each split sub-array will have shape: + a designated axis. - $$ - (1, 2, n_i, b_i, 4, 5) - $$ + **How it works:** - where: + - If the basis expects an input shape `(n_samples, n_inputs)`, then the feature axis length will + be `total_n_features = n_inputs * n_basis_funcs`. This axis is reshaped into dimensions + `(n_inputs, n_basis_funcs)`. + - If the basis expects an input of shape `(n_samples,)`, then the feature axis length will + be `total_n_features = n_basis_funcs`. This axis is reshaped into `(1, n_basis_funcs)`. - - $n_i$ represents the number of inputs associated with the i-th component, - - $b_i$ represents the number of basis functions in that component. + For example, if the input array `x` has shape `(1, 2, total_n_features, 4, 5)`, + then after applying this method, it will be reshaped into `(1, 2, n_inputs, n_basis_funcs, 4, 5)`. The specified axis (`axis`) determines where the split occurs, and all other dimensions remain unchanged. See the example section below for the most common use cases. 
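Following up on the reshaping rule described above (see also the doctests later in this hunk), a minimal sketch of `split_by_feature` on a conv-mode basis with two inputs:

```python
import numpy as np
from nemos.basis import BSplineBasis

basis = BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature")
X = basis.compute_features(np.random.randn(20, 2))  # feature axis has 2 * 5 columns

split = basis.split_by_feature(X, axis=1)
print(split["feature"].shape)  # (20, 2, 5): (..., n_inputs, n_basis_funcs, ...)
```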
@@ -1717,18 +827,8 @@ def split_by_feature( ------- dict A dictionary where: - - **Keys**: Labels of the additive basis components. - - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: - - $$ - (..., n_i, b_i, ...) - $$ - - - `n_i`: The number of inputs processed by the i-th basis component. - - `b_i`: The number of basis functions for the i-th basis component. - - These sub-arrays are reshaped along the specified axis, with all other dimensions - remaining the same. + - **Key**: Label of the basis. + - **Value**: the array reshaped to: `(..., n_inputs, n_basis_funcs, ...) Examples -------- @@ -1736,19 +836,15 @@ def split_by_feature( >>> from nemos.basis import BSplineBasis >>> from nemos.glm import GLM >>> # Define an additive basis - >>> basis = ( - ... BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") + - ... BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2") - ... ) + >>> basis = BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature") >>> # Generate a sample input array and compute features - >>> x1, x2 = np.random.randn(20), np.random.randn(20) - >>> X = basis.compute_features(x1, x2) + >>> x = np.random.randn(20) + >>> X = basis.compute_features(x) >>> # Split the feature matrix along axis 1 >>> split_features = basis.split_by_feature(X, axis=1) >>> for feature, arr in split_features.items(): ... print(f"{feature}: shape {arr.shape}") - feature_1: shape (20, 1, 5) - feature_2: shape (20, 1, 6) + feature: shape (20, 1, 5) >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, ... label="multi_input") @@ -1763,1516 +859,895 @@ def split_by_feature( >>> split_coef = basis.split_by_feature(model.coef_, axis=0) >>> for feature, coef in split_coef.items(): ... print(f"{feature}: shape {coef.shape}") - feature_1: shape (1, 5) - feature_2: shape (1, 6) - - """ - return super().split_by_feature(x, axis=axis) - - -class MultiplicativeBasis(Basis): - """ - Class representing the multiplication (external product) of two Basis objects. - - Parameters - ---------- - basis1 : - First basis object to multiply. - basis2 : - Second basis object to multiply. - - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. 
- - Examples - -------- - >>> # Generate sample data - >>> import numpy as np - >>> import nemos as nmo - >>> X = np.random.normal(size=(30, 3)) - - >>> # define two basis and multiply - >>> basis_1 = nmo.basis.BSplineBasis(10) - >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) - >>> multiplicative_basis = basis_1 * basis_2 - - >>> # Can multiply or add another basis to the AdditiveBasis object - >>> # This will cause the number of output features of the result basis to grow accordingly - >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) - >>> multiplicative_basis_2 = multiplicative_basis * basis_3 - """ - - def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") - self._n_input_dimensionality = ( - basis1._n_input_dimensionality + basis2._n_input_dimensionality - ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - return - - def _check_n_basis_min(self) -> None: - pass - - def _set_kernel(self, *xi: NDArray) -> Basis: - """Call fit on the multiplied basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to `scikit-learn` API. - - Returns - ------- - : - The MultiplicativeBasis ready to be evaluated. - """ - self._basis1._set_kernel(*xi) - self._basis2._set_kernel(*xi) - return self - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Evaluate the basis at the input samples. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. + feature: shape (1, 5) - Returns - ------- - : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) """ - X = np.asarray( - row_wise_kron( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), - transpose=False, + if x.shape[axis] != self.n_output_features: + raise ValueError( + "`x.shape[axis]` does not match the expected number of features." + f" `x.shape[axis] == {x.shape[axis]}`, while the expected number " + f"of features is {self.n_output_features}" ) - ) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute the features for the multiplied bases, and compute their outer product. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. 
- - Returns - ------- - : - The features, shape (n_samples, n_basis_funcs) - - """ - kron = support_pynapple(conv_type="numpy")(row_wise_kron) - X = kron( - self._basis1._compute_features(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._compute_features(*xi[self._basis1._n_input_dimensionality :]), - transpose=False, - ) - return X - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features - ) - return self - - -class SplineBasis(Basis, abc.ABC): - """ - SplineBasis class inherits from the Basis class and represents spline basis functions. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : optional - Spline order. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. - Attributes - ---------- - order : int - Spline order. 
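The `order` attribute documented above is validated against `n_basis_funcs`: spline bases require `order <= n_basis_funcs`, and violating this raises at construction. A brief sketch (pre-refactor API):

```python
from nemos.basis import BSplineBasis

BSplineBasis(n_basis_funcs=5, order=3)      # fine: order <= n_basis_funcs

try:
    BSplineBasis(n_basis_funcs=3, order=5)  # more continuity requested than basis elements
except ValueError as e:
    print(e)  # "... `order` parameter cannot be larger than `n_basis_funcs` parameter."
```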
- """ + # Get the slice dictionary based on predefined feature slicing + slice_dict = self._get_feature_slicing(split_by_input=False)[0] - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 2, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = None, - **kwargs, - ) -> None: - self.order = order - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, + # Helper function to build index tuples for each slice + def build_index_tuple(slice_obj, axis: int, ndim: int): + """Create an index tuple to apply a slice on the given axis.""" + index = [slice(None)] * ndim # Initialize index for all dimensions + index[axis] = slice_obj # Replace the axis with the slice object + return tuple(index) + + # Get the dict for slicing the correct axis + index_dict = jax.tree_util.tree_map( + lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict ) - self._n_input_dimensionality = 1 + # Custom leaf function to identify index tuples as leaves + def is_leaf(val): + # Check if it's a tuple, length matches ndim, and all elements are slice objects + if isinstance(val, tuple) and len(val) == x.ndim: + return all(isinstance(v, slice) for v in val) + return False - @property - def order(self): - return self._order + # Apply the slicing using the custom leaf function + out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - @order.setter - def order(self, value): - """Setter for the order parameter.""" - if value != int(value): - raise ValueError( - f"Spline order must be an integer! Order {value} provided." + # reshape the arrays to spilt by n_basis_input + reshaped_out = dict() + for i, vals in enumerate(out.items()): + key, val = vals + shape = list(val.shape) + reshaped_out[key] = val.reshape( + shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] ) - value = int(value) - if value < 1: - raise ValueError(f"Spline order must be positive! Order {value} provided.") - - # Set to None only the first time the setter is called. - orig_order = copy.deepcopy(getattr(self, "_order", None)) - - # Set the order - self._order = value + return reshaped_out - # If the order was already initialized, re-check basis - if orig_order is not None: - try: - self._check_n_basis_min() - except ValueError as e: - self._order = orig_order - raise e + def _check_input_shape_consistency(self, x: NDArray): + """Check input consistency across calls.""" + # remove sample axis + shape = x.shape[1:] + if self._input_shape is not None and self._input_shape != shape: + expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] + expected_shape_str = expected_shape_str.replace(",)", ")") + raise ValueError( + f"Input shape mismatch detected.\n\n" + f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " + f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" + f" Expected: {expected_shape_str}\n" + f" But got: {x.shape}.\n\n" + "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " + "but all other dimensions must remain the same. If you need to process inputs with a " + "different shape, please create a new basis instance." 
+ ) - def _generate_knots( - self, - sample_pts: NDArray, - perc_low: float = 0.0, - perc_high: float = 1.0, - is_cyclic: bool = False, - ) -> NDArray: + def _set_num_output_features(self, *xi: NDArray) -> Basis: """ - Generate knot locations for spline basis functions. + Pre-compute the number of inputs and output features. + + This function computes the number of inputs that are provided to the basis and uses + that number, and the n_basis_funcs to calculate the number of output features that + `self.compute_features` will return. These quantities and the input shape (excluding the sample axis) + are stored in `self._n_basis_input` and `self._n_output_features`, and `self._input_shape` + respectively. Parameters ---------- - sample_pts : (n_samples,) - The sample points. - perc_low - The low percentile value, between [0,1). - perc_high - The high percentile value, between (0,1]. - is_cyclic : optional - Whether the spline is cyclic. + xi: + The input arrays. Returns ------- - The knot locations for the spline basis functions. + : + The basis itself, for chaining. Raises ------ - AssertionError - If the percentiles or order values are not within the valid range. + ValueError: + If the number of inputs do not match `self._n_basis_input`, if `self._n_basis_input` was + not None. + + Notes + ----- + Once a `compute_features` is called, we enforce that for all subsequent calls of the method, + the input that the basis receives preserves the shape of all axes, except for the sample axis. + This condition guarantees the consistency of the feature axis, and therefore that + `self.split_by_feature` behaves appropriately. + """ - # Determine number of interior knots. - num_interior_knots = self.n_basis_funcs - self.order - if is_cyclic: - num_interior_knots += self.order - 1 - - # Spline basis have support on the semi-open [a, b) interval, we add a small epsilon - # to mx so that the so that basis_element(max(samples)) != 0 - knot_locs = np.concatenate( - ( - np.zeros(self.order - 1), - np.linspace(0, (1 + np.finfo(float).eps), num_interior_knots + 2), - np.full(self.order - 1, 1 + np.finfo(float).eps), - ) - ) - return knot_locs + # Check that the input shape matches expectation + # Note that this method is reimplemented in AdditiveBasis and MultiplicativeBasis + # so we can assume that len(xi) == 1 + xi = xi[0] + self._check_input_shape_consistency(xi) - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. + # remove sample axis (samples are allowed to vary) + shape = xi.shape[1:] - Check that the spline-basis has at least as many basis as the order. + self._input_shape = shape - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - if self.n_basis_funcs < self.order: - raise ValueError( - f"{self.__class__.__name__} `order` parameter cannot be larger " - "than `n_basis_funcs` parameter." - ) + # remove sample axis & get the total input number + n_inputs = (1,) if xi.ndim == 1 else (np.prod(shape),) + + self._n_basis_input = n_inputs + self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] + return self -class MSplineBasis(SplineBasis): - r""" - M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation. +class TransformerBasis: + """Basis as `scikit-learn` transformers. - M-splines are a type of spline basis function used for smooth curve fitting - and data representation. 
They are positive and integrate to one, making them - suitable for probabilistic models and density estimation. The order of an - M-spline defines its smoothness, with higher orders resulting in smoother - splines. + This class abstracts the underlying basis function details, offering methods + similar to scikit-learn's transformers but specifically designed for basis + transformations. It supports fitting to data (calculating any necessary parameters + of the basis functions), transforming data (applying the basis functions to + data), and both fitting and transforming in one step. - This class provides functionality to create M-spline basis functions, allowing - for flexible and smooth modeling of data. It inherits from the `SplineBasis` - abstract class, providing specific implementations for M-splines. + `TransformerBasis`, unlike `Basis`, is compatible with scikit-learn pipelining and + model selection, enabling the cross-validation of the basis type and parameters, + for example `n_basis_funcs`. See the example section below. Parameters ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : - The order of the splines used in basis functions. Must be between [1, - n_basis_funcs]. Default is 2. Higher order splines have more continuous - derivatives at each interior knot, resulting in smoother basis functions. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs: - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. + basis : + A concrete subclass of `Basis`. Examples -------- - >>> from numpy import linspace - >>> from nemos.basis import MSplineBasis - >>> n_basis_funcs = 5 - >>> order = 3 - >>> mspline_basis = MSplineBasis(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) - - # References - ------------ - [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, - 3(4), 425-441. - - Notes - ----- - MSplines must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of $\alpha$, the values on the co-domain (y-axis) values - will shrink by a factor of $1/\alpha$. - For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. - If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. 
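The Notes above say that stretching the M-spline domain by a factor alpha shrinks the basis values by 1/alpha, since each function keeps integrating to 1. A quick numerical check of that claim, using the pre-refactor `MSplineBasis` name (renamed elsewhere in this patch):

```python
import numpy as np
from nemos.basis import MSplineBasis

n_basis, order = 10, 3
basis_unit = MSplineBasis(n_basis, order=order)                 # default domain (0, 1)
basis_wide = MSplineBasis(n_basis, order=order, bounds=(0, 2))  # domain stretched by alpha = 2

peak_unit = basis_unit(np.linspace(0, 1, 1000)).max()
peak_wide = basis_wide(np.linspace(0, 2, 1000)).max()
print(peak_unit / peak_wide)  # ~2: heights scale by 1/alpha when the domain doubles
```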
- """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 2, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "MSplineBasis", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: - """ - Evaluate the M-spline basis functions at given sample points. - - Parameters - ---------- - sample_pts : - An array of sample points where the M-spline basis functions are to be - evaluated. + >>> from nemos.basis import BSplineBasis, TransformerBasis + >>> from nemos.glm import GLM + >>> from sklearn.pipeline import Pipeline + >>> from sklearn.model_selection import GridSearchCV + >>> import numpy as np + >>> np.random.seed(123) - Returns - ------- - : - An array where each column corresponds to one M-spline basis function - evaluated at the input sample points. The shape of the array is - (len(sample_pts), n_basis_funcs). + >>> # Generate data + >>> num_samples, num_features = 10000, 1 + >>> x = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = BSplineBasis(10) + >>> features = basis.compute_features(x) # basis transformed time series + >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights + >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts - Notes - ----- - The implementation uses a recursive definition of M-splines. Boundary - conditions are handled such that the basis functions are positive and - integrate to one over the domain defined by the sample points. - """ - sample_pts, scaling = min_max_rescale_samples(sample_pts, self.bounds) - # add knots if not passed - knot_locs = self._generate_knots( - sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False - ) + >>> # transformer can be used in pipelines + >>> transformer = TransformerBasis(basis) + >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),]) + >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API + >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas + >>> # TransformerBasis parameter can be cross-validated. + >>> # 5-fold cross-validate the number of basis + >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10]) + >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5) + >>> grid_cv = grid_cv.fit(x[:, None], y) + >>> print("Cross-validated number of basis:", grid_cv.best_params_) + Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} + """ - X = np.stack( - [ - mspline(sample_pts, self.order, i, knot_locs) - for i in range(self.n_basis_funcs) - ], - axis=1, - ) - # re-normalize so that it integrates to 1 over the range. - X /= scaling - return X + def __init__(self, basis: Basis): + self._basis = copy.deepcopy(basis) - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """ - Evaluate the M-spline basis functions on a uniformly spaced grid. + @staticmethod + def _unpack_inputs(X: FeatureMatrix): + """Unpack impute without using transpose. - This method creates a uniformly spaced grid of sample points within the domain - [0, 1] and evaluates all the M-spline basis functions at these points. 
It is - particularly useful for visualizing the shape and distribution of the basis - functions across their domain. + Unpack horizontally stacked inputs using slicing. This works gracefully with `pynapple`, + returning a list of Tsd objects. Attempt to unpack using *X.T will raise a `pynapple` + exception since `pynapple` assumes that the time axis is the first axis. Parameters ---------- - n_samples : - The number of points in the uniformly spaced grid. A higher number of - samples will result in a more detailed visualization of the basis functions. + X: + The inputs horizontally stacked. Returns ------- - X : NDArray - A 1D array of uniformly spaced sample points within the domain [0, 1]. - Shape: `(n_samples,)`. - Y : NDArray - A 2D array where each row corresponds to the evaluated M-spline basis - function values at the points in X. Shape: `(n_samples, n_basis_funcs)`. - - Examples - -------- - Evaluate and visualize 4 M-spline basis functions of order 3: + : + A tuple of each individual input. - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import MSplineBasis - >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> for i in range(4): - ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') - >>> plt.title('M-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') - >>> plt.xlabel('Domain') - Text(0.5, 0, 'Domain') - >>> plt.ylabel('Basis Function Value') - Text(0, 0.5, 'Basis Function Value') - >>> l = plt.legend() """ - return super().evaluate_on_grid(n_samples) - - -class BSplineBasis(SplineBasis): - """ - B-spline[$^{[1]}$](#references) 1-dimensional basis functions. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : - Order of the splines used in basis functions. Must lie within [1, n_basis_funcs]. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. - - Attributes - ---------- - order : - Spline order. - - - # References - ------------ - [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. - Mathematics and Visualization. Springer, Berlin, Heidelberg. 
https://doi.org/10.1007/978-3-662-04919-8_5 - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import BSplineBasis - - >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis(sample_points) - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "BSplineBasis", - **kwargs, - ): - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) + return (X[:, k] for k in range(X.shape[1])) - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: + def fit(self, X: FeatureMatrix, y=None): """ - Evaluate the B-spline basis functions with given sample points. + Compute the convolutional kernels. + + If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. Parameters ---------- - sample_pts : - The sample points at which the B-spline is evaluated, shape (n_samples,). + X : + The data to fit the basis functions to, shape (num_samples, num_input). + y : ignored + Not used, present for API consistency by convention. Returns ------- - basis_funcs : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs). + self : + The transformer object. - Raises - ------ - AssertionError - If the sample points are not within the B-spline knots. + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline, TransformerBasis - Notes - ----- - The evaluation is performed by looping over each element and using `splev` - from SciPy to compute the basis values. - """ - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - # add knots - knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) + >>> # Example input + >>> X = np.random.normal(size=(100, 2)) - basis_eval = bspline( - sample_pts, knot_locs, order=self.order, der=0, outer_ok=False - ) - return basis_eval + >>> # Define and fit tranformation basis + >>> basis = EvalMSpline(10) + >>> transformer = TransformerBasis(basis) + >>> transformer_fitted = transformer.fit(X) + """ + self._basis._set_kernel() + return self - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the B-spline basis set on a grid of equi-spaced sample points. + def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: + """ + Transform the data using the fitted basis functions. Parameters ---------- - n_samples : - The number of samples. + X : + The data to transform using the basis functions, shape (num_samples, num_input). + y : + Not used, present for API consistency by convention. Returns ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Raised cosine basis functions, shape (n_samples, n_basis_funcs) - - Notes - ----- - The evaluation is performed by looping over each element and using `splev` from - SciPy to compute the basis values. + : + The data transformed by the basis functions. 
Examples -------- >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import BSplineBasis - >>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - + >>> from nemos.basis import EvalMSpline, TransformerBasis -class CyclicBSplineBasis(SplineBasis): - """ - B-spline 1-dimensional basis functions for cyclic splines. + >>> # Example input + >>> X = np.random.normal(size=(10000, 2)) - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : - Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. + >>> # Define and fit tranformation basis + >>> basis = EvalMSpline(10, mode="conv", window_size=200) + >>> transformer = TransformerBasis(basis) + >>> # Before calling `fit` the convolution kernel is not set + >>> transformer.kernel_ - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. - order : int - Order of the splines used in basis functions. + >>> transformer_fitted = transformer.fit(X) + >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) + >>> transformer_fitted.kernel_.shape + (200, 10) - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import CyclicBSplineBasis - >>> X = np.random.normal(size=(1000, 1)) + >>> # Transform basis + >>> feature_transformed = transformer.transform(X[:, 0:1]) + """ + # transpose does not work with pynapple + # can't use func(*X.T) to unwrap - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_basis(sample_points) - """ + return self._basis._compute_features(*self._unpack_inputs(X)) - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "CyclicBSplineBasis", - **kwargs, - ): - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - if self.order < 2: - raise ValueError( - f"Order >= 2 required for cyclic B-spline, " - f"order {self.order} specified instead!" 
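The comments above note that `transform` cannot unpack inputs with `*X.T`, because pynapple containers treat the first axis as time; instead `_unpack_inputs` slices one column at a time. A minimal numpy-only sketch of that slicing strategy (the helper name here is illustrative, not part of the library):

```python
import numpy as np

def unpack_columns(X):
    # yield each input as a 1D slice along the sample axis, mirroring the
    # column-wise unpacking described above; slicing keeps the time/sample
    # axis first, which is what pynapple-style containers expect
    return (X[:, k] for k in range(X.shape[1]))

X = np.random.normal(size=(100, 3))
columns = list(unpack_columns(X))
print(len(columns), columns[0].shape)  # 3 (100,)
```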
- ) + def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: + """ + Compute the kernels and the features. - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Evaluate the Cyclic B-spline basis functions with given sample points. + This method is a convenience that combines fit and transform into + one step. Parameters ---------- - sample_pts : - The sample points at which the cyclic B-spline is evaluated, shape - (n_samples,). + X : + The data to fit the basis functions to and then transform. + y : + Not used, present for API consistency by convention. Returns ------- - basis_funcs : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + array-like + The data transformed by the basis functions, after fitting the basis + functions to the data. - Notes - ----- - The evaluation is performed by looping over each element and using `splev` from - SciPy to compute the basis values. + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(100, 1)) + + >>> # Define tranformation basis + >>> basis = EvalMSpline(10) + >>> transformer = TransformerBasis(basis) + >>> # Fit and transform basis + >>> feature_transformed = transformer.fit_transform(X) """ - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) + return self._basis.compute_features(*self._unpack_inputs(X)) - # for cyclic, do not repeat knots - knot_locs = np.unique(knot_locs) + def __getstate__(self): + """ + Explicitly define how to pickle TransformerBasis object. - nk = knot_locs.shape[0] + See https://docs.python.org/3/library/pickle.html#object.__getstate__ + and https://docs.python.org/3/library/pickle.html#pickle-state + """ + return {"_basis": self._basis} - # make sure knots are sorted - knot_locs.sort() + def __setstate__(self, state): + """ + Define how to populate the object's state when unpickling. - # extend knots - xc = knot_locs[nk - self.order] - knots = np.hstack( - ( - knot_locs[0] - knot_locs[-1] + knot_locs[nk - self.order : nk - 1], - knot_locs, - ) - ) + Note that during unpickling a new object is created without calling __init__. + Needed to avoid infinite recursion in __getattr__ when unpickling. - ind = sample_pts > xc + See https://docs.python.org/3/library/pickle.html#object.__setstate__ + and https://docs.python.org/3/library/pickle.html#pickle-state + """ + self._basis = state["_basis"] - basis_eval = bspline(sample_pts, knots, order=self.order, der=0, outer_ok=True) - sample_pts[ind] = sample_pts[ind] - knots.max() + knot_locs[0] + def __getattr__(self, name: str): + """ + Enable easy access to attributes of the underlying Basis object. - if np.sum(ind): - basis_eval[ind] = basis_eval[ind] + bspline( - sample_pts[ind], knots, order=self.order, outer_ok=True, der=0 - ) - # restore points - sample_pts[ind] = sample_pts[ind] + knots.max() - knot_locs[0] - return basis_eval + Examples + -------- + >>> from nemos import basis + >>> bas = basis.RaisedCosineBasisLinear(5) + >>> trans_bas = basis.TransformerBasis(bas) + >>> bas.n_basis_funcs + 5 + >>> trans_bas.n_basis_funcs + 5 + """ + return getattr(self._basis, name) - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. 
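Because `__getstate__`/`__setstate__` only carry the wrapped `_basis`, a `TransformerBasis` should survive a pickle round trip, with attribute access still delegated through `__getattr__` afterwards. A sketch, assuming the import path used in the doctests of this diff:

```python
import pickle

from nemos import basis  # import path as in the doctest above

trans_bas = basis.TransformerBasis(basis.RaisedCosineBasisLinear(5))

# round-trip through pickle: only _basis is stored and restored
restored = pickle.loads(pickle.dumps(trans_bas))

# attribute lookups still fall through to the wrapped basis
print(restored.n_basis_funcs)  # 5
```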
+ def __setattr__(self, name: str, value) -> None: + r""" + Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. - Parameters - ---------- - n_samples : - The number of samples. + Setting any other attribute is not allowed. Returns ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Raised cosine basis functions, shape (n_samples, n_basis_funcs) + None - Notes - ----- - The evaluation is performed by looping over each element and using `splev` from - SciPy to compute the basis values. + Raises + ------ + ValueError + If the attribute being set is not `_basis` or an attribute of `_basis`. Examples -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import CyclicBSplineBasis - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100) + >>> import nemos as nmo + >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10)) + >>> # allowed + >>> trans_bas._basis = nmo.basis.BSplineBasis(10) + >>> # allowed + >>> trans_bas.n_basis_funcs = 20 + >>> # not allowed + >>> try: + ... trans_bas.random_attribute_name = "some value" + ... except ValueError as e: + ... print(repr(e)) + ValueError('Only setting _basis or existing attributes of _basis is allowed.') + """ + # allow self._basis = basis + if name == "_basis": + super().__setattr__(name, value) + # allow changing existing attributes of self._basis + elif hasattr(self._basis, name): + setattr(self._basis, name, value) + # don't allow setting any other attribute + else: + raise ValueError( + "Only setting _basis or existing attributes of _basis is allowed." + ) + + def __sklearn_clone__(self) -> TransformerBasis: """ - return super().evaluate_on_grid(n_samples) + Customize how TransformerBasis objects are cloned when used with sklearn.model_selection. + By default, scikit-learn tries to clone the object by calling __init__ using the output of get_params, + which fails in our case. -class RaisedCosineBasisLinear(Basis): - """Represent linearly-spaced raised cosine basis functions. + For more info: https://scikit-learn.org/stable/developers/develop.html#cloning + """ + cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) + cloned_obj._basis.kernel_ = None + return cloned_obj - This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) - to uniformly tile the internal points of the domain. + def set_params(self, **parameters) -> TransformerBasis: + """ + Set TransformerBasis parameters. - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - width : - Width of the raised cosine. By default, it's set to 2.0. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. 
- **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. + When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly + or the parameters of the underlying Basis, but not both. - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLinear - >>> X = np.random.normal(size=(1000, 1)) - - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - - # References - ------------ - [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - """ + Examples + -------- + >>> from nemos.basis import BSplineBasis, EvalMSpline, TransformerBasis + >>> basis = EvalMSpline(10) + >>> transformer_basis = TransformerBasis(basis=basis) + + >>> # setting parameters of _basis is allowed + >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) + 8 + >>> # setting _basis directly is allowed + >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)) + + >>> # mixing is not allowed, this will raise an exception + >>> try: + ... transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2) + ... except ValueError as e: + ... print(repr(e)) + ValueError('Set either new _basis object or parameters for existing _basis, not both.') + """ + new_basis = parameters.pop("_basis", None) + if new_basis is not None: + self._basis = new_basis + if len(parameters) > 0: + raise ValueError( + "Set either new _basis object or parameters for existing _basis, not both." + ) + else: + self._basis = self._basis.set_params(**parameters) - def __init__( - self, - n_basis_funcs: int, - mode="eval", - width: float = 2.0, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "RaisedCosineBasisLinear", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - self._n_input_dimensionality = 1 - self._check_width(width) - self._width = width - # for these linear raised-cosine basis functions, - # the samples must be rescaled to 0 and 1. - self._rescale_samples = True + return self - @property - def width(self): - """Return width of the raised cosine.""" - return self._width + def get_params(self, deep: bool = True) -> dict: + """Extend the dict of parameters from the underlying Basis with _basis.""" + return {"_basis": self._basis, **self._basis.get_params(deep)} - @width.setter - def width(self, width: float): - self._check_width(width) - self._width = width + def __dir__(self) -> list[str]: + """Extend the list of properties of methods with the ones from the underlying Basis.""" + return super().__dir__() + self._basis.__dir__() - @staticmethod - def _check_width(width: float) -> None: - """Validate the width value. 
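`get_params` exposes `_basis` next to the wrapped basis' own parameters, which is what lets scikit-learn utilities clone the transformer and tune it. A hedged sketch (class names as in the doctests above; `clone` dispatches to `__sklearn_clone__` only on scikit-learn >= 1.3):

```python
from sklearn.base import clone

from nemos.basis import EvalMSpline, TransformerBasis  # names as used in the doctests above

transformer = TransformerBasis(EvalMSpline(10))

# _basis is reported alongside the wrapped basis' parameters
params = transformer.get_params()
print("_basis" in params, params["n_basis_funcs"])  # True 10

# cloning deep-copies the wrapped basis and resets any fitted kernel
print(clone(transformer).n_basis_funcs)  # 10

# parameters of the wrapped basis can be updated through the transformer
print(transformer.set_params(n_basis_funcs=8).n_basis_funcs)  # 8
```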
+ def __add__(self, other: TransformerBasis) -> TransformerBasis: + """ + Add two TransformerBasis objects. Parameters ---------- - width : - The width value to validate. + other + The other TransformerBasis object to add. - Raises - ------ - ValueError - If width <= 1 or 2*width is not a positive integer. Values that do not match - this constraint will result in: - - No overlap between bumps (width < 1). - - Oscillatory behavior when summing the basis elements (2*width not integer). + Returns + ------- + : TransformerBasis + The resulting Basis object. """ - if width <= 1 or (not np.isclose(width * 2, round(2 * width))): - raise ValueError( - f"Invalid raised cosine width. " - f"2*width must be a positive integer, 2*width = {2 * width} instead!" - ) + return TransformerBasis(self._basis + other._basis) - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Generate basis functions with given samples. + def __mul__(self, other: TransformerBasis) -> TransformerBasis: + """ + Multiply two TransformerBasis objects. Parameters ---------- - sample_pts : - Spacing for basis functions, holding elements on interval [0, 1], Shape (number of samples, ). - - Raises - ------ - ValueError - If the sample provided do not lie in [0,1]. - - """ - if self._rescale_samples: - # note that sample points is converted to NDArray - # with the decorator. - # copy is necessary otherwise: - # basis1 = nmo.basis.RaisedCosineBasisLinear(5) - # basis2 = nmo.basis.RaisedCosineBasisLog(5) - # additive_basis = basis1 + basis2 - # additive_basis(*([x] * 2)) would modify both inputs - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) - - peaks = self._compute_peaks() - delta = peaks[1] - peaks[0] - # generate a set of shifted cosines, and constrain them to be non-zero - # over a single period, then enforce the codomain to be [0,1], by adding 1 - # and then multiply by 0.5 - basis_funcs = 0.5 * ( - np.cos( - np.clip( - np.pi * (sample_pts[:, None] - peaks[None]) / (delta * self.width), - -np.pi, - np.pi, - ) - ) - + 1 - ) - return basis_funcs - - def _compute_peaks(self) -> NDArray: - """ - Compute the location of raised cosine peaks. + other + The other TransformerBasis object to multiply. Returns ------- - Peak locations of each basis element. + : + The resulting Basis object. """ - return np.linspace(0, 1, self.n_basis_funcs) + return TransformerBasis(self._basis * other._basis) - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. + def __pow__(self, exponent: int) -> TransformerBasis: + """Exponentiation of a TransformerBasis object. + + Define the power of a basis by repeatedly applying the method __mul__. + The exponent must be a positive integer. Parameters ---------- - n_samples : - The number of samples. + exponent : + Positive integer exponent Returns ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. 
- basis_funcs : - Raised cosine basis functions, shape (n_samples, n_basis_funcs) - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import RaisedCosineBasisLinear - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Check that the number of basis is at least 2. + : + The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. Raises ------ + TypeError + If the provided exponent is not an integer. ValueError - If n_basis_funcs < 2. + If the integer is zero or negative. """ - if self.n_basis_funcs < 2: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 2 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) - + # errors are handled by Basis.__pow__ + return TransformerBasis(self._basis**exponent) -class RaisedCosineBasisLog(RaisedCosineBasisLinear): - """Represent log-spaced raised cosine basis functions. - Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) - to uniformly tile the internal points of the domain. +class AdditiveBasis(Basis): + """ + Class representing the addition of two Basis objects. Parameters ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - width : - Width of the raised cosine. - time_scaling : - Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with - larger values resulting in more stretching. As this approaches 0, the - transformation becomes linear. - enforce_decay_to_zero: - If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements - and subsequently trims off the extra basis elements. This ensures that the final basis element - decays to 0. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. + basis1 : + First basis object to add. + basis2 : + Second basis object to add. + + Attributes + ---------- + n_basis_funcs : int + Number of basis functions. 
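With the `__add__`, `__mul__`, and `__pow__` overloads above, transformer-wrapped bases compose just like plain bases, each operation returning a new `TransformerBasis` around the combined object. A sketch under the same naming assumptions as the doctests in this diff:

```python
from nemos.basis import BSplineBasis, RaisedCosineBasisLinear, TransformerBasis

t1 = TransformerBasis(BSplineBasis(5))
t2 = TransformerBasis(RaisedCosineBasisLinear(6))

# addition / multiplication wrap AdditiveBasis / MultiplicativeBasis objects
print((t1 + t2).n_basis_funcs)   # 5 + 6 = 11
print((t1 * t2).n_basis_funcs)   # 5 * 6 = 30
print((t1 ** 2).n_basis_funcs)   # 5 * 5 = 25
```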
Examples -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLog - >>> X = np.random.normal(size=(1000, 1)) - - >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - - # References - ------------ - [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 - """ + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + >>> X = np.random.normal(size=(30, 2)) - def __init__( - self, - n_basis_funcs: int, - mode="eval", - width: float = 2.0, - time_scaling: float = None, - enforce_decay_to_zero: bool = True, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "RaisedCosineBasisLog", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - width=width, - window_size=window_size, - bounds=bounds, - **kwargs, - label=label, - ) - # The samples are scaled appropriately in the self._transform_samples which scales - # and applies the log-stretch, no additional transform is needed. - self._rescale_samples = False - if time_scaling is None: - time_scaling = 50.0 + >>> # define two basis objects and add them + >>> basis_1 = nmo.basis.BSplineBasis(10) + >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> additive_basis = basis_1 + basis_2 - self.time_scaling = time_scaling - self.enforce_decay_to_zero = enforce_decay_to_zero + >>> # can add another basis to the AdditiveBasis object + >>> X = np.random.normal(size=(30, 3)) + >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> additive_basis_2 = additive_basis + basis_3 + """ - @property - def time_scaling(self): - """Getter property for time_scaling.""" - return self._time_scaling + def __init__(self, basis1: Basis, basis2: Basis) -> None: + self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs + super().__init__(self.n_basis_funcs, mode="eval") + self._n_input_dimensionality = ( + basis1._n_input_dimensionality + basis2._n_input_dimensionality + ) + self._n_basis_input = None + self._n_output_features = None + self._label = "(" + basis1.label + " + " + basis2.label + ")" + self._basis1 = basis1 + self._basis2 = basis2 + return - @time_scaling.setter - def time_scaling(self, time_scaling): - """Setter property for time_scaling.""" - self._check_time_scaling(time_scaling) - self._time_scaling = time_scaling + def _set_num_output_features(self, *xi: NDArray) -> Basis: + self._n_basis_input = ( + *self._basis1._set_num_output_features( + *xi[: self._basis1._n_input_dimensionality] + )._n_basis_input, + *self._basis2._set_num_output_features( + *xi[self._basis1._n_input_dimensionality :] + )._n_basis_input, + ) + self._n_output_features = ( + self._basis1.n_output_features + self._basis2.n_output_features + ) + return self - @staticmethod - def _check_time_scaling(time_scaling: float) -> None: - if time_scaling <= 0: - raise ValueError( - f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." 
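As the docstring example above implies, an `AdditiveBasis` consumes one input per component and horizontally stacks the component features, so the output has as many columns as the sum of the component basis sizes. A small sketch reusing the classes from that example:

```python
import numpy as np
import nemos as nmo

x1 = np.random.normal(size=100)
x2 = np.random.normal(size=100)

additive_basis = nmo.basis.BSplineBasis(10) + nmo.basis.RaisedCosineBasisLinear(15)

# one input per component; features are concatenated column-wise
X = additive_basis.compute_features(x1, x2)
print(X.shape)  # (100, 25)
```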
- ) + def _check_n_basis_min(self) -> None: + pass - def _transform_samples( - self, - sample_pts: ArrayLike, - ) -> NDArray: + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__(self, *xi: ArrayLike) -> FeatureMatrix: """ - Map the sample domain to log-space. + Evaluate the basis at the input samples. Parameters ---------- - sample_pts : - Sample points used for evaluating the splines, - shape (n_samples, ). + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- - Transformed version of the sample points that matches the Raised Cosine basis domain, - shape (n_samples, ). + : + The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + """ - # rescale to [0,1] - # copy is necessary to avoid unwanted rescaling in additive/multiplicative basis. - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) - # This log-stretching of the sample axis has the following effect: - # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. - # - as the time_scaling tends to inf, basis will be small and dense around 0 and - # progressively larger and less dense towards 1. - log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( - self.time_scaling + 1 + X = np.hstack( + ( + self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), + ) ) - return log_spaced_pts + return X - def _compute_peaks(self) -> NDArray: + def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ - Peak location of each log-spaced cosine basis element. + Compute features for added bases and concatenate. - Compute the peak location for the log-spaced raised cosine basis. - Enforcing that the last basis decays to zero is equivalent to - setting the last peak to a value smaller than 1. + Parameters + ---------- + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. Returns ------- - Peak locations of each basis element. + : + The features, shape (n_samples, n_basis_funcs) """ - if self.enforce_decay_to_zero: - # compute the last peak location such that the last - # basis element decays to zero at the last sample. - last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1) - else: - last_peak = 1 - return np.linspace(0, last_peak, self.n_basis_funcs) + # the numpy conversion is important, there is some in-place + # array modification in basis. + hstack_pynapple = support_pynapple(conv_type="numpy")(np.hstack) + X = hstack_pynapple( + ( + self._basis1._compute_features( + *xi[: self._basis1._n_input_dimensionality] + ), + self._basis2._compute_features( + *xi[self._basis1._n_input_dimensionality :] + ), + ), + ) + return X - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Generate log-spaced raised cosine basis with given samples. + def _set_kernel(self, *xi: ArrayLike) -> Basis: + """Call fit on the added basis. + + If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. Parameters ---------- - sample_pts : - Spacing for basis functions. Samples will be rescaled to the interval [0, 1]. 
+ *xi: + The sample inputs. Unused, necessary to conform to `scikit-learn` API. Returns ------- - basis_funcs : - Log-raised cosine basis functions, shape (n_samples, n_basis_funcs). - - Raises - ------ - ValueError - If the sample provided do not lie in [0,1]. + : + The AdditiveBasis ready to be evaluated. """ - return super().__call__(self._transform_samples(sample_pts)) - - -class OrthExponentialBasis(Basis): - """Set of 1D basis decaying exponential functions numerically orthogonalized. - - Parameters - ---------- - n_basis_funcs - Number of basis functions. - decay_rates : - Decay rates of the exponentials, shape (n_basis_funcs,). - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when - `mode='conv'`; These arguments are used to change the default behavior of the convolution. - For example, changing the `predictor_causality`, which by default is set to `"causal"`. - Note that one cannot change the default value for the `axis` parameter. Basis assumes - that the convolution axis is `axis=0`. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import OrthExponentialBasis - >>> X = np.random.normal(size=(1000, 1)) - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = ortho_basis(sample_points) - """ + self._basis1._set_kernel() + self._basis2._set_kernel() + return self - def __init__( + def split_by_feature( self, - n_basis_funcs: int, - decay_rates: NDArray[np.floating], - mode="eval", - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "OrthExponentialBasis", - **kwargs, + x: NDArray, + axis: int = 1, ): - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - self.decay_rates = decay_rates - self._check_rates() - self._n_input_dimensionality = 1 + r""" + Decompose an array along a specified axis into sub-arrays based on the basis components. - @property - def decay_rates(self): - """Decay rate getter.""" - return self._decay_rates - - @decay_rates.setter - def decay_rates(self, value: NDArray): - """Decay rate setter.""" - value = np.asarray(value) - if value.shape[0] != self.n_basis_funcs: - raise ValueError( - f"The number of basis functions must match the number of decay rates provided. " - f"Number of basis functions provided: {self.n_basis_funcs}, " - f"Number of decay rates provided: {value.shape[0]}" - ) - self._decay_rates = value + This function takes an array (e.g., a design matrix or model coefficients) and splits it along + a designated axis. Each split corresponds to a different additive component of the basis, + preserving all dimensions except the specified axis. 
- def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. + **How It Works:** - Checks that the number of basis is at least 1. + Suppose the basis is made up of **m components**, each with $b_i$ basis functions and $n_i$ inputs. + The total number of features, $N$, is calculated as: - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - if self.n_basis_funcs < 1: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 1 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) + $$ + N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m + $$ - def _check_rates(self) -> None: - """ - Check if the decay rates list has duplicate entries. + This method splits any axis of length $N$ into sub-arrays, one for each basis component. - Raises - ------ - ValueError - If two or more decay rates are repeated, which would result in a linearly - dependent set of functions for the basis. - """ - if len(set(self._decay_rates)) != len(self._decay_rates): - raise ValueError( - "Two or more rate are repeated! Repeating rate will result in a " - "linearly dependent set of function for the basis." - ) + The sub-array for the i-th basis component is reshaped into dimensions + $(n_i, b_i)$. + + For example, if the array shape is $(1, 2, N, 4, 5)$, then each split sub-array will have shape: + + $$ + (1, 2, n_i, b_i, 4, 5) + $$ + + where: - def _check_sample_size(self, *sample_pts: NDArray) -> None: - """Check that the sample size is greater than the number of basis. + - $n_i$ represents the number of inputs associated with the i-th component, + - $b_i$ represents the number of basis functions in that component. - This is necessary for the orthogonalization procedure, - that otherwise will return (sample_size, ) basis elements instead of the expected number. + The specified axis (`axis`) determines where the split occurs, and all other dimensions + remain unchanged. See the example section below for the most common use cases. Parameters ---------- - sample_pts - Spacing for basis functions, holding elements on the interval [0, inf). + x : + The input array to be split, representing concatenated features, coefficients, + or other data. The shape of `x` along the specified axis must match the total + number of features generated by the basis, i.e., `self.n_output_features`. + + **Examples:** + - For a design matrix: `(n_samples, total_n_features)` + - For model coefficients: `(total_n_features,)` or `(total_n_features, n_neurons)`. + + axis : int, optional + The axis along which to split the features. Defaults to 1. + Use `axis=1` for design matrices (features along columns) and `axis=0` for + coefficient arrays (features along rows). All other dimensions are preserved. Raises ------ ValueError - If the number of basis element is less than the number of samples. - """ - if sample_pts[0].size < self.n_basis_funcs: - raise ValueError( - "OrthExponentialBasis requires at least as many samples as basis functions!\n" - f"Class instantiated with {self.n_basis_funcs} basis functions " - f"but only {sample_pts[0].size} samples provided!" - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: NDArray, - ) -> FeatureMatrix: - """Generate basis functions with given spacing. 
- - Parameters - ---------- - sample_pts - Spacing for basis functions, holding elements on the interval [0, - inf), shape (n_samples,). + If the shape of `x` along the specified axis does not match `self.n_output_features`. Returns ------- - basis_funcs - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) - - """ - self._check_sample_size(sample_pts) - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - valid_idx = ~np.isnan(sample_pts) - # because of how scipy.linalg.orth works, have to create a matrix of - # shape (n_pts, n_basis_funcs) and then transpose, rather than - # directly computing orth on the matrix of shape (n_basis_funcs, - # n_pts) - exp_decay_eval = np.stack( - [np.exp(-lam * sample_pts[valid_idx]) for lam in self._decay_rates], axis=1 - ) - # count the linear independent components (could be lower than n_basis_funcs for num precision). - n_independent_component = np.linalg.matrix_rank(exp_decay_eval) - # initialize output to nan - basis_funcs = np.full( - shape=(sample_pts.shape[0], n_independent_component), fill_value=np.nan - ) - # orthonormalize on valid points - basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) - return basis_funcs + dict + A dictionary where: + - **Keys**: Labels of the additive basis components. + - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. + $$ + (..., n_i, b_i, ...) + $$ - Parameters - ---------- - n_samples : - The number of samples. + - `n_i`: The number of inputs processed by the i-th basis component. + - `b_i`: The number of basis functions for the i-th basis component. - Returns - ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) + These sub-arrays are reshaped along the specified axis, with all other dimensions + remaining the same. Examples -------- >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import OrthExponentialBasis - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + >>> from nemos.basis import BSplineBasis + >>> from nemos.glm import GLM + >>> # Define an additive basis + >>> basis = ( + ... BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") + + ... BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2") + ... ) + >>> # Generate a sample input array and compute features + >>> x1, x2 = np.random.randn(20), np.random.randn(20) + >>> X = basis.compute_features(x1, x2) + >>> # Split the feature matrix along axis 1 + >>> split_features = basis.split_by_feature(X, axis=1) + >>> for feature, arr in split_features.items(): + ... print(f"{feature}: shape {arr.shape}") + feature_1: shape (20, 1, 5) + feature_2: shape (20, 1, 6) + >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: + >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, + ... 
label="multi_input") + >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + multi_input, shape (20, 2, 6) + >>> # the method can be used to decompose the glm coefficients in the various features + >>> counts = np.random.poisson(size=20) + >>> model = GLM().fit(X, counts) + >>> split_coef = basis.split_by_feature(model.coef_, axis=0) + >>> for feature, coef in split_coef.items(): + ... print(f"{feature}: shape {coef.shape}") + feature_1: shape (1, 5) + feature_2: shape (1, 6) + """ - return super().evaluate_on_grid(n_samples) + return super().split_by_feature(x, axis=axis) -def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray: - """Compute M-spline basis function. +class MultiplicativeBasis(Basis): + """ + Class representing the multiplication (external product) of two Basis objects. Parameters ---------- - x - Spacing for basis functions, shape (n_sample_points, ). - k - Order of the spline basis. - i - Number of the spline basis. - T - knot locations. should lie in interval [0, 1], shape (k + n_basis_funcs,). - - Returns - ------- - spline - M-spline basis function, shape (n_sample_points, ). + basis1 : + First basis object to multiply. + basis2 : + Second basis object to multiply. + + Attributes + ---------- + n_basis_funcs : int + Number of basis functions. Examples -------- + >>> # Generate sample data >>> import numpy as np - >>> from numpy import linspace - >>> from nemos.basis import mspline + >>> import nemos as nmo + >>> X = np.random.normal(size=(30, 3)) - >>> sample_points = linspace(0, 1, 100) - >>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline - >>> mspline_eval.shape - (100,) - """ - # Boundary conditions. - if (T[i + k] - T[i]) < 1e-6: - return np.zeros_like(x) + >>> # define two basis and multiply + >>> basis_1 = nmo.basis.BSplineBasis(10) + >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> multiplicative_basis = basis_1 * basis_2 - # Special base case of first-order spline basis. - elif k == 1: - v = np.zeros_like(x) - v[(x >= T[i]) & (x < T[i + 1])] = 1 / (T[i + 1] - T[i]) - return v + >>> # Can multiply or add another basis to the AdditiveBasis object + >>> # This will cause the number of output features of the result basis to grow accordingly + >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> multiplicative_basis_2 = multiplicative_basis * basis_3 + """ - # General case, defined recursively - else: - return ( - k - * ( - (x - T[i]) * mspline(x, k - 1, i, T) - + (T[i + k] - x) * mspline(x, k - 1, i + 1, T) - ) - / ((k - 1) * (T[i + k] - T[i])) + def __init__(self, basis1: Basis, basis2: Basis) -> None: + self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs + super().__init__(self.n_basis_funcs, mode="eval") + self._n_input_dimensionality = ( + basis1._n_input_dimensionality + basis2._n_input_dimensionality ) + self._n_basis_input = None + self._n_output_features = None + self._label = "(" + basis1.label + " * " + basis2.label + ")" + self._basis1 = basis1 + self._basis2 = basis2 + return + def _check_n_basis_min(self) -> None: + pass -def bspline( - sample_pts: NDArray, - knots: NDArray, - order: int = 4, - der: int = 0, - outer_ok: bool = False, -) -> NDArray: - """ - Calculate and return the evaluation of B-spline basis. 
- - This function evaluates B-spline basis for given sample points. It checks for - out of range points and optionally handles them. It also handles the NaNs if present. + def _set_kernel(self, *xi: NDArray) -> Basis: + """Call fit on the multiplied basis. - Parameters - ---------- - sample_pts : - An array containing sample points for which B-spline basis needs to be evaluated, - shape (n_samples,) - knots : - An array containing knots for the B-spline basis. The knots are sorted in ascending order. - order : - The order of the B-spline basis. - der : - The derivative of the B-spline basis to be evaluated. - outer_ok : - If True, allows for evaluation at points outside the range of knots. - Default is False, in which case an assertion error is raised when - points outside the knots range are encountered. - - Returns - ------- - basis_eval : - An array containing the evaluation of B-spline basis for the given sample points. - Shape (n_samples, n_basis_funcs). + If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - Raises - ------ - AssertionError - If `outer_ok` is False and the sample points lie outside the B-spline knots range. + Parameters + ---------- + *xi: + The sample inputs. Unused, necessary to conform to `scikit-learn` API. - Notes - ----- - The function uses splev function from scipy.interpolate library for the basis evaluation. + Returns + ------- + : + The MultiplicativeBasis ready to be evaluated. + """ + self._basis1._set_kernel() + self._basis2._set_kernel() + return self - Examples - -------- - >>> import numpy as np - >>> from numpy import linspace - >>> from nemos.basis import bspline - - >>> sample_points = linspace(0, 1, 100) - >>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1]) - >>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline - >>> bspline_eval.shape - (100, 10) - """ - knots.sort() - nk = knots.shape[0] + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Evaluate the basis at the input samples. - # check for out of range points (in cyclic b-spline need_outer must be set to False) - need_outer = any(sample_pts < knots[order - 1]) or any( - sample_pts > knots[nk - order] - ) - assert ( - not need_outer - ) | outer_ok, 'sample points must lie within the B-spline knots range unless "outer_ok==True".' + Parameters + ---------- + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. - # select knots that are within the knots range (this takes care of eventual NaNs) - in_sample = (sample_pts >= knots[0]) & (sample_pts <= knots[-1]) + Returns + ------- + : + The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + """ + X = np.asarray( + row_wise_kron( + self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), + self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), + transpose=False, + ) + ) + return X - if need_outer: - reps = order - 1 - knots = np.hstack((np.ones(reps) * knots[0], knots, np.ones(reps) * knots[-1])) - nk = knots.shape[0] - else: - reps = 0 + def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Compute the features for the multiplied bases, and compute their outer product. 
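`row_wise_kron` is an internal helper; conceptually it takes, for every sample, the outer product of the two evaluated feature rows and flattens it. A numpy-only equivalent for illustration (not the library implementation):

```python
import numpy as np

def row_wise_kron_sketch(A, B):
    # for each sample t, flatten the outer product of A[t, :] and B[t, :]
    return np.einsum("ti,tj->tij", A, B).reshape(A.shape[0], -1)

A = np.random.normal(size=(30, 10))  # features from basis1
B = np.random.normal(size=(30, 15))  # features from basis2
print(row_wise_kron_sketch(A, B).shape)  # (30, 150)
```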
- # number of basis elements - n_basis = nk - order + Parameters + ---------- + xi[0], ..., xi[n] : (n_samples,) + Tuple of input samples, each with the same number of samples. The + number of input arrays must equal the number of combined bases. - # initialize the basis element container - basis_eval = np.full((n_basis - 2 * reps, sample_pts.shape[0]), np.nan) + Returns + ------- + : + The features, shape (n_samples, n_basis_funcs) - # loop one element at the time and evaluate the basis using splev - id_basis = np.eye(n_basis, nk, dtype=np.int8) - for i in range(reps, len(knots) - order - reps): - basis_eval[i - reps, in_sample] = splev( - sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der + """ + kron = support_pynapple(conv_type="numpy")(row_wise_kron) + X = kron( + self._basis1._compute_features(*xi[: self._basis1._n_input_dimensionality]), + self._basis2._compute_features(*xi[self._basis1._n_input_dimensionality :]), + transpose=False, ) + return X - return basis_eval.T + def _set_num_output_features(self, *xi: NDArray) -> Basis: + self._n_basis_input = ( + *self._basis1._set_num_output_features( + *xi[: self._basis1._n_input_dimensionality] + )._n_basis_input, + *self._basis2._set_num_output_features( + *xi[self._basis1._n_input_dimensionality :] + )._n_basis_input, + ) + self._n_output_features = ( + self._basis1.n_output_features * self._basis2.n_output_features + ) + return self \ No newline at end of file diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py new file mode 100644 index 00000000..adcef722 --- /dev/null +++ b/src/nemos/basis/_basis_mixin.py @@ -0,0 +1,156 @@ +"""Mixin classes for basis.""" + +from numpy.typing import ArrayLike +from ..convolve import create_convolutional_predictor +import numpy as np +from typing import Union, Tuple + + +class EvalBasisMixin: + + def __init__(self, *args, **kwargs): + self._bounds = kwargs.pop("bounds", None) + + def _compute_features(self, *xi: ArrayLike): + """ + Apply the basis transformation to the input data. + + The basis evaluated at the samples, or $b_i(*xi)$, where $b_i$ is a + basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. + + Parameters + ---------- + *xi: + The input samples over which to apply the basis transformation. The samples can be passed + as multiple arguments, each representing a different dimension for multivariate inputs. + + Returns + ------- + : + A matrix with the transformed features. Faturehe basis evaluated at the samples, + or $b_i(*xi)$, where $b_i$ is a basis element. xi[k] must be a one-dimensional array + or a pynapple Tsd. + + """ + return self.__call__(*xi) + + def _set_kernel(self) -> "EvalBasisMixin": + """ + Prepare or compute the convolutional kernel for the basis functions. + + For EvalBasisMixin, this method might not perform any operation but simply return the + instance itself, as no kernel preparation is necessary. + + Returns + ------- + self : + The instance itself. + + """ + return self + + @property + def bounds(self): + return self._bounds + + @bounds.setter + def bounds(self, values: Union[None, Tuple[float, float]]): + """Setter for bounds.""" + if values is not None and self.mode == "conv": + raise ValueError("`bounds` should only be set when `mode=='eval'`.") + + if values is not None and len(values) != 2: + raise ValueError( + f"The provided `bounds` must be of length two. Length {len(values)} provided instead!" 
+ ) + + # convert to float and store + try: + self._bounds = values if values is None else tuple(map(float, values)) + except (ValueError, TypeError): + raise TypeError("Could not convert `bounds` to float.") + + if values is not None and values[1] <= values[0]: + raise ValueError( + f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." + ) + + +class ConvBasisMixin: + + def __init__(self, *args, **kwargs): + self._window_size = kwargs.pop("window_size") + + def _compute_features(self, *xi: ArrayLike): + """ + Apply the basis transformation to the input data. + + A bank of basis filters (created by calling fit) is convolved with the + samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions + except for the sample-axis are flattened, so that the method always returns a matrix. + For example, if samples are of shape (num_samples, 2, 3), the output will be + (num_samples, num_basis_funcs * 2 * 3). + The time-axis can be specified at basis initialization by setting the keyword argument `axis`. + For example, if `axis == 1` your samples should be (N1, num_samples N3, ...), the output of + transform will be (num_samples, num_basis_funcs * N1 * N3 *...). + + Parameters + ---------- + *xi: + The input samples over which to apply the basis transformation. The samples can be passed + as multiple arguments, each representing a different dimension for multivariate inputs. + + """ + # before calling the convolve, check that the input matches + # the expectation. We can check xi[0] only, since convolution + # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. + conv = create_convolutional_predictor( + self.kernel_, *xi, **self._conv_kwargs + ) + # make sure to return a matrix + return np.reshape(conv, newshape=(conv.shape[0], -1)) + + def _set_kernel(self) -> "ConvBasisMixin": + """ + Prepare or compute the convolutional kernel for the basis functions. + + This method is called to prepare the basis functions for convolution operations + in subclasses where the 'conv' mode is used. It typically involves computing a + kernel based on the basis functions that will be used for convolution with the + input data. The specifics of kernel computation depend on the subclass implementation + and the nature of the basis functions. + + Returns + ------- + self : + The instance itself, modified to include the computed kernel if applicable. This + allows for method chaining and integration into transformation pipelines. + + Notes + ----- + Subclasses implementing this method should detail the specifics of how the kernel is + computed and how the input parameters are utilized. If the basis operates in 'eval' + mode exclusively, this method should simply return `self` without modification. + """ + self.kernel_ = self.__call__(np.linspace(0, 1, self.window_size)) + return self + + @property + def window_size(self): + return self._window_size + + @window_size.setter + def window_size(self, window_size): + """Setter for the window size parameter.""" + + if window_size is None: + raise ValueError( + "If the basis is in `conv` mode, you must provide a window_size!" + ) + + elif not (isinstance(window_size, int) and window_size > 0): + raise ValueError( + f"`window_size` must be a positive integer. {window_size} provided instead!" 
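In conv mode the kernel is just the basis evaluated on `linspace(0, 1, window_size)`, and after the convolution everything except the sample axis is flattened into a single feature dimension. A shape-only numpy sketch of that final reshape (the intermediate convolution output shape is an assumption for illustration):

```python
import numpy as np

window_size, n_basis_funcs = 200, 10

# stand-in for kernel_: the basis evaluated on window_size equi-spaced points
kernel = np.random.normal(size=(window_size, n_basis_funcs))
print(kernel.shape)  # (200, 10)

# suppose convolving a (num_samples, 2, 3) input with this kernel bank yields
# an array of shape (num_samples, 2, 3, n_basis_funcs) -- assumed here for illustration
conv = np.random.normal(size=(1000, 2, 3, n_basis_funcs))

# the feature matrix keeps the sample axis and flattens everything else
X = conv.reshape(conv.shape[0], -1)
print(X.shape)  # (1000, 60)
```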
+ ) + + self._window_size = window_size \ No newline at end of file diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py new file mode 100644 index 00000000..c02242c2 --- /dev/null +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -0,0 +1,408 @@ + +# required to get ArrayLike to render correctly +from __future__ import annotations + +from typing import Optional, Tuple + +import numpy as np +import scipy.linalg +from numpy.typing import ArrayLike, NDArray + + +from ..type_casting import support_pynapple +from ..typing import FeatureMatrix +from ._basis_mixin import EvalBasisMixin, ConvBasisMixin + +from ._basis import Basis, check_transform_input, check_one_dimensional +import abc + + +class RaisedCosineBasisLinear(Basis, abc.ABC): + """Represent linearly-spaced raised cosine basis functions. + + This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + width : + Width of the raised cosine. By default, it's set to 2.0. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import RaisedCosineBasisLinear + >>> X = np.random.normal(size=(1000, 1)) + + >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cosine_basis(sample_points) + + # References + ------------ + [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 + """ + + def __init__( + self, + n_basis_funcs: int, + mode="eval", + width: float = 2.0, + window_size: Optional[int] = None, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "RaisedCosineBasisLinear", + **kwargs, + ) -> None: + super().__init__( + n_basis_funcs, + mode=mode, + window_size=window_size, + bounds=bounds, + label=label, + **kwargs, + ) + self._n_input_dimensionality = 1 + self._check_width(width) + self._width = width + # for these linear raised-cosine basis functions, + # the samples must be rescaled to 0 and 1. 
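The cosine bumps this class evaluates (see `__call__` below, unchanged from the original module) are half cosine periods centred on evenly spaced peaks and clipped to zero outside a single period. A standalone numpy sketch of that formula:

```python
import numpy as np

n_basis_funcs, width = 5, 2.0
x = np.linspace(0, 1, 100)

peaks = np.linspace(0, 1, n_basis_funcs)  # evenly spaced peak locations
delta = peaks[1] - peaks[0]

# 0.5 * (cos(clipped phase) + 1): each bump spans one period and maps to [0, 1]
bumps = 0.5 * (
    np.cos(
        np.clip(np.pi * (x[:, None] - peaks[None]) / (delta * width), -np.pi, np.pi)
    )
    + 1
)
print(bumps.shape)  # (100, 5)
```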
+ self._rescale_samples = True + + @property + def width(self): + """Return width of the raised cosine.""" + return self._width + + @width.setter + def width(self, width: float): + self._check_width(width) + self._width = width + + @staticmethod + def _check_width(width: float) -> None: + """Validate the width value. + + Parameters + ---------- + width : + The width value to validate. + + Raises + ------ + ValueError + If width <= 1 or 2*width is not a positive integer. Values that do not match + this constraint will result in: + - No overlap between bumps (width < 1). + - Oscillatory behavior when summing the basis elements (2*width not integer). + """ + if width <= 1 or (not np.isclose(width * 2, round(2 * width))): + raise ValueError( + f"Invalid raised cosine width. " + f"2*width must be a positive integer, 2*width = {2 * width} instead!" + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__( + self, + sample_pts: ArrayLike, + ) -> FeatureMatrix: + """Generate basis functions with given samples. + + Parameters + ---------- + sample_pts : + Spacing for basis functions, holding elements on interval [0, 1], Shape (number of samples, ). + + Raises + ------ + ValueError + If the sample provided do not lie in [0,1]. + + """ + if self._rescale_samples: + # note that sample points is converted to NDArray + # with the decorator. + # copy is necessary otherwise: + # basis1 = nmo.basis.RaisedCosineBasisLinear(5) + # basis2 = nmo.basis.RaisedCosineBasisLog(5) + # additive_basis = basis1 + basis2 + # additive_basis(*([x] * 2)) would modify both inputs + sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) + + peaks = self._compute_peaks() + delta = peaks[1] - peaks[0] + # generate a set of shifted cosines, and constrain them to be non-zero + # over a single period, then enforce the codomain to be [0,1], by adding 1 + # and then multiply by 0.5 + basis_funcs = 0.5 * ( + np.cos( + np.clip( + np.pi * (sample_pts[:, None] - peaks[None]) / (delta * self.width), + -np.pi, + np.pi, + ) + ) + + 1 + ) + return basis_funcs + + def _compute_peaks(self) -> NDArray: + """ + Compute the location of raised cosine peaks. + + Returns + ------- + Peak locations of each basis element. + """ + return np.linspace(0, 1, self.n_basis_funcs) + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import RaisedCosineBasisLinear + >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100) + """ + return super().evaluate_on_grid(n_samples) + + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. + + Check that the number of basis is at least 2. + + Raises + ------ + ValueError + If n_basis_funcs < 2. + """ + if self.n_basis_funcs < 2: + raise ValueError( + f"Object class {self.__class__.__name__} requires >= 2 basis elements. 
" + f"{self.n_basis_funcs} basis elements specified instead" + ) + + +class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): + """Represent log-spaced raised cosine basis functions. + + Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references) + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + width : + Width of the raised cosine. + time_scaling : + Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with + larger values resulting in more stretching. As this approaches 0, the + transformation becomes linear. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import RaisedCosineBasisLog + >>> X = np.random.normal(size=(1000, 1)) + + >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cosine_basis(sample_points) + + # References + ------------ + [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005 + """ + + def __init__( + self, + n_basis_funcs: int, + mode="eval", + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + window_size: Optional[int] = None, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "RaisedCosineBasisLog", + **kwargs, + ) -> None: + super().__init__( + n_basis_funcs, + mode=mode, + width=width, + window_size=window_size, + bounds=bounds, + **kwargs, + label=label, + ) + # The samples are scaled appropriately in the self._transform_samples which scales + # and applies the log-stretch, no additional transform is needed. 
+ self._rescale_samples = False + if time_scaling is None: + time_scaling = 50.0 + + self.time_scaling = time_scaling + self.enforce_decay_to_zero = enforce_decay_to_zero + + @property + def time_scaling(self): + """Getter property for time_scaling.""" + return self._time_scaling + + @time_scaling.setter + def time_scaling(self, time_scaling): + """Setter property for time_scaling.""" + self._check_time_scaling(time_scaling) + self._time_scaling = time_scaling + + @staticmethod + def _check_time_scaling(time_scaling: float) -> None: + if time_scaling <= 0: + raise ValueError( + f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." + ) + + def _transform_samples( + self, + sample_pts: ArrayLike, + ) -> NDArray: + """ + Map the sample domain to log-space. + + Parameters + ---------- + sample_pts : + Sample points used for evaluating the splines, + shape (n_samples, ). + + Returns + ------- + Transformed version of the sample points that matches the Raised Cosine basis domain, + shape (n_samples, ). + """ + # rescale to [0,1] + # copy is necessary to avoid unwanted rescaling in additive/multiplicative basis. + sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) + # This log-stretching of the sample axis has the following effect: + # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. + # - as the time_scaling tends to inf, basis will be small and dense around 0 and + # progressively larger and less dense towards 1. + log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( + self.time_scaling + 1 + ) + return log_spaced_pts + + def _compute_peaks(self) -> NDArray: + """ + Peak location of each log-spaced cosine basis element. + + Compute the peak location for the log-spaced raised cosine basis. + Enforcing that the last basis decays to zero is equivalent to + setting the last peak to a value smaller than 1. + + Returns + ------- + Peak locations of each basis element. + + """ + if self.enforce_decay_to_zero: + # compute the last peak location such that the last + # basis element decays to zero at the last sample. + last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1) + else: + last_peak = 1 + return np.linspace(0, last_peak, self.n_basis_funcs) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__( + self, + sample_pts: ArrayLike, + ) -> FeatureMatrix: + """Generate log-spaced raised cosine basis with given samples. + + Parameters + ---------- + sample_pts : + Spacing for basis functions. Samples will be rescaled to the interval [0, 1]. + + Returns + ------- + basis_funcs : + Log-raised cosine basis functions, shape (n_samples, n_basis_funcs). + + Raises + ------ + ValueError + If the sample provided do not lie in [0,1]. 
+ """ + return super().__call__(self._transform_samples(sample_pts)) \ No newline at end of file diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py new file mode 100644 index 00000000..081171fc --- /dev/null +++ b/src/nemos/basis/_spline_basis.py @@ -0,0 +1,768 @@ + +# required to get ArrayLike to render correctly +from __future__ import annotations + +import abc +import copy +from typing import Optional, Tuple + +import numpy as np +from numpy.typing import ArrayLike, NDArray +from scipy.interpolate import splev + + +from ..type_casting import support_pynapple +from ..typing import FeatureMatrix +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples + + +class SplineBasis(Basis, abc.ABC): + """ + SplineBasis class inherits from the Basis class and represents spline basis functions. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + order : optional + Spline order. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Attributes + ---------- + order : int + Spline order. + """ + + def __init__( + self, + n_basis_funcs: int, + order: int = 2, + label: Optional[str] = None, + **kwargs, + ) -> None: + self.order = order + super().__init__( + n_basis_funcs, + label=label, + **kwargs, + ) + + self._n_input_dimensionality = 1 + + @property + def order(self): + return self._order + + @order.setter + def order(self, value): + """Setter for the order parameter.""" + if value != int(value): + raise ValueError( + f"Spline order must be an integer! Order {value} provided." + ) + value = int(value) + if value < 1: + raise ValueError(f"Spline order must be positive! Order {value} provided.") + + # Set to None only the first time the setter is called. + orig_order = copy.deepcopy(getattr(self, "_order", None)) + + # Set the order + self._order = value + + # If the order was already initialized, re-check basis + if orig_order is not None: + try: + self._check_n_basis_min() + except ValueError as e: + self._order = orig_order + raise e + + def _generate_knots( + self, + is_cyclic: bool = False, + ) -> NDArray: + """ + Generate knots locations for spline basis functions. + + Parameters + ---------- + is_cyclic : optional + Whether the spline is cyclic. + + Returns + ------- + The knot locations for the spline basis functions. + + Raises + ------ + AssertionError + If the percentiles or order values are not within the valid range. + """ + # Determine number of interior knots. 
+ num_interior_knots = self.n_basis_funcs - self.order + if is_cyclic: + num_interior_knots += self.order - 1 + + # Spline basis have support on the semi-open [a, b) interval, we add a small epsilon + # to mx so that the so that basis_element(max(samples)) != 0 + knot_locs = np.concatenate( + ( + np.zeros(self.order - 1), + np.linspace(0, (1 + np.finfo(float).eps), num_interior_knots + 2), + np.full(self.order - 1, 1 + np.finfo(float).eps), + ) + ) + return knot_locs + + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. + + Check that the spline-basis has at least as many basis as the order. + + Raises + ------ + ValueError + If an insufficient number of basis element is requested for the basis type + """ + if self.n_basis_funcs < self.order: + raise ValueError( + f"{self.__class__.__name__} `order` parameter cannot be larger " + "than `n_basis_funcs` parameter." + ) + + +class MSplineBasis(SplineBasis, abc.ABC): + r""" + M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation. + + M-splines are a type of spline basis function used for smooth curve fitting + and data representation. They are positive and integrate to one, making them + suitable for probabilistic models and density estimation. The order of an + M-spline defines its smoothness, with higher orders resulting in smoother + splines. + + This class provides functionality to create M-spline basis functions, allowing + for flexible and smooth modeling of data. It inherits from the `SplineBasis` + abstract class, providing specific implementations for M-splines. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + order : + The order of the splines used in basis functions. Must be between [1, + n_basis_funcs]. Default is 2. Higher order splines have more continuous + derivatives at each interior knot, resulting in smoother basis functions. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs: + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = mspline_basis(sample_points) + + # References + ------------ + [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, + 3(4), 425-441. 
+ + Notes + ----- + MSplines must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain + (x-axis) of an MSpline basis is expanded by a factor of $\alpha$, the values on the co-domain (y-axis) values + will shrink by a factor of $1/\alpha$. + For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. + If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. + """ + + def __init__( + self, + n_basis_funcs: int, + order: int = 2, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", + **kwargs, + ) -> None: + super().__init__( + n_basis_funcs, + mode="eval", + order=order, + bounds=bounds, + label=label, + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: + """ + Evaluate the M-spline basis functions at given sample points. + + Parameters + ---------- + sample_pts : + An array of sample points where the M-spline basis functions are to be + evaluated. + + Returns + ------- + : + An array where each column corresponds to one M-spline basis function + evaluated at the input sample points. The shape of the array is + (len(sample_pts), n_basis_funcs). + + Notes + ----- + The implementation uses a recursive definition of M-splines. Boundary + conditions are handled such that the basis functions are positive and + integrate to one over the domain defined by the sample points. + """ + sample_pts, scaling = min_max_rescale_samples(sample_pts, self.bounds) + # add knots if not passed + knot_locs = self._generate_knots(is_cyclic=False) + + X = np.stack( + [ + mspline(sample_pts, self.order, i, knot_locs) + for i in range(self.n_basis_funcs) + ], + axis=1, + ) + # re-normalize so that it integrates to 1 over the range. + X /= scaling + return X + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Evaluate the M-spline basis functions on a uniformly spaced grid. + + This method creates a uniformly spaced grid of sample points within the domain + [0, 1] and evaluates all the M-spline basis functions at these points. It is + particularly useful for visualizing the shape and distribution of the basis + functions across their domain. + + Parameters + ---------- + n_samples : + The number of points in the uniformly spaced grid. A higher number of + samples will result in a more detailed visualization of the basis functions. + + Returns + ------- + X : NDArray + A 1D array of uniformly spaced sample points within the domain [0, 1]. + Shape: `(n_samples,)`. + Y : NDArray + A 2D array where each row corresponds to the evaluated M-spline basis + function values at the points in X. Shape: `(n_samples, n_basis_funcs)`. + + Examples + -------- + Evaluate and visualize 4 M-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalMSpline + >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('M-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return super().evaluate_on_grid(n_samples) + + +class BSplineBasis(SplineBasis, abc.ABC): + """ + B-spline[$^{[1]}$](#references) 1-dimensional basis functions. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + order : + Order of the splines used in basis functions. Must lie within [1, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Attributes + ---------- + order : + Spline order. + + + # References + ------------ + [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import BSplineBasis + + >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = bspline_basis(sample_points) + """ + + def __init__( + self, + n_basis_funcs: int, + mode="eval", + order: int = 4, + window_size: Optional[int] = None, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "BSplineBasis", + **kwargs, + ): + super().__init__( + n_basis_funcs, + mode=mode, + order=order, + window_size=window_size, + bounds=bounds, + label=label, + **kwargs, + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: + """ + Evaluate the B-spline basis functions with given sample points. + + Parameters + ---------- + sample_pts : + The sample points at which the B-spline is evaluated, shape (n_samples,). + + Returns + ------- + basis_funcs : + The basis function evaluated at the samples, shape (n_samples, n_basis_funcs). + + Raises + ------ + AssertionError + If the sample points are not within the B-spline knots. + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` + from SciPy to compute the basis values. 
+ """ + sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + # add knots + knot_locs = self._generate_knots(is_cyclic=False) + + basis_eval = bspline( + sample_pts, knot_locs, order=self.order, der=0, outer_ok=False + ) + return basis_eval + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import BSplineBasis + >>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) + """ + return super().evaluate_on_grid(n_samples) + + +class CyclicBSplineBasis(SplineBasis, abc.ABC): + """ + B-spline 1-dimensional basis functions for cyclic splines. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + order : + Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Attributes + ---------- + n_basis_funcs : int + Number of basis functions. + order : int + Order of the splines used in basis functions. 
+ + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import CyclicBSplineBasis + >>> X = np.random.normal(size=(1000, 1)) + + >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cyclic_basis(sample_points) + """ + + def __init__( + self, + n_basis_funcs: int, + mode="eval", + order: int = 4, + window_size: Optional[int] = None, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "CyclicBSplineBasis", + **kwargs, + ): + super().__init__( + n_basis_funcs, + mode=mode, + order=order, + window_size=window_size, + bounds=bounds, + label=label, + **kwargs, + ) + if self.order < 2: + raise ValueError( + f"Order >= 2 required for cyclic B-spline, " + f"order {self.order} specified instead!" + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__( + self, + sample_pts: ArrayLike, + ) -> FeatureMatrix: + """Evaluate the Cyclic B-spline basis functions with given sample points. + + Parameters + ---------- + sample_pts : + The sample points at which the cyclic B-spline is evaluated, shape + (n_samples,). + + Returns + ------- + basis_funcs : + The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + + """ + sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + knot_locs = self._generate_knots(is_cyclic=True) + + # for cyclic, do not repeat knots + knot_locs = np.unique(knot_locs) + + nk = knot_locs.shape[0] + + # make sure knots are sorted + knot_locs.sort() + + # extend knots + xc = knot_locs[nk - self.order] + knots = np.hstack( + ( + knot_locs[0] - knot_locs[-1] + knot_locs[nk - self.order : nk - 1], + knot_locs, + ) + ) + + ind = sample_pts > xc + + basis_eval = bspline(sample_pts, knots, order=self.order, der=0, outer_ok=True) + sample_pts[ind] = sample_pts[ind] - knots.max() + knot_locs[0] + + if np.sum(ind): + basis_eval[ind] = basis_eval[ind] + bspline( + sample_pts[ind], knots, order=self.order, outer_ok=True, der=0 + ) + # restore points + sample_pts[ind] = sample_pts[ind] + knots.max() - knot_locs[0] + return basis_eval + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Raised cosine basis functions, shape (n_samples, n_basis_funcs) + + Notes + ----- + The evaluation is performed by looping over each element and using `splev` from + SciPy to compute the basis values. + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import CyclicBSplineBasis + >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100) + """ + return super().evaluate_on_grid(n_samples) + + + +def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray: + """Compute M-spline basis function. + + Parameters + ---------- + x + Spacing for basis functions, shape (n_sample_points, ). + k + Order of the spline basis. + i + Number of the spline basis. + T + knot locations. 
should lie in interval [0, 1], shape (k + n_basis_funcs,). + + Returns + ------- + spline + M-spline basis function, shape (n_sample_points, ). + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis._spline_basis import mspline + >>> sample_points = linspace(0, 1, 100) + >>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline + >>> mspline_eval.shape + (100,) + """ + # Boundary conditions. + if (T[i + k] - T[i]) < 1e-6: + return np.zeros_like(x) + + # Special base case of first-order spline basis. + elif k == 1: + v = np.zeros_like(x) + v[(x >= T[i]) & (x < T[i + 1])] = 1 / (T[i + 1] - T[i]) + return v + + # General case, defined recursively + else: + return ( + k + * ( + (x - T[i]) * mspline(x, k - 1, i, T) + + (T[i + k] - x) * mspline(x, k - 1, i + 1, T) + ) + / ((k - 1) * (T[i + k] - T[i])) + ) + + +def bspline( + sample_pts: NDArray, + knots: NDArray, + order: int = 4, + der: int = 0, + outer_ok: bool = False, +) -> NDArray: + """ + Calculate and return the evaluation of B-spline basis. + + This function evaluates B-spline basis for given sample points. It checks for + out of range points and optionally handles them. It also handles the NaNs if present. + + Parameters + ---------- + sample_pts : + An array containing sample points for which B-spline basis needs to be evaluated, + shape (n_samples,) + knots : + An array containing knots for the B-spline basis. The knots are sorted in ascending order. + order : + The order of the B-spline basis. + der : + The derivative of the B-spline basis to be evaluated. + outer_ok : + If True, allows for evaluation at points outside the range of knots. + Default is False, in which case an assertion error is raised when + points outside the knots range are encountered. + + Returns + ------- + basis_eval : + An array containing the evaluation of B-spline basis for the given sample points. + Shape (n_samples, n_basis_funcs). + + Raises + ------ + AssertionError + If `outer_ok` is False and the sample points lie outside the B-spline knots range. + + Notes + ----- + The function uses splev function from scipy.interpolate library for the basis evaluation. + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis._spline_basis import bspline + >>> sample_points = linspace(0, 1, 100) + >>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1]) + >>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline + >>> bspline_eval.shape + (100, 10) + """ + knots.sort() + nk = knots.shape[0] + + # check for out of range points (in cyclic b-spline need_outer must be set to False) + need_outer = any(sample_pts < knots[order - 1]) or any( + sample_pts > knots[nk - order] + ) + assert ( + not need_outer + ) | outer_ok, 'sample points must lie within the B-spline knots range unless "outer_ok==True".' 
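+    # When out-of-range points are allowed (outer_ok=True), the code below pads the
+    # knot vector by repeating each boundary knot (order - 1) times so that splev can
+    # evaluate the basis over the full [knots[0], knots[-1]] interval; samples outside
+    # that interval (or NaN) are left as NaN in the output.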
+ + # select knots that are within the knots range (this takes care of eventual NaNs) + in_sample = (sample_pts >= knots[0]) & (sample_pts <= knots[-1]) + + if need_outer: + reps = order - 1 + knots = np.hstack((np.ones(reps) * knots[0], knots, np.ones(reps) * knots[-1])) + nk = knots.shape[0] + else: + reps = 0 + + # number of basis elements + n_basis = nk - order + + # initialize the basis element container + basis_eval = np.full((n_basis - 2 * reps, sample_pts.shape[0]), np.nan) + + # loop one element at the time and evaluate the basis using splev + id_basis = np.eye(n_basis, nk, dtype=np.int8) + for i in range(reps, len(knots) - order - reps): + basis_eval[i - reps, in_sample] = splev( + sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der + ) + + return basis_eval.T diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py new file mode 100644 index 00000000..60ae2056 --- /dev/null +++ b/src/nemos/basis/basis.py @@ -0,0 +1,358 @@ +"""Bases classes.""" + +# required to get ArrayLike to render correctly +from __future__ import annotations + +from typing import Optional, Tuple + +import numpy as np +import scipy.linalg +from numpy.typing import ArrayLike, NDArray + + +from ..type_casting import support_pynapple +from ..typing import FeatureMatrix +from ._basis_mixin import EvalBasisMixin, ConvBasisMixin + +from ._basis import Basis, check_transform_input, check_one_dimensional +from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis + +__all__ = [ + "EvalMSpline", + "ConvMSpline", + "EvalBSpline", + "ConvBSpline", + "EvalCyclicBSpline", + "ConvCyclicBSpline", + "RaisedCosineBasisLinear", + "RaisedCosineBasisLog", + "OrthExponentialBasis", +] + + +def __dir__() -> list[str]: + return __all__ + + +class EvalBSpline(EvalBasisMixin, BSplineBasis): + def __init__( + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalBSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + BSplineBasis.__init__( + self, + n_basis_funcs, + mode="eval", + order=order, + bounds=bounds, + label=label, + ) + + + +class ConvBSpline(ConvBasisMixin, BSplineBasis): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvBSpline", + ): + ConvBasisMixin.__init__(self, window_size=window_size) + BSplineBasis.__init__( + self, + n_basis_funcs, + mode="conv", + order=order, + window_size=window_size, + label=label, + ) + + +class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): + def __init__( + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalCyclicBSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + CyclicBSplineBasis.__init__( + self, + n_basis_funcs, + mode="eval", + order=order, + bounds=bounds, + label=label, + ) + + +class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvCyclicBSpline", + ): + ConvBasisMixin.__init__(self, window_size=window_size) + CyclicBSplineBasis.__init__( + self, + n_basis_funcs, + mode="conv", + order=order, + window_size=window_size, + label=label, + ) + + +class EvalMSpline(EvalBasisMixin, MSplineBasis): + def __init__( + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + 
MSplineBasis.__init__( + self, + n_basis_funcs, + mode="eval", + order=order, + bounds=bounds, + label=label, + ) + + +class ConvMSpline(ConvBasisMixin, MSplineBasis): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvMSpline", + ): + MSplineBasis.__init__( + self, + n_basis_funcs, + mode="conv", + order=order, + window_size=window_size, + label=label, + ) + ConvBasisMixin.__init__(self, window_size=window_size) + + + +class OrthExponentialBasis(Basis): + """Set of 1D basis decaying exponential functions numerically orthogonalized. + + Parameters + ---------- + n_basis_funcs + Number of basis functions. + decay_rates : + Decay rates of the exponentials, shape (n_basis_funcs,). + mode : + The mode of operation. 'eval' for evaluation at sample points, + 'conv' for convolutional operation. + window_size : + The window size for convolution. Required if mode is 'conv'. + bounds : + The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when + `mode='conv'`; These arguments are used to change the default behavior of the convolution. + For example, changing the `predictor_causality`, which by default is set to `"causal"`. + Note that one cannot change the default value for the `axis` parameter. Basis assumes + that the convolution axis is `axis=0`. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import OrthExponentialBasis + >>> X = np.random.normal(size=(1000, 1)) + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = ortho_basis(sample_points) + """ + + def __init__( + self, + n_basis_funcs: int, + decay_rates: NDArray[np.floating], + mode="eval", + window_size: Optional[int] = None, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "OrthExponentialBasis", + **kwargs, + ): + super().__init__( + n_basis_funcs, + mode=mode, + window_size=window_size, + bounds=bounds, + label=label, + **kwargs, + ) + self.decay_rates = decay_rates + self._check_rates() + self._n_input_dimensionality = 1 + + @property + def decay_rates(self): + """Decay rate getter.""" + return self._decay_rates + + @decay_rates.setter + def decay_rates(self, value: NDArray): + """Decay rate setter.""" + value = np.asarray(value) + if value.shape[0] != self.n_basis_funcs: + raise ValueError( + f"The number of basis functions must match the number of decay rates provided. " + f"Number of basis functions provided: {self.n_basis_funcs}, " + f"Number of decay rates provided: {value.shape[0]}" + ) + self._decay_rates = value + + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. + + Checks that the number of basis is at least 1. 
+ + Raises + ------ + ValueError + If an insufficient number of basis element is requested for the basis type + """ + if self.n_basis_funcs < 1: + raise ValueError( + f"Object class {self.__class__.__name__} requires >= 1 basis elements. " + f"{self.n_basis_funcs} basis elements specified instead" + ) + + def _check_rates(self) -> None: + """ + Check if the decay rates list has duplicate entries. + + Raises + ------ + ValueError + If two or more decay rates are repeated, which would result in a linearly + dependent set of functions for the basis. + """ + if len(set(self._decay_rates)) != len(self._decay_rates): + raise ValueError( + "Two or more rate are repeated! Repeating rate will result in a " + "linearly dependent set of function for the basis." + ) + + def _check_sample_size(self, *sample_pts: NDArray) -> None: + """Check that the sample size is greater than the number of basis. + + This is necessary for the orthogonalization procedure, + that otherwise will return (sample_size, ) basis elements instead of the expected number. + + Parameters + ---------- + sample_pts + Spacing for basis functions, holding elements on the interval [0, inf). + + Raises + ------ + ValueError + If the number of basis element is less than the number of samples. + """ + if sample_pts[0].size < self.n_basis_funcs: + raise ValueError( + "OrthExponentialBasis requires at least as many samples as basis functions!\n" + f"Class instantiated with {self.n_basis_funcs} basis functions " + f"but only {sample_pts[0].size} samples provided!" + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__( + self, + sample_pts: NDArray, + ) -> FeatureMatrix: + """Generate basis functions with given spacing. + + Parameters + ---------- + sample_pts + Spacing for basis functions, holding elements on the interval [0, + inf), shape (n_samples,). + + Returns + ------- + basis_funcs + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + """ + self._check_sample_size(sample_pts) + sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + valid_idx = ~np.isnan(sample_pts) + # because of how scipy.linalg.orth works, have to create a matrix of + # shape (n_pts, n_basis_funcs) and then transpose, rather than + # directly computing orth on the matrix of shape (n_basis_funcs, + # n_pts) + exp_decay_eval = np.stack( + [np.exp(-lam * sample_pts[valid_idx]) for lam in self._decay_rates], axis=1 + ) + # count the linear independent components (could be lower than n_basis_funcs for num precision). + n_independent_component = np.linalg.matrix_rank(exp_decay_eval) + # initialize output to nan + basis_funcs = np.full( + shape=(sample_pts.shape[0], n_independent_component), fill_value=np.nan + ) + # orthonormalize on valid points + basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) + return basis_funcs + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. 
+ basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import OrthExponentialBasis + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + """ + return super().evaluate_on_grid(n_samples) diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index 79437c65..e0c77e03 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -9,7 +9,7 @@ from jax.typing import ArrayLike as JaxArray from numpy.typing import NDArray -from .basis import Basis +from nemos.basis.basis import Basis from .tree_utils import get_valid_multitree, tree_slice from .type_casting import support_pynapple from .validation import _warn_if_not_float64 diff --git a/src/nemos/typing.py b/src/nemos/typing.py index dd9bc5a6..406a5e13 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -7,6 +7,8 @@ from jax._src.typing import ArrayLike from .pytrees import FeaturePytree +from pynapple import TsdFrame +from numpy.typing import NDArray DESIGN_INPUT_TYPE = Union[jnp.ndarray, FeaturePytree] @@ -51,3 +53,5 @@ ], # Step-size for optimization (must be a float) Tuple[jnp.ndarray, jnp.ndarray], ] + +FeatureMatrix = Union[NDArray, TsdFrame] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index cb39ee37..e114e43e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -286,7 +286,7 @@ def coupled_model_simulate(): ) # shrink the filters for simulation stability coupling_filter_bank *= 0.8 - basis = nmo.basis.RaisedCosineBasisLog(20) + basis = nemos.basis.basis.RaisedCosineBasisLog(20) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/tests/test_basis.py b/tests/test_basis.py index 72f987c6..156a8292 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -12,7 +12,7 @@ import utils_testing from sklearn.base import clone as sk_clone -import nemos.basis as basis +import nemos.basis.basis as basis import nemos.convolve as convolve from nemos.utils import pynapple_concatenate_numpy @@ -1752,7 +1752,7 @@ def test_transformer_get_params(self): class TestMSplineBasis(BasisFuncsTesting): - cls = basis.MSplineBasis + cls = basis.EvalMSpline @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) @@ -4940,7 +4940,7 @@ def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): if mode == "eval": window_size = None - if basis_class == basis.MSplineBasis: + if basis_class == basis.EvalMSpline: basis_obj = basis_class( n_basis_funcs=n_basis, order=4, mode=mode, window_size=window_size ) @@ -4964,13 +4964,13 @@ def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): n_basis_funcs=n_basis, order=3, mode=mode, window_size=window_size ) elif basis_class == basis.AdditiveBasis: - b1 = basis.MSplineBasis( + b1 = basis.EvalMSpline( n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size ) b2 = basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis + 1) basis_obj = b1 + b2 elif basis_class == 
basis.MultiplicativeBasis: - b1 = basis.MSplineBasis( + b1 = basis.EvalMSpline( n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size ) b2 = basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis + 1) @@ -4992,9 +4992,9 @@ class TestAdditiveBasis(CombinedBasis): def test_non_empty_samples(self, samples, mode, ws): if mode == "conv" and len(samples[0]) < 2: return - basis_obj = basis.MSplineBasis( + basis_obj = basis.EvalMSpline( 5, mode=mode, window_size=ws - ) + basis.MSplineBasis(5, mode=mode, window_size=ws) + ) + basis.EvalMSpline(5, mode=mode, window_size=ws) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -5017,7 +5017,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. """ - basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5) + basis_obj = basis.EvalMSpline(5) + basis.EvalMSpline(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -5534,9 +5534,9 @@ class TestMultiplicativeBasis(CombinedBasis): def test_non_empty_samples(self, samples, mode, ws): if mode == "conv" and len(samples[0]) < 2: return - basis_obj = basis.MSplineBasis( + basis_obj = basis.EvalMSpline( 5, mode=mode, window_size=ws - ) * basis.MSplineBasis(5, mode=mode, window_size=ws) + ) * basis.EvalMSpline(5, mode=mode, window_size=ws) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -5559,7 +5559,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. 
""" - basis_obj = basis.MSplineBasis(5) * basis.MSplineBasis(5) + basis_obj = basis.EvalMSpline(5) * basis.EvalMSpline(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -5725,7 +5725,7 @@ def test_evaluate_on_grid_input_number( with expectation: basis_obj.evaluate_on_grid(*inputs) - @pytest.mark.parametrize("basis_a", [basis.MSplineBasis]) + @pytest.mark.parametrize("basis_a", [basis.EvalMSpline]) @pytest.mark.parametrize("basis_b", [basis.OrthExponentialBasis]) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) @@ -6146,7 +6146,7 @@ def test_power_of_basis(exponent, basis_class): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6169,7 +6169,7 @@ def test_basis_to_transformer(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6197,7 +6197,7 @@ def test_transformer_has_the_same_public_attributes_as_basis(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6224,7 +6224,7 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6249,7 +6249,7 @@ def test_basis_to_transformer_makes_a_copy(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6265,7 +6265,7 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6285,7 +6285,7 @@ def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_func @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6305,7 +6305,7 @@ def test_transformerbasis_setattr_basis(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6326,7 +6326,7 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6349,7 +6349,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6371,7 +6371,7 @@ def test_transformerbasis_setattr_illegal_attribute(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6401,7 +6401,7 @@ def test_transformerbasis_addition(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6431,7 +6431,7 @@ def 
test_transformerbasis_multiplication(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6462,7 +6462,7 @@ def test_transformerbasis_exponentiation( @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6485,7 +6485,7 @@ def test_transformerbasis_dir(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6512,7 +6512,7 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6563,7 +6563,7 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs): @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6599,7 +6599,7 @@ def test_multi_epoch_pynapple_basis( predictor_causality=predictor_causality, shift=shift, ) - bas = basis.MSplineBasis(3) * bas + bas = basis.EvalMSpline(3) * bas else: bas = basis_cls( 5, @@ -6651,7 +6651,7 @@ def test_multi_epoch_pynapple_basis( @pytest.mark.parametrize( "basis_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -6687,7 +6687,7 @@ def test_multi_epoch_pynapple_basis_transformer( predictor_causality=predictor_causality, shift=shift, ) - bas = basis.MSplineBasis(3) * bas + bas = basis.EvalMSpline(3) * bas else: bas = basis_cls( 5, diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index a1aec5de..ab0a6cbe 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from nemos.basis import BSplineBasis, RaisedCosineBasisLinear +from nemos.basis.basis import BSplineBasis, RaisedCosineBasisLinear from nemos.identifiability_constraints import ( _WARN_FLOAT32_MESSAGE, _find_drop_column, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 703d42ff..5487b497 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "bas", [ - basis.MSplineBasis(5), + basis.EvalMSpline(5), basis.BSplineBasis(5), basis.CyclicBSplineBasis(5), basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), @@ -29,7 +29,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): @pytest.mark.parametrize( "bas", [ - basis.MSplineBasis(5), + basis.EvalMSpline(5), basis.BSplineBasis(5), basis.CyclicBSplineBasis(5), basis.RaisedCosineBasisLinear(5), @@ -48,7 +48,7 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): @pytest.mark.parametrize( "bas", [ - basis.MSplineBasis(5), + basis.EvalMSpline(5), basis.BSplineBasis(5), basis.CyclicBSplineBasis(5), basis.RaisedCosineBasisLinear(5), @@ -73,7 +73,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( @pytest.mark.parametrize( "bas_cls", [ - basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -94,7 +94,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( @pytest.mark.parametrize( "bas_cls", [ 
- basis.MSplineBasis, + basis.EvalMSpline, basis.BSplineBasis, basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, @@ -122,14 +122,14 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( @pytest.mark.parametrize( "bas, expected_nans", [ - (basis.MSplineBasis(5), 0), + (basis.EvalMSpline(5), 0), (basis.BSplineBasis(5), 0), (basis.CyclicBSplineBasis(5), 0), (basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), 0), (basis.RaisedCosineBasisLinear(5), 0), (basis.RaisedCosineBasisLog(5), 0), - (basis.RaisedCosineBasisLog(5) + basis.MSplineBasis(5), 0), - (basis.MSplineBasis(5, mode="conv", window_size=3), 6), + (basis.RaisedCosineBasisLog(5) + basis.EvalMSpline(5), 0), + (basis.EvalMSpline(5, mode="conv", window_size=3), 6), (basis.BSplineBasis(5, mode="conv", window_size=3), 6), ( basis.CyclicBSplineBasis( @@ -147,12 +147,12 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( (basis.RaisedCosineBasisLog(5, mode="conv", window_size=3), 6), ( basis.RaisedCosineBasisLog(5, mode="conv", window_size=3) - + basis.MSplineBasis(5), + + basis.EvalMSpline(5), 6, ), ( basis.RaisedCosineBasisLog(5, mode="conv", window_size=3) - * basis.MSplineBasis(5), + * basis.EvalMSpline(5), 6, ), ], diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 70ff315a..939f5eab 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -5,7 +5,7 @@ import numpy as np import pytest -import nemos.basis as basis +from nemos import basis import nemos.simulation as simulation From 145ac9bd23ce2a09d46ac8c5e2c593b75b889a5c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 14 Nov 2024 14:36:48 -0500 Subject: [PATCH 002/109] setup raised cosine --- src/nemos/basis/__init__.py | 7 +- src/nemos/basis/_basis.py | 57 +---------- src/nemos/basis/_basis_mixin.py | 73 ++++++++++++-- src/nemos/basis/_raised_cosine_basis.py | 13 +-- src/nemos/basis/basis.py | 129 ++++++++++++++++++++++-- 5 files changed, 194 insertions(+), 85 deletions(-) diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index 29cbf9af..4995292e 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -1,3 +1,6 @@ -from .basis import (EvalMSpline, ConvMSpline, EvalCyclicBSpline, ConvCyclicBSpline, EvalBSpline, ConvBSpline, - RaisedCosineBasisLog, RaisedCosineBasisLog, RaisedCosineBasisLog, RaisedCosineBasisLog, +from .basis import (EvalMSpline, ConvMSpline, + EvalCyclicBSpline, ConvCyclicBSpline, + EvalBSpline, ConvBSpline, + EvalRaisedCosineLinear, ConvRaisedCosineLinear, + EvalRaisedCosineLog, ConvRaisedCosineLog, OrthExponentialBasis) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 8b1e2686..a89a2e43 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -152,66 +152,15 @@ def __init__( else: self._label = str(label) - # pop the two mode dependent kwargs - window_size = kwargs.pop("window_size", None) - if window_size: - self._window_size = window_size - bounds = kwargs.pop("bounds", None) - if bounds: - self._bounds = bounds - # the rest should be convolutional kwargs self._check_convolution_kwargs() self.kernel_ = None + @abc.abstractmethod def _check_convolution_kwargs(self): - """Check convolution kwargs settings. - - Raises - ------ - ValueError: - - If `self._conv_kwargs` are not None and `mode =="eval". - - If `axis` is provided as an argument, and it is different from 0 - (samples must always be in the first axis). 
- - If `self._conv_kwargs` include parameters not recognized or that do not have - default values in `create_convolutional_predictor`. - """ - # this should not be hit since **kwargs are not allowed at EvalBasis init. - if self._mode == "eval" and self._conv_kwargs: - raise ValueError( - f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" - ) - - if "axis" in self._conv_kwargs: - raise ValueError( - "Setting the `axis` parameter is not allowed. Basis requires the " - "convolution to be applied along the first axis (`axis=0`).\n" - "Please transpose your input so that the desired axis for " - "convolution is the first dimension (axis=0)." - ) - convolve_params = inspect.signature(create_convolutional_predictor).parameters - convolve_configs = { - key - for key, param in convolve_params.items() - if param.default - # prevent user from passing - # `basis_matrix` or `time_series` in kwargs. - is not inspect.Parameter.empty - } - if not set(self._conv_kwargs.keys()).issubset(convolve_configs): - # do not encourage to set axis. - convolve_configs = convolve_configs.difference({"axis"}) - # remove the parameter in case axis=0 was passed, since it is allowed. - invalid = ( - set(self._conv_kwargs.keys()) - .difference(convolve_configs) - .difference({"axis"}) - ) - raise ValueError( - f"Unrecognized keyword arguments: {invalid}. " - f"Allowed convolution keyword arguments are: {convolve_configs}." - ) + """Check convolution kwargs settings.""" + pass @property def n_output_features(self) -> int | None: diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index adcef722..660d2d44 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -3,13 +3,13 @@ from numpy.typing import ArrayLike from ..convolve import create_convolutional_predictor import numpy as np -from typing import Union, Tuple - +from typing import Union, Tuple, Optional +import inspect class EvalBasisMixin: - def __init__(self, *args, **kwargs): - self._bounds = kwargs.pop("bounds", None) + def __init__(self, bounds: Optional[Tuple[float, float]] = None): + self.bounds = bounds def _compute_features(self, *xi: ArrayLike): """ @@ -56,9 +56,6 @@ def bounds(self): @bounds.setter def bounds(self, values: Union[None, Tuple[float, float]]): """Setter for bounds.""" - if values is not None and self.mode == "conv": - raise ValueError("`bounds` should only be set when `mode=='eval'`.") - if values is not None and len(values) != 2: raise ValueError( f"The provided `bounds` must be of length two. Length {len(values)} provided instead!" @@ -75,11 +72,26 @@ def bounds(self, values: Union[None, Tuple[float, float]]): f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." ) + def _check_convolution_kwargs(self): + """Check convolution kwargs settings. + + Raises + ------ + ValueError: + If `self._conv_kwargs` are not None. + """ + # this should not be hit since **kwargs are not allowed at EvalBasis init. + # still keep it for compliance with Abstract class Basis. + if self._conv_kwargs: + raise ValueError( + f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" + ) + class ConvBasisMixin: - def __init__(self, *args, **kwargs): - self._window_size = kwargs.pop("window_size") + def __init__(self, window_size: int): + self.window_size = window_size def _compute_features(self, *xi: ArrayLike): """ @@ -153,4 +165,45 @@ def window_size(self, window_size): f"`window_size` must be a positive integer. 
{window_size} provided instead!" ) - self._window_size = window_size \ No newline at end of file + self._window_size = window_size + + def _check_convolution_kwargs(self): + """Check convolution kwargs settings. + + Raises + ------ + ValueError: + - If `axis` is provided as an argument, and it is different from 0 + (samples must always be in the first axis). + - If `self._conv_kwargs` include parameters not recognized or that do not have + default values in `create_convolutional_predictor`. + """ + if "axis" in self._conv_kwargs: + raise ValueError( + "Setting the `axis` parameter is not allowed. Basis requires the " + "convolution to be applied along the first axis (`axis=0`).\n" + "Please transpose your input so that the desired axis for " + "convolution is the first dimension (axis=0)." + ) + convolve_params = inspect.signature(create_convolutional_predictor).parameters + convolve_configs = { + key + for key, param in convolve_params.items() + if param.default + # prevent user from passing + # `basis_matrix` or `time_series` in kwargs. + is not inspect.Parameter.empty + } + if not set(self._conv_kwargs.keys()).issubset(convolve_configs): + # do not encourage to set axis. + convolve_configs = convolve_configs.difference({"axis"}) + # remove the parameter in case axis=0 was passed, since it is allowed. + invalid = ( + set(self._conv_kwargs.keys()) + .difference(convolve_configs) + .difference({"axis"}) + ) + raise ValueError( + f"Unrecognized keyword arguments: {invalid}. " + f"Allowed convolution keyword arguments are: {convolve_configs}." + ) \ No newline at end of file diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index c02242c2..4989e172 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -5,15 +5,14 @@ from typing import Optional, Tuple import numpy as np -import scipy.linalg from numpy.typing import ArrayLike, NDArray from ..type_casting import support_pynapple from ..typing import FeatureMatrix -from ._basis_mixin import EvalBasisMixin, ConvBasisMixin -from ._basis import Basis, check_transform_input, check_one_dimensional + +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples import abc @@ -71,16 +70,12 @@ def __init__( n_basis_funcs: int, mode="eval", width: float = 2.0, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineBasisLinear", **kwargs, ) -> None: super().__init__( n_basis_funcs, mode=mode, - window_size=window_size, - bounds=bounds, label=label, **kwargs, ) @@ -289,8 +284,6 @@ def __init__( width: float = 2.0, time_scaling: float = None, enforce_decay_to_zero: bool = True, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "RaisedCosineBasisLog", **kwargs, ) -> None: @@ -298,8 +291,6 @@ def __init__( n_basis_funcs, mode=mode, width=width, - window_size=window_size, - bounds=bounds, **kwargs, label=label, ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 60ae2056..55219a2d 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -16,6 +16,7 @@ from ._basis import Basis, check_transform_input, check_one_dimensional from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis +from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog __all__ = [ "EvalMSpline", @@ -24,8 +25,8 @@ "ConvBSpline", "EvalCyclicBSpline", "ConvCyclicBSpline", - 
"RaisedCosineBasisLinear", - "RaisedCosineBasisLog", + "EvalRaisedCosineLinear", + "ConvRaisedCosineLinear", "OrthExponentialBasis", ] @@ -48,7 +49,6 @@ def __init__( n_basis_funcs, mode="eval", order=order, - bounds=bounds, label=label, ) @@ -68,7 +68,6 @@ def __init__( n_basis_funcs, mode="conv", order=order, - window_size=window_size, label=label, ) @@ -87,7 +86,6 @@ def __init__( n_basis_funcs, mode="eval", order=order, - bounds=bounds, label=label, ) @@ -106,7 +104,6 @@ def __init__( n_basis_funcs, mode="conv", order=order, - window_size=window_size, label=label, ) @@ -125,7 +122,6 @@ def __init__( n_basis_funcs, mode="eval", order=order, - bounds=bounds, label=label, ) @@ -138,15 +134,132 @@ def __init__( order: int = 4, label: Optional[str] = "ConvMSpline", ): + ConvBasisMixin.__init__(self, window_size=window_size) MSplineBasis.__init__( self, n_basis_funcs, mode="conv", order=order, - window_size=window_size, label=label, ) + +class EvalMSpline(EvalBasisMixin, MSplineBasis): + def __init__( + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + MSplineBasis.__init__( + self, + n_basis_funcs, + mode="eval", + order=order, + label=label, + ) + + +class ConvMSpline(ConvBasisMixin, MSplineBasis): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvMSpline", + ): + ConvBasisMixin.__init__(self, window_size=window_size) + MSplineBasis.__init__( + self, + n_basis_funcs, + mode="conv", + order=order, + label=label, + ) + + +class EvalRaisedCosineLinear(EvalBasisMixin, RaisedCosineBasisLinear): + def __init__( + self, + n_basis_funcs: int, + width: float = 2.0, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + RaisedCosineBasisLinear.__init__( + self, + n_basis_funcs, + width=width, + mode="eval", + label=label, + ) + + +class ConvRaisedCosineLinear(ConvBasisMixin, RaisedCosineBasisLinear): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + width: float = 2.0, + label: Optional[str] = "ConvMSpline", + **conv_kwargs, + ): + ConvBasisMixin.__init__(self, window_size=window_size) + RaisedCosineBasisLinear.__init__( + self, + n_basis_funcs, + mode="conv", + width=width, + label=label, + **conv_kwargs, + ) + +class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): + def __init__( + self, + n_basis_funcs: int, + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", + ): + EvalBasisMixin.__init__(self, bounds=bounds) + RaisedCosineBasisLog.__init__( + self, + n_basis_funcs, + width=width, + time_scaling=time_scaling, + enforce_decay_to_zero=enforce_decay_to_zero, + mode="eval", + label=label, + ) + + +class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): + def __init__( + self, + n_basis_funcs: int, + window_size: int, + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + label: Optional[str] = "ConvMSpline", + **conv_kwargs, + ): ConvBasisMixin.__init__(self, window_size=window_size) + RaisedCosineBasisLog.__init__( + self, + n_basis_funcs, + mode="conv", + width=width, + time_scaling=time_scaling, + enforce_decay_to_zero=enforce_decay_to_zero, + label=label, + **conv_kwargs, + ) From 
5a4bb558f24eec37e3f8aa866e897bf984050e64 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 18 Nov 2024 12:44:29 -0500 Subject: [PATCH 003/109] fixed docs class orth exp --- src/nemos/basis/__init__.py | 3 +- src/nemos/basis/_basis.py | 65 ++--- src/nemos/basis/_basis_mixin.py | 12 +- src/nemos/basis/_decaying_exponential.py | 220 +++++++++++++++ src/nemos/basis/basis.py | 330 +++++++---------------- src/nemos/identifiability_constraints.py | 2 +- 6 files changed, 352 insertions(+), 280 deletions(-) create mode 100644 src/nemos/basis/_decaying_exponential.py diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index 4995292e..fe6de99a 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -3,4 +3,5 @@ EvalBSpline, ConvBSpline, EvalRaisedCosineLinear, ConvRaisedCosineLinear, EvalRaisedCosineLog, ConvRaisedCosineLog, - OrthExponentialBasis) + EvalOrthExponential, ConvOrthExponential) +from ._basis import TransformerBasis diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 18c18649..baf0639a 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -3,7 +3,7 @@ import abc import copy -import inspect + from functools import wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union @@ -13,7 +13,6 @@ from pynapple import Tsd, TsdFrame from ..base_class import Base -from ..convolve import create_convolutional_predictor from ..type_casting import support_pynapple from ..utils import row_wise_kron @@ -103,12 +102,6 @@ class Basis(Base, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. 
""" @@ -117,13 +110,10 @@ def __init__( n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", label: Optional[str] = None, - **kwargs, ) -> None: self.n_basis_funcs = n_basis_funcs self._n_input_dimensionality = 0 - self._conv_kwargs = kwargs - self._mode = mode self._n_basis_input = None @@ -138,9 +128,6 @@ def __init__( else: self._label = str(label) - # the rest should be convolutional kwargs - self._check_convolution_kwargs() - self.kernel_ = None @abc.abstractmethod @@ -258,12 +245,12 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: Examples -------- >>> import numpy as np - >>> from nemos.basis import BSplineBasis + >>> from nemos.basis import EvalBSpline >>> # Generate data >>> num_samples = 10000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) + >>> basis = EvalBSpline(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (10000, 10) @@ -425,8 +412,8 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: >>> # Evaluate and visualize 4 M-spline basis functions of order 3: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import MSplineBasis - >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) + >>> from nemos.basis import EvalMSpline + >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) >>> p = plt.plot(sample_points, basis_values) >>> _ = plt.title('M-Spline Basis Functions') @@ -588,7 +575,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer() + >>> basis = nmo.basis.EvalRaisedCosineLinear(10).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( @@ -772,10 +759,10 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import BSplineBasis + >>> from nemos.basis import ConvBSpline >>> from nemos.glm import GLM >>> # Define an additive basis - >>> basis = BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature") + >>> basis = ConvBSpline(n_basis_funcs=5, window_size=10, label="feature") >>> # Generate a sample input array and compute features >>> x = np.random.randn(20) >>> X = basis.compute_features(x) @@ -785,7 +772,7 @@ def split_by_feature( ... print(f"{feature}: shape {arr.shape}") feature: shape (20, 1, 5) >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, + >>> multi_input_basis = ConvBSpline(n_basis_funcs=6, window_size=10, ... 
label="multi_input") >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) @@ -934,7 +921,7 @@ class TransformerBasis: Examples -------- - >>> from nemos.basis import BSplineBasis, TransformerBasis + >>> from nemos.basis import EvalBSpline, TransformerBasis >>> from nemos.glm import GLM >>> from sklearn.pipeline import Pipeline >>> from sklearn.model_selection import GridSearchCV @@ -944,7 +931,7 @@ class TransformerBasis: >>> # Generate data >>> num_samples, num_features = 10000, 1 >>> x = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) + >>> basis = EvalBSpline(10) >>> features = basis.compute_features(x) # basis transformed time series >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts @@ -1129,7 +1116,7 @@ def __getattr__(self, name: str): Examples -------- >>> from nemos import basis - >>> bas = basis.RaisedCosineBasisLinear(5) + >>> bas = basis.EvalRaisedCosineLinear(5) >>> trans_bas = basis.TransformerBasis(bas) >>> bas.n_basis_funcs 5 @@ -1158,7 +1145,7 @@ def __setattr__(self, name: str, value) -> None: >>> import nemos as nmo >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10)) >>> # allowed - >>> trans_bas._basis = nmo.basis.BSplineBasis(10) + >>> trans_bas._basis = nmo.basis.EvalBSpline(10) >>> # allowed >>> trans_bas.n_basis_funcs = 20 >>> # not allowed @@ -1202,7 +1189,7 @@ def set_params(self, **parameters) -> TransformerBasis: Examples -------- - >>> from nemos.basis import BSplineBasis, EvalMSpline, TransformerBasis + >>> from nemos.basis import EvalBSpline, EvalMSpline, TransformerBasis >>> basis = EvalMSpline(10) >>> transformer_basis = TransformerBasis(basis=basis) @@ -1210,11 +1197,11 @@ def set_params(self, **parameters) -> TransformerBasis: >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) 8 >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)) + >>> print(type(transformer_basis.set_params(_basis=EvalBSpline(10))._basis)) >>> # mixing is not allowed, this will raise an exception >>> try: - ... transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2) + ... transformer_basis.set_params(_basis=EvalBSpline(10), n_basis_funcs=2) ... except ValueError as e: ... print(repr(e)) ValueError('Set either new _basis object or parameters for existing _basis, not both.') @@ -1322,13 +1309,13 @@ class AdditiveBasis(Basis): >>> X = np.random.normal(size=(30, 2)) >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.BSplineBasis(10) - >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> basis_1 = nmo.basis.EvalBSpline(10) + >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) >>> additive_basis = basis_1 + basis_2 >>> # can add another basis to the AdditiveBasis object >>> X = np.random.normal(size=(30, 3)) - >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> basis_3 = nmo.basis.EvalRaisedCosineLog(100) >>> additive_basis_2 = additive_basis + basis_3 """ @@ -1520,12 +1507,12 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import BSplineBasis + >>> from nemos.basis import ConvBSpline >>> from nemos.glm import GLM >>> # Define an additive basis >>> basis = ( - ... BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") + - ... 
BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2") + ... ConvBSpline(n_basis_funcs=5, window_size=10, label="feature_1") + + ... ConvBSpline(n_basis_funcs=6, window_size=10, label="feature_2") ... ) >>> # Generate a sample input array and compute features >>> x1, x2 = np.random.randn(20), np.random.randn(20) @@ -1537,7 +1524,7 @@ def split_by_feature( feature_1: shape (20, 1, 5) feature_2: shape (20, 1, 6) >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, + >>> multi_input_basis = ConvBSpline(n_basis_funcs=6, window_size=10, ... label="multi_input") >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) @@ -1581,13 +1568,13 @@ class MultiplicativeBasis(Basis): >>> X = np.random.normal(size=(30, 3)) >>> # define two basis and multiply - >>> basis_1 = nmo.basis.BSplineBasis(10) - >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) + >>> basis_1 = nmo.basis.EvalBSpline(10) + >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) >>> multiplicative_basis = basis_1 * basis_2 >>> # Can multiply or add another basis to the AdditiveBasis object >>> # This will cause the number of output features of the result basis to grow accordingly - >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) + >>> basis_3 = nmo.basis.EvalRaisedCosineLog(100) >>> multiplicative_basis_2 = multiplicative_basis * basis_3 """ diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d6f95c7e..35ea3087 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -91,8 +91,10 @@ def _check_convolution_kwargs(self): class ConvBasisMixin: - def __init__(self, window_size: int): + def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.window_size = window_size + self._conv_kwargs = {} if conv_kwargs is None else conv_kwargs + self._check_convolution_kwargs() def _compute_features(self, *xi: ArrayLike): """ @@ -172,6 +174,14 @@ def window_size(self, window_size): self._window_size = window_size + @property + def conv_kwargs(self): + """The convolutional kwargs. + + Keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`. + """ + return self._conv_kwargs + def _check_convolution_kwargs(self): """Check convolution kwargs settings. 
diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py new file mode 100644 index 00000000..bd82ca41 --- /dev/null +++ b/src/nemos/basis/_decaying_exponential.py @@ -0,0 +1,220 @@ +"""Decaying exponential basis.""" + +# required to get ArrayLike to render correctly +from __future__ import annotations + +import abc +from typing import Optional, Tuple + +import numpy as np +import scipy.linalg +from numpy.typing import NDArray + + +from ..type_casting import support_pynapple +from ..typing import FeatureMatrix + +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples + +_ORTHEXPONENTIAL_EVAL_IMPORT = ">>> from nemos.basis import EvalOrthExponential" +_ORTHEXPONENTIAL_CONV_IMPORT = ">>> from nemos.basis import ConvOrthExponential" + +_ORTHEXPONENTIAL_EVAL_INIT = '>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates)' +_ORTHEXPONENTIAL_CONV_INIT = '>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, window_size, decay_rates)' + + +class OrthExponentialBasis(Basis, abc.ABC): + """Set of 1D basis decaying exponential functions numerically orthogonalized. + + Parameters + ---------- + n_basis_funcs + Number of basis functions. + decay_rates : + Decay rates of the exponentials, shape ``(n_basis_funcs,)``. + mode : + The mode of operation. ``'eval'`` for evaluation at sample points, + ``'conv'`` for convolutional operation. + window_size : + The window size for convolution. Required if mode is ``'conv'``. + bounds : + The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + **kwargs : + Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + ``mode='conv'``; These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. + + """ + + def __init__( + self, + n_basis_funcs: int, + decay_rates: NDArray[np.floating], + mode="eval", + label: Optional[str] = "OrthExponentialBasis", + **kwargs, + ): + super().__init__( + n_basis_funcs, + mode=mode, + label=label, + **kwargs, + ) + self.decay_rates = decay_rates + self._check_rates() + self._n_input_dimensionality = 1 + + @property + def decay_rates(self): + r"""Decay rate. + + The rate of decay of the exponential functions. If :math:`f_i(t) = \exp{-\alpha_i t}` is the i-th decay + exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rate`` vector. + """ + return self._decay_rates + + @decay_rates.setter + def decay_rates(self, value: NDArray): + """Decay rate setter.""" + value = np.asarray(value, dtype=float) + if value.shape[0] != self.n_basis_funcs: + raise ValueError( + f"The number of basis functions must match the number of decay rates provided. " + f"Number of basis functions provided: {self.n_basis_funcs}, " + f"Number of decay rates provided: {value.shape[0]}" + ) + self._decay_rates = value + + def _check_n_basis_min(self) -> None: + """Check that the user required enough basis elements. 
+ + Checks that the number of basis is at least 1. + + Raises + ------ + ValueError + If an insufficient number of basis element is requested for the basis type + """ + if self.n_basis_funcs < 1: + raise ValueError( + f"Object class {self.__class__.__name__} requires >= 1 basis elements. " + f"{self.n_basis_funcs} basis elements specified instead" + ) + + def _check_rates(self) -> None: + """ + Check if the decay rates list has duplicate entries. + + Raises + ------ + ValueError + If two or more decay rates are repeated, which would result in a linearly + dependent set of functions for the basis. + """ + if len(set(self._decay_rates)) != len(self._decay_rates): + raise ValueError( + "Two or more rate are repeated! Repeating rate will result in a " + "linearly dependent set of function for the basis." + ) + + def _check_sample_size(self, *sample_pts: NDArray) -> None: + """Check that the sample size is greater than the number of basis. + + This is necessary for the orthogonalization procedure, + that otherwise will return (sample_size, ) basis elements instead of the expected number. + + Parameters + ---------- + sample_pts + Spacing for basis functions, holding elements on the interval [0, inf). + + Raises + ------ + ValueError + If the number of basis element is less than the number of samples. + """ + if sample_pts[0].size < self.n_basis_funcs: + raise ValueError( + "OrthExponentialBasis requires at least as many samples as basis functions!\n" + f"Class instantiated with {self.n_basis_funcs} basis functions " + f"but only {sample_pts[0].size} samples provided!" + ) + + @support_pynapple(conv_type="numpy") + @check_transform_input + @check_one_dimensional + def __call__( + self, + sample_pts: NDArray, + ) -> FeatureMatrix: + """Generate basis functions with given spacing. + + Parameters + ---------- + sample_pts + Spacing for basis functions, holding elements on the interval [0, + inf), shape (n_samples,). + + Returns + ------- + basis_funcs + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + """ + self._check_sample_size(sample_pts) + sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + valid_idx = ~np.isnan(sample_pts) + # because of how scipy.linalg.orth works, have to create a matrix of + # shape (n_pts, n_basis_funcs) and then transpose, rather than + # directly computing orth on the matrix of shape (n_basis_funcs, + # n_pts) + exp_decay_eval = np.stack( + [np.exp(-lam * sample_pts[valid_idx]) for lam in self._decay_rates], axis=1 + ) + # count the linear independent components (could be lower than n_basis_funcs for num precision). + n_independent_component = np.linalg.matrix_rank(exp_decay_eval) + # initialize output to nan + basis_funcs = np.full( + shape=(sample_pts.shape[0], n_independent_component), fill_value=np.nan + ) + # orthonormalize on valid points + basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) + return basis_funcs + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of samples. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. 
+ basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import OrthExponentialBasis + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + """ + return super().evaluate_on_grid(n_samples) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 8f250fe5..b3a43ae2 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -5,18 +5,14 @@ from typing import Optional, Tuple -import numpy as np -import scipy.linalg -from numpy.typing import ArrayLike, NDArray +from numpy.typing import NDArray -from ..type_casting import support_pynapple -from ..typing import FeatureMatrix from ._basis_mixin import EvalBasisMixin, ConvBasisMixin -from ._basis import Basis, check_transform_input, check_one_dimensional from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from ._decaying_exponential import OrthExponentialBasis __all__ = [ "EvalMSpline", @@ -27,7 +23,10 @@ "ConvCyclicBSpline", "EvalRaisedCosineLinear", "ConvRaisedCosineLinear", - "OrthExponentialBasis", + "EvalRaisedCosineLog", + "ConvRaisedCosineLog", + "EvalOrthExponential", + "ConvOrthExponential", ] @@ -61,8 +60,9 @@ def __init__( window_size: int, order: int = 4, label: Optional[str] = "ConvBSpline", + conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( self, n_basis_funcs, @@ -97,8 +97,9 @@ def __init__( window_size: int, order: int = 4, label: Optional[str] = "ConvCyclicBSpline", + conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( self, n_basis_funcs, @@ -133,43 +134,9 @@ def __init__( window_size: int, order: int = 4, label: Optional[str] = "ConvMSpline", + conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size) - MSplineBasis.__init__( - self, - n_basis_funcs, - mode="conv", - order=order, - label=label, - ) - -class EvalMSpline(EvalBasisMixin, MSplineBasis): - def __init__( - self, - n_basis_funcs: int, - order: int = 4, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalMSpline", - ): - EvalBasisMixin.__init__(self, bounds=bounds) - MSplineBasis.__init__( - self, - n_basis_funcs, - mode="eval", - order=order, - label=label, - ) - - -class ConvMSpline(ConvBasisMixin, MSplineBasis): - def __init__( - self, - n_basis_funcs: int, - window_size: int, - order: int = 4, - label: Optional[str] = "ConvMSpline", - ): - ConvBasisMixin.__init__(self, window_size=window_size) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( self, n_basis_funcs, @@ -185,7 +152,7 @@ def __init__( n_basis_funcs: int, width: float = 2.0, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalMSpline", + label: Optional[str] = "EvalRaisedCosineLinear", ): EvalBasisMixin.__init__(self, bounds=bounds) 
RaisedCosineBasisLinear.__init__( @@ -203,17 +170,16 @@ def __init__( n_basis_funcs: int, window_size: int, width: float = 2.0, - label: Optional[str] = "ConvMSpline", - **conv_kwargs, + label: Optional[str] = "ConvRaisedCosineLinear", + conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLinear.__init__( self, n_basis_funcs, mode="conv", width=width, label=label, - **conv_kwargs, ) class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): @@ -224,7 +190,7 @@ def __init__( time_scaling: float = None, enforce_decay_to_zero: bool = True, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalMSpline", + label: Optional[str] = "EvalRaisedCosineLog", ): EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( @@ -246,10 +212,10 @@ def __init__( width: float = 2.0, time_scaling: float = None, enforce_decay_to_zero: bool = True, - label: Optional[str] = "ConvMSpline", - **conv_kwargs, + label: Optional[str] = "ConvRaisedCosineLog", + conv_kwargs: Optional[dict] = None, ): - ConvBasisMixin.__init__(self, window_size=window_size) + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( self, n_basis_funcs, @@ -258,33 +224,73 @@ def __init__( time_scaling=time_scaling, enforce_decay_to_zero=enforce_decay_to_zero, label=label, - **conv_kwargs, ) +class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): + def __init__( + self, + n_basis_funcs: int, + decay_rates: NDArray, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalOrthExponential", + ): + """Set of 1D basis decaying exponential functions numerically orthogonalized. + + Parameters + ---------- + n_basis_funcs + Number of basis functions. + decay_rates : + Decay rates of the exponentials, shape ``(n_basis_funcs,)``. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis import ConvOrthExponential + >>> X = np.random.normal(size=(1000, 1)) + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size = 10 + >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates) + >>> sample_points = linspace(0, 1, 100) + >>> # evaluate the basis + >>> basis_functions = ortho_basis.compute_features(sample_points) + + """ + EvalBasisMixin.__init__(self, bounds=bounds) + OrthExponentialBasis.__init__( + self, + n_basis_funcs, + decay_rates=decay_rates, + mode="eval", + label=label, + ) + -class OrthExponentialBasis(Basis): +class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. Parameters ---------- n_basis_funcs Number of basis functions. + window_size : + The window size for convolution as number of samples. decay_rates : Decay rates of the exponentials, shape ``(n_basis_funcs,)``. - mode : - The mode of operation. ``'eval'`` for evaluation at sample points, - ``'conv'`` for convolutional operation. 
- window_size : - The window size for convolution. Required if mode is ``'conv'``. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : + conv_kwargs : Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. @@ -293,183 +299,31 @@ class OrthExponentialBasis(Basis): Examples -------- - >>> from numpy import linspace - >>> from nemos.basis import OrthExponentialBasis + >>> import numpy as np + >>> from nemos.basis import ConvOrthExponential >>> X = np.random.normal(size=(1000, 1)) >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = ortho_basis(sample_points) - """ + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size = 10 + >>> ortho_basis = ConvOrthExponential(n_basis_funcs, window_size, decay_rates) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> basis_functions = ortho_basis.compute_features(sample_points) + """ def __init__( - self, - n_basis_funcs: int, - decay_rates: NDArray[np.floating], - mode="eval", - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "OrthExponentialBasis", - **kwargs, + self, + n_basis_funcs: int, + window_size: int, + decay_rates: NDArray, + label: Optional[str] = "ConvOrthExponential", + conv_kwargs: Optional[dict] = None, ): - super().__init__( + ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) + OrthExponentialBasis.__init__( + self, n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, + mode="conv", + decay_rates=decay_rates, label=label, - **kwargs, - ) - self.decay_rates = decay_rates - self._check_rates() - self._n_input_dimensionality = 1 - - @property - def decay_rates(self): - """Decay rate. - - The rate of decay of the exponential functions. If :math:`f_i(t) = \exp{-\alpha_i t}` is the i-th decay - exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rate`` vector. - """ - return self._decay_rates - - @decay_rates.setter - def decay_rates(self, value: NDArray): - """Decay rate setter.""" - value = np.asarray(value) - if value.shape[0] != self.n_basis_funcs: - raise ValueError( - f"The number of basis functions must match the number of decay rates provided. " - f"Number of basis functions provided: {self.n_basis_funcs}, " - f"Number of decay rates provided: {value.shape[0]}" - ) - self._decay_rates = value - - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Checks that the number of basis is at least 1. 
- - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - if self.n_basis_funcs < 1: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 1 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) - - def _check_rates(self) -> None: - """ - Check if the decay rates list has duplicate entries. - - Raises - ------ - ValueError - If two or more decay rates are repeated, which would result in a linearly - dependent set of functions for the basis. - """ - if len(set(self._decay_rates)) != len(self._decay_rates): - raise ValueError( - "Two or more rate are repeated! Repeating rate will result in a " - "linearly dependent set of function for the basis." - ) - - def _check_sample_size(self, *sample_pts: NDArray) -> None: - """Check that the sample size is greater than the number of basis. - - This is necessary for the orthogonalization procedure, - that otherwise will return (sample_size, ) basis elements instead of the expected number. - - Parameters - ---------- - sample_pts - Spacing for basis functions, holding elements on the interval [0, inf). - - Raises - ------ - ValueError - If the number of basis element is less than the number of samples. - """ - if sample_pts[0].size < self.n_basis_funcs: - raise ValueError( - "OrthExponentialBasis requires at least as many samples as basis functions!\n" - f"Class instantiated with {self.n_basis_funcs} basis functions " - f"but only {sample_pts[0].size} samples provided!" - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: NDArray, - ) -> FeatureMatrix: - """Generate basis functions with given spacing. - - Parameters - ---------- - sample_pts - Spacing for basis functions, holding elements on the interval [0, - inf), shape (n_samples,). - - Returns - ------- - basis_funcs - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) - - """ - self._check_sample_size(sample_pts) - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - valid_idx = ~np.isnan(sample_pts) - # because of how scipy.linalg.orth works, have to create a matrix of - # shape (n_pts, n_basis_funcs) and then transpose, rather than - # directly computing orth on the matrix of shape (n_basis_funcs, - # n_pts) - exp_decay_eval = np.stack( - [np.exp(-lam * sample_pts[valid_idx]) for lam in self._decay_rates], axis=1 ) - # count the linear independent components (could be lower than n_basis_funcs for num precision). - n_independent_component = np.linalg.matrix_rank(exp_decay_eval) - # initialize output to nan - basis_funcs = np.full( - shape=(sample_pts.shape[0], n_independent_component), fill_value=np.nan - ) - # orthonormalize on valid points - basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) - return basis_funcs - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. - - Returns - ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. 
- basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import OrthExponentialBasis - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index afd03f2c..fe1eab31 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -9,7 +9,7 @@ from jax.typing import ArrayLike as JaxArray from numpy.typing import NDArray -from nemos.basis.basis import Basis +from .basis._basis import Basis from .tree_utils import get_valid_multitree, tree_slice from .type_casting import support_pynapple from .validation import _warn_if_not_float64 From 346f8123e653d719b50c5ef91b985a5722467a0a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 18 Nov 2024 16:42:53 -0500 Subject: [PATCH 004/109] documentation fix --- docs/api_reference.rst | 12 ++---------- src/nemos/basis/__init__.py | 4 +++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 779e31a6..a10be9b4 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -31,16 +31,8 @@ Provides basis function classes to construct and transform features for model in :nosignatures: Basis - SplineBasis - BSplineBasis - CyclicBSplineBasis - MSplineBasis - OrthExponentialBasis - RaisedCosineBasisLinear - RaisedCosineBasisLog - AdditiveBasis - MultiplicativeBasis - TransformerBasis + EvalOrthExponential + ConvOrthExponential .. 
_observation_models: The ``nemos.observation_models`` module diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index fe6de99a..175baaa9 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -4,4 +4,6 @@ EvalRaisedCosineLinear, ConvRaisedCosineLinear, EvalRaisedCosineLog, ConvRaisedCosineLog, EvalOrthExponential, ConvOrthExponential) -from ._basis import TransformerBasis +from ._basis import AdditiveBasis, MultiplicativeBasis, Basis +from ._spline_basis import BSplineBasis +from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog From ef438d3b49b61be4a17d15c1d97d428491582e15 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 18 Nov 2024 16:52:09 -0500 Subject: [PATCH 005/109] evaluate on grid orth exp --- src/nemos/basis/_decaying_exponential.py | 17 ------- src/nemos/basis/basis.py | 63 ++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index bd82ca41..8acc515f 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -16,12 +16,6 @@ from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples -_ORTHEXPONENTIAL_EVAL_IMPORT = ">>> from nemos.basis import EvalOrthExponential" -_ORTHEXPONENTIAL_CONV_IMPORT = ">>> from nemos.basis import ConvOrthExponential" - -_ORTHEXPONENTIAL_EVAL_INIT = '>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates)' -_ORTHEXPONENTIAL_CONV_INIT = '>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, window_size, decay_rates)' - class OrthExponentialBasis(Basis, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -205,16 +199,5 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: basis_funcs : Evaluated exponentially decaying basis functions, numerically orthogonalized, shape (n_samples, n_basis_funcs) - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import OrthExponentialBasis - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index b3a43ae2..7e3731f7 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -275,6 +275,37 @@ def __init__( label=label, ) + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Generate basis functions with given spacing. + + Parameters + ---------- + n_samples: + The number of samples. + + Returns + ------- + X : + Array of shape ``(n_samples,)`` containing the equi-spaced sample + points where we've evaluated the basis. 
+ basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape ``(n_samples, n_basis_funcs)`` + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalOrthExponential + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates=decay_rates) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return super().evaluate_on_grid(n_samples=n_samples) + class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -327,3 +358,35 @@ def __init__( decay_rates=decay_rates, label=label, ) + + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """Generate basis functions with given spacing. + + Parameters + ---------- + n_samples: + The number of samples. + + Returns + ------- + X : + Array of shape ``(n_samples,)`` containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape ``(n_samples, n_basis_funcs)`` + + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvOrthExponential + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = ConvOrthExponential(n_basis_funcs, window_size, decay_rates=decay_rates) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return super().evaluate_on_grid(n_samples=n_samples) + From 8028855565f370c2770b28f4fe3975fab74c81c1 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 18 Nov 2024 16:59:56 -0500 Subject: [PATCH 006/109] compute features orth exp --- src/nemos/basis/basis.py | 70 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 7e3731f7..8de1e095 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple -from numpy.typing import NDArray +from numpy.typing import NDArray, ArrayLike from ._basis_mixin import EvalBasisMixin, ConvBasisMixin @@ -13,6 +13,7 @@ from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._decaying_exponential import OrthExponentialBasis +from ..typing import FeatureMatrix __all__ = [ "EvalMSpline", @@ -306,6 +307,40 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples=n_samples) + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Compute the basis functions and transform input data into model features. + + This method is designed to be a high-level interface for transforming input + data using the basis functions defined by the subclass. It evaluates the basis functions at the sample + points. + + Parameters + ---------- + *xi : + Input data arrays to be transformed. + + Returns + ------- + : + Transformed features, consisting of the basis functions evaluated at the input samples. 
+ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvOrthExponential + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvOrthExponential(10, window_size=100, decay_rates=np.arange(1, 11)) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return super().compute_features(*xi) + class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -390,3 +425,36 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples=n_samples) + def compute_features(*xi: ArrayLike) -> FeatureMatrix: + """ + Compute the basis functions and transform input data into model features. + + This method is designed to be a high-level interface for transforming input + data using the basis functions defined by the subclass. Performs a convolution operation between + the input data and the basis functions. + + Parameters + ---------- + *xi : + Input data arrays to be transformed. + + Returns + ------- + : + Transformed features, consisting of convolved input samples with the basis functions. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvOrthExponential + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvOrthExponential(10, window_size=100, decay_rates=np.arange(1, 11)) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return super().compute_features(*xi) \ No newline at end of file From 71ff34cdee7295814e1b05ad96cbd79ed1e9f968 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 18 Nov 2024 17:25:51 -0500 Subject: [PATCH 007/109] split by orth exp --- src/nemos/basis/basis.py | 142 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 8de1e095..e36b0642 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -341,6 +341,78 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ return super().compute_features(*xi) + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. + + This function takes an array (e.g., a design matrix or model coefficients) and splits it along + a designated axis. + + **How it works:** + + - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will + be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions + ``(n_inputs, n_basis_funcs)``. + - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will + be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. + + For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, + then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. + + The specified axis (``axis``) determines where the split occurs, and all other dimensions + remain unchanged. See the example section below for the most common use cases. 
+ + Parameters + ---------- + x : + The input array to be split, representing concatenated features, coefficients, + or other data. The shape of ``x`` along the specified axis must match the total + number of features generated by the basis, i.e., ``self.n_output_features``. + + **Examples:** + - For a design matrix: ``(n_samples, total_n_features)`` + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. + + axis : int, optional + The axis along which to split the features. Defaults to 1. + Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for + coefficient arrays (features along rows). All other dimensions are preserved. + + Raises + ------ + ValueError + If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. + + Returns + ------- + dict + A dictionary where: + - **Key**: Label of the basis. + - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalOrthExponential + >>> from nemos.glm import GLM + >>> # Define an additive basis + >>> basis = EvalOrthExponential(n_basis_funcs=5, label="feature") + >>> # Generate a sample input array and compute features + >>> x = np.random.randn(20) + >>> X = basis.compute_features(x) + >>> # Split the feature matrix along axis 1 + >>> split_features = basis.split_by_feature(X, axis=1) + >>> for feature, arr in split_features.items(): + ... print(f"{feature}: shape {arr.shape}") + feature: shape (20, 1, 5) + + """ + return super().split_by_feature(x, axis=axis) + class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -457,4 +529,72 @@ def compute_features(*xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return super().compute_features(*xi) \ No newline at end of file + return super().compute_features(*xi) + + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. + + This function takes an array (e.g., a design matrix or model coefficients) and splits it along + a designated axis. + + **How it works:** + + - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will + be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions + ``(n_inputs, n_basis_funcs)``. + - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will + be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. + + For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, + then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. + + The specified axis (``axis``) determines where the split occurs, and all other dimensions + remain unchanged. See the example section below for the most common use cases. + + Parameters + ---------- + x : + The input array to be split, representing concatenated features, coefficients, + or other data. The shape of ``x`` along the specified axis must match the total + number of features generated by the basis, i.e., ``self.n_output_features``. + + **Examples:** + - For a design matrix: ``(n_samples, total_n_features)`` + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. 
+ + axis : int, optional + The axis along which to split the features. Defaults to 1. + Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for + coefficient arrays (features along rows). All other dimensions are preserved. + + Raises + ------ + ValueError + If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. + + Returns + ------- + dict + A dictionary where: + - **Key**: Label of the basis. + - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvOrthExponential + >>> from nemos.glm import GLM + >>> basis = ConvOrthExponential(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return super().split_by_feature(x, axis=axis) \ No newline at end of file From e3ad3ea441debd12987503014c25abfe209494d2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 11:47:02 -0500 Subject: [PATCH 008/109] added changes --- src/nemos/_documentation_utils/plotting.py | 2 +- src/nemos/typing.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/nemos/_documentation_utils/plotting.py b/src/nemos/_documentation_utils/plotting.py index b94c0079..086ccb2c 100644 --- a/src/nemos/_documentation_utils/plotting.py +++ b/src/nemos/_documentation_utils/plotting.py @@ -33,7 +33,7 @@ from matplotlib.patches import Rectangle from numpy.typing import NDArray -from nemos.basis.basis import RaisedCosineBasisLog +from ..basis import RaisedCosineBasisLog warnings.warn( "plotting functions contained within `_documentation_utils` are intended for nemos's documentation. " diff --git a/src/nemos/typing.py b/src/nemos/typing.py index 339e6441..f1cfc4fc 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -7,8 +7,6 @@ from jax.typing import ArrayLike from .pytrees import FeaturePytree -from pynapple import TsdFrame -from numpy.typing import NDArray DESIGN_INPUT_TYPE = Union[jnp.ndarray, FeaturePytree] @@ -53,5 +51,3 @@ ], # Step-size for optimization (must be a float) Tuple[jnp.ndarray, jnp.ndarray], ] - -FeatureMatrix = Union[NDArray, TsdFrame] \ No newline at end of file From 798bec0b60a2f05bff4161b399f30014a396dd7e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 11:47:20 -0500 Subject: [PATCH 009/109] updated basis old --- src/nemos/basis_old.py | 49 ++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/nemos/basis_old.py b/src/nemos/basis_old.py index ad4f0625..f7067de8 100644 --- a/src/nemos/basis_old.py +++ b/src/nemos/basis_old.py @@ -503,7 +503,7 @@ class Basis(Base, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. 
Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -512,11 +512,14 @@ class Basis(Base, abc.ABC): Raises ------ ValueError: - - If ``mode`` is not 'eval' or 'conv'. - - If ``kwargs`` are not None and ``mode =="eval"``. - - If ``kwargs`` include parameters not recognized or do not have + If ``mode`` is not 'eval' or 'conv'. + ValueError: + If ``kwargs`` are not None and ``mode =="eval"``. + ValueError: + If ``kwargs`` include parameters not recognized or do not have default values in ``create_convolutional_predictor``. - - If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). + ValueError: + If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). """ def __init__( @@ -1005,9 +1008,10 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Raises ------ ValueError - - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what + If the time point number is inconsistent between inputs or if the number of inputs doesn't match what the Basis object requires. - - If one of the n_samples is <= 0. + ValueError + If one of the n_samples is <= 0. Notes ----- @@ -1151,7 +1155,7 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis: Returns ------- : - The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. + The product of the basis with itself "exponent" times. Equivalent to ``self * self * ... * self``. Raises ------ @@ -1326,10 +1330,11 @@ def split_by_feature( **How it works:** - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will - be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions - ``(n_inputs, n_basis_funcs)``. + be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions + ``(n_inputs, n_basis_funcs)``. + - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will - be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. + be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. @@ -1345,7 +1350,9 @@ def split_by_feature( number of features generated by the basis, i.e., ``self.n_output_features``. **Examples:** + - For a design matrix: ``(n_samples, total_n_features)`` + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. axis : int, optional @@ -1921,7 +1928,7 @@ class SplineBasis(Basis, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. 
Basis assumes @@ -2087,7 +2094,7 @@ class MSplineBasis(SplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs: - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2253,7 +2260,7 @@ class BSplineBasis(SplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2392,7 +2399,7 @@ class CyclicBSplineBasis(SplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2556,7 +2563,7 @@ class RaisedCosineBasisLinear(Basis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2771,7 +2778,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2942,7 +2949,7 @@ class OrthExponentialBasis(Basis): The label of the basis, intended to be descriptive of the task variable being processed. 
For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -2985,9 +2992,9 @@ def __init__( @property def decay_rates(self): - """Decay rate. + r"""Decay rate. - The rate of decay of the exponential functions. If :math:`f_i(t) = \exp{-\alpha_i t}` is the i-th decay + The rate of decay of the exponential functions. If :math:`f_i(t) = e^{-\alpha_i t}` is the i-th decay exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rate`` vector. """ return self._decay_rates @@ -3272,4 +3279,4 @@ def bspline( sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der ) - return basis_eval.T + return basis_eval.T \ No newline at end of file From dc46d931ad8265f75d8201f532f2204c0eff0f09 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 11:52:22 -0500 Subject: [PATCH 010/109] fixed _basis.py docstrings --- src/nemos/basis/_basis.py | 176 ++++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 85 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 4b3ef319..26a5db31 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -109,7 +109,7 @@ class Basis(Base, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -118,13 +118,15 @@ class Basis(Base, abc.ABC): Raises ------ ValueError: - - If ``mode`` is not 'eval' or 'conv'. - - If ``kwargs`` are not None and ``mode =="eval"``. - - If ``kwargs`` include parameters not recognized or do not have + If ``mode`` is not 'eval' or 'conv'. + ValueError: + If ``kwargs`` are not None and ``mode =="eval"``. + ValueError: + If ``kwargs`` include parameters not recognized or do not have default values in ``create_convolutional_predictor``. - - If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). + ValueError: + If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). """ - def __init__( self, n_basis_funcs: int, @@ -175,7 +177,10 @@ def label(self) -> str: @property def n_basis_input(self) -> tuple | None: - """Number of expected inputs.""" + """Number of expected inputs. + + The number of inputs ``compute_feature`` expects. + """ if self._n_basis_input is None: return return self._n_basis_input @@ -334,7 +339,7 @@ def _get_samples(self, *n_samples: int) -> Generator[NDArray]: Returns ------- : - A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by `n_samples`. 
+ A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by ``n_samples``. """ # handling of defaults when evaluating on a grid # (i.e. when we cannot use max and min of samples) @@ -393,42 +398,43 @@ def _check_has_kernel(self) -> None: def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. - The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. - The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing - instead of the default cartesian indexing, see Notes. - - Parameters - ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of - n_samples must equal the number of combined bases. - - Returns - ------- - *Xs : - A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, - where n equals to the number of inputs. - The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. - Y : - The basis function evaluated at the samples, - shape ``(n_samples[0], ... , n_samples[n], number of basis)``. - - Raises - ------ - ValueError - - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what - the Basis object requires. - - If one of the n_samples is <= 0. - - Notes - ----- - Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size - :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. - This differs from the numpy.meshgrid default, which uses Cartesian indexing. - For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. - - Examples - -------- + The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. + The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing + instead of the default cartesian indexing, see Notes. + + Parameters + ---------- + n_samples[0],...,n_samples[n] + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. + + Returns + ------- + *Xs : + A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, + where n equals to the number of inputs. + The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. + Y : + The basis function evaluated at the samples, + shape ``(n_samples[0], ... , n_samples[n], number of basis)``. + + Raises + ------ + ValueError + If the time point number is inconsistent between inputs or if the number of inputs doesn't match what + the Basis object requires. + ValueError + If one of the n_samples is <= 0. + + Notes + ----- + Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size + :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. + This differs from the numpy.meshgrid default, which uses Cartesian indexing. + For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. + + Examples + -------- >>> # Evaluate and visualize 4 M-spline basis functions of order 3: >>> import numpy as np >>> import matplotlib.pyplot as plt @@ -562,7 +568,7 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis: Returns ------- : - The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. 
+ The product of the basis with itself "exponent" times. Equivalent to ``self * self * ... * self``. Raises ------ @@ -621,19 +627,19 @@ def _get_feature_slicing( Calculate and return the slicing for features based on the input structure. This method determines how to slice the features for different basis types. - If the instance is of `AdditiveBasis` type, the slicing is calculated recursively + If the instance is of ``AdditiveBasis`` type, the slicing is calculated recursively for each component basis. Otherwise, it determines the slicing based on - the number of basis functions and `split_by_input` flag. + the number of basis functions and ``split_by_input`` flag. Parameters ---------- n_inputs : - The number of input basis for each component, by default it uses `self._n_basis_input`. + The number of input basis for each component, by default it uses ``self._n_basis_input``. start_slice : The starting index for slicing, by default it starts from 0. split_by_input : Flag indicating whether to split the slicing by individual inputs or not. - If `False`, a single slice is generated for all inputs. + If ``False``, a single slice is generated for all inputs. Returns ------- @@ -737,10 +743,11 @@ def split_by_feature( **How it works:** - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will - be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions - ``(n_inputs, n_basis_funcs)``. + be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions + ``(n_inputs, n_basis_funcs)``. + - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will - be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. + be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. @@ -756,7 +763,9 @@ def split_by_feature( number of features generated by the basis, i.e., ``self.n_output_features``. **Examples:** + - For a design matrix: ``(n_samples, total_n_features)`` + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. axis : int, optional @@ -874,8 +883,8 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: This function computes the number of inputs that are provided to the basis and uses that number, and the n_basis_funcs to calculate the number of output features that - `self.compute_features` will return. These quantities and the input shape (excluding the sample axis) - are stored in `self._n_basis_input` and `self._n_output_features`, and `self._input_shape` + ``self.compute_features`` will return. These quantities and the input shape (excluding the sample axis) + are stored in ``self._n_basis_input`` and ``self._n_output_features``, and ``self._input_shape`` respectively. Parameters @@ -891,15 +900,15 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: Raises ------ ValueError: - If the number of inputs do not match `self._n_basis_input`, if `self._n_basis_input` was + If the number of inputs do not match ``self._n_basis_input``, if ``self._n_basis_input`` was not None. 
Notes ----- - Once a `compute_features` is called, we enforce that for all subsequent calls of the method, + Once a ``compute_features`` is called, we enforce that for all subsequent calls of the method, the input that the basis receives preserves the shape of all axes, except for the sample axis. This condition guarantees the consistency of the feature axis, and therefore that - `self.split_by_feature` behaves appropriately. + ``self.split_by_feature`` behaves appropriately. """ # Check that the input shape matches expectation @@ -1435,7 +1444,7 @@ def _set_kernel(self, *xi: ArrayLike) -> Basis: Parameters ---------- *xi: - The sample inputs. Unused, necessary to conform to `scikit-learn` API. + The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. Returns ------- @@ -1460,52 +1469,50 @@ def split_by_feature( **How It Works:** - Suppose the basis is made up of **m components**, each with $b_i$ basis functions and $n_i$ inputs. - The total number of features, $N$, is calculated as: + Suppose the basis is made up of **m components**, each with :math:`b_i` basis functions and :math:`n_i` inputs. + The total number of features, :math:`N`, is calculated as: - $$ - N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m - $$ + .. math:: + N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m - This method splits any axis of length $N$ into sub-arrays, one for each basis component. + This method splits any axis of length :math:`N` into sub-arrays, one for each basis component. The sub-array for the i-th basis component is reshaped into dimensions - $(n_i, b_i)$. + :math:`(n_i, b_i)`. - For example, if the array shape is $(1, 2, N, 4, 5)$, then each split sub-array will have shape: + For example, if the array shape is :math:`(1, 2, N, 4, 5)`, then each split sub-array will have shape: - $$ - (1, 2, n_i, b_i, 4, 5) - $$ + .. math:: + (1, 2, n_i, b_i, 4, 5) where: - - $n_i$ represents the number of inputs associated with the i-th component, - - $b_i$ represents the number of basis functions in that component. + - :math:`n_i` represents the number of inputs associated with the i-th component, + - :math:`b_i` represents the number of basis functions in that component. - The specified axis (`axis`) determines where the split occurs, and all other dimensions + The specified axis (``axis``) determines where the split occurs, and all other dimensions remain unchanged. See the example section below for the most common use cases. Parameters ---------- x : The input array to be split, representing concatenated features, coefficients, - or other data. The shape of `x` along the specified axis must match the total - number of features generated by the basis, i.e., `self.n_output_features`. + or other data. The shape of ``x`` along the specified axis must match the total + number of features generated by the basis, i.e., ``self.n_output_features``. **Examples:** - - For a design matrix: `(n_samples, total_n_features)` - - For model coefficients: `(total_n_features,)` or `(total_n_features, n_neurons)`. + - For a design matrix: ``(n_samples, total_n_features)`` + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. axis : int, optional The axis along which to split the features. Defaults to 1. - Use `axis=1` for design matrices (features along columns) and `axis=0` for + Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for coefficient arrays (features along rows). All other dimensions are preserved. 
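As a concrete instance of the feature-count formula above, and assuming the ``Eval`` classes accept the ``label`` keyword as the other classes in this patch do, an additive combination of a 5-function and a 6-function basis applied to two 1D inputs yields N = 5*1 + 6*1 = 11 output features, which ``split_by_feature`` separates back into the two labelled components:

import numpy as np
from nemos.basis import EvalBSpline, EvalRaisedCosineLinear

# Two components: b_1 = 5 and b_2 = 6 basis functions, one 1D input each,
# so the total number of output features is N = 5*1 + 6*1 = 11.
add_basis = EvalBSpline(5, label="speed") + EvalRaisedCosineLinear(6, label="angle")
x, y = np.random.randn(2, 30)
X = add_basis.compute_features(x, y)
print(X.shape)                         # (30, 11)

# Each entry of the returned dict has shape (n_samples, n_i, b_i).
split = add_basis.split_by_feature(X, axis=1)
for name, arr in split.items():
    print(name, arr.shape)             # speed (30, 1, 5), angle (30, 1, 6)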
Raises ------ ValueError - If the shape of `x` along the specified axis does not match `self.n_output_features`. + If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. Returns ------- @@ -1514,12 +1521,11 @@ def split_by_feature( - **Keys**: Labels of the additive basis components. - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: - $$ - (..., n_i, b_i, ...) - $$ + .. math:: + (..., n_i, b_i, ...) - - `n_i`: The number of inputs processed by the i-th basis component. - - `b_i`: The number of basis functions for the i-th basis component. + - ``n_i``: The number of inputs processed by the i-th basis component. + - ``b_i``: The number of basis functions for the i-th basis component. These sub-arrays are reshaped along the specified axis, with all other dimensions remaining the same. @@ -1577,7 +1583,7 @@ class MultiplicativeBasis(Basis): Attributes ---------- - n_basis_funcs : int + n_basis_funcs : Number of basis functions. Examples @@ -1622,7 +1628,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis: Parameters ---------- *xi: - The sample inputs. Unused, necessary to conform to `scikit-learn` API. + The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. Returns ------- From 9d7d34f3c6a56e6baa447f3537d42804bafc44f8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 11:56:43 -0500 Subject: [PATCH 011/109] fix minor changes --- README.md | 10 +++---- docs/background/plot_02_ND_basis_function.md | 6 ++-- docs/how_to_guide/plot_02_glm_demo.md | 2 +- docs/how_to_guide/plot_06_glm_pytree.md | 2 +- docs/quickstart.md | 30 ++++++++++---------- tests/conftest.py | 2 +- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 1dbeab9f..e9abf895 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,9 @@ In this example, we'll construct a time-series of features using the basis objec import nemos as nmo # Instantiate the basis -basis_1 = nemos.basis.basis.EvalMSpline(n_basis_funcs=5) -basis_2 = nemos.basis.basis.CyclicBSplineBasis(n_basis_funcs=6) -basis_3 = nemos.basis.basis.EvalMSpline(n_basis_funcs=7) +basis_1 = nmo.basis.MSplineBasis(n_basis_funcs=5) +basis_2 = nmo.basis.CyclicBSplineBasis(n_basis_funcs=6) +basis_3 = nmo.basis.MSplineBasis(n_basis_funcs=7) basis = basis_1 * basis_2 + basis_3 @@ -111,8 +111,8 @@ import nemos as nmo # generate 5 basis functions of 100 time-bins, # and convolve the counts with the basis. -X = nemos.basis.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100 - ).compute_features(spike_counts) +X = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=100 + ).compute_features(spike_counts) ``` #### Population GLM diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/plot_02_ND_basis_function.md index f6948b5b..dd2ca2c9 100644 --- a/docs/background/plot_02_ND_basis_function.md +++ b/docs/background/plot_02_ND_basis_function.md @@ -339,9 +339,9 @@ will output a $K^N \times T$ matrix. 
T = 10 n_basis = 8 -a_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -b_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -c_basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +a_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +b_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +c_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) diff --git a/docs/how_to_guide/plot_02_glm_demo.md b/docs/how_to_guide/plot_02_glm_demo.md index 0fcecc91..fe18d3ec 100644 --- a/docs/how_to_guide/plot_02_glm_demo.md +++ b/docs/how_to_guide/plot_02_glm_demo.md @@ -329,7 +329,7 @@ coupling_filter_bank *= 0.8 # define a basis function n_basis_funcs = 20 -basis = nemos.basis.basis.RaisedCosineBasisLog(n_basis_funcs) +basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 6616ac1f..6945460e 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -283,7 +283,7 @@ Let's create our basis and then arrange our data properly. unit_no = 7 spikes = nwb['units'][unit_no] -basis = nemos.basis.basis.CyclicBSplineBasis(10, order=5) +basis = nmo.basis.CyclicBSplineBasis(10, order=5) x = np.linspace(-np.pi, np.pi, 100) plt.figure() plt.plot(x, basis(x)) diff --git a/docs/quickstart.md b/docs/quickstart.md index 0bddcf7d..062bdb25 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -162,10 +162,10 @@ you need to specify the number of basis functions. For some `basis` objects, add ```python ->> > import nemos as nmo +>>> import nemos as nmo ->> > n_basis_funcs = 10 ->> > basis = nemos.basis.basis.RaisedCosineBasisLinear(n_basis_funcs) +>>> n_basis_funcs = 10 +>>> basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs) ``` @@ -201,11 +201,11 @@ number of sample points. ```python ->> > import nemos as nmo +>>> import nemos as nmo ->> > n_basis_funcs = 10 ->> > # define a filter bank of 10 basis function, 200 samples long. ->> > basis = nemos.basis.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) +>>> n_basis_funcs = 10 +>>> # define a filter bank of 10 basis function, 200 samples long. 
+>>> basis = nmo.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) ``` @@ -339,21 +339,21 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa >>> path = nmo.fetch.fetch_data("A2929-200711.nwb") >>> data = nap.load_file(path) ->>> # load spikes and head direction +>>> # load spikes and head direction >>> spikes = data["units"] >>> head_dir = data["ry"] ->>> # restrict and bin +>>> # restrict and bin >>> counts = spikes[6].count(0.01, ep=head_dir.time_support) ->>> # down-sample head direction ->>> upsampled_head_dir = head_dir.bin_average(0.01) +>>> # down-sample head direction +>>> upsampled_head_dir = head_dir.bin_average(0.01) ->>> # create your features ->>> X = nemos.basis.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir) +>>> # create your features +>>> X = nmo.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir) ->>> # add a neuron axis and fit model ->>> model = nmo.glm.GLM().fit(X, counts) +>>> # add a neuron axis and fit model +>>> model = nmo.glm.GLM().fit(X, counts) ``` diff --git a/tests/conftest.py b/tests/conftest.py index e3debe48..77af28b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -286,7 +286,7 @@ def coupled_model_simulate(): ) # shrink the filters for simulation stability coupling_filter_bank *= 0.8 - basis = nemos.basis.basis.RaisedCosineBasisLog(20) + basis = nmo.basis.RaisedCosineBasisLog(20) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) From 25abf3d9ff8b0b84151a46f0d8bf8e10431d76e6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 12:00:54 -0500 Subject: [PATCH 012/109] added back feature matrix --- src/nemos/typing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index f1cfc4fc..9cd83a26 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -6,6 +6,9 @@ import jaxopt from jax.typing import ArrayLike +import pynapple as nap +from statsmodels.tools.typing import NDArray + from .pytrees import FeaturePytree DESIGN_INPUT_TYPE = Union[jnp.ndarray, FeaturePytree] @@ -51,3 +54,5 @@ ], # Step-size for optimization (must be a float) Tuple[jnp.ndarray, jnp.ndarray], ] + +FeatureMatrix = nap.TsdFrame | NDArray \ No newline at end of file From 13bc62a73c597d5c4be2a86e7ae911c5330a2df8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 12:04:45 -0500 Subject: [PATCH 013/109] updated mixin description --- src/nemos/basis/_basis_mixin.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 35ea3087..9c0e1a6a 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -15,7 +15,7 @@ def _compute_features(self, *xi: ArrayLike): """ Apply the basis transformation to the input data. - The basis evaluated at the samples, or $b_i(*xi)$, where $b_i$ is a + The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. Parameters @@ -27,8 +27,8 @@ def _compute_features(self, *xi: ArrayLike): Returns ------- : - A matrix with the transformed features. Faturehe basis evaluated at the samples, - or $b_i(*xi)$, where $b_i$ is a basis element. xi[k] must be a one-dimensional array + A matrix with the transformed features. 
The basis evaluated at the samples, + or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. """ @@ -79,7 +79,7 @@ def _check_convolution_kwargs(self): Raises ------ ValueError: - If `self._conv_kwargs` are not None. + If ``self._conv_kwargs`` are not None. """ # this should not be hit since **kwargs are not allowed at EvalBasis init. # still keep it for compliance with Abstract class Basis. @@ -105,8 +105,8 @@ def _compute_features(self, *xi: ArrayLike): except for the sample-axis are flattened, so that the method always returns a matrix. For example, if samples are of shape (num_samples, 2, 3), the output will be (num_samples, num_basis_funcs * 2 * 3). - The time-axis can be specified at basis initialization by setting the keyword argument `axis`. - For example, if `axis == 1` your samples should be (N1, num_samples N3, ...), the output of + The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. + For example, if ``axis == 1`` your samples should be (N1, num_samples N3, ...), the output of transform will be (num_samples, num_basis_funcs * N1 * N3 *...). Parameters @@ -188,10 +188,11 @@ def _check_convolution_kwargs(self): Raises ------ ValueError: - - If `axis` is provided as an argument, and it is different from 0 + If ``axis`` is provided as an argument, and it is different from 0 (samples must always be in the first axis). - - If `self._conv_kwargs` include parameters not recognized or that do not have - default values in `create_convolutional_predictor`. + ValueError: + If ``self._conv_kwargs`` include parameters not recognized or that do not have + default values in ``create_convolutional_predictor``. """ if "axis" in self._conv_kwargs: raise ValueError( From 7c45ab740ab97a65d5b99f417e0a7c8809bd0299 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 12:09:54 -0500 Subject: [PATCH 014/109] fixed docstrings orthexp --- src/nemos/basis/_decaying_exponential.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 8acc515f..72ae7199 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -39,12 +39,23 @@ class OrthExponentialBasis(Basis, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes that the convolution axis is ``axis=0``. 
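For readers unfamiliar with what "numerically orthogonalized" means for this basis, the following stand-alone sketch (plain numpy/scipy, not the library implementation) builds the raw exponentials e^{-alpha_i t} described by the ``decay_rates`` property below and orthonormalizes their columns with ``scipy.linalg.orth``, the routine the class itself relies on according to the comments in its ``__call__`` method:

import numpy as np
from scipy.linalg import orth

# Raw decaying exponentials e^{-alpha_i * t}, one per column.
t = np.linspace(0, 1, 100)
decay_rates = np.arange(1, 6)                  # alpha_1, ..., alpha_5
raw = np.exp(-np.outer(t, decay_rates))        # shape (100, 5)

# Orthonormalize the columns; the spanned subspace is unchanged.
Q = orth(raw)
print(Q.shape, np.allclose(Q.T @ Q, np.eye(Q.shape[1])))

The library version additionally rescales the sample points and masks NaNs before orthogonalizing; the sketch only illustrates the orthogonalization step.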
+ Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import OrthExponentialBasis + >>> X = np.random.normal(size=(1000, 1)) + >>> n_basis_funcs = 5 + >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates + >>> window_size=10 + >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = ortho_basis(sample_points) """ def __init__( @@ -69,7 +80,7 @@ def __init__( def decay_rates(self): r"""Decay rate. - The rate of decay of the exponential functions. If :math:`f_i(t) = \exp{-\alpha_i t}` is the i-th decay + The rate of decay of the exponential functions. If :math:`f_i(t) = e^{-\alpha_i t}` is the i-th decay exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rate`` vector. """ return self._decay_rates @@ -153,14 +164,14 @@ def __call__( Parameters ---------- sample_pts - Spacing for basis functions, holding elements on the interval [0, - inf), shape (n_samples,). + Spacing for basis functions, holding elements on the interval :math:`[0,inf)`, + shape ``(n_samples,)``. Returns ------- basis_funcs Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) + orthogonalized, shape ``(n_samples, n_basis_funcs)``. """ self._check_sample_size(sample_pts) From 6dbe3625aa15b91808bd3794e6f08cdaf45e9b58 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 12:14:35 -0500 Subject: [PATCH 015/109] updated raised cos --- src/nemos/basis/_raised_cosine_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 64dd492a..58ffafa5 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -252,7 +252,7 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes From 113ca768d82773a5f5cc18a0f2c9f8db177e956a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 12:27:06 -0500 Subject: [PATCH 016/109] fixed spline docstrings --- src/nemos/basis/_spline_basis.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index b23d6a80..8b7a0688 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -191,7 +191,7 @@ class MSplineBasis(SplineBasis, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. 
**kwargs: - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -297,10 +297,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: ------- X : NDArray A 1D array of uniformly spaced sample points within the domain [0, 1]. - Shape: `(n_samples,)`. + Shape: ``(n_samples,)``. Y : NDArray A 2D array where each row corresponds to the evaluated M-spline basis - function values at the points in X. Shape: `(n_samples, n_basis_funcs)`. + function values at the points in X. Shape: ``(n_samples, n_basis_funcs)``. Examples -------- @@ -351,7 +351,7 @@ class BSplineBasis(SplineBasis, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes @@ -421,7 +421,7 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: Notes ----- - The evaluation is performed by looping over each element and using `splev` + The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. """ sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) @@ -444,14 +444,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Returns ------- X : - Array of shape (n_samples,) containing the equi-spaced sample + Array of shape ``(n_samples,)`` containing the equi-spaced sample points where we've evaluated the basis. basis_funcs : - Raised cosine basis functions, shape (n_samples, n_basis_funcs) + Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)`` Notes ----- - The evaluation is performed by looping over each element and using `splev` from + The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. Examples @@ -490,7 +490,7 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when ``mode='conv'``; These arguments are used to change the default behavior of the convolution. For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. 
Basis assumes @@ -560,7 +560,7 @@ def __call__( Notes ----- - The evaluation is performed by looping over each element and using `splev` from + The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. """ @@ -608,14 +608,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Returns ------- X : - Array of shape (n_samples,) containing the equi-spaced sample + Array of shape ``(n_samples,)`` containing the equi-spaced sample points where we've evaluated the basis. basis_funcs : - Raised cosine basis functions, shape (n_samples, n_basis_funcs) + Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)`` Notes ----- - The evaluation is performed by looping over each element and using `splev` from + The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. Examples From db188ade32cb97929813ed2ad02130fd1ee4138f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 16:21:22 -0500 Subject: [PATCH 017/109] moving stuff around --- src/nemos/basis/_basis.py | 212 ++++++++++++----------- src/nemos/basis/_basis_mixin.py | 15 -- src/nemos/basis/_decaying_exponential.py | 30 ++-- src/nemos/basis/_raised_cosine_basis.py | 26 +++ src/nemos/basis/_spline_basis.py | 10 -- src/nemos/basis/basis.py | 146 ++++------------ 6 files changed, 187 insertions(+), 252 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 26a5db31..a757e4af 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -152,11 +152,6 @@ def __init__( self.kernel_ = None - @abc.abstractmethod - def _check_convolution_kwargs(self): - """Check convolution kwargs settings.""" - pass - @property def n_output_features(self) -> int | None: """ @@ -267,19 +262,6 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: input samples with the basis functions. The output shape varies based on the subclass and mode. - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import EvalBSpline - - >>> # Generate data - >>> num_samples = 10000 - >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalBSpline(10) - >>> features = basis.compute_features(X) # basis transformed time series - >>> features.shape - (10000, 10) - Notes ----- Subclasses should implement how to handle the transformation specific to their @@ -343,10 +325,11 @@ def _get_samples(self, *n_samples: int) -> Generator[NDArray]: """ # handling of defaults when evaluating on a grid # (i.e. when we cannot use max and min of samples) - if self.bounds is None: + bounds = getattr(self, "bounds", None) + if bounds is None: mn, mx = 0, 1 else: - mn, mx = self.bounds + mn, mx = bounds return (np.linspace(mn, mx, n_samples[k]) for k in range(len(n_samples))) @support_pynapple(conv_type="numpy") @@ -398,55 +381,20 @@ def _check_has_kernel(self) -> None: def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. - The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. - The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing - instead of the default cartesian indexing, see Notes. - Parameters ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of - n_samples must equal the number of combined bases. + n_samples : + The number of samples. 
Returns ------- - *Xs : - A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, - where n equals to the number of inputs. - The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. - Y : - The basis function evaluated at the samples, - shape ``(n_samples[0], ... , n_samples[n], number of basis)``. - - Raises - ------ - ValueError - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what - the Basis object requires. - ValueError - If one of the n_samples is <= 0. - - Notes - ----- - Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size - :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. - This differs from the numpy.meshgrid default, which uses Cartesian indexing. - For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. - - Examples - -------- - >>> # Evaluate and visualize 4 M-spline basis functions of order 3: - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalMSpline - >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> p = plt.plot(sample_points, basis_values) - >>> _ = plt.title('M-Spline Basis Functions') - >>> _ = plt.xlabel('Domain') - >>> _ = plt.ylabel('Basis Function Value') - >>> _ = plt.legend([f'Function {i+1}' for i in range(4)]); - """ + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + Evaluated exponentially decaying basis functions, numerically + orthogonalized, shape (n_samples, n_basis_funcs) + """ self._check_input_dimensionality(n_samples) if self._has_zero_samples(n_samples): @@ -784,38 +732,6 @@ def split_by_feature( A dictionary where: - **Key**: Label of the basis. - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import ConvBSpline - >>> from nemos.glm import GLM - >>> # Define an additive basis - >>> basis = ConvBSpline(n_basis_funcs=5, window_size=10, label="feature") - >>> # Generate a sample input array and compute features - >>> x = np.random.randn(20) - >>> X = basis.compute_features(x) - >>> # Split the feature matrix along axis 1 - >>> split_features = basis.split_by_feature(X, axis=1) - >>> for feature, arr in split_features.items(): - ... print(f"{feature}: shape {arr.shape}") - feature: shape (20, 1, 5) - >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = ConvBSpline(n_basis_funcs=6, window_size=10, - ... label="multi_input") - >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) - >>> for feature, sub_dict in split_features_multi.items(): - ... print(f"{feature}, shape {sub_dict.shape}") - multi_input, shape (20, 2, 6) - >>> # the method can be used to decompose the glm coefficients in the various features - >>> counts = np.random.poisson(size=20) - >>> model = GLM().fit(X, counts) - >>> split_coef = basis.split_by_feature(model.coef_, axis=0) - >>> for feature, coef in split_coef.items(): - ... 
print(f"{feature}: shape {coef.shape}") - feature: shape (1, 5) - """ if x.shape[axis] != self.n_output_features: raise ValueError( @@ -1569,6 +1485,46 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. + The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing + instead of the default cartesian indexing, see Notes. + + Parameters + ---------- + n_samples[0],...,n_samples[n] + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. + + Returns + ------- + *Xs : + A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, + where n equals to the number of inputs. + The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. + Y : + The basis function evaluated at the samples, + shape ``(n_samples[0], ... , n_samples[n], number of basis)``. + + Raises + ------ + ValueError + If the time point number is inconsistent between inputs or if the number of inputs doesn't match what + the Basis object requires. + ValueError + If one of the n_samples is <= 0. + + Notes + ----- + Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size + :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. + This differs from the numpy.meshgrid default, which uses Cartesian indexing. + For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. + """ + return super().evaluate_on_grid(*n_samples) + class MultiplicativeBasis(Basis): """ @@ -1656,6 +1612,14 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix: ------- : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + >>> mult_basis = nmo.basis.EvalBSpline(5) * nmo.basis.EvalRaisedCosineLinear(6) + >>> x, y = np.random.randn(2, 30) + >>> X = mult_basis(x, y) """ X = np.asarray( row_wise_kron( @@ -1681,6 +1645,13 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: : The features, shape (n_samples, n_basis_funcs) + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + >>> mult_basis = nmo.basis.EvalBSpline(5) * nmo.basis.EvalRaisedCosineLinear(6) + >>> x, y = np.random.randn(2, 30) + >>> X = mult_basis.compute_features(x, y) """ kron = support_pynapple(conv_type="numpy")(row_wise_kron) X = kron( @@ -1702,4 +1673,51 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: self._n_output_features = ( self._basis1.n_output_features * self._basis2.n_output_features ) - return self \ No newline at end of file + return self + + def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. + The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing + instead of the default cartesian indexing, see Notes. + + Parameters + ---------- + n_samples[0],...,n_samples[n] + The number of samples in each axis of the grid. The length of + n_samples must equal the number of combined bases. 
+ + Returns + ------- + *Xs : + A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, + where n equals to the number of inputs. + The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. + Y : + The basis function evaluated at the samples, + shape ``(n_samples[0], ... , n_samples[n], number of basis)``. + + Raises + ------ + ValueError + If the time point number is inconsistent between inputs or if the number of inputs doesn't match what + the Basis object requires. + ValueError + If one of the n_samples is <= 0. + + Notes + ----- + Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size + :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. + This differs from the numpy.meshgrid default, which uses Cartesian indexing. + For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. + + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + >>> mult_basis = nmo.basis.EvalBSpline(4) * nmo.basis.EvalRaisedCosineLinear(5) + >>> X, Y, Z = mult_basis.evaluate_on_grid(10, 10) + """ + return super().evaluate_on_grid(*n_samples) \ No newline at end of file diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 9c0e1a6a..5ecf5f4f 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -73,21 +73,6 @@ def bounds(self, values: Union[None, Tuple[float, float]]): f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." ) - def _check_convolution_kwargs(self): - """Check convolution kwargs settings. - - Raises - ------ - ValueError: - If ``self._conv_kwargs`` are not None. - """ - # this should not be hit since **kwargs are not allowed at EvalBasis init. - # still keep it for compliance with Abstract class Basis. - if self._conv_kwargs: - raise ValueError( - f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" - ) - class ConvBasisMixin: diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 72ae7199..8342b9ef 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -17,6 +17,22 @@ from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples +import inspect + + +def add_orth_exp_decay_docstring(method_name): + attr = getattr(OrthExponentialBasis, method_name, None) + if attr is None: + raise AttributeError(f"OrthExponentialBasis has no attribute {method_name}!") + doc = attr.__doc__ + # Decorator to add the docstring + def wrapper(func): + func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings + return func + + return wrapper + + class OrthExponentialBasis(Basis, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -44,18 +60,6 @@ class OrthExponentialBasis(Basis, abc.ABC): For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. Note that one cannot change the default value for the ``axis`` parameter. Basis assumes that the convolution axis is ``axis=0``. 
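The ``add_orth_exp_decay_docstring`` helper introduced above follows a generic docstring-forwarding pattern: grab a method's docstring from the reference class at decoration time and prepend it to the decorated method's own text. A minimal self-contained sketch of the pattern, using hypothetical classes rather than the nemos ones, is:

class Parent:
    def compute(self):
        """Parent description of compute."""

def add_parent_docstring(method_name):
    # Capture the parent's docstring once, at decoration time.
    doc = getattr(Parent, method_name).__doc__

    def wrapper(func):
        # Prepend the inherited description to the subclass-specific text.
        func.__doc__ = "\n".join([doc, func.__doc__])
        return func

    return wrapper

class Child(Parent):
    @add_parent_docstring("compute")
    def compute(self):
        """Extra example shown only on the child."""

print(Child.compute.__doc__)   # parent text, a newline, then the child's addition

This is the same mechanism that lets the ``Eval``/``Conv`` wrappers in ``basis.py`` reuse the shared descriptions while adding class-specific Examples sections.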
- - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import OrthExponentialBasis - >>> X = np.random.normal(size=(1000, 1)) - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = ortho_basis(sample_points) """ def __init__( @@ -175,7 +179,7 @@ def __call__( """ self._check_sample_size(sample_pts) - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) valid_idx = ~np.isnan(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of # shape (n_pts, n_basis_funcs) and then transpose, rather than diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 58ffafa5..7994e0cd 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -16,6 +16,32 @@ import abc +def add_raised_cosine_linear_docstring(method_name): + attr = getattr(RaisedCosineBasisLinear, method_name, None) + if attr is None: + raise AttributeError(f"RaisedCosineBasisLinear has no attribute {method_name}!") + doc = attr.__doc__ + # Decorator to add the docstring + def wrapper(func): + func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings + return func + + return wrapper + + +def add_raised_cosine_log_docstring(method_name): + attr = getattr(RaisedCosineBasisLog, method_name, None) + if attr is None: + raise AttributeError(f"RaisedCosineBasisLog has no attribute {method_name}!") + doc = attr.__doc__ + # Decorator to add the docstring + def wrapper(func): + func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings + return func + + return wrapper + + class RaisedCosineBasisLinear(Basis, abc.ABC): """Represent linearly-spaced raised cosine basis functions. diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 8b7a0688..6149c733 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -341,12 +341,6 @@ class BSplineBasis(SplineBasis, abc.ABC): Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. The B-splines have (order-2) continuous derivatives at each interior knot. The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. 
@@ -382,8 +376,6 @@ def __init__( n_basis_funcs: int, mode="eval", order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineBasis", **kwargs, ): @@ -391,8 +383,6 @@ def __init__( n_basis_funcs, mode=mode, order=order, - window_size=window_size, - bounds=bounds, label=label, **kwargs, ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index e36b0642..959aedf7 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -12,7 +12,7 @@ from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog -from ._decaying_exponential import OrthExponentialBasis +from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring from ..typing import FeatureMatrix __all__ = [ @@ -31,6 +31,7 @@ ] + def __dir__() -> list[str]: return __all__ @@ -357,6 +358,7 @@ def split_by_feature( - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions ``(n_inputs, n_basis_funcs)``. + - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. @@ -369,30 +371,32 @@ def split_by_feature( Parameters ---------- x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of ``x`` along the specified axis must match the total - number of features generated by the basis, i.e., ``self.n_output_features``. + The input array to be split, representing concatenated features, coefficients, + or other data. The shape of ``x`` along the specified axis must match the total + number of features generated by the basis, i.e., ``self.n_output_features``. + + **Examples:** - **Examples:** - - For a design matrix: ``(n_samples, total_n_features)`` - - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. + - For a design matrix: ``(n_samples, total_n_features)`` + + - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. axis : int, optional - The axis along which to split the features. Defaults to 1. - Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for - coefficient arrays (features along rows). All other dimensions are preserved. + The axis along which to split the features. Defaults to 1. + Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for + coefficient arrays (features along rows). All other dimensions are preserved. Raises ------ ValueError - If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. + If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. Returns ------- dict - A dictionary where: - - **Key**: Label of the basis. - - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` + A dictionary where: + - **Key**: Label of the basis. + - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` Examples -------- @@ -415,26 +419,7 @@ def split_by_feature( class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): - """Set of 1D basis decaying exponential functions numerically orthogonalized. - - Parameters - ---------- - n_basis_funcs - Number of basis functions. 
- window_size : - The window size for convolution as number of samples. - decay_rates : - Decay rates of the exponentials, shape ``(n_basis_funcs,)``. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - conv_kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - + """ Examples -------- >>> import numpy as np @@ -466,23 +451,9 @@ def __init__( label=label, ) + @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Generate basis functions with given spacing. - - Parameters - ---------- - n_samples: - The number of samples. - - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape ``(n_samples, n_basis_funcs)`` - + """ Examples -------- >>> import numpy as np @@ -495,26 +466,11 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return super().evaluate_on_grid(n_samples=n_samples) + return super().evaluate_on_grid(n_samples) - def compute_features(*xi: ArrayLike) -> FeatureMatrix: + @add_orth_exp_decay_docstring("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ - Compute the basis functions and transform input data into model features. - - This method is designed to be a high-level interface for transforming input - data using the basis functions defined by the subclass. Performs a convolution operation between - the input data and the basis functions. - - Parameters - ---------- - *xi : - Input data arrays to be transformed. - - Returns - ------- - : - Transformed features, consisting of convolved input samples with the basis functions. - Examples -------- >>> import numpy as np @@ -529,61 +485,15 @@ def compute_features(*xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return super().compute_features(*xi) + return OrthExponentialBasis.compute_features(self, *xi) + @add_orth_exp_decay_docstring("split_by_feature") def split_by_feature( self, x: NDArray, axis: int = 1, ): r""" - Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. - - This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. - - **How it works:** - - - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will - be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions - ``(n_inputs, n_basis_funcs)``. - - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will - be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. - - For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, - then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. 
- - The specified axis (``axis``) determines where the split occurs, and all other dimensions - remain unchanged. See the example section below for the most common use cases. - - Parameters - ---------- - x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of ``x`` along the specified axis must match the total - number of features generated by the basis, i.e., ``self.n_output_features``. - - **Examples:** - - For a design matrix: ``(n_samples, total_n_features)`` - - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. - - axis : int, optional - The axis along which to split the features. Defaults to 1. - Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for - coefficient arrays (features along rows). All other dimensions are preserved. - - Raises - ------ - ValueError - If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. - - Returns - ------- - dict - A dictionary where: - - **Key**: Label of the basis. - - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` - Examples -------- >>> import numpy as np @@ -597,4 +507,6 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return super().split_by_feature(x, axis=axis) \ No newline at end of file + return OrthExponentialBasis.split_by_feature(self, x, axis=axis) + + From 32bcf82a67fa1a144a9a685d7e990e32e37bb690 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 17:29:22 -0500 Subject: [PATCH 018/109] start editing docstrings --- src/nemos/basis/_basis.py | 15 ++ src/nemos/basis/_decaying_exponential.py | 37 +--- src/nemos/basis/_raised_cosine_basis.py | 46 +---- src/nemos/basis/basis.py | 210 +++++++++++++---------- tests/test_basis.py | 88 +++++++--- 5 files changed, 216 insertions(+), 180 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index a757e4af..5a91e73a 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -19,6 +19,21 @@ from ..typing import FeatureMatrix from ..validation import check_fraction_valid_samples + + +def add_docstring(method_name, cls=None): + attr = getattr(cls, method_name, None) + if attr is None: + raise AttributeError(f"{cls.__name__} has no attribute {method_name}!") + doc = attr.__doc__ + # Decorator to add the docstring + def wrapper(func): + func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings + return func + + return wrapper + + def check_transform_input(func: Callable) -> Callable: """Check input before calling basis. 
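Editor's note: the ``add_docstring`` helper added to ``_basis.py`` above generalizes the per-class decorators from the previous commit, and the module-level ``partial(add_docstring, cls=...)`` aliases then give each basis family its own decorator. A small stand-alone sketch of the pattern; the class names below are invented for illustration.

```python
from functools import partial


def add_docstring(method_name, cls=None):
    """Prepend ``cls.<method_name>.__doc__`` to the decorated function's docstring."""
    attr = getattr(cls, method_name, None)
    if attr is None:
        raise AttributeError(f"{cls.__name__} has no attribute {method_name}!")
    doc = attr.__doc__

    def wrapper(func):
        func.__doc__ = "\n".join([doc, func.__doc__])  # combine parent and child docs
        return func

    return wrapper


class BaseSketch:
    def compute(self):
        """Shared parameter and return description lives here."""


# One specialized decorator per concrete class, mirroring the ``partial`` usage.
add_base_docstring = partial(add_docstring, cls=BaseSketch)


class ConcreteSketch(BaseSketch):
    @add_base_docstring("compute")
    def compute(self):
        """Examples specific to ConcreteSketch go here."""


print(ConcreteSketch.compute.__doc__)
# Shared parameter and return description lives here.
# Examples specific to ConcreteSketch go here.
```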
diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 8342b9ef..532a26e7 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -4,6 +4,7 @@ from __future__ import annotations import abc +from functools import partial from typing import Optional, Tuple import numpy as np @@ -14,23 +15,7 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples - - -import inspect - - -def add_orth_exp_decay_docstring(method_name): - attr = getattr(OrthExponentialBasis, method_name, None) - if attr is None: - raise AttributeError(f"OrthExponentialBasis has no attribute {method_name}!") - doc = attr.__doc__ - # Decorator to add the docstring - def wrapper(func): - func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings - return func - - return wrapper +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring class OrthExponentialBasis(Basis, abc.ABC): @@ -198,21 +183,5 @@ def __call__( basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) return basis_funcs - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. - Returns - ------- - X : - Array of shape (n_samples,) containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape (n_samples, n_basis_funcs) - """ - return super().evaluate_on_grid(n_samples) +add_orth_exp_decay_docstring = partial(add_docstring, cls=OrthExponentialBasis) \ No newline at end of file diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 7994e0cd..3bed1d46 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -7,39 +7,14 @@ import numpy as np from numpy.typing import ArrayLike, NDArray - from ..type_casting import support_pynapple from ..typing import FeatureMatrix -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring import abc - -def add_raised_cosine_linear_docstring(method_name): - attr = getattr(RaisedCosineBasisLinear, method_name, None) - if attr is None: - raise AttributeError(f"RaisedCosineBasisLinear has no attribute {method_name}!") - doc = attr.__doc__ - # Decorator to add the docstring - def wrapper(func): - func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings - return func - - return wrapper - - -def add_raised_cosine_log_docstring(method_name): - attr = getattr(RaisedCosineBasisLog, method_name, None) - if attr is None: - raise AttributeError(f"RaisedCosineBasisLog has no attribute {method_name}!") - doc = attr.__doc__ - # Decorator to add the docstring - def wrapper(func): - func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings - return func - - return wrapper +from functools import partial class RaisedCosineBasisLinear(Basis, abc.ABC): @@ -216,14 +191,6 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: points where we've evaluated the basis. 
basis_funcs : Raised cosine basis functions, shape (n_samples, n_basis_funcs) - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import RaisedCosineBasisLinear - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -368,7 +335,7 @@ def _transform_samples( """ # rescale to [0,1] # copy is necessary to avoid unwanted rescaling in additive/multiplicative basis. - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) + sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), getattr(self, "bounds", None)) # This log-stretching of the sample axis has the following effect: # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. # - as the time_scaling tends to inf, basis will be small and dense around 0 and @@ -423,4 +390,9 @@ def __call__( ValueError If the sample provided do not lie in [0,1]. """ - return super().__call__(self._transform_samples(sample_pts)) \ No newline at end of file + return super().__call__(self._transform_samples(sample_pts)) + + +add_raised_cosine_linear_docstring = partial(add_docstring, cls=RaisedCosineBasisLinear) + +add_raised_cosine_log_docstring = partial(add_docstring, cls=RaisedCosineBasisLog) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 959aedf7..e48c372e 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -11,7 +11,7 @@ from ._basis_mixin import EvalBasisMixin, ConvBasisMixin from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis -from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog, add_raised_cosine_linear_docstring, add_raised_cosine_log_docstring from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring from ..typing import FeatureMatrix @@ -205,6 +205,64 @@ def __init__( label=label, ) + @add_raised_cosine_log_docstring("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalRaisedCosineLog + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = EvalRaisedCosineLog(n_basis_funcs) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) + + @add_raised_cosine_log_docstring("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLog + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalRaisedCosineLog(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return RaisedCosineBasisLog.compute_features(self, *xi) + + @add_raised_cosine_log_docstring("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.glm import GLM + >>> basis = 
EvalRaisedCosineLog(n_basis_funcs=6, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) + class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): def __init__( @@ -228,6 +286,64 @@ def __init__( label=label, ) + @add_raised_cosine_log_docstring("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvRaisedCosineLog + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = ConvRaisedCosineLog(n_basis_funcs, window_size) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) + + @add_raised_cosine_log_docstring("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLog + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvRaisedCosineLog(10, window_size=100) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return RaisedCosineBasisLog.compute_features(self, *xi) + + @add_raised_cosine_log_docstring("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLog + >>> from nemos.glm import GLM + >>> basis = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) + class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): def __init__( @@ -277,23 +393,9 @@ def __init__( label=label, ) + @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Generate basis functions with given spacing. - - Parameters - ---------- - n_samples: - The number of samples. - - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape ``(n_samples, n_basis_funcs)`` - + """ Examples -------- >>> import numpy as np @@ -308,33 +410,18 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples=n_samples) + @add_orth_exp_decay_docstring("compute_features") def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ - Compute the basis functions and transform input data into model features. 
- - This method is designed to be a high-level interface for transforming input - data using the basis functions defined by the subclass. It evaluates the basis functions at the sample - points. - - Parameters - ---------- - *xi : - Input data arrays to be transformed. - - Returns - ------- - : - Transformed features, consisting of the basis functions evaluated at the input samples. - Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import EvalOrthExponential >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvOrthExponential(10, window_size=100, decay_rates=np.arange(1, 11)) + >>> basis = EvalOrthExponential(10, decay_rates=np.arange(1, 11)) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) @@ -342,62 +429,13 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ return super().compute_features(*xi) + @add_orth_exp_decay_docstring("split_by_feature") def split_by_feature( self, x: NDArray, axis: int = 1, ): - r""" - Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. - - This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. - - **How it works:** - - - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will - be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions - ``(n_inputs, n_basis_funcs)``. - - - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will - be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. - - For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, - then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. - - The specified axis (``axis``) determines where the split occurs, and all other dimensions - remain unchanged. See the example section below for the most common use cases. - - Parameters - ---------- - x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of ``x`` along the specified axis must match the total - number of features generated by the basis, i.e., ``self.n_output_features``. - - **Examples:** - - - For a design matrix: ``(n_samples, total_n_features)`` - - - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. - - axis : int, optional - The axis along which to split the features. Defaults to 1. - Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for - coefficient arrays (features along rows). All other dimensions are preserved. - - Raises - ------ - ValueError - If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. - - Returns - ------- - dict - A dictionary where: - - **Key**: Label of the basis. 
- - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` - + """ Examples -------- >>> import numpy as np diff --git a/tests/test_basis.py b/tests/test_basis.py index 156a8292..3c65bed1 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,6 +1,7 @@ import abc import inspect import itertools +from functools import partial import pickle from contextlib import nullcontext as does_not_raise from typing import Literal @@ -15,7 +16,7 @@ import nemos.basis.basis as basis import nemos.convolve as convolve from nemos.utils import pynapple_concatenate_numpy - +from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring # automatic define user accessible basis and check the methods def list_all_basis_classes() -> list[type]: @@ -26,10 +27,9 @@ def list_all_basis_classes() -> list[type]: return [ class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis) - if issubclass(class_obj, basis.Basis) + if issubclass(class_obj, Basis) ] - def test_all_basis_are_tested() -> None: """Meta-test. @@ -60,6 +60,48 @@ def test_all_basis_are_tested() -> None: ) +@pytest.mark.parametrize( + "basis_instance", + [ + basis.EvalRaisedCosineLog(10), + basis.ConvRaisedCosineLog(10, window_size=11), + basis.EvalOrthExponential(10, np.arange(1, 11)), + basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + ] +) +@pytest.mark.parametrize( + "method_name", ["evaluate_on_grid", "compute_features", "split_by_feature"] +) +def test_example_docstrings_add(basis_instance, method_name): + method = getattr(basis_instance, method_name) + doc = method.__doc__ + examp_delim = "\n Examples\n --------" + assert examp_delim in doc + doc_components = doc.split(examp_delim) + assert len(doc_components) == 2 + assert len(doc_components[0].strip()) > 0 + assert basis_instance.__class__.__name__ in doc_components[1] + + +def test_add_docstring(): + + class CustomClass: + def method(self): + """My extra text.""" + pass + + custom_add_docstring = partial(add_docstring, cls=CustomClass) + + class CustomSubClass(CustomClass): + @custom_add_docstring("method") + def method(self): + """My custom method.""" + pass + + assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method." + + + class BasisFuncsTesting(abc.ABC): """ An abstract base class that sets the foundation for individual basis function testing. 
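Editor's note: a hedged usage sketch of what these new tests exercise, assuming the constructor signature used in the parametrization above survives later commits. Both checks are expected to print ``True`` when the docstring stitching works as the tests assert.

```python
import numpy as np
from nemos.basis import ConvOrthExponential

# Same constructor arguments as the test parametrization above.
ortho = ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12)
doc = ortho.compute_features.__doc__

print("Compute the basis functions" in doc)  # inherited description present
print("ConvOrthExponential" in doc)          # class-specific example appended
```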
@@ -4963,13 +5005,13 @@ def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): basis_obj = basis_class( n_basis_funcs=n_basis, order=3, mode=mode, window_size=window_size ) - elif basis_class == basis.AdditiveBasis: + elif basis_class == AdditiveBasis: b1 = basis.EvalMSpline( n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size ) b2 = basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis + 1) basis_obj = b1 + b2 - elif basis_class == basis.MultiplicativeBasis: + elif basis_class == MultiplicativeBasis: b1 = basis.EvalMSpline( n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size ) @@ -4983,7 +5025,7 @@ def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): class TestAdditiveBasis(CombinedBasis): - cls = basis.AdditiveBasis + cls = AdditiveBasis @pytest.mark.parametrize( "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] @@ -5525,7 +5567,7 @@ def test_expected_input_number(self, n_input, expectation): class TestMultiplicativeBasis(CombinedBasis): - cls = basis.MultiplicativeBasis + cls = MultiplicativeBasis @pytest.mark.parametrize( "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] @@ -6103,7 +6145,7 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize( - "exponent", [-1, 0, 0.5, basis.RaisedCosineBasisLog(4), 1, 2, 3] + "exponent", [-1, 0, 0.5, basis.EvalRaisedCosineLog(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) def test_power_of_basis(exponent, basis_class): @@ -6385,7 +6427,7 @@ def test_transformerbasis_addition(basis_cls): trans_bas_b = basis.TransformerBasis(basis_cls(n_basis_funcs_b)) trans_bas_sum = trans_bas_a + trans_bas_b assert isinstance(trans_bas_sum, basis.TransformerBasis) - assert isinstance(trans_bas_sum._basis, basis.AdditiveBasis) + assert isinstance(trans_bas_sum._basis, AdditiveBasis) assert ( trans_bas_sum.n_basis_funcs == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs @@ -6415,7 +6457,7 @@ def test_transformerbasis_multiplication(basis_cls): trans_bas_b = basis.TransformerBasis(basis_cls(n_basis_funcs_b)) trans_bas_prod = trans_bas_a * trans_bas_b assert isinstance(trans_bas_prod, basis.TransformerBasis) - assert isinstance(trans_bas_prod._basis, basis.MultiplicativeBasis) + assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) assert ( trans_bas_prod.n_basis_funcs == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs @@ -6456,7 +6498,7 @@ def test_transformerbasis_exponentiation( with pytest.raises(error_type, match=error_message): trans_bas_exp = trans_bas**exponent assert isinstance(trans_bas_exp, basis.TransformerBasis) - assert isinstance(trans_bas_exp._basis, basis.MultiplicativeBasis) + assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) @pytest.mark.parametrize( @@ -6568,15 +6610,15 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs): basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, basis.RaisedCosineBasisLog, - basis.AdditiveBasis, - basis.MultiplicativeBasis, + AdditiveBasis, + MultiplicativeBasis, ], ) def test_multi_epoch_pynapple_basis( basis_cls, tsd, window_size, shift, predictor_causality, nan_index ): """Test nan location in multi-epoch pynapple tsd.""" - if basis_cls == basis.AdditiveBasis: + if basis_cls == AdditiveBasis: bas = basis.BSplineBasis( 5, mode="conv", @@ -6591,7 +6633,7 @@ def test_multi_epoch_pynapple_basis( predictor_causality=predictor_causality, shift=shift, ) - elif basis_cls == basis.MultiplicativeBasis: + elif 
basis_cls == MultiplicativeBasis: bas = basis.RaisedCosineBasisLog( 5, mode="conv", @@ -6656,15 +6698,15 @@ def test_multi_epoch_pynapple_basis( basis.CyclicBSplineBasis, basis.RaisedCosineBasisLinear, basis.RaisedCosineBasisLog, - basis.AdditiveBasis, - basis.MultiplicativeBasis, + AdditiveBasis, + MultiplicativeBasis, ], ) def test_multi_epoch_pynapple_basis_transformer( basis_cls, tsd, window_size, shift, predictor_causality, nan_index ): """Test nan location in multi-epoch pynapple tsd.""" - if basis_cls == basis.AdditiveBasis: + if basis_cls == AdditiveBasis: bas = basis.BSplineBasis( 5, mode="conv", @@ -6679,7 +6721,7 @@ def test_multi_epoch_pynapple_basis_transformer( predictor_causality=predictor_causality, shift=shift, ) - elif basis_cls == basis.MultiplicativeBasis: + elif basis_cls == MultiplicativeBasis: bas = basis.RaisedCosineBasisLog( 5, mode="conv", @@ -6804,7 +6846,7 @@ def test__get_splitter( ): # skip nested if any( - bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2, bas3] ): return @@ -6981,7 +7023,7 @@ def test__get_splitter_split_by_input( ): # skip nested if any( - bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2] ): return @@ -7034,7 +7076,7 @@ def test__get_splitter_split_by_input( def test_duplicate_keys(bas1, bas2, bas3): # skip nested if any( - bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2, bas3] ): return @@ -7086,7 +7128,7 @@ def test_duplicate_keys(bas1, bas2, bas3): def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes): # skip nested if any( - bas in (basis.AdditiveBasis, basis.MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2] ): return From 5761ccdd5025e834d6b0cfe1193ac45bf0a06c88 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 18:41:35 -0500 Subject: [PATCH 019/109] improved tests --- src/nemos/basis/_decaying_exponential.py | 19 + src/nemos/basis/_raised_cosine_basis.py | 3 +- src/nemos/basis/_spline_basis.py | 86 +--- src/nemos/basis/basis.py | 512 ++++++++++++++++++++++- tests/test_basis.py | 20 +- 5 files changed, 570 insertions(+), 70 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 532a26e7..cf72c0a1 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -183,5 +183,24 @@ def __call__( basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) return basis_funcs + def evaluate_on_grid(self, n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: + """Evaluate the basis set on a grid of equi-spaced sample points. + + Parameters + ---------- + n_samples : + The number of points in the uniformly spaced grid. A higher number of + samples will result in a more detailed visualization of the basis functions. + + Returns + ------- + X : + Array of shape (n_samples,) containing the equi-spaced sample + points where we've evaluated the basis. + basis_funcs : + OrthExponential basis functions, shape (n_samples, n_basis_funcs). 
+ """ + return super().evaluate_on_grid(n_samples) + add_orth_exp_decay_docstring = partial(add_docstring, cls=OrthExponentialBasis) \ No newline at end of file diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 3bed1d46..f25449c3 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -182,7 +182,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Parameters ---------- n_samples : - The number of samples. + The number of points in the uniformly spaced grid. A higher number of + samples will result in a more detailed visualization of the basis functions. Returns ------- diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 6149c733..413e870d 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -4,6 +4,7 @@ import abc import copy +from functools import partial from typing import Optional, Tuple import numpy as np @@ -13,7 +14,7 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples +from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring class SplineBasis(Basis, abc.ABC): @@ -183,10 +184,6 @@ class MSplineBasis(SplineBasis, abc.ABC): derivatives at each interior knot, resulting in smoother basis functions. window_size : The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. @@ -197,16 +194,6 @@ class MSplineBasis(SplineBasis, abc.ABC): Note that one cannot change the default value for the ``axis`` parameter. Basis assumes that the convolution axis is ``axis=0``. - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import MSplineBasis - >>> n_basis_funcs = 5 - >>> order = 3 - >>> mspline_basis = MSplineBasis(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) - References ---------- .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, @@ -219,13 +206,22 @@ class MSplineBasis(SplineBasis, abc.ABC): will shrink by a factor of :math:`1/\alpha`. For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. 
+ + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = mspline_basis(sample_points) """ def __init__( self, n_basis_funcs: int, order: int = 2, - bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalMSpline", **kwargs, ) -> None: @@ -233,7 +229,6 @@ def __init__( n_basis_funcs, mode="eval", order=order, - bounds=bounds, label=label, ) @@ -301,25 +296,6 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Y : NDArray A 2D array where each row corresponds to the evaluated M-spline basis function values at the points in X. Shape: ``(n_samples, n_basis_funcs)``. - - Examples - -------- - Evaluate and visualize 4 M-spline basis functions of order 3: - - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalMSpline - >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> for i in range(4): - ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') - >>> plt.title('M-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') - >>> plt.xlabel('Domain') - Text(0.5, 0, 'Domain') - >>> plt.ylabel('Basis Function Value') - Text(0, 0.5, 'Basis Function Value') - >>> l = plt.legend() """ return super().evaluate_on_grid(n_samples) @@ -429,7 +405,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Parameters ---------- n_samples : - The number of samples. + The number of points in the uniformly spaced grid. A higher number of + samples will result in a more detailed visualization of the basis functions. Returns ------- @@ -443,14 +420,6 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: ----- The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import BSplineBasis - >>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -470,12 +439,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. The B-splines have (order-2) continuous derivatives at each interior knot. The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. 
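Editor's note: the amplitude-scaling note in the M-spline docstring above can be checked numerically. This is a sketch, assuming ``EvalMSpline`` keeps the evaluation behaviour the note describes (peak values shrink as the sample domain widens).

```python
import numpy as np
from nemos.basis import EvalMSpline

mspline = EvalMSpline(n_basis_funcs=5, order=3)
narrow = mspline(np.linspace(0, 1, 1_000))  # domain width 1
wide = mspline(np.linspace(0, 2, 1_000))    # domain width 2

# If the integrate-to-one normalization holds, peaks shrink by the width ratio.
print(np.nanmax(narrow) / np.nanmax(wide))  # expected to be close to 2
```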
@@ -496,9 +459,9 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import CyclicBSplineBasis + >>> from nemos.basis import EvalCyclicBSpline >>> X = np.random.normal(size=(1000, 1)) - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) + >>> cyclic_basis = EvalCyclicBSpline(n_basis_funcs=5, order=3, mode="conv", window_size=10) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = cyclic_basis(sample_points) """ @@ -517,8 +480,6 @@ def __init__( n_basis_funcs, mode=mode, order=order, - window_size=window_size, - bounds=bounds, label=label, **kwargs, ) @@ -593,7 +554,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Parameters ---------- n_samples : - The number of samples. + The number of points in the uniformly spaced grid. A higher number of + samples will result in a more detailed visualization of the basis functions. Returns ------- @@ -607,14 +569,6 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: ----- The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import CyclicBSplineBasis - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100) """ return super().evaluate_on_grid(n_samples) @@ -761,3 +715,7 @@ def bspline( ) return basis_eval.T + +add_docstrings_mspline = partial(add_docstring, cls=MSplineBasis) +add_docstrings_bspline = partial(add_docstring, cls=BSplineBasis) +add_docstrings_cyclic_bspline = partial(add_docstring, cls=CyclicBSplineBasis) \ No newline at end of file diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index e48c372e..49cb657f 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -7,10 +7,9 @@ from numpy.typing import NDArray, ArrayLike - from ._basis_mixin import EvalBasisMixin, ConvBasisMixin -from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis +from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis, add_docstrings_mspline, add_docstrings_bspline, add_docstrings_cyclic_bspline from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog, add_raised_cosine_linear_docstring, add_raised_cosine_log_docstring from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring from ..typing import FeatureMatrix @@ -53,6 +52,72 @@ def __init__( label=label, ) + @add_docstrings_bspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalBSpline + >>> from nemos.glm import GLM + >>> basis = EvalBSpline(n_basis_funcs=6, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... 
print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return BSplineBasis.split_by_feature(self, x, axis=axis) + + @add_docstrings_bspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalBSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalBSpline(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return BSplineBasis.compute_features(self, xi) + + @add_docstrings_bspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 B-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalBSpline + >>> bspline_basis = EvalBSpline(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('B-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return BSplineBasis.evaluate_on_grid(self, n_samples) + + class ConvBSpline(ConvBasisMixin, BSplineBasis): @@ -73,6 +138,71 @@ def __init__( label=label, ) + @add_docstrings_bspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvBSpline + >>> from nemos.glm import GLM + >>> basis = ConvBSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return BSplineBasis.split_by_feature(self, x, axis=axis) + + @add_docstrings_bspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvBSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvBSpline(10, window_size=11) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return BSplineBasis.compute_features(self, xi) + + @add_docstrings_bspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 B-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvBSpline + >>> bspline_basis = ConvBSpline(n_basis_funcs=4, order=3, window_size=10) + >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('B-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return BSplineBasis.evaluate_on_grid(self, n_samples) + class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): def __init__( @@ -91,6 +221,71 @@ def __init__( label=label, ) + @add_docstrings_cyclic_bspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalCyclicBSpline + >>> from nemos.glm import GLM + >>> basis = EvalCyclicBSpline(n_basis_funcs=6, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) + + @add_docstrings_cyclic_bspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalCyclicBSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalCyclicBSpline(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return CyclicBSplineBasis.compute_features(self, xi) + + @add_docstrings_cyclic_bspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 Cyclic B-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalCyclicBSpline + >>> cbspline_basis = EvalCyclicBSpline(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = cbspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('Cyclic B-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) + class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): def __init__( @@ -110,6 +305,72 @@ def __init__( label=label, ) + @add_docstrings_cyclic_bspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvCyclicBSpline + >>> from nemos.glm import GLM + >>> basis = ConvCyclicBSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... 
print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) + + @add_docstrings_cyclic_bspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvCyclicBSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvCyclicBSpline(10, window_size=11) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return CyclicBSplineBasis.compute_features(self, xi) + + @add_docstrings_cyclic_bspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 Cyclic B-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvCyclicBSpline + >>> cbspline_basis = ConvCyclicBSpline(n_basis_funcs=4, order=3, window_size=10) + >>> sample_points, basis_values = cbspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('Cyclic B-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) + + class EvalMSpline(EvalBasisMixin, MSplineBasis): def __init__( @@ -128,6 +389,71 @@ def __init__( label=label, ) + @add_docstrings_mspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline + >>> from nemos.glm import GLM + >>> basis = EvalMSpline(n_basis_funcs=6, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return EvalMSpline.split_by_feature(self, x, axis=axis) + + @add_docstrings_mspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalMSpline(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return MSplineBasis.compute_features(self, xi) + + @add_docstrings_mspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 M-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalMSpline + >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) + >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('M-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return MSplineBasis.evaluate_on_grid(self, n_samples) + class ConvMSpline(ConvBasisMixin, MSplineBasis): def __init__( @@ -147,6 +473,71 @@ def __init__( label=label, ) + @add_docstrings_mspline("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvMSpline + >>> from nemos.glm import GLM + >>> basis = ConvMSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return MSplineBasis.split_by_feature(self, x, axis=axis) + + @add_docstrings_mspline("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvMSpline + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvMSpline(10, window_size=11) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return MSplineBasis.compute_features(self, xi) + + @add_docstrings_mspline("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + Evaluate and visualize 4 M-spline basis functions of order 3: + + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvMSpline + >>> mspline_basis = ConvMSpline(n_basis_funcs=4, order=3, window_size=10) + >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) + >>> for i in range(4): + ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + >>> plt.title('M-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') + >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') + >>> plt.ylabel('Basis Function Value') + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() + """ + return MSplineBasis.evaluate_on_grid(self, n_samples) + class EvalRaisedCosineLinear(EvalBasisMixin, RaisedCosineBasisLinear): def __init__( @@ -165,6 +556,64 @@ def __init__( label=label, ) + @add_raised_cosine_linear_docstring("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import EvalRaisedCosineLinear + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = EvalRaisedCosineLinear(n_basis_funcs) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return RaisedCosineBasisLinear.evaluate_on_grid(self, n_samples) + + @add_raised_cosine_linear_docstring("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLinear + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalRaisedCosineLinear(10) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return RaisedCosineBasisLinear.compute_features(self, *xi) + + @add_raised_cosine_linear_docstring("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLinear + >>> from nemos.glm import GLM + >>> basis = EvalRaisedCosineLinear(n_basis_funcs=6, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... 
print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) + class ConvRaisedCosineLinear(ConvBasisMixin, RaisedCosineBasisLinear): def __init__( @@ -184,6 +633,65 @@ def __init__( label=label, ) + @add_raised_cosine_linear_docstring("evaluate_on_grid") + def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: + """ + Examples + -------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from nemos.basis import ConvRaisedCosineLinear + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size=10 + >>> ortho_basis = ConvRaisedCosineLinear(n_basis_funcs, window_size) + >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) + + """ + return RaisedCosineBasisLinear.evaluate_on_grid(self, n_samples) + + @add_raised_cosine_linear_docstring("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLinear + + >>> # Generate data + >>> num_samples = 1000 + >>> X = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = ConvRaisedCosineLinear(10, window_size=100) + >>> features = basis.compute_features(X) # basis transformed time series + >>> features.shape + (1000, 10) + + """ + return RaisedCosineBasisLinear.compute_features(self, *xi) + + @add_raised_cosine_linear_docstring("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLinear + >>> from nemos.glm import GLM + >>> basis = ConvRaisedCosineLinear(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> X_multi = basis.compute_features(np.random.randn(20, 2)) + >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> for feature, sub_dict in split_features_multi.items(): + ... 
print(f"{feature}, shape {sub_dict.shape}") + two_inputs, shape (20, 2, 6) + + """ + return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) + + class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): def __init__( self, diff --git a/tests/test_basis.py b/tests/test_basis.py index 3c65bed1..470ca9e9 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -5,7 +5,7 @@ import pickle from contextlib import nullcontext as does_not_raise from typing import Literal - +import re import jax.numpy import numpy as np import pynapple as nap @@ -63,6 +63,14 @@ def test_all_basis_are_tested() -> None: @pytest.mark.parametrize( "basis_instance", [ + basis.EvalBSpline(10), + basis.ConvBSpline(10, window_size=11), + basis.EvalCyclicBSpline(10), + basis.ConvCyclicBSpline(10, window_size=11), + basis.EvalMSpline(10), + basis.ConvMSpline(10, window_size=11), + basis.EvalRaisedCosineLinear(10), + basis.ConvRaisedCosineLinear(10, window_size=11), basis.EvalRaisedCosineLog(10), basis.ConvRaisedCosineLog(10, window_size=11), basis.EvalOrthExponential(10, np.arange(1, 11)), @@ -70,9 +78,14 @@ def test_all_basis_are_tested() -> None: ] ) @pytest.mark.parametrize( - "method_name", ["evaluate_on_grid", "compute_features", "split_by_feature"] + "method_name, descr_match", + [ + ("evaluate_on_grid", ".+The number of points in the uniformly spaced grid"), + ("compute_features", "Compute the basis functions and transform input data into model features"), + ("split_by_feature", "Decompose an array along a specified axis into sub-arrays") + ] ) -def test_example_docstrings_add(basis_instance, method_name): +def test_example_docstrings_add(basis_instance, method_name, descr_match): method = getattr(basis_instance, method_name) doc = method.__doc__ examp_delim = "\n Examples\n --------" @@ -80,6 +93,7 @@ def test_example_docstrings_add(basis_instance, method_name): doc_components = doc.split(examp_delim) assert len(doc_components) == 2 assert len(doc_components[0].strip()) > 0 + assert re.search(descr_match, doc_components[0]) assert basis_instance.__class__.__name__ in doc_components[1] From 649c0d6d8526c9f63e27d4b2ab4c8e7815e95574 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 18:41:46 -0500 Subject: [PATCH 020/109] improved tests --- tests/test_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 470ca9e9..6f6643c8 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -80,7 +80,7 @@ def test_all_basis_are_tested() -> None: @pytest.mark.parametrize( "method_name, descr_match", [ - ("evaluate_on_grid", ".+The number of points in the uniformly spaced grid"), + ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), ("compute_features", "Compute the basis functions and transform input data into model features"), ("split_by_feature", "Decompose an array along a specified axis into sub-arrays") ] From 8a7f2595e218e2cde17e89073542a1587921b0a6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 18:47:39 -0500 Subject: [PATCH 021/109] improved tests --- tests/test_basis.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_basis.py b/tests/test_basis.py index 6f6643c8..f5d30fb4 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -89,13 +89,22 @@ def test_example_docstrings_add(basis_instance, method_name, descr_match): method = getattr(basis_instance, method_name) doc = method.__doc__ examp_delim = "\n Examples\n --------" + assert 
examp_delim in doc doc_components = doc.split(examp_delim) assert len(doc_components) == 2 assert len(doc_components[0].strip()) > 0 assert re.search(descr_match, doc_components[0]) + + # check that the basis name is in the example assert basis_instance.__class__.__name__ in doc_components[1] + # check that no other basis name is in the example + for basis_name in basis.__dir__(): + if basis_name == basis_instance.__class__.__name__: + continue + assert basis_name not in doc_components[1] + def test_add_docstring(): From ab56c9b5c28f50f88f23e120ce6714727acc1fec Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 19:03:38 -0500 Subject: [PATCH 022/109] fixed eval changed testing --- src/nemos/basis/_raised_cosine_basis.py | 2 +- src/nemos/basis/_spline_basis.py | 6 +++--- src/nemos/basis/basis.py | 10 ++++----- tests/test_basis.py | 28 +++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index f25449c3..8a214fe8 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -147,7 +147,7 @@ def __call__( # basis2 = nmo.basis.RaisedCosineBasisLog(5) # additive_basis = basis1 + basis2 # additive_basis(*([x] * 2)) would modify both inputs - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) + sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), getattr(self, "bounds", None)) peaks = self._compute_peaks() delta = peaks[1] - peaks[0] diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 413e870d..ac87d273 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -258,7 +258,7 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: conditions are handled such that the basis functions are positive and integrate to one over the domain defined by the sample points. """ - sample_pts, scaling = min_max_rescale_samples(sample_pts, self.bounds) + sample_pts, scaling = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) # add knots if not passed knot_locs = self._generate_knots(is_cyclic=False) @@ -390,7 +390,7 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. """ - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) # add knots knot_locs = self._generate_knots(is_cyclic=False) @@ -515,7 +515,7 @@ def __call__( SciPy to compute the basis values. 
""" - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) + sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) knot_locs = self._generate_knots(is_cyclic=True) # for cyclic, do not repeat knots diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 49cb657f..73bb677b 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -409,7 +409,7 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return EvalMSpline.split_by_feature(self, x, axis=axis) + return MSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_mspline("compute_features") def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: @@ -916,7 +916,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return super().evaluate_on_grid(n_samples=n_samples) + return OrthExponentialBasis.evaluate_on_grid(self, n_samples=n_samples) @add_orth_exp_decay_docstring("compute_features") def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: @@ -935,7 +935,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return super().compute_features(*xi) + return OrthExponentialBasis.compute_features(self, *xi) @add_orth_exp_decay_docstring("split_by_feature") def split_by_feature( @@ -961,7 +961,7 @@ def split_by_feature( feature: shape (20, 1, 5) """ - return super().split_by_feature(x, axis=axis) + return OrthExponentialBasis.split_by_feature(self, x, axis=axis) class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): @@ -1012,7 +1012,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return super().evaluate_on_grid(n_samples) + return OrthExponentialBasis.evaluate_on_grid(self, n_samples) @add_orth_exp_decay_docstring("compute_features") def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: diff --git a/tests/test_basis.py b/tests/test_basis.py index f5d30fb4..90df413e 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -15,6 +15,11 @@ import nemos.basis.basis as basis import nemos.convolve as convolve +from nemos.basis import EvalMSpline +from nemos.basis._spline_basis import BSplineBasis, MSplineBasis, CyclicBSplineBasis +from nemos.basis._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from nemos.basis._decaying_exponential import OrthExponentialBasis + from nemos.utils import pynapple_concatenate_numpy from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring @@ -124,6 +129,29 @@ def method(self): assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method." 
+@pytest.mark.parametrize( + "basis_instance, super_class", + [ + (basis.EvalBSpline(10), BSplineBasis), + (basis.ConvBSpline(10, window_size=11), BSplineBasis), + (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), + (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), + (basis.EvalMSpline(10), MSplineBasis), + (basis.ConvMSpline(10, window_size=11), MSplineBasis), + (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), + (basis.ConvRaisedCosineLinear(10, window_size=11),RaisedCosineBasisLinear), + (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), + (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), + (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis), + ] +) +def test_expected_output_eval_on_grid(basis_instance, super_class): + x, y = super_class.evaluate_on_grid(basis_instance, 100) + xx, yy = basis_instance.evaluate_on_grid(100) + np.testing.assert_equal(xx, x) + np.testing.assert_equal(yy, y) + class BasisFuncsTesting(abc.ABC): """ From 593d7a96298d0961de7121fddfa28e11e5f1bfae Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 25 Nov 2024 19:23:14 -0500 Subject: [PATCH 023/109] added some basic testing --- src/nemos/basis/basis.py | 36 +++++++++++++------------- tests/test_basis.py | 55 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 18 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 73bb677b..4eca4c89 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -75,7 +75,7 @@ def split_by_feature( return BSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_bspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -161,7 +161,7 @@ def split_by_feature( return BSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_bspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -244,7 +244,7 @@ def split_by_feature( return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_cyclic_bspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -328,7 +328,7 @@ def split_by_feature( return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_cyclic_bspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -412,7 +412,7 @@ def split_by_feature( return MSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_mspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -496,7 +496,7 @@ def split_by_feature( return MSplineBasis.split_by_feature(self, x, axis=axis) @add_docstrings_mspline("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -574,7 +574,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return RaisedCosineBasisLinear.evaluate_on_grid(self, 
n_samples) @add_raised_cosine_linear_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -590,7 +590,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return RaisedCosineBasisLinear.compute_features(self, *xi) + return RaisedCosineBasisLinear.compute_features(self, xi) @add_raised_cosine_linear_docstring("split_by_feature") def split_by_feature( @@ -651,7 +651,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return RaisedCosineBasisLinear.evaluate_on_grid(self, n_samples) @add_raised_cosine_linear_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -667,7 +667,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return RaisedCosineBasisLinear.compute_features(self, *xi) + return RaisedCosineBasisLinear.compute_features(self, xi) @add_raised_cosine_linear_docstring("split_by_feature") def split_by_feature( @@ -731,7 +731,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) @add_raised_cosine_log_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -747,7 +747,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return RaisedCosineBasisLog.compute_features(self, *xi) + return RaisedCosineBasisLog.compute_features(self, xi) @add_raised_cosine_log_docstring("split_by_feature") def split_by_feature( @@ -812,7 +812,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) @add_raised_cosine_log_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -828,7 +828,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return RaisedCosineBasisLog.compute_features(self, *xi) + return RaisedCosineBasisLog.compute_features(self, xi) @add_raised_cosine_log_docstring("split_by_feature") def split_by_feature( @@ -919,7 +919,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return OrthExponentialBasis.evaluate_on_grid(self, n_samples=n_samples) @add_orth_exp_decay_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -935,7 +935,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return OrthExponentialBasis.compute_features(self, *xi) + return OrthExponentialBasis.compute_features(self, xi) @add_orth_exp_decay_docstring("split_by_feature") def split_by_feature( @@ -1015,7 +1015,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return OrthExponentialBasis.evaluate_on_grid(self, n_samples) @add_orth_exp_decay_docstring("compute_features") - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- @@ -1031,7 +1031,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: (1000, 10) """ - return 
OrthExponentialBasis.compute_features(self, *xi) + return OrthExponentialBasis.compute_features(self, xi) @add_orth_exp_decay_docstring("split_by_feature") def split_by_feature( diff --git a/tests/test_basis.py b/tests/test_basis.py index 90df413e..b170847b 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -153,6 +153,61 @@ def test_expected_output_eval_on_grid(basis_instance, super_class): np.testing.assert_equal(yy, y) +@pytest.mark.parametrize( + "basis_instance, super_class", + [ + (basis.EvalBSpline(10), BSplineBasis), + (basis.ConvBSpline(10, window_size=11), BSplineBasis), + (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), + (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), + (basis.EvalMSpline(10), MSplineBasis), + (basis.ConvMSpline(10, window_size=11), MSplineBasis), + (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), + (basis.ConvRaisedCosineLinear(10, window_size=11),RaisedCosineBasisLinear), + (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), + (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), + (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis), + ] +) +def test_expected_output_compute_features(basis_instance, super_class): + x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) + xx = basis_instance.compute_features(np.linspace(0, 1, 100)) + nans = np.isnan(x.sum(axis=1)) + assert np.all(np.isnan(xx[nans])) + np.testing.assert_array_equal(xx[~nans], x[~nans]) + + +@pytest.mark.parametrize( + "basis_instance, super_class", + [ + (basis.EvalBSpline(10, label="label"), BSplineBasis), + (basis.ConvBSpline(10, window_size=11, label="label"), BSplineBasis), + (basis.EvalCyclicBSpline(10, label="label"), CyclicBSplineBasis), + (basis.ConvCyclicBSpline(10, window_size=11, label="label"), CyclicBSplineBasis), + (basis.EvalMSpline(10, label="label"), MSplineBasis), + (basis.ConvMSpline(10, window_size=11, label="label"), MSplineBasis), + (basis.EvalRaisedCosineLinear(10, label="label"), RaisedCosineBasisLinear), + (basis.ConvRaisedCosineLinear(10, window_size=11, label="label"),RaisedCosineBasisLinear), + (basis.EvalRaisedCosineLog(10, label="label"), RaisedCosineBasisLog), + (basis.ConvRaisedCosineLog(10, window_size=11, label="label"), RaisedCosineBasisLog), + (basis.EvalOrthExponential(10, np.arange(1, 11), label="label"), OrthExponentialBasis), + (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12, label="label"), OrthExponentialBasis), + ] +) +def test_expected_output_split_by_feature(basis_instance, super_class): + x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) + xdict = super_class.split_by_feature(basis_instance, x) + xxdict = basis_instance.split_by_feature(x) + assert xdict.keys() == xxdict.keys() + xx = xxdict["label"] + x = xdict["label"] + nans = np.isnan(x.sum(axis=(1,2))) + assert np.all(np.isnan(xx[nans])) + np.testing.assert_array_equal(xx[~nans], x[~nans]) + + + class BasisFuncsTesting(abc.ABC): """ An abstract base class that sets the foundation for individual basis function testing. 
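
The two patches above adjust how the Eval*/Conv* wrappers delegate to their concrete parent classes (explicit parent-class calls instead of ``super()``, a single ``xi`` argument for ``compute_features``) and add tests asserting that each wrapper reproduces its parent's output. A minimal usage sketch of the resulting public API, assembled from the docstring examples added in this series; it assumes a nemos build that already includes these patches, and the data, labels, and window size below are illustrative only, not taken from the patches:

    import numpy as np
    from nemos.basis import EvalRaisedCosineLinear, ConvRaisedCosineLinear

    # "Eval" bases evaluate the basis functions at each sample point.
    x = np.random.normal(size=(1000,))
    eval_basis = EvalRaisedCosineLinear(10, label="stimulus")
    X_eval = eval_basis.compute_features(x)  # feature matrix, shape (1000, 10)

    # "Conv" bases convolve the samples with a bank of basis kernels.
    counts = np.random.poisson(1.0, size=(1000,))
    conv_basis = ConvRaisedCosineLinear(10, window_size=100, label="history")
    X_conv = conv_basis.compute_features(counts)  # shape (1000, 10)

    # Features can be regrouped per labeled input, keyed by the basis label.
    split = eval_basis.split_by_feature(X_eval, axis=1)
    for label, arr in split.items():
        print(label, arr.shape)  # e.g. stimulus (1000, 1, 10)

The same pattern applies to every Eval/Conv pair covered by the new tests (B-spline, cyclic B-spline, M-spline, raised cosine linear/log, orthogonalized exponential).
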
From 5d3e746ce73e4c3d2a27198925fe89b7b9c88182 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 09:06:02 -0500 Subject: [PATCH 024/109] linted --- src/nemos/basis/__init__.py | 24 ++- src/nemos/basis/_basis.py | 9 +- src/nemos/basis/_basis_mixin.py | 15 +- src/nemos/basis/_decaying_exponential.py | 16 +- src/nemos/basis/_raised_cosine_basis.py | 24 ++- src/nemos/basis/_spline_basis.py | 26 ++- src/nemos/basis/basis.py | 224 ++++++++++++----------- src/nemos/basis_old.py | 2 +- src/nemos/typing.py | 5 +- tests/test_basis.py | 89 ++++++--- tests/test_simulation.py | 2 +- 11 files changed, 252 insertions(+), 184 deletions(-) diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index 175baaa9..fa3fc70d 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -1,9 +1,17 @@ -from .basis import (EvalMSpline, ConvMSpline, - EvalCyclicBSpline, ConvCyclicBSpline, - EvalBSpline, ConvBSpline, - EvalRaisedCosineLinear, ConvRaisedCosineLinear, - EvalRaisedCosineLog, ConvRaisedCosineLog, - EvalOrthExponential, ConvOrthExponential) -from ._basis import AdditiveBasis, MultiplicativeBasis, Basis -from ._spline_basis import BSplineBasis +from ._basis import AdditiveBasis, Basis, MultiplicativeBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from ._spline_basis import BSplineBasis +from .basis import ( + ConvBSpline, + ConvCyclicBSpline, + ConvMSpline, + ConvOrthExponential, + ConvRaisedCosineLinear, + ConvRaisedCosineLog, + EvalBSpline, + EvalCyclicBSpline, + EvalMSpline, + EvalOrthExponential, + EvalRaisedCosineLinear, + EvalRaisedCosineLog, +) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 5a91e73a..a17c5f9b 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -3,7 +3,6 @@ import abc import copy - from functools import wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union @@ -14,18 +13,17 @@ from ..base_class import Base from ..type_casting import support_pynapple - -from ..utils import row_wise_kron from ..typing import FeatureMatrix +from ..utils import row_wise_kron from ..validation import check_fraction_valid_samples - def add_docstring(method_name, cls=None): attr = getattr(cls, method_name, None) if attr is None: raise AttributeError(f"{cls.__name__} has no attribute {method_name}!") doc = attr.__doc__ + # Decorator to add the docstring def wrapper(func): func.__doc__ = "\n".join([doc, func.__doc__]) # Combine docstrings @@ -142,6 +140,7 @@ class Basis(Base, abc.ABC): ValueError: If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). 
""" + def __init__( self, n_basis_funcs: int, @@ -1735,4 +1734,4 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: >>> mult_basis = nmo.basis.EvalBSpline(4) * nmo.basis.EvalRaisedCosineLinear(5) >>> X, Y, Z = mult_basis.evaluate_on_grid(10, 10) """ - return super().evaluate_on_grid(*n_samples) \ No newline at end of file + return super().evaluate_on_grid(*n_samples) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 5ecf5f4f..4ed23bb7 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -1,10 +1,13 @@ """Mixin classes for basis.""" +import inspect +from typing import Optional, Tuple, Union + +import numpy as np from numpy.typing import ArrayLike + from ..convolve import create_convolutional_predictor -import numpy as np -from typing import Union, Tuple, Optional -import inspect + class EvalBasisMixin: @@ -104,9 +107,7 @@ def _compute_features(self, *xi: ArrayLike): # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. - conv = create_convolutional_predictor( - self.kernel_, *xi, **self._conv_kwargs - ) + conv = create_convolutional_predictor(self.kernel_, *xi, **self._conv_kwargs) # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) @@ -207,4 +208,4 @@ def _check_convolution_kwargs(self): raise ValueError( f"Unrecognized keyword arguments: {invalid}. " f"Allowed convolution keyword arguments are: {convolve_configs}." - ) \ No newline at end of file + ) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index cf72c0a1..e8e95093 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -11,11 +11,15 @@ import scipy.linalg from numpy.typing import NDArray - from ..type_casting import support_pynapple from ..typing import FeatureMatrix - -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring +from ._basis import ( + Basis, + add_docstring, + check_one_dimensional, + check_transform_input, + min_max_rescale_samples, +) class OrthExponentialBasis(Basis, abc.ABC): @@ -164,7 +168,9 @@ def __call__( """ self._check_sample_size(sample_pts) - sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) + sample_pts, _ = min_max_rescale_samples( + sample_pts, getattr(self, "bounds", None) + ) valid_idx = ~np.isnan(sample_pts) # because of how scipy.linalg.orth works, have to create a matrix of # shape (n_pts, n_basis_funcs) and then transpose, rather than @@ -203,4 +209,4 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(n_samples) -add_orth_exp_decay_docstring = partial(add_docstring, cls=OrthExponentialBasis) \ No newline at end of file +add_orth_exp_decay_docstring = partial(add_docstring, cls=OrthExponentialBasis) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 8a214fe8..7bab06e3 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -1,7 +1,8 @@ - # required to get ArrayLike to render correctly from __future__ import annotations +import abc +from functools import partial from typing import Optional, Tuple import numpy as np @@ -9,12 +10,13 @@ from ..type_casting import support_pynapple from 
..typing import FeatureMatrix - - -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring -import abc - -from functools import partial +from ._basis import ( + Basis, + add_docstring, + check_one_dimensional, + check_transform_input, + min_max_rescale_samples, +) class RaisedCosineBasisLinear(Basis, abc.ABC): @@ -147,7 +149,9 @@ def __call__( # basis2 = nmo.basis.RaisedCosineBasisLog(5) # additive_basis = basis1 + basis2 # additive_basis(*([x] * 2)) would modify both inputs - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), getattr(self, "bounds", None)) + sample_pts, _ = min_max_rescale_samples( + np.copy(sample_pts), getattr(self, "bounds", None) + ) peaks = self._compute_peaks() delta = peaks[1] - peaks[0] @@ -336,7 +340,9 @@ def _transform_samples( """ # rescale to [0,1] # copy is necessary to avoid unwanted rescaling in additive/multiplicative basis. - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), getattr(self, "bounds", None)) + sample_pts, _ = min_max_rescale_samples( + np.copy(sample_pts), getattr(self, "bounds", None) + ) # This log-stretching of the sample axis has the following effect: # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. # - as the time_scaling tends to inf, basis will be small and dense around 0 and diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index ac87d273..0eba604b 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -1,4 +1,3 @@ - # required to get ArrayLike to render correctly from __future__ import annotations @@ -11,10 +10,15 @@ from numpy.typing import ArrayLike, NDArray from scipy.interpolate import splev - from ..type_casting import support_pynapple from ..typing import FeatureMatrix -from ._basis import Basis, check_transform_input, check_one_dimensional, min_max_rescale_samples, add_docstring +from ._basis import ( + Basis, + add_docstring, + check_one_dimensional, + check_transform_input, + min_max_rescale_samples, +) class SplineBasis(Basis, abc.ABC): @@ -258,7 +262,9 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: conditions are handled such that the basis functions are positive and integrate to one over the domain defined by the sample points. """ - sample_pts, scaling = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) + sample_pts, scaling = min_max_rescale_samples( + sample_pts, getattr(self, "bounds", None) + ) # add knots if not passed knot_locs = self._generate_knots(is_cyclic=False) @@ -390,7 +396,9 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: The evaluation is performed by looping over each element and using ``splev`` from SciPy to compute the basis values. """ - sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) + sample_pts, _ = min_max_rescale_samples( + sample_pts, getattr(self, "bounds", None) + ) # add knots knot_locs = self._generate_knots(is_cyclic=False) @@ -515,7 +523,9 @@ def __call__( SciPy to compute the basis values. 
""" - sample_pts, _ = min_max_rescale_samples(sample_pts, getattr(self, "bounds", None)) + sample_pts, _ = min_max_rescale_samples( + sample_pts, getattr(self, "bounds", None) + ) knot_locs = self._generate_knots(is_cyclic=True) # for cyclic, do not repeat knots @@ -573,7 +583,6 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) - def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray: """Compute M-spline basis function. @@ -716,6 +725,7 @@ def bspline( return basis_eval.T + add_docstrings_mspline = partial(add_docstring, cls=MSplineBasis) add_docstrings_bspline = partial(add_docstring, cls=BSplineBasis) -add_docstrings_cyclic_bspline = partial(add_docstring, cls=CyclicBSplineBasis) \ No newline at end of file +add_docstrings_cyclic_bspline = partial(add_docstring, cls=CyclicBSplineBasis) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 4eca4c89..258dad5e 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -5,14 +5,25 @@ from typing import Optional, Tuple -from numpy.typing import NDArray, ArrayLike +from numpy.typing import ArrayLike, NDArray -from ._basis_mixin import EvalBasisMixin, ConvBasisMixin - -from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis, add_docstrings_mspline, add_docstrings_bspline, add_docstrings_cyclic_bspline -from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog, add_raised_cosine_linear_docstring, add_raised_cosine_log_docstring -from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring from ..typing import FeatureMatrix +from ._basis_mixin import ConvBasisMixin, EvalBasisMixin +from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring +from ._raised_cosine_basis import ( + RaisedCosineBasisLinear, + RaisedCosineBasisLog, + add_raised_cosine_linear_docstring, + add_raised_cosine_log_docstring, +) +from ._spline_basis import ( + BSplineBasis, + CyclicBSplineBasis, + MSplineBasis, + add_docstrings_bspline, + add_docstrings_cyclic_bspline, + add_docstrings_mspline, +) __all__ = [ "EvalMSpline", @@ -30,18 +41,17 @@ ] - def __dir__() -> list[str]: return __all__ class EvalBSpline(EvalBasisMixin, BSplineBasis): def __init__( - self, - n_basis_funcs: int, - order: int = 4, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalBSpline", + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalBSpline", ): EvalBasisMixin.__init__(self, bounds=bounds) BSplineBasis.__init__( @@ -54,9 +64,9 @@ def __init__( @add_docstrings_bspline("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -118,16 +128,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return BSplineBasis.evaluate_on_grid(self, n_samples) - - class ConvBSpline(ConvBasisMixin, BSplineBasis): def __init__( - self, - n_basis_funcs: int, - window_size: int, - order: int = 4, - label: Optional[str] = "ConvBSpline", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvBSpline", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( @@ -140,9 +148,9 @@ def __init__( @add_docstrings_bspline("split_by_feature") def split_by_feature( - 
self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -206,11 +214,11 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): def __init__( - self, - n_basis_funcs: int, - order: int = 4, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalCyclicBSpline", + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalCyclicBSpline", ): EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( @@ -223,9 +231,9 @@ def __init__( @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -289,12 +297,12 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): def __init__( - self, - n_basis_funcs: int, - window_size: int, - order: int = 4, - label: Optional[str] = "ConvCyclicBSpline", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvCyclicBSpline", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( @@ -307,9 +315,9 @@ def __init__( @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -371,14 +379,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) - class EvalMSpline(EvalBasisMixin, MSplineBasis): def __init__( - self, - n_basis_funcs: int, - order: int = 4, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalMSpline", + self, + n_basis_funcs: int, + order: int = 4, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalMSpline", ): EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( @@ -457,12 +464,12 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class ConvMSpline(ConvBasisMixin, MSplineBasis): def __init__( - self, - n_basis_funcs: int, - window_size: int, - order: int = 4, - label: Optional[str] = "ConvMSpline", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + order: int = 4, + label: Optional[str] = "ConvMSpline", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( @@ -475,9 +482,9 @@ def __init__( @add_docstrings_mspline("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -541,11 +548,11 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class EvalRaisedCosineLinear(EvalBasisMixin, RaisedCosineBasisLinear): def __init__( - self, - n_basis_funcs: int, - width: float = 2.0, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalRaisedCosineLinear", + self, + n_basis_funcs: int, + width: float = 2.0, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalRaisedCosineLinear", ): EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLinear.__init__( @@ -594,9 +601,9 @@ def compute_features(self, xi: ArrayLike) -> 
FeatureMatrix: @add_raised_cosine_linear_docstring("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -617,12 +624,12 @@ def split_by_feature( class ConvRaisedCosineLinear(ConvBasisMixin, RaisedCosineBasisLinear): def __init__( - self, - n_basis_funcs: int, - window_size: int, - width: float = 2.0, - label: Optional[str] = "ConvRaisedCosineLinear", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + width: float = 2.0, + label: Optional[str] = "ConvRaisedCosineLinear", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLinear.__init__( @@ -671,9 +678,9 @@ def compute_features(self, xi: ArrayLike) -> FeatureMatrix: @add_raised_cosine_linear_docstring("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -694,13 +701,13 @@ def split_by_feature( class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): def __init__( - self, - n_basis_funcs: int, - width: float = 2.0, - time_scaling: float = None, - enforce_decay_to_zero: bool = True, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalRaisedCosineLog", + self, + n_basis_funcs: int, + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalRaisedCosineLog", ): EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( @@ -751,9 +758,9 @@ def compute_features(self, xi: ArrayLike) -> FeatureMatrix: @add_raised_cosine_log_docstring("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -774,14 +781,14 @@ def split_by_feature( class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): def __init__( - self, - n_basis_funcs: int, - window_size: int, - width: float = 2.0, - time_scaling: float = None, - enforce_decay_to_zero: bool = True, - label: Optional[str] = "ConvRaisedCosineLog", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + width: float = 2.0, + time_scaling: float = None, + enforce_decay_to_zero: bool = True, + label: Optional[str] = "ConvRaisedCosineLog", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( @@ -832,9 +839,9 @@ def compute_features(self, xi: ArrayLike) -> FeatureMatrix: @add_raised_cosine_log_docstring("split_by_feature") def split_by_feature( - self, - x: NDArray, - axis: int = 1, + self, + x: NDArray, + axis: int = 1, ): r""" Examples @@ -855,11 +862,11 @@ def split_by_feature( class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): def __init__( - self, - n_basis_funcs: int, - decay_rates: NDArray, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalOrthExponential", + self, + n_basis_funcs: int, + decay_rates: NDArray, + bounds: Optional[Tuple[float, float]] = None, + label: Optional[str] = "EvalOrthExponential", ): """Set of 1D basis decaying exponential functions numerically orthogonalized. 
@@ -980,13 +987,14 @@ class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): >>> basis_functions = ortho_basis.compute_features(sample_points) """ + def __init__( - self, - n_basis_funcs: int, - window_size: int, - decay_rates: NDArray, - label: Optional[str] = "ConvOrthExponential", - conv_kwargs: Optional[dict] = None, + self, + n_basis_funcs: int, + window_size: int, + decay_rates: NDArray, + label: Optional[str] = "ConvOrthExponential", + conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) OrthExponentialBasis.__init__( @@ -1054,5 +1062,3 @@ def split_by_feature( """ return OrthExponentialBasis.split_by_feature(self, x, axis=axis) - - diff --git a/src/nemos/basis_old.py b/src/nemos/basis_old.py index f7067de8..164936c7 100644 --- a/src/nemos/basis_old.py +++ b/src/nemos/basis_old.py @@ -3279,4 +3279,4 @@ def bspline( sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der ) - return basis_eval.T \ No newline at end of file + return basis_eval.T diff --git a/src/nemos/typing.py b/src/nemos/typing.py index 9cd83a26..42314b90 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -4,9 +4,8 @@ import jax.numpy as jnp import jaxopt -from jax.typing import ArrayLike - import pynapple as nap +from jax.typing import ArrayLike from statsmodels.tools.typing import NDArray from .pytrees import FeaturePytree @@ -55,4 +54,4 @@ Tuple[jnp.ndarray, jnp.ndarray], ] -FeatureMatrix = nap.TsdFrame | NDArray \ No newline at end of file +FeatureMatrix = nap.TsdFrame | NDArray diff --git a/tests/test_basis.py b/tests/test_basis.py index b170847b..2d5fe8f7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,11 +1,12 @@ import abc import inspect import itertools -from functools import partial import pickle +import re from contextlib import nullcontext as does_not_raise +from functools import partial from typing import Literal -import re + import jax.numpy import numpy as np import pynapple as nap @@ -16,12 +17,15 @@ import nemos.basis.basis as basis import nemos.convolve as convolve from nemos.basis import EvalMSpline -from nemos.basis._spline_basis import BSplineBasis, MSplineBasis, CyclicBSplineBasis -from nemos.basis._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring from nemos.basis._decaying_exponential import OrthExponentialBasis - +from nemos.basis._raised_cosine_basis import ( + RaisedCosineBasisLinear, + RaisedCosineBasisLog, +) +from nemos.basis._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from nemos.utils import pynapple_concatenate_numpy -from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring + # automatic define user accessible basis and check the methods def list_all_basis_classes() -> list[type]: @@ -35,6 +39,7 @@ def list_all_basis_classes() -> list[type]: if issubclass(class_obj, Basis) ] + def test_all_basis_are_tested() -> None: """Meta-test. 
@@ -80,15 +85,21 @@ def test_all_basis_are_tested() -> None: basis.ConvRaisedCosineLog(10, window_size=11), basis.EvalOrthExponential(10, np.arange(1, 11)), basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), - ] + ], ) @pytest.mark.parametrize( "method_name, descr_match", [ ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), - ("compute_features", "Compute the basis functions and transform input data into model features"), - ("split_by_feature", "Decompose an array along a specified axis into sub-arrays") - ] + ( + "compute_features", + "Compute the basis functions and transform input data into model features", + ), + ( + "split_by_feature", + "Decompose an array along a specified axis into sub-arrays", + ), + ], ) def test_example_docstrings_add(basis_instance, method_name, descr_match): method = getattr(basis_instance, method_name) @@ -139,12 +150,15 @@ def method(self): (basis.EvalMSpline(10), MSplineBasis), (basis.ConvMSpline(10, window_size=11), MSplineBasis), (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11),RaisedCosineBasisLinear), + (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), - (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis), - ] + ( + basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + OrthExponentialBasis, + ), + ], ) def test_expected_output_eval_on_grid(basis_instance, super_class): x, y = super_class.evaluate_on_grid(basis_instance, 100) @@ -163,12 +177,15 @@ def test_expected_output_eval_on_grid(basis_instance, super_class): (basis.EvalMSpline(10), MSplineBasis), (basis.ConvMSpline(10, window_size=11), MSplineBasis), (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11),RaisedCosineBasisLinear), + (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), - (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis), - ] + ( + basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + OrthExponentialBasis, + ), + ], ) def test_expected_output_compute_features(basis_instance, super_class): x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) @@ -184,16 +201,33 @@ def test_expected_output_compute_features(basis_instance, super_class): (basis.EvalBSpline(10, label="label"), BSplineBasis), (basis.ConvBSpline(10, window_size=11, label="label"), BSplineBasis), (basis.EvalCyclicBSpline(10, label="label"), CyclicBSplineBasis), - (basis.ConvCyclicBSpline(10, window_size=11, label="label"), CyclicBSplineBasis), + ( + basis.ConvCyclicBSpline(10, window_size=11, label="label"), + CyclicBSplineBasis, + ), (basis.EvalMSpline(10, label="label"), MSplineBasis), (basis.ConvMSpline(10, window_size=11, label="label"), MSplineBasis), (basis.EvalRaisedCosineLinear(10, label="label"), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11, label="label"),RaisedCosineBasisLinear), + ( + 
basis.ConvRaisedCosineLinear(10, window_size=11, label="label"), + RaisedCosineBasisLinear, + ), (basis.EvalRaisedCosineLog(10, label="label"), RaisedCosineBasisLog), - (basis.ConvRaisedCosineLog(10, window_size=11, label="label"), RaisedCosineBasisLog), - (basis.EvalOrthExponential(10, np.arange(1, 11), label="label"), OrthExponentialBasis), - (basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12, label="label"), OrthExponentialBasis), - ] + ( + basis.ConvRaisedCosineLog(10, window_size=11, label="label"), + RaisedCosineBasisLog, + ), + ( + basis.EvalOrthExponential(10, np.arange(1, 11), label="label"), + OrthExponentialBasis, + ), + ( + basis.ConvOrthExponential( + 10, decay_rates=np.arange(1, 11), window_size=12, label="label" + ), + OrthExponentialBasis, + ), + ], ) def test_expected_output_split_by_feature(basis_instance, super_class): x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) @@ -202,12 +236,11 @@ def test_expected_output_split_by_feature(basis_instance, super_class): assert xdict.keys() == xxdict.keys() xx = xxdict["label"] x = xdict["label"] - nans = np.isnan(x.sum(axis=(1,2))) + nans = np.isnan(x.sum(axis=(1, 2))) assert np.all(np.isnan(xx[nans])) np.testing.assert_array_equal(xx[~nans], x[~nans]) - class BasisFuncsTesting(abc.ABC): """ An abstract base class that sets the foundation for individual basis function testing. @@ -5140,9 +5173,9 @@ class TestAdditiveBasis(CombinedBasis): def test_non_empty_samples(self, samples, mode, ws): if mode == "conv" and len(samples[0]) < 2: return - basis_obj = basis.EvalMSpline( + basis_obj = basis.EvalMSpline(5, mode=mode, window_size=ws) + basis.EvalMSpline( 5, mode=mode, window_size=ws - ) + basis.EvalMSpline(5, mode=mode, window_size=ws) + ) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -5682,9 +5715,9 @@ class TestMultiplicativeBasis(CombinedBasis): def test_non_empty_samples(self, samples, mode, ws): if mode == "conv" and len(samples[0]) < 2: return - basis_obj = basis.EvalMSpline( + basis_obj = basis.EvalMSpline(5, mode=mode, window_size=ws) * basis.EvalMSpline( 5, mode=mode, window_size=ws - ) * basis.EvalMSpline(5, mode=mode, window_size=ws) + ) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 939f5eab..f6444bc1 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -5,8 +5,8 @@ import numpy as np import pytest -from nemos import basis import nemos.simulation as simulation +from nemos import basis @pytest.mark.parametrize( From a1f62713ff73f6eeef8373412d56ce6a4c7a5654 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 09:06:32 -0500 Subject: [PATCH 025/109] removed basis old --- src/nemos/basis_old.py | 3282 ---------------------------------------- 1 file changed, 3282 deletions(-) delete mode 100644 src/nemos/basis_old.py diff --git a/src/nemos/basis_old.py b/src/nemos/basis_old.py deleted file mode 100644 index 164936c7..00000000 --- a/src/nemos/basis_old.py +++ /dev/null @@ -1,3282 +0,0 @@ -"""Bases classes.""" - -# required to get ArrayLike to render correctly -from __future__ import annotations - -import abc -import copy -import inspect -from functools import wraps -from typing import Callable, Generator, Literal, Optional, Tuple, Union - -import jax -import numpy as np -import scipy.linalg -from numpy.typing import 
ArrayLike, NDArray -from pynapple import Tsd, TsdFrame -from scipy.interpolate import splev - -from .base_class import Base -from .convolve import create_convolutional_predictor -from .type_casting import support_pynapple -from .utils import row_wise_kron -from .validation import check_fraction_valid_samples - -FeatureMatrix = Union[NDArray, TsdFrame] - -__all__ = [ - "MSplineBasis", - "BSplineBasis", - "CyclicBSplineBasis", - "RaisedCosineBasisLinear", - "RaisedCosineBasisLog", - "OrthExponentialBasis", - "AdditiveBasis", - "MultiplicativeBasis", - "TransformerBasis", -] - - -def __dir__() -> list[str]: - return __all__ - - -def check_transform_input(func: Callable) -> Callable: - """Check input before calling basis. - - This decorator allows to raise an exception that is more readable - when the wrong number of input is provided to __call__. - """ - - @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs) -> NDArray: - xi = self._check_transform_input(*xi) - return func(self, *xi, **kwargs) # Call the basis - - return wrapper - - -def check_one_dimensional(func: Callable) -> Callable: - @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs): - if any(x.ndim != 1 for x in xi): - raise ValueError("Input sample must be one dimensional!") - return func(self, *xi, **kwargs) - - return wrapper - - -def min_max_rescale_samples( - sample_pts: NDArray, - bounds: Optional[Tuple[float, float]] = None, -) -> Tuple[NDArray, float]: - """Rescale samples to [0,1]. - - Parameters - ---------- - sample_pts: - The original samples. - bounds: - Sample bounds. ``bounds[0]`` and ``bounds[1]`` are mapped to 0 and 1, respectively. - Default are ``min(sample_pts), max(sample_pts)``. - - Warns - ----- - UserWarning - If more than 90% of the sample points contain NaNs or Infs. - """ - sample_pts = sample_pts.astype(float) - vmin = np.nanmin(sample_pts) if bounds is None else bounds[0] - vmax = np.nanmax(sample_pts) if bounds is None else bounds[1] - sample_pts[(sample_pts < vmin) | (sample_pts > vmax)] = np.nan - sample_pts -= vmin - # this passes if `samples_pts` contains a single value - if vmin != vmax: - scaling = vmax - vmin - sample_pts /= scaling - else: - scaling = 1.0 - - check_fraction_valid_samples( - sample_pts, - err_msg="All the samples lie outside the [vmin, vmax] range.", - warn_msg="More than 90% of the samples lie outside the [vmin, vmax] range.", - ) - - return sample_pts, scaling - - -class TransformerBasis: - """Basis as ``scikit-learn`` transformers. - - This class abstracts the underlying basis function details, offering methods - similar to scikit-learn's transformers but specifically designed for basis - transformations. It supports fitting to data (calculating any necessary parameters - of the basis functions), transforming data (applying the basis functions to - data), and both fitting and transforming in one step. - - ``TransformerBasis``, unlike ``Basis``, is compatible with scikit-learn pipelining and - model selection, enabling the cross-validation of the basis type and parameters, - for example ``n_basis_funcs``. See the example section below. - - Parameters - ---------- - basis : - A concrete subclass of ``Basis``. 
- - Examples - -------- - >>> from nemos.basis import BSplineBasis, TransformerBasis - >>> from nemos.glm import GLM - >>> from sklearn.pipeline import Pipeline - >>> from sklearn.model_selection import GridSearchCV - >>> import numpy as np - >>> np.random.seed(123) - >>> # Generate data - >>> num_samples, num_features = 10000, 1 - >>> x = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) - >>> features = basis.compute_features(x) # basis transformed time series - >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights - >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts - >>> # transformer can be used in pipelines - >>> transformer = TransformerBasis(basis) - >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),]) - >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API - >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas - >>> # TransformerBasis parameter can be cross-validated. - >>> # 5-fold cross-validate the number of basis - >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10]) - >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5) - >>> grid_cv = grid_cv.fit(x[:, None], y) - >>> print("Cross-validated number of basis:", grid_cv.best_params_) - Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} - """ - - def __init__(self, basis: Basis): - self._basis = copy.deepcopy(basis) - - @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack impute without using transpose. - - Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. - - Parameters - ---------- - X: - The inputs horizontally stacked. - - Returns - ------- - : - A tuple of each individual input. - - """ - return (X[:, k] for k in range(X.shape[1])) - - def fit(self, X: FeatureMatrix, y=None): - """ - Compute the convolutional kernels. - - If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. - - Parameters - ---------- - X : - The data to fit the basis functions to, shape (num_samples, num_input). - y : ignored - Not used, present for API consistency by convention. - - Returns - ------- - self : - The transformer object. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis - >>> # Example input - >>> X = np.random.normal(size=(100, 2)) - >>> # Define and fit tranformation basis - >>> basis = MSplineBasis(10) - >>> transformer = TransformerBasis(basis) - >>> transformer_fitted = transformer.fit(X) - """ - self._basis._set_kernel(*self._unpack_inputs(X)) - return self - - def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: - """ - Transform the data using the fitted basis functions. - - Parameters - ---------- - X : - The data to transform using the basis functions, shape (num_samples, num_input). - y : - Not used, present for API consistency by convention. - - Returns - ------- - : - The data transformed by the basis functions. 
- - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis - >>> # Example input - >>> X = np.random.normal(size=(10000, 2)) - >>> # Define and fit tranformation basis - >>> basis = MSplineBasis(10, mode="conv", window_size=200) - >>> transformer = TransformerBasis(basis) - >>> # Before calling `fit` the convolution kernel is not set - >>> transformer.kernel_ - >>> transformer_fitted = transformer.fit(X) - >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) - >>> transformer_fitted.kernel_.shape - (200, 10) - >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) - """ - # transpose does not work with pynapple - # can't use func(*X.T) to unwrap - - return self._basis._compute_features(*self._unpack_inputs(X)) - - def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: - """ - Compute the kernels and the features. - - This method is a convenience that combines fit and transform into - one step. - - Parameters - ---------- - X : - The data to fit the basis functions to and then transform. - y : - Not used, present for API consistency by convention. - - Returns - ------- - array-like - The data transformed by the basis functions, after fitting the basis - functions to the data. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import MSplineBasis, TransformerBasis - >>> # Example input - >>> X = np.random.normal(size=(100, 1)) - >>> # Define tranformation basis - >>> basis = MSplineBasis(10) - >>> transformer = TransformerBasis(basis) - >>> # Fit and transform basis - >>> feature_transformed = transformer.fit_transform(X) - """ - return self._basis.compute_features(*self._unpack_inputs(X)) - - def __getstate__(self): - """ - Explicitly define how to pickle TransformerBasis object. - - See https://docs.python.org/3/library/pickle.html#object.__getstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - return {"_basis": self._basis} - - def __setstate__(self, state): - """ - Define how to populate the object's state when unpickling. - - Note that during unpickling a new object is created without calling __init__. - Needed to avoid infinite recursion in __getattr__ when unpickling. - - See https://docs.python.org/3/library/pickle.html#object.__setstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - self._basis = state["_basis"] - - def __getattr__(self, name: str): - """ - Enable easy access to attributes of the underlying Basis object. - - Examples - -------- - >>> from nemos import basis - >>> bas = basis.RaisedCosineBasisLinear(5) - >>> trans_bas = basis.TransformerBasis(bas) - >>> bas.n_basis_funcs - 5 - >>> trans_bas.n_basis_funcs - 5 - """ - return getattr(self._basis, name) - - def __setattr__(self, name: str, value) -> None: - r""" - Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. - - Setting any other attribute is not allowed. - - Returns - ------- - None - - Raises - ------ - ValueError - If the attribute being set is not ``_basis`` or an attribute of ``_basis``. - - Examples - -------- - >>> import nemos as nmo - >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.MSplineBasis(10)) - >>> # allowed - >>> trans_bas._basis = nmo.basis.BSplineBasis(10) - >>> # allowed - >>> trans_bas.n_basis_funcs = 20 - >>> # not allowed - >>> try: - ... trans_bas.random_attribute_name = "some value" - ... except ValueError as e: - ... 
print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') - """ - # allow self._basis = basis - if name == "_basis": - super().__setattr__(name, value) - # allow changing existing attributes of self._basis - elif hasattr(self._basis, name): - setattr(self._basis, name, value) - # don't allow setting any other attribute - else: - raise ValueError( - "Only setting _basis or existing attributes of _basis is allowed." - ) - - def __sklearn_clone__(self) -> TransformerBasis: - """ - Customize how TransformerBasis objects are cloned when used with sklearn.model_selection. - - By default, scikit-learn tries to clone the object by calling __init__ using the output of get_params, - which fails in our case. - - For more info: https://scikit-learn.org/stable/developers/develop.html#cloning - """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) - cloned_obj._basis.kernel_ = None - return cloned_obj - - def set_params(self, **parameters) -> TransformerBasis: - """ - Set TransformerBasis parameters. - - When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly - or the parameters of the underlying Basis, but not both. - - Examples - -------- - >>> from nemos.basis import BSplineBasis, MSplineBasis, TransformerBasis - >>> basis = MSplineBasis(10) - >>> transformer_basis = TransformerBasis(basis=basis) - >>> # setting parameters of _basis is allowed - >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) - 8 - >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)) - - >>> # mixing is not allowed, this will raise an exception - >>> try: - ... transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2) - ... except ValueError as e: - ... print(repr(e)) - ValueError('Set either new _basis object or parameters for existing _basis, not both.') - """ - new_basis = parameters.pop("_basis", None) - if new_basis is not None: - self._basis = new_basis - if len(parameters) > 0: - raise ValueError( - "Set either new _basis object or parameters for existing _basis, not both." - ) - else: - self._basis = self._basis.set_params(**parameters) - - return self - - def get_params(self, deep: bool = True) -> dict: - """Extend the dict of parameters from the underlying Basis with _basis.""" - return {"_basis": self._basis, **self._basis.get_params(deep)} - - def __dir__(self) -> list[str]: - """Extend the list of properties of methods with the ones from the underlying Basis.""" - return super().__dir__() + self._basis.__dir__() - - def __add__(self, other: TransformerBasis) -> TransformerBasis: - """ - Add two TransformerBasis objects. - - Parameters - ---------- - other - The other TransformerBasis object to add. - - Returns - ------- - : TransformerBasis - The resulting Basis object. - """ - return TransformerBasis(self._basis + other._basis) - - def __mul__(self, other: TransformerBasis) -> TransformerBasis: - """ - Multiply two TransformerBasis objects. - - Parameters - ---------- - other - The other TransformerBasis object to multiply. - - Returns - ------- - : - The resulting Basis object. - """ - return TransformerBasis(self._basis * other._basis) - - def __pow__(self, exponent: int) -> TransformerBasis: - """Exponentiation of a TransformerBasis object. - - Define the power of a basis by repeatedly applying the method __mul__. - The exponent must be a positive integer. 
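# A minimal sketch of the exponentiation described above, using only the wrapper API
# documented in this module; the basis type and sizes below are illustrative choices,
# not a prescription.
import nemos as nmo

trans_bas = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(5))
squared = trans_bas ** 2  # equivalent to trans_bas * trans_bas
# The result wraps a MultiplicativeBasis with 5 * 5 = 25 basis functions.
print(squared.n_basis_funcs)  # expected: 25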
- - Parameters - ---------- - exponent : - Positive integer exponent - - Returns - ------- - : - The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. - - Raises - ------ - TypeError - If the provided exponent is not an integer. - ValueError - If the integer is zero or negative. - """ - # errors are handled by Basis.__pow__ - return TransformerBasis(self._basis**exponent) - - -class Basis(Base, abc.ABC): - """ - Abstract base class for defining basis functions for feature transformation. - - Basis functions are mathematical constructs that can represent data in alternative, - often more compact or interpretable forms. This class provides a template for such - transformations, with specific implementations defining the actual behavior. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Raises - ------ - ValueError: - If ``mode`` is not 'eval' or 'conv'. - ValueError: - If ``kwargs`` are not None and ``mode =="eval"``. - ValueError: - If ``kwargs`` include parameters not recognized or do not have - default values in ``create_convolutional_predictor``. - ValueError: - If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis). - """ - - def __init__( - self, - n_basis_funcs: int, - mode: Literal["eval", "conv"] = "eval", - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = None, - **kwargs, - ) -> None: - self.n_basis_funcs = n_basis_funcs - self._n_input_dimensionality = 0 - - self._conv_kwargs = kwargs - - # check mode - if mode not in ["conv", "eval"]: - raise ValueError( - f"`mode` should be either 'conv' or 'eval'. '{mode}' provided instead!" - ) - - self._mode = mode - - self._n_basis_input = None - - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - - if label is None: - self._label = self.__class__.__name__ - else: - self._label = str(label) - - self.window_size = window_size - self.bounds = bounds - - self._check_convolution_kwargs() - - self.kernel_ = None - - def _check_convolution_kwargs(self): - """Check convolution kwargs settings. - - Raises - ------ - ValueError: - - If ``self._conv_kwargs`` are not None and ``mode =="eval"``. 
- - If ``axis`` is provided as an argument, and it is different from 0 - (samples must always be in the first axis). - - If ``self._conv_kwargs`` include parameters not recognized or that do not have - default values in ``create_convolutional_predictor``. - """ - if self._mode == "eval" and self._conv_kwargs: - raise ValueError( - f"kwargs should only be set when mode=='conv', but '{self._mode}' provided instead!" - ) - - if "axis" in self._conv_kwargs: - raise ValueError( - "Setting the `axis` parameter is not allowed. Basis requires the " - "convolution to be applied along the first axis (`axis=0`).\n" - "Please transpose your input so that the desired axis for " - "convolution is the first dimension (axis=0)." - ) - convolve_params = inspect.signature(create_convolutional_predictor).parameters - convolve_configs = { - key - for key, param in convolve_params.items() - if param.default - # prevent user from passing - # `basis_matrix` or `time_series` in kwargs. - is not inspect.Parameter.empty - } - if not set(self._conv_kwargs.keys()).issubset(convolve_configs): - # do not encourage to set axis. - convolve_configs = convolve_configs.difference({"axis"}) - # remove the parameter in case axis=0 was passed, since it is allowed. - invalid = ( - set(self._conv_kwargs.keys()) - .difference(convolve_configs) - .difference({"axis"}) - ) - raise ValueError( - f"Unrecognized keyword arguments: {invalid}. " - f"Allowed convolution keyword arguments are: {convolve_configs}." - ) - - @property - def n_output_features(self) -> int | None: - """ - Number of features returned by the basis. - - Notes - ----- - The number of output features can be determined only when the number of inputs - provided to the basis is known. Therefore, before the first call to ``compute_features``, - this property will return ``None``. After that call, ``n_output_features`` will be available. - """ - return self._n_output_features - - @property - def label(self) -> str: - """Label for the basis.""" - return self._label - - @property - def n_basis_input(self) -> tuple | None: - """Number of expected inputs. - - The number of inputs ``compute_feature`` expects. - """ - if self._n_basis_input is None: - return - return self._n_basis_input - - @property - def n_basis_funcs(self): - """Number of basis functions.""" - return self._n_basis_funcs - - @n_basis_funcs.setter - def n_basis_funcs(self, value): - orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None)) - self._n_basis_funcs = value - try: - self._check_n_basis_min() - except ValueError as e: - self._n_basis_funcs = orig_n_basis - raise e - - @property - def bounds(self): - """Range of values covered by the basis.""" - return self._bounds - - @bounds.setter - def bounds(self, values: Union[None, Tuple[float, float]]): - """Setter for bounds.""" - if values is not None and self.mode == "conv": - raise ValueError("`bounds` should only be set when `mode=='eval'`.") - - if values is not None and len(values) != 2: - raise ValueError( - f"The provided `bounds` must be of length two. Length {len(values)} provided instead!" - ) - - # convert to float and store - try: - self._bounds = values if values is None else tuple(map(float, values)) - except (ValueError, TypeError): - raise TypeError("Could not convert `bounds` to float.") - - if values is not None and values[1] <= values[0]: - raise ValueError( - f"Invalid bound {values}. Lower bound is greater or equal than the upper bound." 
- ) - - @property - def mode(self): - """Mode of operation, either ``"conv"`` or ``"eval"``.""" - return self._mode - - @property - def window_size(self): - """Window size as number of samples. - - Duration of the convolutional kernel in number of samples. - """ - return self._window_size - - @window_size.setter - def window_size(self, window_size): - """Setter for the window size parameter.""" - if self.mode == "eval": - if window_size: - raise ValueError( - "If basis is in `mode=='eval'`, `window_size` should be None." - ) - - else: - if window_size is None: - raise ValueError( - "If the basis is in `conv` mode, you must provide a window_size!" - ) - - elif not (isinstance(window_size, int) and window_size > 0): - raise ValueError( - f"`window_size` must be a positive integer. {window_size} provided instead!" - ) - - self._window_size = window_size - - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix ``X``. - - Removes columns from ``X`` until ``[1, X]`` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. - - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - r""" - Apply the basis transformation to the input data. - - This method operates in two modes: - - 'eval': Evaluates the basis functions at the given sample points. - - 'conv': Applies a convolution operation between the input data and the basis functions, - using a window size defined at initialization. - - Parameters - ---------- - *xi: - The input samples over which to apply the basis transformation. The samples can be passed - as multiple arguments, each representing a different dimension for multivariate inputs. - - Returns - ------- - : - A matrix with the transformed features. The shape of the output depends on the operation mode: - - If ``mode == 'eval'``, the basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` - is a basis element. ``xi[k]`` must be a one-dimensional array or a pynapple Tsd. - - - If ``mode == 'conv'``, a bank of basis filters (created by calling fit) is convolved with the - samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions - except for the sample-axis are flattened, so that the method always returns a matrix. - For example, if samples are of shape (num_samples, 2, 3), the output will be - (num_samples, num_basis_funcs * 2 * 3). - The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. 
- For example, if ``axis == 1`` your samples should be (N1, num_samples N3, ...), the output of - transform will be (num_samples, num_basis_funcs * N1 * N3 *...). - - Raises - ------ - ValueError: - - If an invalid mode is specified or necessary parameters for the chosen mode are missing. - - In mode "conv", if the number of inputs to be convolved, doesn't match the number of inputs - set at initialization. - """ - # check if self.kernel_ is not None for mode="conv" - self._check_has_kernel() - if self.mode == "eval": # evaluate at the sample - return self.__call__(*xi) - else: # convolve, called only at the last layer - # before calling the convolve, check that the input matches - # the expectation. We can check xi[0] only, since convolution - # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. - conv = create_convolutional_predictor( - self.kernel_, *xi, **self._conv_kwargs - ) - # make sure to return a matrix - return np.reshape(conv, newshape=(conv.shape[0], -1)) - - @check_transform_input - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute the basis functions and transform input data into model features. - - This method is designed to be a high-level interface for transforming input - data using the basis functions defined by the subclass. Depending on the basis' - mode ('eval' or 'conv'), it either evaluates the basis functions at the sample - points or performs a convolution operation between the input data and the - basis functions. - - Parameters - ---------- - *xi : - Input data arrays to be transformed. The shape and content requirements - depend on the subclass and mode of operation ('eval' or 'conv'). - - Returns - ------- - : - Transformed features. In 'eval' mode, it corresponds to the basis functions - evaluated at the input samples. In 'conv' mode, it consists of convolved - input samples with the basis functions. The output shape varies based on - the subclass and mode. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import BSplineBasis - >>> # Generate data - >>> num_samples = 10000 - >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = BSplineBasis(10) - >>> features = basis.compute_features(X) # basis transformed time series - >>> features.shape - (10000, 10) - - Notes - ----- - Subclasses should implement how to handle the transformation specific to their - basis function types and operation modes. - """ - self._set_num_output_features(*xi) - if self.kernel_ is None: - self._set_kernel(*xi) - return self._compute_features(*xi) - - def _set_kernel(self, *xi: ArrayLike) -> Basis: - """ - Prepare or compute the convolutional kernel for the basis functions. - - This method is called to prepare the basis functions for convolution operations - in subclasses where the 'conv' mode is used. It typically involves computing a - kernel based on the basis functions that will be used for convolution with the - input data. The specifics of kernel computation depend on the subclass implementation - and the nature of the basis functions. - - In 'eval' mode, this method might not perform any operation but simply return the - instance itself, as no kernel preparation is necessary. - - Parameters - ---------- - *xi : - The input data based on which the kernel might be computed. The actual use of - these inputs is subclass-specific and might not be applicable for all basis types. 
- - Returns - ------- - self : - The instance itself, modified to include the computed kernel if applicable. This - allows for method chaining and integration into transformation pipelines. - - Notes - ----- - Subclasses implementing this method should detail the specifics of how the kernel is - computed and how the input parameters are utilized. If the basis operates in 'eval' - mode exclusively, this method should simply return ``self`` without modification. - """ - if self.mode == "conv": - self.kernel_ = self.__call__(np.linspace(0, 1, self.window_size)) - return self - - @abc.abstractmethod - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Abstract method to evaluate the basis functions at given points. - - This method must be implemented by subclasses to define the specific behavior - of the basis transformation. The implementation depends on the type of basis - (e.g., spline, raised cosine), and it should evaluate the basis functions at - the specified points in the domain. - - Parameters - ---------- - *xi : - Variable number of arguments, each representing an array of points at which - to evaluate the basis functions. The dimensions and requirements of these - inputs vary depending on the specific basis implementation. - - Returns - ------- - : - An array containing the evaluated values of the basis functions at the input - points. The shape and structure of this array are specific to the subclass - implementation. - """ - pass - - def _get_samples(self, *n_samples: int) -> Generator[NDArray]: - """Get equi-spaced samples for all the input dimensions. - - This will be used to evaluate the basis on a grid of - points derived by the samples. - - Parameters - ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. - - Returns - ------- - : - A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by ``n_samples``. - """ - # handling of defaults when evaluating on a grid - # (i.e. when we cannot use max and min of samples) - if self.bounds is None: - mn, mx = 0, 1 - else: - mn, mx = self.bounds - return (np.linspace(mn, mx, n_samples[k]) for k in range(len(n_samples))) - - @support_pynapple(conv_type="numpy") - def _check_transform_input( - self, *xi: ArrayLike - ) -> Tuple[Union[NDArray, Tsd, TsdFrame]]: - """Check transform input. - - Parameters - ---------- - xi[0],...,xi[n] : - The input samples, each with shape (number of samples, ). - - Raises - ------ - ValueError - - If the time point number is inconsistent between inputs. - - If the number of inputs doesn't match what the Basis object requires. - - At least one of the samples is empty. 
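# Illustrative sketch of the checks listed above, assuming a two-input basis obtained by
# multiplying two 1D bases; the bad inputs below are arbitrary examples, and the exact
# exception raised in each case comes from the validators in this class.
import numpy as np
import nemos as nmo

two_input_basis = nmo.basis.BSplineBasis(5) * nmo.basis.BSplineBasis(5)
bad_inputs = (
    (np.array(["a", "b", "c"]), np.ones(3)),  # not castable to float
    (np.ones(0), np.ones(0)),                 # empty samples
    (np.ones(10), np.ones(7)),                # inconsistent sample sizes
)
for args in bad_inputs:
    try:
        two_input_basis.compute_features(*args)
    except (TypeError, ValueError) as error:
        print(type(error).__name__)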
- - """ - # check that the input is array-like (i.e., whether we can cast it to - # numeric arrays) - try: - # make sure array is at least 1d (so that we succeed when only - # passed a scalar) - xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) - # ValueError here surfaces the exception with e.g., `x=np.array["a", "b"])` - except (TypeError, ValueError): - raise TypeError("Input samples must be array-like of floats!") - - # check for non-empty samples - if self._has_zero_samples(tuple(len(x) for x in xi)): - raise ValueError("All sample provided must be non empty.") - - # checks on input and outputs - self._check_samples_consistency(*xi) - self._check_input_dimensionality(xi) - - return xi - - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points. - The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing - instead of the default cartesian indexing, see Notes. - - Parameters - ---------- - n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of - n_samples must equal the number of combined bases. - - Returns - ------- - *Xs : - A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid, - where n equals to the number of inputs. - The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``. - Y : - The basis function evaluated at the samples, - shape ``(n_samples[0], ... , n_samples[n], number of basis)``. - - Raises - ------ - ValueError - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what - the Basis object requires. - ValueError - If one of the n_samples is <= 0. - - Notes - ----- - Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size - :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. - This differs from the numpy.meshgrid default, which uses Cartesian indexing. - For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. 
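# Sketch of the "ij" (matrix) indexing convention described in the Notes, using a
# two-input additive basis; the component bases and grid sizes are arbitrary examples.
import nemos as nmo

basis_2d = nmo.basis.BSplineBasis(5) + nmo.basis.MSplineBasis(6)
X1, X2, Y = basis_2d.evaluate_on_grid(30, 20)
# With indexing="ij", both grid axes and the evaluated basis share the (30, 20) layout.
print(X1.shape, X2.shape, Y.shape)  # expected: (30, 20) (30, 20) (30, 20, 11)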
-
- Examples
- --------
- >>> # Evaluate and visualize 4 M-spline basis functions of order 3:
- >>> import numpy as np
- >>> import matplotlib.pyplot as plt
- >>> from nemos.basis import MSplineBasis
- >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
- >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
- >>> p = plt.plot(sample_points, basis_values)
- >>> _ = plt.title('M-Spline Basis Functions')
- >>> _ = plt.xlabel('Domain')
- >>> _ = plt.ylabel('Basis Function Value')
- >>> _ = plt.legend([f'Function {i+1}' for i in range(4)]);
- """
- self._check_input_dimensionality(n_samples)
-
- if self._has_zero_samples(n_samples):
- raise ValueError("All sample counts provided must be greater than zero.")
-
- # get the samples
- sample_tuple = self._get_samples(*n_samples)
- Xs = np.meshgrid(*sample_tuple, indexing="ij")
-
- # evaluates the basis on a flat NDArray and reshape to match meshgrid output
- Y = self.__call__(*tuple(grid_axis.flatten() for grid_axis in Xs)).reshape(
- (*n_samples, self.n_basis_funcs)
- )
-
- return *Xs, Y
-
- @staticmethod
- def _has_zero_samples(n_samples: Tuple[int, ...]) -> bool:
- return any([n <= 0 for n in n_samples])
-
- def _check_input_dimensionality(self, xi: Tuple) -> None:
- """
- Check that the number of inputs provided by the user matches the number of inputs required.
-
- Parameters
- ----------
- xi[0], ..., xi[n] :
- The input samples, shape (number of samples, ).
-
- Raises
- ------
- TypeError
- If the number of inputs doesn't match what the Basis object requires.
- """
- if len(xi) != self._n_input_dimensionality:
- raise TypeError(
- f"Input dimensionality mismatch. This basis evaluation requires {self._n_input_dimensionality} inputs, "
- f"{len(xi)} inputs provided instead."
- )
-
- @staticmethod
- def _check_samples_consistency(*xi: NDArray) -> None:
- """
- Check that each input provided to the Basis object has the same number of time points.
-
- Parameters
- ----------
- xi[0], ..., xi[n] :
- The input samples, shape (number of samples, ).
-
- Raises
- ------
- ValueError
- If the time point number is inconsistent between inputs.
- """
- sample_sizes = [sample.shape[0] for sample in xi]
- if any(elem != sample_sizes[0] for elem in sample_sizes):
- raise ValueError(
- "Sample size mismatch. Input elements have inconsistent sample sizes."
- )
-
- @abc.abstractmethod
- def _check_n_basis_min(self) -> None:
- """Check that the user requested enough basis elements.
-
- Most bases work with at least 1 element, but some,
- such as the RaisedCosineBasisLog, require a minimum of 2 basis functions to be well defined.
-
- Raises
- ------
- ValueError
- If an insufficient number of basis elements is requested for the basis type
- """
- pass
-
- def __add__(self, other: Basis) -> AdditiveBasis:
- """
- Add two Basis objects together.
-
- Parameters
- ----------
- other
- The other Basis object to add.
-
- Returns
- -------
- : AdditiveBasis
- The resulting Basis object.
- """
- return AdditiveBasis(self, other)
-
- def __mul__(self, other: Basis) -> MultiplicativeBasis:
- """
- Multiply two Basis objects together.
-
- Parameters
- ----------
- other
- The other Basis object to multiply.
-
- Returns
- -------
- :
- The resulting Basis object.
- """
- return MultiplicativeBasis(self, other)
-
- def __pow__(self, exponent: int) -> MultiplicativeBasis:
- """Exponentiation of a Basis object.
-
- Define the power of a basis by repeatedly applying the method __mul__.
- The exponent must be a positive integer.
-
- Parameters
- ----------
- exponent :
- Positive integer exponent
-
- Returns
- -------
- :
- The product of the basis with itself "exponent" times. Equivalent to ``self * self * ... * self``.
-
- Raises
- ------
- TypeError
- If the provided exponent is not an integer.
- ValueError
- If the integer is zero or negative.
- """
- if not isinstance(exponent, int):
- raise TypeError("Exponent should be an integer!")
-
- if exponent <= 0:
- raise ValueError("Exponent should be a positive integer!")
-
- result = self
- for _ in range(exponent - 1):
- result = result * self
- return result
-
- def to_transformer(self) -> TransformerBasis:
- """
- Turn the Basis into a TransformerBasis for use with scikit-learn.
-
- Examples
- --------
- Jointly cross-validating basis and GLM parameters with scikit-learn.
-
- >>> import numpy as np
- >>> import nemos as nmo
- >>> from sklearn.pipeline import Pipeline
- >>> from sklearn.model_selection import GridSearchCV
- >>> # load some data
- >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
- >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer()
- >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.)
- >>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
- >>> param_grid = dict(
- ... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
- ... basis__n_basis_funcs=(3, 5, 10, 20, 100),
- ... )
- >>> gridsearch = GridSearchCV(
- ... pipeline,
- ... param_grid=param_grid,
- ... cv=5,
- ... )
- >>> gridsearch = gridsearch.fit(X, y)
- """
- return TransformerBasis(copy.deepcopy(self))
-
- def _get_feature_slicing(
- self,
- n_inputs: Optional[tuple] = None,
- start_slice: Optional[int] = None,
- split_by_input: bool = True,
- ) -> Tuple[dict, int]:
- """
- Calculate and return the slicing for features based on the input structure.
-
- This method determines how to slice the features for different basis types.
- If the instance is of ``AdditiveBasis`` type, the slicing is calculated recursively
- for each component basis. Otherwise, it determines the slicing based on
- the number of basis functions and the ``split_by_input`` flag.
-
- Parameters
- ----------
- n_inputs :
- The number of inputs for each basis component; by default it uses ``self._n_basis_input``.
- start_slice :
- The starting index for slicing, by default it starts from 0.
- split_by_input :
- Flag indicating whether to split the slicing by individual inputs or not.
- If ``False``, a single slice is generated for all inputs.
-
- Returns
- -------
- split_dict :
- Dictionary with keys as labels and values as slices representing
- the slicing for each input or additive component, when ``split_by_input`` is
- ``True`` or ``False`` respectively.
- start_slice :
- The updated starting index after slicing.
-
- See Also
- --------
- _get_default_slicing : Handles default slicing logic.
- _merge_slicing_dicts : Merges multiple slicing dictionaries, handling key conflicts.
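# Sketch of the slicing structure this helper returns (it is a private method, shown here
# only to illustrate the docstring); the labels "x" and "y" and the sizes are hypothetical.
import numpy as np
import nemos as nmo

basis = nmo.basis.BSplineBasis(5, label="x") + nmo.basis.MSplineBasis(6, label="y")
_ = basis.compute_features(np.random.randn(50), np.random.randn(50))  # sets the input counts
slices, next_start = basis._get_feature_slicing(split_by_input=False)
print(slices)      # expected: {'x': slice(0, 5, None), 'y': slice(5, 11, None)}
print(next_start)  # expected: 11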
- """ - # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input - start_slice = start_slice or 0 - - # If the instance is of AdditiveBasis type, handle slicing for the additive components - if isinstance(self, AdditiveBasis): - split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], - start_slice, - split_by_input=split_by_input, - ) - sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], - start_slice, - split_by_input=split_by_input, - ) - split_dict = self._merge_slicing_dicts(split_dict, sp2) - else: - # Handle the default case for other basis types - split_dict, start_slice = self._get_default_slicing( - split_by_input, start_slice - ) - - return split_dict, start_slice - - def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: - """Merge two slicing dictionaries, handling key conflicts.""" - for key, val in dict2.items(): - if key in dict1: - new_key = self._generate_unique_key(dict1, key) - dict1[new_key] = val - else: - dict1[key] = val - return dict1 - - @staticmethod - def _generate_unique_key(existing_dict: dict, key: str) -> str: - """Generate a unique key if there is a conflict.""" - extra = 1 - new_key = f"{key}-{extra}" - while new_key in existing_dict: - extra += 1 - new_key = f"{key}-{extra}" - return new_key - - def _get_default_slicing( - self, split_by_input: bool, start_slice: int - ) -> Tuple[dict, int]: - """Handle default slicing logic.""" - if split_by_input: - # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): - split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) - } - else: - split_dict = { - self.label: { - f"{i}": slice( - start_slice + i * self.n_basis_funcs, - start_slice + (i + 1) * self.n_basis_funcs, - ) - for i in range(self._n_basis_input[0]) - } - } - else: - split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) - } - start_slice += self._n_output_features - return split_dict, start_slice - - def split_by_feature( - self, - x: NDArray, - axis: int = 1, - ): - r""" - Decompose an array along a specified axis into sub-arrays based on the number of expected inputs. - - This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. - - **How it works:** - - - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will - be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions - ``(n_inputs, n_basis_funcs)``. - - - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will - be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``. - - For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``, - then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``. - - The specified axis (``axis``) determines where the split occurs, and all other dimensions - remain unchanged. See the example section below for the most common use cases. - - Parameters - ---------- - x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of ``x`` along the specified axis must match the total - number of features generated by the basis, i.e., ``self.n_output_features``. 
- - **Examples:** - - - For a design matrix: ``(n_samples, total_n_features)`` - - - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. - - axis : int, optional - The axis along which to split the features. Defaults to 1. - Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for - coefficient arrays (features along rows). All other dimensions are preserved. - - Raises - ------ - ValueError - If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. - - Returns - ------- - dict - A dictionary where: - - **Key**: Label of the basis. - - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import BSplineBasis - >>> from nemos.glm import GLM - >>> # Define an additive basis - >>> basis = BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature") - >>> # Generate a sample input array and compute features - >>> x = np.random.randn(20) - >>> X = basis.compute_features(x) - >>> # Split the feature matrix along axis 1 - >>> split_features = basis.split_by_feature(X, axis=1) - >>> for feature, arr in split_features.items(): - ... print(f"{feature}: shape {arr.shape}") - feature: shape (20, 1, 5) - >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, - ... label="multi_input") - >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) - >>> for feature, sub_dict in split_features_multi.items(): - ... print(f"{feature}, shape {sub_dict.shape}") - multi_input, shape (20, 2, 6) - >>> # the method can be used to decompose the glm coefficients in the various features - >>> counts = np.random.poisson(size=20) - >>> model = GLM().fit(X, counts) - >>> split_coef = basis.split_by_feature(model.coef_, axis=0) - >>> for feature, coef in split_coef.items(): - ... print(f"{feature}: shape {coef.shape}") - feature: shape (1, 5) - - """ - if x.shape[axis] != self.n_output_features: - raise ValueError( - "`x.shape[axis]` does not match the expected number of features." 
- f" `x.shape[axis] == {x.shape[axis]}`, while the expected number "
- f"of features is {self.n_output_features}"
- )
-
- # Get the slice dictionary based on predefined feature slicing
- slice_dict = self._get_feature_slicing(split_by_input=False)[0]
-
- # Helper function to build index tuples for each slice
- def build_index_tuple(slice_obj, axis: int, ndim: int):
- """Create an index tuple to apply a slice on the given axis."""
- index = [slice(None)] * ndim # Initialize index for all dimensions
- index[axis] = slice_obj # Replace the axis with the slice object
- return tuple(index)
-
- # Get the dict for slicing the correct axis
- index_dict = jax.tree_util.tree_map(
- lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict
- )
-
- # Custom leaf function to identify index tuples as leaves
- def is_leaf(val):
- # Check if it's a tuple, length matches ndim, and all elements are slice objects
- if isinstance(val, tuple) and len(val) == x.ndim:
- return all(isinstance(v, slice) for v in val)
- return False
-
- # Apply the slicing using the custom leaf function
- out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf)
-
- # reshape the arrays to split by n_basis_input
- reshaped_out = dict()
- for i, vals in enumerate(out.items()):
- key, val = vals
- shape = list(val.shape)
- reshaped_out[key] = val.reshape(
- shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :]
- )
- return reshaped_out
-
- def _check_input_shape_consistency(self, x: NDArray):
- """Check input consistency across calls."""
- # remove sample axis
- shape = x.shape[1:]
- if self._input_shape is not None and self._input_shape != shape:
- expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:]
- expected_shape_str = expected_shape_str.replace(",)", ")")
- raise ValueError(
- f"Input shape mismatch detected.\n\n"
- f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with "
- f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n"
- f" Expected: {expected_shape_str}\n"
- f" But got: {x.shape}.\n\n"
- "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, "
- "but all other dimensions must remain the same. If you need to process inputs with a "
- "different shape, please create a new basis instance."
- )
-
- def _set_num_output_features(self, *xi: NDArray) -> Basis:
- """
- Pre-compute the number of inputs and output features.
-
- This function computes the number of inputs provided to the basis and, together with
- ``n_basis_funcs``, uses it to calculate the number of output features that
- ``self.compute_features`` will return. These quantities and the input shape (excluding the
- sample axis) are stored in ``self._n_basis_input``, ``self._n_output_features``, and
- ``self._input_shape``, respectively.
-
- Parameters
- ----------
- xi:
- The input arrays.
-
- Returns
- -------
- :
- The basis itself, for chaining.
-
- Raises
- ------
- ValueError:
- If the number of inputs does not match ``self._n_basis_input``, when ``self._n_basis_input``
- is not None.
-
- Notes
- -----
- Once ``compute_features`` is called, we enforce that for all subsequent calls of the method,
- the input that the basis receives preserves the shape of all axes, except for the sample axis.
- This condition guarantees the consistency of the feature axis, and therefore that
- ``self.split_by_feature`` behaves appropriately.
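# Sketch of the shape-consistency rule from the Notes above: after the first call, only the
# sample axis may change; the basis type, window size, and shapes here are arbitrary examples.
import numpy as np
import nemos as nmo

basis = nmo.basis.BSplineBasis(5, mode="conv", window_size=10)
_ = basis.compute_features(np.random.randn(100, 3))  # first call fixes the input shape to (..., 3)
_ = basis.compute_features(np.random.randn(50, 3))   # allowed: only the number of samples changed
try:
    basis.compute_features(np.random.randn(50, 4))   # trailing shape changed
except ValueError as error:
    print("inconsistent input shape:", type(error).__name__)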
- - """ - # Check that the input shape matches expectation - # Note that this method is reimplemented in AdditiveBasis and MultiplicativeBasis - # so we can assume that len(xi) == 1 - xi = xi[0] - self._check_input_shape_consistency(xi) - - # remove sample axis (samples are allowed to vary) - shape = xi.shape[1:] - - self._input_shape = shape - - # remove sample axis & get the total input number - n_inputs = (1,) if xi.ndim == 1 else (np.prod(shape),) - - self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] - return self - - -class AdditiveBasis(Basis): - """ - Class representing the addition of two Basis objects. - - Parameters - ---------- - basis1 : - First basis object to add. - basis2 : - Second basis object to add. - - Attributes - ---------- - n_basis_funcs : - Number of basis functions. - - Examples - -------- - >>> # Generate sample data - >>> import numpy as np - >>> import nemos as nmo - >>> X = np.random.normal(size=(30, 2)) - >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.BSplineBasis(10) - >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15) - >>> additive_basis = basis_1 + basis_2 - >>> # can add another basis to the AdditiveBasis object - >>> X = np.random.normal(size=(30, 3)) - >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100) - >>> additive_basis_2 = additive_basis + basis_3 - - """ - - def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") - self._n_input_dimensionality = ( - basis1._n_input_dimensionality + basis2._n_input_dimensionality - ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - return - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features - ) - return self - - def _check_n_basis_min(self) -> None: - pass - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Evaluate the basis at the input samples. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - - """ - X = np.hstack( - ( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), - ) - ) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute features for added bases and concatenate. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The features, shape (n_samples, n_basis_funcs) - - """ - # the numpy conversion is important, there is some in-place - # array modification in basis. 
- hstack_pynapple = support_pynapple(conv_type="numpy")(np.hstack) - X = hstack_pynapple( - ( - self._basis1._compute_features( - *xi[: self._basis1._n_input_dimensionality] - ), - self._basis2._compute_features( - *xi[self._basis1._n_input_dimensionality :] - ), - ), - ) - return X - - def _set_kernel(self, *xi: ArrayLike) -> Basis: - """Call fit on the added basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. - - Returns - ------- - : - The AdditiveBasis ready to be evaluated. - """ - self._basis1._set_kernel(*xi) - self._basis2._set_kernel(*xi) - return self - - def split_by_feature( - self, - x: NDArray, - axis: int = 1, - ): - r""" - Decompose an array along a specified axis into sub-arrays based on the basis components. - - This function takes an array (e.g., a design matrix or model coefficients) and splits it along - a designated axis. Each split corresponds to a different additive component of the basis, - preserving all dimensions except the specified axis. - - **How It Works:** - - Suppose the basis is made up of **m components**, each with :math:`b_i` basis functions and :math:`n_i` inputs. - The total number of features, :math:`N`, is calculated as: - - .. math:: - N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m - - This method splits any axis of length :math:`N` into sub-arrays, one for each basis component. - - The sub-array for the i-th basis component is reshaped into dimensions - :math:`(n_i, b_i)`. - - For example, if the array shape is :math:`(1, 2, N, 4, 5)`, then each split sub-array will have shape: - - .. math:: - (1, 2, n_i, b_i, 4, 5) - - where: - - - :math:`n_i` represents the number of inputs associated with the i-th component, - - :math:`b_i` represents the number of basis functions in that component. - - The specified axis (``axis``) determines where the split occurs, and all other dimensions - remain unchanged. See the example section below for the most common use cases. - - Parameters - ---------- - x : - The input array to be split, representing concatenated features, coefficients, - or other data. The shape of ``x`` along the specified axis must match the total - number of features generated by the basis, i.e., ``self.n_output_features``. - - **Examples:** - - For a design matrix: ``(n_samples, total_n_features)`` - - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``. - - axis : int, optional - The axis along which to split the features. Defaults to 1. - Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for - coefficient arrays (features along rows). All other dimensions are preserved. - - Raises - ------ - ValueError - If the shape of ``x`` along the specified axis does not match ``self.n_output_features``. - - Returns - ------- - dict - A dictionary where: - - **Keys**: Labels of the additive basis components. - - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: - - .. math:: - (..., n_i, b_i, ...) - - - ``n_i``: The number of inputs processed by the i-th basis component. - - ``b_i``: The number of basis functions for the i-th basis component. - - These sub-arrays are reshaped along the specified axis, with all other dimensions - remaining the same. 
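# Worked instance of the feature-count formula above, N = b_1 * n_1 + b_2 * n_2, with
# hypothetical components: 5 basis functions on one input and 6 basis functions on two inputs.
import numpy as np
import nemos as nmo

add_basis = (
    nmo.basis.BSplineBasis(5, mode="conv", window_size=10, label="one_input")
    + nmo.basis.BSplineBasis(6, mode="conv", window_size=10, label="two_inputs")
)
X = add_basis.compute_features(np.random.randn(40), np.random.randn(40, 2))
print(X.shape)  # expected: (40, 17), since 5 * 1 + 6 * 2 == 17
parts = add_basis.split_by_feature(X, axis=1)
print({key: val.shape for key, val in parts.items()})
# expected: {'one_input': (40, 1, 5), 'two_inputs': (40, 2, 6)}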
-
- Examples
- --------
- >>> import numpy as np
- >>> from nemos.basis import BSplineBasis
- >>> from nemos.glm import GLM
- >>> # Define an additive basis
- >>> basis = (
- ... BSplineBasis(n_basis_funcs=5, mode="conv", window_size=10, label="feature_1") +
- ... BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10, label="feature_2")
- ... )
- >>> # Generate a sample input array and compute features
- >>> x1, x2 = np.random.randn(20), np.random.randn(20)
- >>> X = basis.compute_features(x1, x2)
- >>> # Split the feature matrix along axis 1
- >>> split_features = basis.split_by_feature(X, axis=1)
- >>> for feature, arr in split_features.items():
- ... print(f"{feature}: shape {arr.shape}")
- feature_1: shape (20, 1, 5)
- feature_2: shape (20, 1, 6)
- >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested:
- >>> multi_input_basis = BSplineBasis(n_basis_funcs=6, mode="conv", window_size=10,
- ... label="multi_input")
- >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2))
- >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1)
- >>> for feature, sub_dict in split_features_multi.items():
- ... print(f"{feature}, shape {sub_dict.shape}")
- multi_input, shape (20, 2, 6)
- >>> # the method can be used to decompose the glm coefficients into the various features
- >>> counts = np.random.poisson(size=20)
- >>> model = GLM().fit(X, counts)
- >>> split_coef = basis.split_by_feature(model.coef_, axis=0)
- >>> for feature, coef in split_coef.items():
- ... print(f"{feature}: shape {coef.shape}")
- feature_1: shape (1, 5)
- feature_2: shape (1, 6)
-
- """
- return super().split_by_feature(x, axis=axis)
-
-
- class MultiplicativeBasis(Basis):
- """
- Class representing the multiplication (external product) of two Basis objects.
-
- Parameters
- ----------
- basis1 :
- First basis object to multiply.
- basis2 :
- Second basis object to multiply.
-
- Attributes
- ----------
- n_basis_funcs :
- Number of basis functions.
-
- Examples
- --------
- >>> # Generate sample data
- >>> import numpy as np
- >>> import nemos as nmo
- >>> X = np.random.normal(size=(30, 3))
- >>> # define two bases and multiply them
- >>> basis_1 = nmo.basis.BSplineBasis(10)
- >>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
- >>> multiplicative_basis = basis_1 * basis_2
- >>> # Can multiply or add another basis to the MultiplicativeBasis object
- >>> # This will cause the number of output features of the resulting basis to grow accordingly
- >>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
- >>> multiplicative_basis_2 = multiplicative_basis * basis_3
- """
-
- def __init__(self, basis1: Basis, basis2: Basis) -> None:
- self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs
- super().__init__(self.n_basis_funcs, mode="eval")
- self._n_input_dimensionality = (
- basis1._n_input_dimensionality + basis2._n_input_dimensionality
- )
- self._n_basis_input = None
- self._n_output_features = None
- self._label = "(" + basis1.label + " * " + basis2.label + ")"
- self._basis1 = basis1
- self._basis2 = basis2
- return
-
- def _check_n_basis_min(self) -> None:
- pass
-
- def _set_kernel(self, *xi: NDArray) -> Basis:
- """Call fit on the multiplied bases.
-
- If any of the multiplied bases are in "conv" mode, this will prepare their kernels for the convolution.
-
- Parameters
- ----------
- *xi:
- The sample inputs. Unused, necessary to conform to ``scikit-learn`` API.
-
- Returns
- -------
- :
- The MultiplicativeBasis ready to be evaluated.
- """ - self._basis1._set_kernel(*xi) - self._basis2._set_kernel(*xi) - return self - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Evaluate the basis at the input samples. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - """ - X = np.asarray( - row_wise_kron( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), - transpose=False, - ) - ) - return X - - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: - """ - Compute the features for the multiplied bases, and compute their outer product. - - Parameters - ---------- - xi[0], ..., xi[n] : (n_samples,) - Tuple of input samples, each with the same number of samples. The - number of input arrays must equal the number of combined bases. - - Returns - ------- - : - The features, shape (n_samples, n_basis_funcs) - - """ - kron = support_pynapple(conv_type="numpy")(row_wise_kron) - X = kron( - self._basis1._compute_features(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._compute_features(*xi[self._basis1._n_input_dimensionality :]), - transpose=False, - ) - return X - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features - ) - return self - - -class SplineBasis(Basis, abc.ABC): - """ - SplineBasis class inherits from the Basis class and represents spline basis functions. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : optional - Spline order. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Attributes - ---------- - order : int - Spline order. 
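# Sketch of the order / n_basis_funcs relationship documented above (values are arbitrary):
# an order larger than the number of basis functions is rejected at construction time.
import nemos as nmo

_ = nmo.basis.BSplineBasis(n_basis_funcs=5, order=3)  # fine: order <= n_basis_funcs
try:
    nmo.basis.BSplineBasis(n_basis_funcs=3, order=5)  # order > n_basis_funcs
except ValueError as error:
    print(error)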
- """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 2, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = None, - **kwargs, - ) -> None: - self.order = order - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - - self._n_input_dimensionality = 1 - - @property - def order(self): - """Spline order. - - Spline order, i.e. the polynomial degree of the spline plus one. - """ - return self._order - - @order.setter - def order(self, value): - """Setter for the order parameter.""" - if value != int(value): - raise ValueError( - f"Spline order must be an integer! Order {value} provided." - ) - value = int(value) - if value < 1: - raise ValueError(f"Spline order must be positive! Order {value} provided.") - - # Set to None only the first time the setter is called. - orig_order = copy.deepcopy(getattr(self, "_order", None)) - - # Set the order - self._order = value - - # If the order was already initialized, re-check basis - if orig_order is not None: - try: - self._check_n_basis_min() - except ValueError as e: - self._order = orig_order - raise e - - def _generate_knots( - self, - sample_pts: NDArray, - perc_low: float = 0.0, - perc_high: float = 1.0, - is_cyclic: bool = False, - ) -> NDArray: - """ - Generate knot locations for spline basis functions. - - Parameters - ---------- - sample_pts : (n_samples,) - The sample points. - perc_low - The low percentile value, between [0,1). - perc_high - The high percentile value, between (0,1]. - is_cyclic : optional - Whether the spline is cyclic. - - Returns - ------- - The knot locations for the spline basis functions. - - Raises - ------ - AssertionError - If the percentiles or order values are not within the valid range. - """ - # Determine number of interior knots. - num_interior_knots = self.n_basis_funcs - self.order - if is_cyclic: - num_interior_knots += self.order - 1 - - # Spline basis have support on the semi-open [a, b) interval, we add a small epsilon - # to mx so that the so that basis_element(max(samples)) != 0 - knot_locs = np.concatenate( - ( - np.zeros(self.order - 1), - np.linspace(0, (1 + np.finfo(float).eps), num_interior_knots + 2), - np.full(self.order - 1, 1 + np.finfo(float).eps), - ) - ) - return knot_locs - - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Check that the spline-basis has at least as many basis as the order. - - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - if self.n_basis_funcs < self.order: - raise ValueError( - f"{self.__class__.__name__} `order` parameter cannot be larger " - "than `n_basis_funcs` parameter." - ) - - -class MSplineBasis(SplineBasis): - r""" - M-spline basis functions for modeling and data transformation. - - M-splines [1]_ are a type of spline basis function used for smooth curve fitting - and data representation. They are positive and integrate to one, making them - suitable for probabilistic models and density estimation. The order of an - M-spline defines its smoothness, with higher orders resulting in smoother - splines. - - This class provides functionality to create M-spline basis functions, allowing - for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` - abstract class, providing specific implementations for M-splines. 
- - Parameters - ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : - The order of the splines used in basis functions. Must be between [1, - n_basis_funcs]. Default is 2. Higher order splines have more continuous - derivatives at each interior knot, resulting in smoother basis functions. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs: - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import MSplineBasis - >>> n_basis_funcs = 5 - >>> order = 3 - >>> mspline_basis = MSplineBasis(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) - - References - ---------- - .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, - 3(4), 425-441. - - Notes - ----- - ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis) values - will shrink by a factor of :math:`1/\alpha`. - For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. - If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 2, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "MSplineBasis", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: - """ - Evaluate the M-spline basis functions at given sample points. - - Parameters - ---------- - sample_pts : - An array of sample points where the M-spline basis functions are to be - evaluated. - - Returns - ------- - : - An array where each column corresponds to one M-spline basis function - evaluated at the input sample points. The shape of the array is - (len(sample_pts), n_basis_funcs). - - Notes - ----- - The implementation uses a recursive definition of M-splines. Boundary - conditions are handled such that the basis functions are positive and - integrate to one over the domain defined by the sample points. 
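# Numerical check of the normalization property stated in the Notes: each M-spline column
# integrates to approximately one over the domain (grid size and rounding are illustrative).
import numpy as np
import nemos as nmo

mspline = nmo.basis.MSplineBasis(n_basis_funcs=4, order=3)
x, Y = mspline.evaluate_on_grid(10_000)
dx = x[1] - x[0]
print(np.round(Y.sum(axis=0) * dx, 2))  # expected: approximately [1. 1. 1. 1.]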
- """ - sample_pts, scaling = min_max_rescale_samples(sample_pts, self.bounds) - # add knots if not passed - knot_locs = self._generate_knots( - sample_pts, perc_low=0.0, perc_high=1.0, is_cyclic=False - ) - - X = np.stack( - [ - mspline(sample_pts, self.order, i, knot_locs) - for i in range(self.n_basis_funcs) - ], - axis=1, - ) - # re-normalize so that it integrates to 1 over the range. - X /= scaling - return X - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """ - Evaluate the M-spline basis functions on a uniformly spaced grid. - - This method creates a uniformly spaced grid of sample points within the domain - [0, 1] and evaluates all the M-spline basis functions at these points. It is - particularly useful for visualizing the shape and distribution of the basis - functions across their domain. - - Parameters - ---------- - n_samples : - The number of points in the uniformly spaced grid. A higher number of - samples will result in a more detailed visualization of the basis functions. - - Returns - ------- - X : NDArray - A 1D array of uniformly spaced sample points within the domain [0, 1]. - Shape: ``(n_samples,)``. - Y : NDArray - A 2D array where each row corresponds to the evaluated M-spline basis - function values at the points in X. Shape: ``(n_samples, n_basis_funcs)``. - - Examples - -------- - Evaluate and visualize 4 M-spline basis functions of order 3: - - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import MSplineBasis - >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) - >>> for i in range(4): - ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') - >>> plt.title('M-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') - >>> plt.xlabel('Domain') - Text(0.5, 0, 'Domain') - >>> plt.ylabel('Basis Function Value') - Text(0, 0.5, 'Basis Function Value') - >>> l = plt.legend() - """ - return super().evaluate_on_grid(n_samples) - - -class BSplineBasis(SplineBasis): - """ - B-spline 1-dimensional basis functions. - - Implementation of the one-dimensional BSpline basis [1]_. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. ``'eval'`` for evaluation at sample points, - 'conv' for convolutional operation. - order : - Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. 
Basis assumes - that the convolution axis is ``axis=0``. - - Attributes - ---------- - order : - Spline order. - - - References - ---------- - .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. - Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import BSplineBasis - >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis(sample_points) - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "BSplineBasis", - **kwargs, - ): - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: - """ - Evaluate the B-spline basis functions with given sample points. - - Parameters - ---------- - sample_pts : - The sample points at which the B-spline is evaluated, shape (n_samples,). - - Returns - ------- - basis_funcs : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs). - - Raises - ------ - AssertionError - If the sample points are not within the B-spline knots. - - Notes - ----- - The evaluation is performed by looping over each element and using ``splev`` - from SciPy to compute the basis values. - """ - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - # add knots - knot_locs = self._generate_knots(sample_pts, 0.0, 1.0) - - basis_eval = bspline( - sample_pts, knot_locs, order=self.order, der=0, outer_ok=False - ) - return basis_eval - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the B-spline basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. - - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)`` - - Notes - ----- - The evaluation is performed by looping over each element and using ``splev`` from - SciPy to compute the basis values. - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import BSplineBasis - >>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - - -class CyclicBSplineBasis(SplineBasis): - """ - B-spline 1-dimensional basis functions for cyclic splines. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - order : - Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. 
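As the Notes above mention, each B-spline element is evaluated with SciPy's ``splev``. A standalone sketch of that evaluation, using one-hot coefficient vectors to select one element at a time (the same trick used by the ``bspline`` helper defined later in this module):

```python
import numpy as np
from scipy.interpolate import splev

order = 4  # cubic B-splines, degree = order - 1
knots = np.concatenate((np.zeros(order - 1), np.linspace(0, 1, 4), np.full(order - 1, 1.0)))
n_basis = len(knots) - order

# one-hot coefficients: the i-th row selects the i-th basis element
coeffs = np.eye(n_basis, len(knots))

x = np.linspace(0, 1, 200)
B = np.stack(
    [splev(x, (knots, coeffs[i], order - 1)) for i in range(n_basis)], axis=1
)
print(B.shape)  # (200, 6)
```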
The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Attributes - ---------- - n_basis_funcs : - Number of basis functions, int. - order : - Order of the splines used in basis functions, int. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import CyclicBSplineBasis - >>> X = np.random.normal(size=(1000, 1)) - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_basis(sample_points) - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "CyclicBSplineBasis", - **kwargs, - ): - super().__init__( - n_basis_funcs, - mode=mode, - order=order, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - if self.order < 2: - raise ValueError( - f"Order >= 2 required for cyclic B-spline, " - f"order {self.order} specified instead!" - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Evaluate the Cyclic B-spline basis functions with given sample points. - - Parameters - ---------- - sample_pts : - The sample points at which the cyclic B-spline is evaluated, shape - (n_samples,). - - Returns - ------- - basis_funcs : - The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) - - Notes - ----- - The evaluation is performed by looping over each element and using ``splev`` from - SciPy to compute the basis values. - - """ - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - knot_locs = self._generate_knots(sample_pts, 0.0, 1.0, is_cyclic=True) - - # for cyclic, do not repeat knots - knot_locs = np.unique(knot_locs) - - nk = knot_locs.shape[0] - - # make sure knots are sorted - knot_locs.sort() - - # extend knots - xc = knot_locs[nk - self.order] - knots = np.hstack( - ( - knot_locs[0] - knot_locs[-1] + knot_locs[nk - self.order : nk - 1], - knot_locs, - ) - ) - - ind = sample_pts > xc - - basis_eval = bspline(sample_pts, knots, order=self.order, der=0, outer_ok=True) - sample_pts[ind] = sample_pts[ind] - knots.max() + knot_locs[0] - - if np.sum(ind): - basis_eval[ind] = basis_eval[ind] + bspline( - sample_pts[ind], knots, order=self.order, outer_ok=True, der=0 - ) - # restore points - sample_pts[ind] = sample_pts[ind] + knots.max() - knot_locs[0] - return basis_eval - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the Cyclic B-spline basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. 
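A quick illustration of the ``order >= 2`` guard in ``CyclicBSplineBasis.__init__`` above, using the pre-refactor constructor from the docstring example:

```python
from nemos.basis import CyclicBSplineBasis  # pre-refactor import path

try:
    CyclicBSplineBasis(n_basis_funcs=5, order=1)
except ValueError as err:
    # Order >= 2 required for cyclic B-spline, order 1 specified instead!
    print(err)
```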
- - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)`` - - Notes - ----- - The evaluation is performed by looping over each element and using ``splev`` from - SciPy to compute the basis values. - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import CyclicBSplineBasis - >>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3) - >>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - - -class RaisedCosineBasisLinear(Basis): - """Represent linearly-spaced raised cosine basis functions. - - This implementation is based on the cosine bumps used by Pillow et al. [1]_ - to uniformly tile the internal points of the domain. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - width : - Width of the raised cosine. By default, it's set to 2.0. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLinear - >>> X = np.random.normal(size=(1000, 1)) - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - - References - ---------- - .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - width: float = 2.0, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "RaisedCosineBasisLinear", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - self._n_input_dimensionality = 1 - self._check_width(width) - self._width = width - # for these linear raised-cosine basis functions, - # the samples must be rescaled to 0 and 1. 
- self._rescale_samples = True - - @property - def width(self): - """Return width of the raised cosine.""" - return self._width - - @width.setter - def width(self, width: float): - self._check_width(width) - self._width = width - - @staticmethod - def _check_width(width: float) -> None: - """Validate the width value. - - Parameters - ---------- - width : - The width value to validate. - - Raises - ------ - ValueError - If width <= 1 or 2*width is not a positive integer. Values that do not match - this constraint will result in: - - No overlap between bumps (width < 1). - - Oscillatory behavior when summing the basis elements (2*width not integer). - """ - if width <= 1 or (not np.isclose(width * 2, round(2 * width))): - raise ValueError( - f"Invalid raised cosine width. " - f"2*width must be a positive integer, 2*width = {2 * width} instead!" - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Generate basis functions with given samples. - - Parameters - ---------- - sample_pts : - Spacing for basis functions, holding elements on interval [0, 1], Shape (number of samples, ). - - Raises - ------ - ValueError - If the sample provided do not lie in [0,1]. - - """ - if self._rescale_samples: - # note that sample points is converted to NDArray - # with the decorator. - # copy is necessary otherwise: - # basis1 = nmo.basis.RaisedCosineBasisLinear(5) - # basis2 = nmo.basis.RaisedCosineBasisLog(5) - # additive_basis = basis1 + basis2 - # additive_basis(*([x] * 2)) would modify both inputs - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) - - peaks = self._compute_peaks() - delta = peaks[1] - peaks[0] - # generate a set of shifted cosines, and constrain them to be non-zero - # over a single period, then enforce the codomain to be [0,1], by adding 1 - # and then multiply by 0.5 - basis_funcs = 0.5 * ( - np.cos( - np.clip( - np.pi * (sample_pts[:, None] - peaks[None]) / (delta * self.width), - -np.pi, - np.pi, - ) - ) - + 1 - ) - return basis_funcs - - def _compute_peaks(self) -> NDArray: - """ - Compute the location of raised cosine peaks. - - Returns - ------- - Peak locations of each basis element. - """ - return np.linspace(0, 1, self.n_basis_funcs) - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. - - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)`` - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import RaisedCosineBasisLinear - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Check that the number of basis is at least 2. - - Raises - ------ - ValueError - If n_basis_funcs < 2. - """ - if self.n_basis_funcs < 2: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 2 basis elements. 
" - f"{self.n_basis_funcs} basis elements specified instead" - ) - - -class RaisedCosineBasisLog(RaisedCosineBasisLinear): - """Represent log-spaced raised cosine basis functions. - - Similar to ``RaisedCosineBasisLinear`` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al. [1]_ - to uniformly tile the internal points of the domain. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - mode : - The mode of operation. 'eval' for evaluation at sample points, - 'conv' for convolutional operation. - width : - Width of the raised cosine. - time_scaling : - Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with - larger values resulting in more stretching. As this approaches 0, the - transformation becomes linear. - enforce_decay_to_zero: - If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements - and subsequently trims off the extra basis elements. This ensures that the final basis element - decays to 0. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLog - >>> X = np.random.normal(size=(1000, 1)) - >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - - References - ---------- - .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. - """ - - def __init__( - self, - n_basis_funcs: int, - mode="eval", - width: float = 2.0, - time_scaling: float = None, - enforce_decay_to_zero: bool = True, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "RaisedCosineBasisLog", - **kwargs, - ) -> None: - super().__init__( - n_basis_funcs, - mode=mode, - width=width, - window_size=window_size, - bounds=bounds, - **kwargs, - label=label, - ) - # The samples are scaled appropriately in the self._transform_samples which scales - # and applies the log-stretch, no additional transform is needed. 
- self._rescale_samples = False - if time_scaling is None: - time_scaling = 50.0 - - self.time_scaling = time_scaling - self.enforce_decay_to_zero = enforce_decay_to_zero - - @property - def time_scaling(self): - """Getter property for time_scaling.""" - return self._time_scaling - - @time_scaling.setter - def time_scaling(self, time_scaling): - """Setter property for time_scaling.""" - self._check_time_scaling(time_scaling) - self._time_scaling = time_scaling - - @staticmethod - def _check_time_scaling(time_scaling: float) -> None: - if time_scaling <= 0: - raise ValueError( - f"Only strictly positive time_scaling are allowed, {time_scaling} provided instead." - ) - - def _transform_samples( - self, - sample_pts: ArrayLike, - ) -> NDArray: - """ - Map the sample domain to log-space. - - Parameters - ---------- - sample_pts : - Sample points used for evaluating the splines, - shape (n_samples, ). - - Returns - ------- - Transformed version of the sample points that matches the Raised Cosine basis domain, - shape (n_samples, ). - """ - # rescale to [0,1] - # copy is necessary to avoid unwanted rescaling in additive/multiplicative basis. - sample_pts, _ = min_max_rescale_samples(np.copy(sample_pts), self.bounds) - # This log-stretching of the sample axis has the following effect: - # - as the time_scaling tends to 0, the points will be linearly spaced across the whole domain. - # - as the time_scaling tends to inf, basis will be small and dense around 0 and - # progressively larger and less dense towards 1. - log_spaced_pts = np.log(self.time_scaling * sample_pts + 1) / np.log( - self.time_scaling + 1 - ) - return log_spaced_pts - - def _compute_peaks(self) -> NDArray: - """ - Peak location of each log-spaced cosine basis element. - - Compute the peak location for the log-spaced raised cosine basis. - Enforcing that the last basis decays to zero is equivalent to - setting the last peak to a value smaller than 1. - - Returns - ------- - Peak locations of each basis element. - - """ - if self.enforce_decay_to_zero: - # compute the last peak location such that the last - # basis element decays to zero at the last sample. - last_peak = 1 - self.width / (self.n_basis_funcs + self.width - 1) - else: - last_peak = 1 - return np.linspace(0, last_peak, self.n_basis_funcs) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: ArrayLike, - ) -> FeatureMatrix: - """Generate log-spaced raised cosine basis with given samples. - - Parameters - ---------- - sample_pts : - Spacing for basis functions. Samples will be rescaled to the interval [0, 1]. - - Returns - ------- - basis_funcs : - Log-raised cosine basis functions, shape (n_samples, n_basis_funcs). - - Raises - ------ - ValueError - If the sample provided do not lie in [0,1]. - """ - return super().__call__(self._transform_samples(sample_pts)) - - -class OrthExponentialBasis(Basis): - """Set of 1D basis decaying exponential functions numerically orthogonalized. - - Parameters - ---------- - n_basis_funcs - Number of basis functions. - decay_rates : - Decay rates of the exponentials, shape ``(n_basis_funcs,)``. - mode : - The mode of operation. ``'eval'`` for evaluation at sample points, - ``'conv'`` for convolutional operation. - window_size : - The window size for convolution. Required if mode is ``'conv'``. - bounds : - The bounds for the basis domain in ``mode="eval"``. 
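The two raised-cosine bases above differ only in how the sample axis is spaced: both use the clipped-cosine bump from ``RaisedCosineBasisLinear.__call__``, and the log variant first passes the samples through the ``_transform_samples`` stretch. A minimal standalone sketch of both pieces (illustration only, peaks chosen as with ``enforce_decay_to_zero=False``):

```python
import numpy as np

def cosine_bump(x, center, delta, width=2.0):
    # 0.5 * (cos(phase) + 1) with the phase clipped to [-pi, pi], so each bump
    # is non-zero over a single period (as in RaisedCosineBasisLinear.__call__)
    phase = np.clip(np.pi * (x - center) / (delta * width), -np.pi, np.pi)
    return 0.5 * (np.cos(phase) + 1)

def log_stretch(x, time_scaling):
    # same mapping as RaisedCosineBasisLog._transform_samples above
    return np.log(time_scaling * x + 1) / np.log(time_scaling + 1)

x = np.linspace(0, 1, 5)
print(np.round(log_stretch(x, 0.1), 2))   # [0.   0.26 0.51 0.76 1.  ]  nearly linear
print(np.round(log_stretch(x, 50.0), 2))  # [0.   0.66 0.83 0.93 1.  ]  compressed towards 1

# linearly-spaced bumps applied to the stretched axis give the log-spaced basis
peaks = np.linspace(0, 1, 5)
delta = peaks[1] - peaks[0]
x_fine = np.linspace(0, 1, 200)
B_log = np.stack([cosine_bump(log_stretch(x_fine, 50.0), c, delta) for c in peaks], axis=1)
print(B_log.shape)  # (200, 5)
```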
The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import OrthExponentialBasis - >>> X = np.random.normal(size=(1000, 1)) - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = ortho_basis(sample_points) - """ - - def __init__( - self, - n_basis_funcs: int, - decay_rates: NDArray, - mode: Literal["eval", "conv"] = "eval", - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "OrthExponentialBasis", - **kwargs, - ): - super().__init__( - n_basis_funcs, - mode=mode, - window_size=window_size, - bounds=bounds, - label=label, - **kwargs, - ) - self.decay_rates = decay_rates - self._check_rates() - self._n_input_dimensionality = 1 - - @property - def decay_rates(self): - r"""Decay rate. - - The rate of decay of the exponential functions. If :math:`f_i(t) = e^{-\alpha_i t}` is the i-th decay - exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rate`` vector. - """ - return self._decay_rates - - @decay_rates.setter - def decay_rates(self, value: NDArray): - """Decay rate setter.""" - value = np.asarray(value) - if value.shape[0] != self.n_basis_funcs: - raise ValueError( - f"The number of basis functions must match the number of decay rates provided. " - f"Number of basis functions provided: {self.n_basis_funcs}, " - f"Number of decay rates provided: {value.shape[0]}" - ) - self._decay_rates = value - - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Checks that the number of basis is at least 1. - - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - if self.n_basis_funcs < 1: - raise ValueError( - f"Object class {self.__class__.__name__} requires >= 1 basis elements. " - f"{self.n_basis_funcs} basis elements specified instead" - ) - - def _check_rates(self) -> None: - """ - Check if the decay rates list has duplicate entries. - - Raises - ------ - ValueError - If two or more decay rates are repeated, which would result in a linearly - dependent set of functions for the basis. - """ - if len(set(self._decay_rates)) != len(self._decay_rates): - raise ValueError( - "Two or more rate are repeated! Repeating rate will result in a " - "linearly dependent set of function for the basis." - ) - - def _check_sample_size(self, *sample_pts: NDArray) -> None: - """Check that the sample size is greater than the number of basis. 
- - This is necessary for the orthogonalization procedure, - that otherwise will return (sample_size, ) basis elements instead of the expected number. - - Parameters - ---------- - sample_pts - Spacing for basis functions, holding elements on the interval [0, inf). - - Raises - ------ - ValueError - If the number of basis element is less than the number of samples. - """ - if sample_pts[0].size < self.n_basis_funcs: - raise ValueError( - "OrthExponentialBasis requires at least as many samples as basis functions!\n" - f"Class instantiated with {self.n_basis_funcs} basis functions " - f"but only {sample_pts[0].size} samples provided!" - ) - - @support_pynapple(conv_type="numpy") - @check_transform_input - @check_one_dimensional - def __call__( - self, - sample_pts: NDArray, - ) -> FeatureMatrix: - """Generate basis functions with given spacing. - - Parameters - ---------- - sample_pts - Spacing for basis functions, holding elements on the interval :math:`[0,inf)`, - shape ``(n_samples,)``. - - Returns - ------- - basis_funcs - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape ``(n_samples, n_basis_funcs)``. - - """ - self._check_sample_size(sample_pts) - sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds) - valid_idx = ~np.isnan(sample_pts) - # because of how scipy.linalg.orth works, have to create a matrix of - # shape (n_pts, n_basis_funcs) and then transpose, rather than - # directly computing orth on the matrix of shape (n_basis_funcs, - # n_pts) - exp_decay_eval = np.stack( - [np.exp(-lam * sample_pts[valid_idx]) for lam in self._decay_rates], axis=1 - ) - # count the linear independent components (could be lower than n_basis_funcs for num precision). - n_independent_component = np.linalg.matrix_rank(exp_decay_eval) - # initialize output to nan - basis_funcs = np.full( - shape=(sample_pts.shape[0], n_independent_component), fill_value=np.nan - ) - # orthonormalize on valid points - basis_funcs[valid_idx] = scipy.linalg.orth(exp_decay_eval) - return basis_funcs - - def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: - """Evaluate the basis set on a grid of equi-spaced sample points. - - Parameters - ---------- - n_samples : - The number of samples. - - Returns - ------- - X : - Array of shape ``(n_samples,)`` containing the equi-spaced sample - points where we've evaluated the basis. - basis_funcs : - Evaluated exponentially decaying basis functions, numerically - orthogonalized, shape ``(n_samples, n_basis_funcs)`` - - Examples - -------- - >>> import numpy as np - >>> import matplotlib.pyplot as plt - >>> from nemos.basis import OrthExponentialBasis - >>> n_basis_funcs = 5 - >>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates - >>> window_size=10 - >>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size) - >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) - """ - return super().evaluate_on_grid(n_samples) - - -def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray: - """Compute M-spline basis function. - - Parameters - ---------- - x - Spacing for basis functions, shape (n_sample_points, ). - k - Order of the spline basis. - i - Number of the spline basis. - T - knot locations. should lie in interval [0, 1], shape (k + n_basis_funcs,). - - Returns - ------- - spline - M-spline basis function, shape (n_sample_points, ). 
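The orthogonalization performed in ``OrthExponentialBasis.__call__`` above amounts to stacking exponential decays column-wise and passing them through ``scipy.linalg.orth``. A minimal sketch with an orthonormality check (the decay rates are arbitrary example values):

```python
import numpy as np
import scipy.linalg

decay_rates = np.array([1.0, 2.0, 5.0, 10.0, 20.0])
x = np.linspace(0, 1, 100)

# columns exp(-lambda_i * x), then numerically orthonormalized
E = np.stack([np.exp(-lam * x) for lam in decay_rates], axis=1)
Q = scipy.linalg.orth(E)

print(Q.shape)                                   # (100, 5)
print(np.allclose(Q.T @ Q, np.eye(Q.shape[1])))  # True: columns are orthonormal
```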
- - Examples - -------- - >>> import numpy as np - >>> from numpy import linspace - >>> from nemos.basis import mspline - >>> sample_points = linspace(0, 1, 100) - >>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline - >>> mspline_eval.shape - (100,) - """ - # Boundary conditions. - if (T[i + k] - T[i]) < 1e-6: - return np.zeros_like(x) - - # Special base case of first-order spline basis. - elif k == 1: - v = np.zeros_like(x) - v[(x >= T[i]) & (x < T[i + 1])] = 1 / (T[i + 1] - T[i]) - return v - - # General case, defined recursively - else: - return ( - k - * ( - (x - T[i]) * mspline(x, k - 1, i, T) - + (T[i + k] - x) * mspline(x, k - 1, i + 1, T) - ) - / ((k - 1) * (T[i + k] - T[i])) - ) - - -def bspline( - sample_pts: NDArray, - knots: NDArray, - order: int = 4, - der: int = 0, - outer_ok: bool = False, -) -> NDArray: - """ - Calculate and return the evaluation of B-spline basis. - - This function evaluates B-spline basis for given sample points. It checks for - out of range points and optionally handles them. It also handles the NaNs if present. - - Parameters - ---------- - sample_pts : - An array containing sample points for which B-spline basis needs to be evaluated, - shape (n_samples,) - knots : - An array containing knots for the B-spline basis. The knots are sorted in ascending order. - order : - The order of the B-spline basis. - der : - The derivative of the B-spline basis to be evaluated. - outer_ok : - If True, allows for evaluation at points outside the range of knots. - Default is False, in which case an assertion error is raised when - points outside the knots range are encountered. - - Returns - ------- - basis_eval : - An array containing the evaluation of B-spline basis for the given sample points. - Shape (n_samples, n_basis_funcs). - - Raises - ------ - AssertionError - If ``outer_ok`` is False and the sample points lie outside the B-spline knots range. - - Notes - ----- - The function uses splev function from scipy.interpolate library for the basis evaluation. - - Examples - -------- - >>> import numpy as np - >>> from numpy import linspace - >>> from nemos.basis import bspline - >>> sample_points = linspace(0, 1, 100) - >>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1]) - >>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline - >>> bspline_eval.shape - (100, 10) - """ - knots.sort() - nk = knots.shape[0] - - # check for out of range points (in cyclic b-spline need_outer must be set to False) - need_outer = any(sample_pts < knots[order - 1]) or any( - sample_pts > knots[nk - order] - ) - assert ( - not need_outer - ) | outer_ok, 'sample points must lie within the B-spline knots range unless "outer_ok==True".' 
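For reference, the recursion implemented by ``mspline`` above is the standard M-spline definition from Ramsay (1988), restated here:

```latex
M_{i,1}(x) =
\begin{cases}
  \dfrac{1}{t_{i+1} - t_i}, & t_i \le x < t_{i+1},\\
  0, & \text{otherwise},
\end{cases}
\qquad
M_{i,k}(x) = \frac{k\left[(x - t_i)\, M_{i,k-1}(x) + (t_{i+k} - x)\, M_{i+1,k-1}(x)\right]}
                  {(k - 1)\,(t_{i+k} - t_i)}.
```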
- - # select knots that are within the knots range (this takes care of eventual NaNs) - in_sample = (sample_pts >= knots[0]) & (sample_pts <= knots[-1]) - - if need_outer: - reps = order - 1 - knots = np.hstack((np.ones(reps) * knots[0], knots, np.ones(reps) * knots[-1])) - nk = knots.shape[0] - else: - reps = 0 - - # number of basis elements - n_basis = nk - order - - # initialize the basis element container - basis_eval = np.full((n_basis - 2 * reps, sample_pts.shape[0]), np.nan) - - # loop one element at the time and evaluate the basis using splev - id_basis = np.eye(n_basis, nk, dtype=np.int8) - for i in range(reps, len(knots) - order - reps): - basis_eval[i - reps, in_sample] = splev( - sample_pts[in_sample], (knots, id_basis[i], order - 1), der=der - ) - - return basis_eval.T From 9927a2ca80a8f3ea5d9519fbf8bdf9ed0a5169d1 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:01:03 -0500 Subject: [PATCH 026/109] fixed first test --- tests/test_basis.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 2d5fe8f7..c1d37201 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -16,7 +16,6 @@ import nemos.basis.basis as basis import nemos.convolve as convolve -from nemos.basis import EvalMSpline from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring from nemos.basis._decaying_exponential import OrthExponentialBasis from nemos.basis._raised_cosine_basis import ( @@ -253,23 +252,24 @@ def cls(self): pass + class TestRaisedCosineLogBasis(BasisFuncsTesting): - cls = basis.RaisedCosineBasisLog + cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_non_empty_samples(self, samples, mode, window_size): + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_non_empty_samples(self, samples, mode, kwargs): if mode == "conv" and len(samples) == 1: return if len(samples) == 0: with pytest.raises( ValueError, match="All sample provided must be non empty" ): - self.cls(5, mode=mode, window_size=window_size).compute_features( + self.cls[mode](5, **kwargs).compute_features( samples ) else: - self.cls(5, mode=mode, window_size=window_size).compute_features(samples) + self.cls[mode](5, **kwargs).compute_features(samples) @pytest.mark.parametrize( "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] From e89543b9a5b1e1ce526527809264a9d602f91b13 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:02:55 -0500 Subject: [PATCH 027/109] fixed second test --- tests/test_basis.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index c1d37201..ca4132c9 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -271,14 +271,9 @@ def test_non_empty_samples(self, samples, mode, kwargs): else: self.cls[mode](5, **kwargs).compute_features(samples) - @pytest.mark.parametrize( - "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] - ) + @pytest.mark.parametrize("eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])]) def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the evaluate() method matches the input sample size. 
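The test updates in this part of the series migrate from the single ``mode=``-switched classes to dedicated evaluation/convolution classes. A minimal sketch of the pattern the updated tests exercise (class names and signatures are taken from these diffs and may still change within the patch series):

```python
import numpy as np
import nemos.basis.basis as basis

# pre-refactor: basis.RaisedCosineBasisLog(5, mode="conv", window_size=10)
eval_bas = basis.EvalRaisedCosineLog(n_basis_funcs=5)
conv_bas = basis.ConvRaisedCosineLog(n_basis_funcs=5, window_size=10)

X_eval = eval_bas.compute_features(np.linspace(0, 1, 100))  # (100, 5)
X_conv = conv_bas.compute_features(np.ones(100))            # (100, 5), causal convolution
```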
- """ - basis_obj = self.cls(n_basis_funcs=5) + basis_obj = self.cls["eval"](n_basis_funcs=5) basis_obj.compute_features(eval_input) @pytest.mark.parametrize( From b2dd18fc1c49c882482efe0a5e390f20ef2940a3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:07:59 -0500 Subject: [PATCH 028/109] fixed test_set_width --- tests/test_basis.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index ca4132c9..dd0dfbc5 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -282,31 +282,35 @@ def test_compute_features_input(self, eval_input): (10, does_not_raise()), (10.5, does_not_raise()), ( - 0.5, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), + 0.5, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), ), ( - 10.3, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), + 10.3, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), ), ( - -10, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), + -10, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), ), (None, pytest.raises(TypeError, match="'<=' not supported between")), ], ) - def test_set_width(self, width, expectation): - basis_obj = self.cls(n_basis_funcs=5) + @pytest.mark.parametrize("cls, kwargs", [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 5}), + ]) + def test_set_width(self, width, expectation, cls, kwargs): + basis_obj = cls(n_basis_funcs=5, **kwargs) with expectation: basis_obj.width = width with expectation: From 0a288c6550732a27688f51627b1b7493f186408f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:12:40 -0500 Subject: [PATCH 029/109] fixed test_compute_features_axis --- tests/test_basis.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index dd0dfbc5..f64e42ef 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -321,27 +321,24 @@ def test_set_width(self, width, expectation, cls, kwargs): [ (dict(), (10,), does_not_raise()), ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), + dict(axis=0), + (10,), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), ), ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), + dict(axis=1), + (2, 10), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), ), ], ) def test_compute_features_axis(self, kwargs, input1_shape, expectation): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. 
- """ with expectation: - basis_obj = self.cls(n_basis_funcs=5, mode="conv", window_size=5, **kwargs) + basis_obj = self.cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs) basis_obj.compute_features(np.ones(input1_shape)) @pytest.mark.parametrize("n_basis_funcs", [4, 5]) From e70c07cc972ae0ec55aee33ac874b483a33485a0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:13:49 -0500 Subject: [PATCH 030/109] fixed test_compute_features_conv_input --- tests/test_basis.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index f64e42ef..5d547a80 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -357,24 +357,23 @@ def test_compute_features_axis(self, kwargs, input1_shape, expectation): ], ) def test_compute_features_conv_input( - self, - n_basis_funcs, - time_scaling, - enforce_decay, - window_size, - input_shape, - expected_n_input, + self, + n_basis_funcs, + time_scaling, + enforce_decay, + window_size, + input_shape, + expected_n_input, ): x = np.ones(input_shape) - bas = self.cls( + basis_obj = self.cls["conv"]( n_basis_funcs=n_basis_funcs, time_scaling=time_scaling, - mode="conv", window_size=window_size, enforce_decay_to_zero=enforce_decay, ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * bas.n_basis_funcs + out = basis_obj.compute_features(x) + assert out.shape[1] == expected_n_input * basis_obj.n_basis_funcs @pytest.mark.parametrize( "args, sample_size", From 417f67b28bd3d13d248de3568f0b05a8dc58a786 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:19:41 -0500 Subject: [PATCH 031/109] fixed test_compute_features_returns_expected_number_of_basis --- tests/test_basis.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 5d547a80..53008e61 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -379,23 +379,27 @@ def test_compute_features_conv_input( "args, sample_size", [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) def test_compute_features_returns_expected_number_of_basis( - self, args, mode, window_size, sample_size + self, args, sample_size, cls, kwargs ): """ - Verifies the number of basis functions returned by the evaluate() method matches + Verifies the number of basis functions returned by the compute_features() method matches the expected number of basis functions. """ - basis_obj = self.cls(mode=mode, window_size=window_size, **args) + basis_obj = cls(**args, **kwargs) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[1] != args["n_basis_funcs"]: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the evaluated basis." - f"The number of basis is {args['n_basis_funcs']}", - f"The first dimension of the evaluated basis is {eval_basis.shape[1]}", - ) - return + assert eval_basis.shape[1] == args["n_basis_funcs"], ( + "Dimensions do not agree: The number of basis should match the first dimension " + f"of the evaluated basis. 
The number of basis is {args['n_basis_funcs']}, but the " + f"evaluated basis has dimension {eval_basis.shape[1]}" + ) @pytest.mark.parametrize("sample_size", [100, 1000]) @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) From 77795b1b228456c389982ed803756dc0f4be1fe7 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:35:29 -0500 Subject: [PATCH 032/109] fixed test_number_of_required_inputs_compute_features --- tests/test_basis.py | 108 ++++++++++++++++++++++---------------------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 53008e61..e13fe5bf 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -403,96 +403,96 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("sample_size", [100, 1000]) @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, mode, window_size + self, n_basis_funcs, sample_size, cls, kwargs ): """ - Checks that the sample size of the output from the evaluate() method matches the input sample size. + Checks that the sample size of the output from the compute_features() method matches the input sample size. """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size - ) + basis_obj = cls(n_basis_funcs=n_basis_funcs, **kwargs) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the evaluated basis." - f"The window size is {sample_size}", - f"The second dimension of the evaluated basis is {eval_basis.shape[0]}", - ) + assert eval_basis.shape[0] == sample_size, ( + f"Dimensions do not agree: The sample size of the output should match the input sample size. " + f"Expected {sample_size}, but got {eval_basis.shape[0]}." + ) @pytest.mark.parametrize( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), + (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), ), ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), ), ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) + """ + Tests that compute_features handles samples correctly within specified bounds. 
+ """ + basis_obj = self.cls["eval"](5, bounds=(vmin, vmax)) with expectation: - bas(samples) + basis_obj.compute_features(samples) @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, mode, window_size - ): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, cls, kwargs): """ Verifies that the minimum number of basis functions required (i.e., 2) is enforced. """ - raise_exception = n_basis_funcs < 2 - if raise_exception: + if n_basis_funcs < 2: with pytest.raises( - ValueError, - match=f"Object class {self.cls.__name__} " - "requires >= 2 basis elements.", + ValueError, + match=f"Object class {cls.__name__} requires >= 2 basis elements.", ): - self.cls( - n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size - ) + cls(n_basis_funcs=n_basis_funcs, **kwargs) else: - self.cls(n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size) + cls(n_basis_funcs=n_basis_funcs, **kwargs) @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_number_of_required_inputs_compute_features(self, n_input, cls, kwargs): """ - Confirms that the compute_features() method correctly handles the number of input samples that are provided. + Confirms that the compute_features() method correctly handles the number of input samples provided. 
""" - basis_obj = self.cls(n_basis_funcs=5, mode=mode, window_size=window_size) + basis_obj = cls(n_basis_funcs=5, **kwargs) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises( - TypeError, match="Input dimensionality mismatch" - ) + expectation = pytest.raises(TypeError, match="missing 1 required positional argument") elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) + expectation = pytest.raises(TypeError, match="takes 2 positional arguments but \d were given") else: expectation = does_not_raise() + with expectation: basis_obj.compute_features(*inputs) From f6c7eadba32bdc6960a396c6990cb4815461b0cc Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:36:03 -0500 Subject: [PATCH 033/109] fixed test_evaluate_on_grid_meshgrid_size --- tests/test_basis.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index e13fe5bf..96effcd4 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -497,15 +497,21 @@ def test_number_of_required_inputs_compute_features(self, n_input, cls, kwargs): basis_obj.compute_features(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_evaluate_on_grid_meshgrid_size(self, sample_size, cls, kwargs): """ Checks that the evaluate_on_grid() method returns a grid of the expected size. """ - basis_obj = self.cls(n_basis_funcs=5) - raise_exception = sample_size <= 0 - if raise_exception: + basis_obj = cls(n_basis_funcs=5, **kwargs) + if sample_size <= 0: with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" + ValueError, match=r"All sample counts provided must be greater" ): basis_obj.evaluate_on_grid(sample_size) else: From 519c0d70e4cec0a204e3b1ac91357523489157c2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:39:37 -0500 Subject: [PATCH 034/109] fixed test_evaluate_on_grid_input_number --- tests/test_basis.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 96effcd4..12069d0d 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -519,15 +519,21 @@ def test_evaluate_on_grid_meshgrid_size(self, sample_size, cls, kwargs): assert grid.shape[0] == sample_size @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_evaluate_on_grid_basis_size(self, sample_size, cls, kwargs): """ Ensures that the evaluate_on_grid() method returns basis functions of the expected size. 
""" - basis_obj = self.cls(n_basis_funcs=5) - raise_exception = sample_size <= 0 - if raise_exception: + basis_obj = cls(n_basis_funcs=5, **kwargs) + if sample_size <= 0: with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" + ValueError, match=r"All sample counts provided must be greater" ): basis_obj.evaluate_on_grid(sample_size) else: @@ -535,11 +541,18 @@ def test_evaluate_on_grid_basis_size(self, sample_size): assert eval_basis.shape[0] == sample_size @pytest.mark.parametrize("n_input", [0, 1, 2]) - def test_evaluate_on_grid_input_number(self, n_input): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_evaluate_on_grid_input_number(self, n_input, cls, kwargs): """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. + Validates that the evaluate_on_grid() method correctly handles the number of input samples provided. """ - basis_obj = self.cls(n_basis_funcs=5) + basis_obj = cls(n_basis_funcs=5, **kwargs) inputs = [10] * n_input if n_input == 0: expectation = pytest.raises( @@ -553,6 +566,7 @@ def test_evaluate_on_grid_input_number(self, n_input): ) else: expectation = does_not_raise() + with expectation: basis_obj.evaluate_on_grid(*inputs) From 117444723f6e5689555802c65fdd2d30ce7ac925 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:41:53 -0500 Subject: [PATCH 035/109] fixed test_time_scaling_values --- tests/test_basis.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 12069d0d..ab93a886 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -571,7 +571,7 @@ def test_evaluate_on_grid_input_number(self, n_input, cls, kwargs): basis_obj.evaluate_on_grid(*inputs) @pytest.mark.parametrize( - "width ,expectation", + "width, expectation", [ (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), @@ -582,17 +582,24 @@ def test_evaluate_on_grid_input_number(self, n_input, cls, kwargs): (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - def test_width_values(self, width, expectation): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 2}), + ], + ) + def test_width_values(self, width, expectation, cls, kwargs): """Test allowable widths: integer multiple of 1/2, greater than 1.""" with expectation: - self.cls(n_basis_funcs=5, width=width) + cls(n_basis_funcs=5, width=width, **kwargs) @pytest.mark.parametrize("width", [1.5, 2, 2.5]) def test_decay_to_zero_basis_number_match(self, width): """Test that the number of basis is preserved.""" n_basis_funcs = 10 - _, ev = self.cls( - n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True + _, ev = self.cls["conv"]( + n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True, window_size=5 ).evaluate_on_grid(2) assert ev.shape[1] == n_basis_funcs, ( "Basis function number mismatch. 
" @@ -600,28 +607,25 @@ def test_decay_to_zero_basis_number_match(self, width): ) @pytest.mark.parametrize( - "time_scaling ,expectation", + "time_scaling, expectation", [ - ( - -1, - pytest.raises( - ValueError, match="Only strictly positive time_scaling are allowed" - ), - ), - ( - 0, - pytest.raises( - ValueError, match="Only strictly positive time_scaling are allowed" - ), - ), + (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), (0.1, does_not_raise()), (10, does_not_raise()), ], ) - def test_time_scaling_values(self, time_scaling, expectation): + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 5}), + ], + ) + def test_time_scaling_values(self, time_scaling, expectation, cls, kwargs): """Test that only positive time_scaling are allowed.""" with expectation: - self.cls(n_basis_funcs=5, time_scaling=time_scaling) + cls(n_basis_funcs=5, time_scaling=time_scaling, **kwargs) def test_time_scaling_property(self): """Test that larger time_scaling results in larger departures from linearity.""" From 38b499bf87ecc4a703d98e5940ded7390000636d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:46:11 -0500 Subject: [PATCH 036/109] fixed some test calls --- tests/test_basis.py | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index ab93a886..0a7f00f5 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -631,11 +631,11 @@ def test_time_scaling_property(self): """Test that larger time_scaling results in larger departures from linearity.""" time_scaling = [0.1, 10, 100] n_basis_funcs = 5 - _, lin_ev = basis.RaisedCosineBasisLinear(n_basis_funcs).evaluate_on_grid(100) + _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) corr = np.zeros(len(time_scaling)) for idx, ts in enumerate(time_scaling): # set default decay to zero to get comparable basis - basis_log = self.cls( + basis_log = self.cls["eval"]( n_basis_funcs=n_basis_funcs, time_scaling=ts, enforce_decay_to_zero=False, @@ -643,7 +643,7 @@ def test_time_scaling_property(self): _, log_ev = basis_log.evaluate_on_grid(100) # compute the correlation corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( - np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) ) # check that the correlation decreases as time_scale increases assert np.all( @@ -653,13 +653,16 @@ def test_time_scaling_property(self): @pytest.mark.parametrize("sample_size", [30]) @pytest.mark.parametrize("n_basis", [5]) def test_pynapple_support_compute_features(self, n_basis, sample_size): + """ + Test compute_features compatibility with pynapple Tsd input. 
+ """ iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( t=np.linspace(0, 1, sample_size), d=np.linspace(0, 1, sample_size), time_support=iset, ) - out = self.cls(n_basis).compute_features(inp) + out = self.cls["eval"](n_basis_funcs=n_basis).compute_features(inp) assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) @@ -672,9 +675,18 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_num(self, num_input, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 3}), + ], + ) + def test_call_input_num(self, num_input, cls, kwargs, expectation): + """ + Test handling of input dimensionality mismatch when calling the basis. + """ + bas = cls(n_basis_funcs=5, **kwargs) with expectation: bas(*([np.linspace(0, 1, 10)] * num_input)) @@ -685,9 +697,18 @@ def test_call_input_num(self, num_input, mode, window_size, expectation): (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize( + "cls, kwargs", + [ + (basis.EvalRaisedCosineLog, {}), + (basis.ConvRaisedCosineLog, {"window_size": 3}), + ], + ) + def test_call_input_shape(self, inp, cls, kwargs, expectation): + """ + Test handling of input shape mismatch when calling the basis. + """ + bas = cls(n_basis_funcs=5, **kwargs) with expectation: bas(inp) From 33a99952e2d092d77c15ad6bd2801e7c674ec0dd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 11:55:34 -0500 Subject: [PATCH 037/109] moved to self.cls --- tests/test_basis.py | 284 +++++++++----------------------------------- 1 file changed, 59 insertions(+), 225 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 0a7f00f5..ba37406d 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -262,12 +262,8 @@ def test_non_empty_samples(self, samples, mode, kwargs): if mode == "conv" and len(samples) == 1: return if len(samples) == 0: - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - self.cls[mode](5, **kwargs).compute_features( - samples - ) + with pytest.raises(ValueError, match="All sample provided must be non empty"): + self.cls[mode](5, **kwargs).compute_features(samples) else: self.cls[mode](5, **kwargs).compute_features(samples) @@ -281,36 +277,15 @@ def test_compute_features_input(self, eval_input): [ (10, does_not_raise()), (10.5, does_not_raise()), - ( - 0.5, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), - ( - 10.3, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), - ( - -10, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), + (0.5, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), + (10.3, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 
2\*width must be a positive")), + (-10, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), (None, pytest.raises(TypeError, match="'<=' not supported between")), ], ) - @pytest.mark.parametrize("cls, kwargs", [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 5}), - ]) - def test_set_width(self, width, expectation, cls, kwargs): - basis_obj = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + def test_set_width(self, width, expectation, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: basis_obj.width = width with expectation: @@ -320,20 +295,8 @@ def test_set_width(self, width, expectation, cls, kwargs): "kwargs, input1_shape, expectation", [ (dict(), (10,), does_not_raise()), - ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), + (dict(axis=0), (10,), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), + (dict(axis=1), (2, 10), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), ], ) def test_compute_features_axis(self, kwargs, input1_shape, expectation): @@ -357,13 +320,13 @@ def test_compute_features_axis(self, kwargs, input1_shape, expectation): ], ) def test_compute_features_conv_input( - self, - n_basis_funcs, - time_scaling, - enforce_decay, - window_size, - input_shape, - expected_n_input, + self, + n_basis_funcs, + time_scaling, + enforce_decay, + window_size, + input_shape, + expected_n_input, ): x = np.ones(input_shape) basis_obj = self.cls["conv"]( @@ -379,21 +342,9 @@ def test_compute_features_conv_input( "args, sample_size", [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], ) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_compute_features_returns_expected_number_of_basis( - self, args, sample_size, cls, kwargs - ): - """ - Verifies the number of basis functions returned by the compute_features() method matches - the expected number of basis functions. - """ - basis_obj = cls(**args, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_compute_features_returns_expected_number_of_basis(self, args, sample_size, mode, kwargs): + basis_obj = self.cls[mode](**args, **kwargs) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[1] == args["n_basis_funcs"], ( "Dimensions do not agree: The number of basis should match the first dimension " @@ -403,20 +354,9 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("sample_size", [100, 1000]) @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, cls, kwargs - ): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. 
- """ - basis_obj = cls(n_basis_funcs=n_basis_funcs, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_funcs, sample_size, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[0] == sample_size, ( f"Dimensions do not agree: The sample size of the output should match the input sample size. " @@ -429,62 +369,30 @@ def test_sample_size_of_compute_features_matches_that_of_input( (0.5, 0, 1, does_not_raise()), (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), + (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - """ - Tests that compute_features handles samples correctly within specified bounds. - """ basis_obj = self.cls["eval"](5, bounds=(vmin, vmax)) with expectation: basis_obj.compute_features(samples) @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, cls, kwargs): - """ - Verifies that the minimum number of basis functions required (i.e., 2) is enforced. - """ + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): if n_basis_funcs < 2: with pytest.raises( - ValueError, - match=f"Object class {cls.__name__} requires >= 2 basis elements.", + ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", ): - cls(n_basis_funcs=n_basis_funcs, **kwargs) + self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) else: - cls(n_basis_funcs=n_basis_funcs, **kwargs) + self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_number_of_required_inputs_compute_features(self, n_input, cls, kwargs): - """ - Confirms that the compute_features() method correctly handles the number of input samples provided. 
- """ - basis_obj = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: expectation = pytest.raises(TypeError, match="missing 1 required positional argument") @@ -497,72 +405,39 @@ def test_number_of_required_inputs_compute_features(self, n_input, cls, kwargs): basis_obj.compute_features(*inputs) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_evaluate_on_grid_meshgrid_size(self, sample_size, cls, kwargs): - """ - Checks that the evaluate_on_grid() method returns a grid of the expected size. - """ - basis_obj = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) if sample_size <= 0: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): + with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): basis_obj.evaluate_on_grid(sample_size) else: grid, _ = basis_obj.evaluate_on_grid(sample_size) assert grid.shape[0] == sample_size @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_evaluate_on_grid_basis_size(self, sample_size, cls, kwargs): - """ - Ensures that the evaluate_on_grid() method returns basis functions of the expected size. - """ - basis_obj = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) if sample_size <= 0: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): + with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): basis_obj.evaluate_on_grid(sample_size) else: _, eval_basis = basis_obj.evaluate_on_grid(sample_size) assert eval_basis.shape[0] == sample_size @pytest.mark.parametrize("n_input", [0, 1, 2]) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_evaluate_on_grid_input_number(self, n_input, cls, kwargs): - """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples provided. 
- """ - basis_obj = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs): + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) inputs = [10] * n_input if n_input == 0: expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) missing 1 required positional argument", + TypeError, match=r"evaluate_on_grid\(\) missing 1 required positional argument", ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + TypeError, match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", ) else: expectation = does_not_raise() @@ -582,21 +457,13 @@ def test_evaluate_on_grid_input_number(self, n_input, cls, kwargs): (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 2}), - ], - ) - def test_width_values(self, width, expectation, cls, kwargs): - """Test allowable widths: integer multiple of 1/2, greater than 1.""" + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_width_values(self, width, expectation, mode, kwargs): with expectation: - cls(n_basis_funcs=5, width=width, **kwargs) + self.cls[mode](n_basis_funcs=5, width=width, **kwargs) @pytest.mark.parametrize("width", [1.5, 2, 2.5]) def test_decay_to_zero_basis_number_match(self, width): - """Test that the number of basis is preserved.""" n_basis_funcs = 10 _, ev = self.cls["conv"]( n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True, window_size=5 @@ -615,37 +482,26 @@ def test_decay_to_zero_basis_number_match(self, width): (10, does_not_raise()), ], ) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 5}), - ], - ) - def test_time_scaling_values(self, time_scaling, expectation, cls, kwargs): - """Test that only positive time_scaling are allowed.""" + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): with expectation: - cls(n_basis_funcs=5, time_scaling=time_scaling, **kwargs) + self.cls[mode](n_basis_funcs=5, time_scaling=time_scaling, **kwargs) def test_time_scaling_property(self): - """Test that larger time_scaling results in larger departures from linearity.""" time_scaling = [0.1, 10, 100] n_basis_funcs = 5 _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) corr = np.zeros(len(time_scaling)) for idx, ts in enumerate(time_scaling): - # set default decay to zero to get comparable basis basis_log = self.cls["eval"]( n_basis_funcs=n_basis_funcs, time_scaling=ts, enforce_decay_to_zero=False, ) _, log_ev = basis_log.evaluate_on_grid(100) - # compute the correlation corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( - np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) ) - # check that the correlation decreases as time_scale increases assert np.all( np.diff(corr) < 0 ), "As time scales increases, deviation from linearity should increase!" 
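# A minimal sketch of the pattern this refactor converges on, assuming the
# Eval*/Conv* classes introduced in this patch series: each test class keeps an
# {"eval": ..., "conv": ...} mapping in `self.cls` and indexes it by mode, so a
# single test body can exercise both constructors without a `mode=` argument.
import numpy as np
import nemos.basis.basis as basis

cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}
x = np.linspace(0, 1, 100)
# eval-mode bases evaluate the basis functions at the samples ...
X_eval = cls["eval"](n_basis_funcs=5).compute_features(x)
# ... while conv-mode bases convolve the samples with a kernel of length `window_size`.
X_conv = cls["conv"](n_basis_funcs=5, window_size=10).compute_features(x)
assert X_eval.shape == (100, 5) and X_conv.shape == (100, 5)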
@@ -653,9 +509,6 @@ def test_time_scaling_property(self): @pytest.mark.parametrize("sample_size", [30]) @pytest.mark.parametrize("n_basis", [5]) def test_pynapple_support_compute_features(self, n_basis, sample_size): - """ - Test compute_features compatibility with pynapple Tsd input. - """ iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( t=np.linspace(0, 1, sample_size), @@ -666,7 +519,6 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) - # TEST CALL @pytest.mark.parametrize( "num_input, expectation", [ @@ -675,18 +527,9 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size): (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), ], ) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 3}), - ], - ) - def test_call_input_num(self, num_input, cls, kwargs, expectation): - """ - Test handling of input dimensionality mismatch when calling the basis. - """ - bas = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_input_num(self, num_input, mode, kwargs, expectation): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: bas(*([np.linspace(0, 1, 10)] * num_input)) @@ -697,18 +540,9 @@ def test_call_input_num(self, num_input, cls, kwargs, expectation): (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize( - "cls, kwargs", - [ - (basis.EvalRaisedCosineLog, {}), - (basis.ConvRaisedCosineLog, {"window_size": 3}), - ], - ) - def test_call_input_shape(self, inp, cls, kwargs, expectation): - """ - Test handling of input shape mismatch when calling the basis. 
- """ - bas = cls(n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_input_shape(self, inp, mode, kwargs, expectation): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: bas(inp) From 5bf53c89feb4d1f5bbec5cae20309848dd088b84 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 12:03:39 -0500 Subject: [PATCH 038/109] refactored a bunch of raised cos tests --- tests/test_basis.py | 87 ++++++++++++++++++--------------------------- 1 file changed, 35 insertions(+), 52 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index ba37406d..a9583bfb 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -547,14 +547,14 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation): bas(inp) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_sample_axis(self, time_axis_shape, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_nan(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_nan(self, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) x = np.linspace(0, 1, 10) x[3] = np.nan assert all(np.isnan(bas(x)[3])) @@ -564,25 +564,25 @@ def test_call_nan(self, mode, window_size): [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), ), ], ) def test_call_input_type(self, samples, expectation): - bas = self.cls(5) + bas = self.cls["eval"](n_basis_funcs=5) # Only eval mode is relevant here with expectation: bas(samples) def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10) - bas_eva = self.cls(5, mode="eval") + bas_con = self.cls["conv"](n_basis_funcs=5, window_size=10) + bas_eval = self.cls["eval"](n_basis_funcs=5) x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) + assert np.all(bas_con(x) == bas_eval(x)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_pynapple_support(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_pynapple_support(self, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) x = np.linspace(0, 1, 10) x_nap = nap.Tsd(t=np.arange(10), d=x) y = bas(x) @@ -592,16 +592,16 @@ def test_pynapple_support(self, mode, window_size): assert np.all(y_nap.t == x_nap.t) @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", 
{"window_size": 3})]) + def test_call_basis_number(self, n_basis, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=n_basis, **kwargs) x = np.linspace(0, 1, 10) assert bas(x).shape[1] == n_basis @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_non_empty(self, n_basis, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=n_basis, **kwargs) with pytest.raises(ValueError, match="All sample provided must"): bas(np.array([])) @@ -612,51 +612,34 @@ def test_call_non_empty(self, n_basis, mode, window_size): (-2, 2, does_not_raise()), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_sample_range(self, mn, mx, expectation, mode, kwargs): + bas = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: bas(np.linspace(mn, mx, 10)) def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) + bas = self.cls["conv"](n_basis_funcs=5, window_size=3) + bas._set_kernel() assert bas.kernel_ is not None def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) + bas = self.cls["conv"](n_basis_funcs=5, window_size=3) + bas._set_kernel() assert bas.kernel_.shape == (3, 5) def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=3) + bas = self.cls["conv"](n_basis_funcs=5, window_size=3) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + ValueError, match="You must call `_set_kernel` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 2 - with expectation: - self.cls(5, mode=mode, window_size=window_size) - @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label): - bas = self.cls(5, label=label) - assert bas.label == (str(label) if label is not None else self.cls.__name__) + bas = self.cls["eval"](n_basis_funcs=5, label=label) + expected_label = str(label) if label is not None else self.cls["eval"].__name__ + assert bas.label == expected_label @pytest.mark.parametrize( "attribute, value", @@ -668,9 +651,9 @@ def test_init_label(self, label): ], ) def test_attr_setter(self, attribute, value): - bas = self.cls(5) + bas = self.cls["eval"](n_basis_funcs=5) with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" + AttributeError, match=rf"can't set attribute|property '{attribute}' of" ): setattr(bas, attribute, value) From 37ac5085991a8f031192871900c80e2237f059f3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 13:04:21 -0500 Subject: [PATCH 039/109] finished test raised cos log --- src/nemos/basis/_basis_mixin.py | 3 + tests/test_basis.py | 207 
+++++++++++++++----------------- 2 files changed, 99 insertions(+), 111 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 4ed23bb7..16848341 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -104,6 +104,9 @@ def _compute_features(self, *xi: ArrayLike): as multiple arguments, each representing a different dimension for multivariate inputs. """ + if self.kernel_ is None: + raise ValueError("You must call `_set_kernel` before `_compute_features`! " + "Convolution kernel is not set.") # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. diff --git a/tests/test_basis.py b/tests/test_basis.py index a9583bfb..e338a60c 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -659,14 +659,14 @@ def test_attr_setter(self, attribute, value): @pytest.mark.parametrize("n_input", [1, 2, 3]) def test_set_num_output_features(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) + bas = self.cls["conv"](n_basis_funcs=5, window_size=10) assert bas.n_output_features is None bas.compute_features(np.random.randn(20, n_input)) assert bas.n_output_features == n_input * bas.n_basis_funcs @pytest.mark.parametrize("n_input", [1, 2, 3]) def test_set_num_basis_input(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) + bas = self.cls["conv"](n_basis_funcs=5, window_size=10) assert bas.n_basis_input is None bas.compute_features(np.random.randn(20, n_input)) assert bas.n_basis_input == (n_input,) @@ -682,7 +682,7 @@ def test_set_num_basis_input(self, n_input): ], ) def test_expected_input_number(self, n_input, expectation): - bas = self.cls(5, mode="conv", window_size=10) + bas = self.cls["conv"](n_basis_funcs=5, window_size=10) x = np.random.randn(20, 2) bas.compute_features(x) with expectation: @@ -693,75 +693,72 @@ def test_expected_input_number(self, n_input, expectation): [ (dict(), does_not_raise()), ( - dict(axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), + dict(axis=0), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), ), ( - dict(axis=1), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), + dict(axis=1), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), ), (dict(shift=True), does_not_raise()), ( - dict(shift=True, axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), + dict(shift=True, axis=0), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), ), ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), ), (dict(shift=True, predictor_causality="causal"), does_not_raise()), ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), ), ], ) def test_init_conv_kwargs(self, conv_kwargs, expectation): with expectation: - self.cls(5, mode="conv", window_size=200, **conv_kwargs) + self.cls["conv"](n_basis_funcs=5, window_size=200, conv_kwargs=conv_kwargs) @pytest.mark.parametrize( "mode, ws, expectation", [ ("conv", 2, does_not_raise()), ( - 
"conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), + "conv", + -1, + pytest.raises(ValueError, match="`window_size` must be a positive "), ), ( - "conv", - None, - pytest.raises( - ValueError, - match="If the basis is in `conv` mode, you must provide a ", - ), + "conv", + None, + pytest.raises( + ValueError, + match="If the basis is in `conv` mode, you must provide a ", + ), ), ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), + "conv", + 1.5, + pytest.raises(ValueError, match="`window_size` must be a positive "), ), - ("eval", None, does_not_raise()), + ("eval", None, pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + )), ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), + "eval", + 10, + pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + ), ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: - self.cls(5, mode=mode, window_size=ws) + self.cls[mode](n_basis_funcs=5, window_size=ws) @pytest.mark.parametrize( "enforce_decay_to_zero, time_scaling, width, window_size, n_basis_funcs, bounds, mode", @@ -771,14 +768,14 @@ def test_init_window_size(self, mode, ws, expectation): ], ) def test_set_params( - self, - enforce_decay_to_zero, - time_scaling, - width, - window_size, - n_basis_funcs, - bounds, - mode: Literal["eval", "conv"], + self, + enforce_decay_to_zero, + time_scaling, + width, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -789,72 +786,62 @@ def test_set_params( n_basis_funcs=n_basis_funcs, bounds=bounds, ) + if window_size is None: + pars.pop("window_size") + if bounds is None: + pars.pop("bounds") + keys = list(pars.keys()) - bas = self.cls( - enforce_decay_to_zero=enforce_decay_to_zero, - time_scaling=time_scaling, - width=width, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - mode=mode, + bas = self.cls[mode]( + **pars ) for i in range(len(pars)): for j in range(i + 1, len(pars)): par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} bas = bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) + assert isinstance(bas, self.cls[mode]) @pytest.mark.parametrize( "mode, expectation", [ ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ("conv", pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), ], ) def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) + kwargs = {"bounds": (1, 2)} with expectation: - self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + self.cls[mode](n_basis_funcs=10, **kwargs) - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) + if mode == "conv": + bas = self.cls["conv"](n_basis_funcs=10, window_size=10) + with pytest.raises(ValueError, match="Invalid parameter 'bounds' for estimator"): + bas.set_params(bounds=(1, 2)) 
@pytest.mark.parametrize( "mode, expectation", [ ("conv", does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ("eval", pytest.raises(TypeError, match="got an unexpected keyword argument 'window_size'")), ], ) def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" + kwargs = {"window_size": 10} with expectation: - self.cls(window_size=10, n_basis_funcs=10, mode=mode) + self.cls[mode](n_basis_funcs=10, **kwargs) - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) + if mode == "conv": + bas = self.cls["conv"](n_basis_funcs=10, window_size=10) + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) - bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) + if mode == "eval": + bas = self.cls["eval"](n_basis_funcs=10) + with pytest.raises(ValueError, match="Invalid parameter 'window_size' for estimator"): + bas.set_params(window_size=10) def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10) + bas = self.cls["conv"](n_basis_funcs=5, window_size=10) x = np.random.normal(size=100) conv = bas.compute_features(x) conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) @@ -863,8 +850,8 @@ def test_convolution_is_performed(self): assert np.all(np.isnan(conv_2[~valid])) def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, mode="eval", test="hi") + with pytest.raises(TypeError, match="got an unexpected keyword argument 'test'"): + self.cls["eval"](n_basis_funcs=5, test="hi") @pytest.mark.parametrize( "bounds, expectation", @@ -877,16 +864,16 @@ def test_conv_kwargs_error(self): ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), + (1, 2, 3), + pytest.raises( + ValueError, match="The provided `bounds` must be of length two" + ), ), ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: - bas = self.cls(3, bounds=bounds) + bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize( @@ -900,15 +887,15 @@ def test_vmin_vmax_init(self, bounds, expectation): ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), ), ], ) def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls(3, bounds=(1, 3)) + bas = self.cls["eval"](n_basis_funcs=3, bounds=(1, 3)) with expectation: bas.set_params(bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None @@ -924,7 +911,7 @@ def test_vmin_vmax_setter(self, bounds, expectation): ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) - bas = self.cls(3, mode="eval", bounds=bounds) + bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) out = bas.compute_features(samples) assert np.all(np.isnan(out[nan_idx])) valid_idx = list(set(samples).difference(nan_idx)) @@ -938,11 +925,9 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): (1, 3, np.arange(5), [0, 4]), ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval( - self, vmin, vmax, samples, nan_idx - ): - bas_no_range = self.cls(3, mode="eval", bounds=None) - bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) + def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): + bas_no_range = self.cls["eval"](n_basis_funcs=3, bounds=None) + bas = self.cls["eval"](n_basis_funcs=3, bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) _, out2 = bas_no_range.evaluate_on_grid(10) assert np.allclose(out1, out2) @@ -957,8 +942,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval( ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): - bas_no_range = self.cls(3, mode="eval", bounds=None) - bas = self.cls(3, mode="eval", bounds=bounds) + bas_no_range = self.cls["eval"](n_basis_funcs=3, bounds=None) + bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) x1, _ = bas.evaluate_on_grid(10) x2, _ = bas_no_range.evaluate_on_grid(10) assert np.allclose(x1, x2 * (mx - mn) + mn) @@ -966,18 +951,18 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx @pytest.mark.parametrize( "bounds, samples, exception", [ - (None, np.arange(5), does_not_raise()), - ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + (None, np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((0, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((1, 4), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((1, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: - self.cls(3, mode="conv", window_size=10, bounds=bounds) + self.cls["conv"](n_basis_funcs=3, window_size=10, bounds=bounds) def test_transformer_get_params(self): - bas = self.cls(5) + bas = self.cls["eval"](n_basis_funcs=5) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") From 2624a0aab31dabc4377bbafbc625330f319f73ae Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 2024 14:45:19 -0500 Subject: [PATCH 040/109] test jointly shared methods --- src/nemos/basis/_basis_mixin.py | 18 +- src/nemos/basis/basis.py | 2 +- tests/test_basis.py | 494 ++++++++++++++++++++++++++++++++ 3 files changed, 507 insertions(+), 7 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py 
b/src/nemos/basis/_basis_mixin.py index 16848341..1f3b342d 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -81,8 +81,7 @@ class ConvBasisMixin: def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.window_size = window_size - self._conv_kwargs = {} if conv_kwargs is None else conv_kwargs - self._check_convolution_kwargs() + self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs def _compute_features(self, *xi: ArrayLike): """ @@ -171,7 +170,14 @@ def conv_kwargs(self): """ return self._conv_kwargs - def _check_convolution_kwargs(self): + @conv_kwargs.setter + def conv_kwargs(self, values: dict): + """Check and set convolution kwargs.""" + self._check_convolution_kwargs(values) + self._conv_kwargs = values + + @staticmethod + def _check_convolution_kwargs(conv_kwargs: dict): """Check convolution kwargs settings. Raises @@ -183,7 +189,7 @@ def _check_convolution_kwargs(self): If ``self._conv_kwargs`` include parameters not recognized or that do not have default values in ``create_convolutional_predictor``. """ - if "axis" in self._conv_kwargs: + if "axis" in conv_kwargs: raise ValueError( "Setting the `axis` parameter is not allowed. Basis requires the " "convolution to be applied along the first axis (`axis=0`).\n" @@ -199,12 +205,12 @@ def _check_convolution_kwargs(self): # `basis_matrix` or `time_series` in kwargs. is not inspect.Parameter.empty } - if not set(self._conv_kwargs.keys()).issubset(convolve_configs): + if not set(conv_kwargs.keys()).issubset(convolve_configs): # do not encourage to set axis. convolve_configs = convolve_configs.difference({"axis"}) # remove the parameter in case axis=0 was passed, since it is allowed. invalid = ( - set(self._conv_kwargs.keys()) + set(conv_kwargs.keys()) .difference(convolve_configs) .difference({"axis"}) ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 258dad5e..7d5c56c0 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -888,7 +888,7 @@ def __init__( -------- >>> import numpy as np >>> from numpy import linspace - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import EvalOrthExponential >>> X = np.random.normal(size=(1000, 1)) >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates diff --git a/tests/test_basis.py b/tests/test_basis.py index e338a60c..49271568 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -16,6 +16,7 @@ import nemos.basis.basis as basis import nemos.convolve as convolve +from nemos.basis import EvalOrthExponential from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring from nemos.basis._decaying_exponential import OrthExponentialBasis from nemos.basis._raised_cosine_basis import ( @@ -26,6 +27,26 @@ from nemos.utils import pynapple_concatenate_numpy +@pytest.fixture() +def class_specific_params(): + shared_params = ["n_basis_funcs", "label"] + eval_params = ["bounds"] + conv_params = ["window_size", "conv_kwargs"] + return dict( + EvalBSpline = shared_params + eval_params + ["order"], + ConvBSpline = shared_params + conv_params + ["order"], + EvalMSpline = shared_params + eval_params + ["order"], + ConvMSpline = shared_params + conv_params +["order"], + EvalCyclicBSpline = shared_params + eval_params + ["order"], + ConvCyclicBSpline = shared_params + conv_params +["order"], + EvalRaisedCosineLinear= shared_params + eval_params + ["width"], + ConvRaisedCosineLinear=shared_params + conv_params 
+["width"], + EvalRaisedCosineLog= shared_params + eval_params + ["width", "time_scaling", "enforce_decay_to_zero"], + ConvRaisedCosineLog= shared_params + conv_params +["width", "time_scaling", "enforce_decay_to_zero"], + EvalOrthExponential= shared_params + eval_params + ["decay_rates"], + ConvOrthExponential = shared_params + conv_params +["decay_rates"] + ) + # automatic define user accessible basis and check the methods def list_all_basis_classes() -> list[type]: """ @@ -252,6 +273,479 @@ def cls(self): pass +# Auto-generated file with stripped classes and shared methods +@pytest.mark.parametrize( + "cls", + [ + # {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}, + {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} + ] +) +class TestSharedMethods: + + @pytest.mark.parametrize("n_basis", [2, 3]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_basis_number(self, n_basis, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=n_basis, **kwargs) + x = np.linspace(0, 1, 10) + assert bas(x).shape[1] == n_basis + + def test_call_equivalent_in_conv(self,cls): + bas_con = cls["conv"](n_basis_funcs=5, window_size=10) + bas_eval = cls["eval"](n_basis_funcs=5) + x = np.linspace(0, 1, 10) + assert np.all(bas_con(x) == bas_eval(x)) + + @pytest.mark.parametrize( + "num_input, expectation", + [ + (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), + (1, does_not_raise()), + (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), + ], + ) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_input_num(self, num_input, mode, kwargs, expectation,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + with expectation: + bas(*([np.linspace(0, 1, 10)] * num_input)) + + @pytest.mark.parametrize( + "inp, expectation", + [ + (np.linspace(0, 1, 10), does_not_raise()), + (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), + ], + ) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_input_shape(self, inp, mode, kwargs, expectation,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + with expectation: + bas(inp) + + @pytest.mark.parametrize( + "samples, expectation", + [ + (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) + def test_call_input_type(self, samples, expectation,cls): + bas = cls["eval"](n_basis_funcs=5) # Only eval mode is relevant here + with expectation: + bas(samples) + + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_nan(self, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + x = np.linspace(0, 1, 10) + x[3] = np.nan + assert all(np.isnan(bas(x)[3])) + + @pytest.mark.parametrize("n_basis", [2, 3]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_non_empty(self, n_basis, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=n_basis, **kwargs) + with pytest.raises(ValueError, match="All sample provided must"): + bas(np.array([])) + + @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_sample_axis(self, time_axis_shape, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + assert bas(np.linspace(0, 
1, time_axis_shape)).shape[0] == time_axis_shape + + @pytest.mark.parametrize( + "mn, mx, expectation", + [ + (0, 1, does_not_raise()), + (-2, 2, does_not_raise()), + ], + ) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_call_sample_range(self, mn, mx, expectation, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + with expectation: + bas(np.linspace(mn, mx, 10)) + + @pytest.mark.parametrize( + "kwargs, input1_shape, expectation", + [ + (dict(), (10,), does_not_raise()), + (dict(axis=0), (10,), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), + (dict(axis=1), (2, 10), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), + ], + ) + def test_compute_features_axis(self, kwargs, input1_shape, expectation,cls): + with expectation: + basis_obj = cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs) + basis_obj.compute_features(np.ones(input1_shape)) + + @pytest.mark.parametrize("n_basis_funcs", [4, 5]) + @pytest.mark.parametrize("time_scaling", [50, 70]) + @pytest.mark.parametrize("enforce_decay", [True, False]) + @pytest.mark.parametrize("window_size", [10, 15]) + @pytest.mark.parametrize("order", [3, 4]) + @pytest.mark.parametrize("width", [2, 3]) + @pytest.mark.parametrize( + "input_shape, expected_n_input", + [ + ((20,), 1), + ((20, 1), 1), + ((20, 2), 2), + ((20, 1, 2), 2), + ((20, 2, 1), 2), + ((20, 2, 2), 4), + ], + ) + def test_compute_features_conv_input( + self, + n_basis_funcs, + time_scaling, + enforce_decay, + window_size, + input_shape, + expected_n_input, + order, + width, + cls, + class_specific_params, + ): + x = np.ones(input_shape) + + kwargs = dict( + n_basis_funcs=n_basis_funcs, + decay_rates=np.arange(1, n_basis_funcs+1), + time_scaling=time_scaling, + window_size=window_size, + enforce_decay_to_zero=enforce_decay, + order=order, + width=width,) + + # figure out which kwargs needs to be removed + kwargs = {key: value for key, value in kwargs.items() if key in class_specific_params[cls["conv"].__name__]} + + basis_obj = cls["conv"](**kwargs) + out = basis_obj.compute_features(x) + assert out.shape[1] == expected_n_input * basis_obj.n_basis_funcs + + @pytest.mark.parametrize("eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])]) + def test_compute_features_input(self, eval_input,cls): + basis_obj = cls["eval"](n_basis_funcs=5) + basis_obj.compute_features(eval_input) + + @pytest.mark.parametrize( + "args, sample_size", + [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], + ) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_compute_features_returns_expected_number_of_basis(self, args, sample_size, mode, kwargs,cls): + basis_obj = cls[mode](**args, **kwargs) + eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) + assert eval_basis.shape[1] == args["n_basis_funcs"], ( + "Dimensions do not agree: The number of basis should match the first dimension " + f"of the evaluated basis. 
The number of basis is {args['n_basis_funcs']}, but the " + f"evaluated basis has dimension {eval_basis.shape[1]}" + ) + + @pytest.mark.parametrize( + "samples, vmin, vmax, expectation", + [ + (0.5, 0, 1, does_not_raise()), + (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), + (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + ], + ) + def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation,cls): + basis_obj = cls["eval"](5, bounds=(vmin, vmax)) + with expectation: + basis_obj.compute_features(samples) + + def test_conv_kwargs_error(self,cls): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'test'"): + cls["eval"](n_basis_funcs=5, test="hi") + + def test_convolution_is_performed(self,cls): + bas = cls["conv"](n_basis_funcs=5, window_size=10) + x = np.random.normal(size=100) + conv = bas.compute_features(x) + conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) + valid = ~np.isnan(conv) + assert np.all(conv[valid] == conv_2[valid]) + assert np.all(np.isnan(conv_2[~valid])) + + @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs,cls): + basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + if sample_size <= 0: + with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): + basis_obj.evaluate_on_grid(sample_size) + else: + _, eval_basis = basis_obj.evaluate_on_grid(sample_size) + assert eval_basis.shape[0] == sample_size + + @pytest.mark.parametrize("n_input", [0, 1, 2]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs,cls): + basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + inputs = [10] * n_input + if n_input == 0: + expectation = pytest.raises( + TypeError, match=r"evaluate_on_grid\(\) missing 1 required positional argument", + ) + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises( + TypeError, match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + ) + else: + expectation = does_not_raise() + + with expectation: + basis_obj.evaluate_on_grid(*inputs) + + @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs,cls): + basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + if sample_size <= 0: + with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): + basis_obj.evaluate_on_grid(sample_size) + else: + grid, _ = basis_obj.evaluate_on_grid(sample_size) + assert grid.shape[0] == sample_size + + def test_fit_kernel(self, cls): + bas = cls["conv"](n_basis_funcs=5, window_size=3) + bas._set_kernel() + assert bas.kernel_ is not None + + def test_fit_kernel_shape(self,cls): + bas = cls["conv"](n_basis_funcs=5, window_size=3) + bas._set_kernel() + assert bas.kernel_.shape == (3, 5) + + @pytest.mark.parametrize( + "mode, ws, expectation", + [ + ("conv", 2, does_not_raise()), + ( + "conv", + -1, + pytest.raises(ValueError, match="`window_size` must be a positive "), + ), 
+ ( + "conv", + None, + pytest.raises( + ValueError, + match="If the basis is in `conv` mode, you must provide a ", + ), + ), + ( + "conv", + 1.5, + pytest.raises(ValueError, match="`window_size` must be a positive "), + ), + ("eval", None, pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + )), + ( + "eval", + 10, + pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + ), + ), + ], + ) + def test_init_window_size(self, mode, ws, expectation,cls): + with expectation: + cls[mode](n_basis_funcs=5, window_size=ws) + + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs,cls): + if n_basis_funcs < 2: + with pytest.raises( + ValueError, match=f"Object class {cls[mode].__name__} requires >= 2 basis elements.", + ): + cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) + else: + cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) + + @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_non_empty_samples(self, samples, mode, kwargs,cls): + if mode == "conv" and len(samples) == 1: + return + if len(samples) == 0: + with pytest.raises(ValueError, match="All sample provided must be non empty"): + cls[mode](5, **kwargs).compute_features(samples) + else: + cls[mode](5, **kwargs).compute_features(samples) + + @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs,cls): + basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + inputs = [np.linspace(0, 1, 20)] * n_input + if n_input == 0: + expectation = pytest.raises(TypeError, match="missing 1 required positional argument") + elif n_input != basis_obj._n_input_dimensionality: + expectation = pytest.raises(TypeError, match="takes 2 positional arguments but \d were given") + else: + expectation = does_not_raise() + + with expectation: + basis_obj.compute_features(*inputs) + + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + def test_pynapple_support(self, mode, kwargs,cls): + bas = cls[mode](n_basis_funcs=5, **kwargs) + x = np.linspace(0, 1, 10) + x_nap = nap.Tsd(t=np.arange(10), d=x) + y = bas(x) + y_nap = bas(x_nap) + assert isinstance(y_nap, nap.TsdFrame) + assert np.all(y == y_nap.d) + assert np.all(y_nap.t == x_nap.t) + + @pytest.mark.parametrize("sample_size", [30]) + @pytest.mark.parametrize("n_basis", [5]) + def test_pynapple_support_compute_features(self, n_basis, sample_size,cls): + iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) + inp = nap.Tsd( + t=np.linspace(0, 1, sample_size), + d=np.linspace(0, 1, sample_size), + time_support=iset, + ) + out = cls["eval"](n_basis_funcs=n_basis).compute_features(inp) + assert isinstance(out, nap.TsdFrame) + assert np.all(out.time_support.values == inp.time_support.values) + + @pytest.mark.parametrize("sample_size", [100, 1000]) + @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_funcs, sample_size, mode, kwargs,cls): + basis_obj = cls[mode](n_basis_funcs=n_basis_funcs, 
**kwargs) + eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) + assert eval_basis.shape[0] == sample_size, ( + f"Dimensions do not agree: The sample size of the output should match the input sample size. " + f"Expected {sample_size}, but got {eval_basis.shape[0]}." + ) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ], + ) + def test_set_bounds(self, mode, expectation,cls): + kwargs = {"bounds": (1, 2)} + with expectation: + cls[mode](n_basis_funcs=10, **kwargs) + + if mode == "conv": + bas = cls["conv"](n_basis_funcs=10, window_size=10) + with pytest.raises(ValueError, match="Invalid parameter 'bounds' for estimator"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "enforce_decay_to_zero, time_scaling, width, window_size, n_basis_funcs, bounds, mode, decay_rates", + [ + (False, 15, 4, None, 10, (1, 2), "eval", np.arange(1, 11)), + (False, 15, 4, 10, 10, None, "conv", np.arange(1, 11)), + ], + ) + @pytest.mark.parametrize( + "order, conv_kwargs", + [ + (10, dict(shift=True)), + ], + ) + def test_set_params( + self, + enforce_decay_to_zero, + time_scaling, + width, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], + order, decay_rates, conv_kwargs, + cls, + class_specific_params + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + enforce_decay_to_zero=enforce_decay_to_zero, + time_scaling=time_scaling, + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + order=order, + decay_rates=decay_rates, + conv_kwargs=conv_kwargs, + ) + pars = {key: value for key, value in pars.items() if key in class_specific_params[cls[mode].__name__]} + + keys = list(pars.keys()) + bas = cls[mode]( + **pars + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas = bas.set_params(**par_set) + assert isinstance(bas, cls[mode]) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(TypeError, match="got an unexpected keyword argument 'window_size'")), + ], + ) + def test_set_window_size(self, mode, expectation,cls): + kwargs = {"window_size": 10} + with expectation: + cls[mode](n_basis_funcs=10, **kwargs) + + if mode == "conv": + bas = cls["conv"](n_basis_funcs=10, window_size=10) + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + if mode == "eval": + bas = cls["eval"](n_basis_funcs=10) + with pytest.raises(ValueError, match="Invalid parameter 'window_size' for estimator"): + bas.set_params(window_size=10) + + def test_transform_fails(self,cls): + bas = cls["conv"](n_basis_funcs=5, window_size=3) + with pytest.raises( + ValueError, match="You must call `_set_kernel` before `_compute_features`" + ): + bas._compute_features(np.linspace(0, 1, 10)) + + def test_transformer_get_params(self,cls): + bas = cls["eval"](n_basis_funcs=5) + bas_transformer = bas.to_transformer() + params_transf = bas_transformer.get_params() + params_transf.pop("_basis") + params_basis = bas.get_params() + assert params_transf == params_basis + class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} From 0f9c28f875bbde207552cd804a271e9c044b3985 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 26 Nov 
2024 17:24:54 -0500 Subject: [PATCH 041/109] refactored all 1d basis --- src/nemos/basis/_decaying_exponential.py | 2 +- tests/test_basis.py | 5390 +++------------------- 2 files changed, 761 insertions(+), 4631 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index e8e95093..d6b92e59 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -118,7 +118,7 @@ def _check_rates(self) -> None: """ if len(set(self._decay_rates)) != len(self._decay_rates): raise ValueError( - "Two or more rate are repeated! Repeating rate will result in a " + "Two or more rates are repeated! Repeating rates will result in a " "linearly dependent set of function for the basis." ) diff --git a/tests/test_basis.py b/tests/test_basis.py index 49271568..b1290fa3 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -11,6 +11,8 @@ import numpy as np import pynapple as nap import pytest +from scipy.stats import expon + import utils_testing from sklearn.base import clone as sk_clone @@ -47,6 +49,14 @@ def class_specific_params(): ConvOrthExponential = shared_params + conv_params +["decay_rates"] ) + +def extra_decay_rates(cls, n_basis): + name = cls.__name__ + if "OrthExp" in name: + return dict(decay_rates=np.arange(1, n_basis + 1)) + return {} + + # automatic define user accessible basis and check the methods def list_all_basis_classes() -> list[type]: """ @@ -277,22 +287,265 @@ def cls(self): @pytest.mark.parametrize( "cls", [ - # {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}, - {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} + {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}, + {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear}, + {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline}, + {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline}, + {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline}, + {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential} ] ) class TestSharedMethods: - @pytest.mark.parametrize("n_basis", [2, 3]) + @pytest.mark.parametrize( + "samples, vmin, vmax, expectation", + [ + (0.5, 0, 1, does_not_raise()), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), + (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], + ) + def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): + if "OrthExp" in cls["eval"].__name__ and not hasattr(samples, "shape"): + return + bas = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + with expectation: + bas(samples) + + @pytest.mark.parametrize( + "attribute, value", + [ + ("label", None), + ("label", "label"), + ("n_basis_input", 1), + ("n_output_features", 5), + ], + ) + def test_attr_setter(self, attribute, value, cls): + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + with pytest.raises( + AttributeError, match=rf"can't set attribute|property '{attribute}' of" + ): + setattr(bas, attribute, value) + + + @pytest.mark.parametrize( + "n_input, expectation", + [ + (2, does_not_raise()), + (0, pytest.raises(ValueError, match="Input shape mismatch detected")), + (1, pytest.raises(ValueError, 
match="Input shape mismatch detected")), + (3, pytest.raises(ValueError, match="Input shape mismatch detected")), + ], + ) + def test_expected_input_number(self, n_input, expectation, cls): + bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["eval"], 5)) + x = np.random.randn(20, 2) + bas.compute_features(x) + with expectation: + bas.compute_features(np.random.randn(30, n_input)) + + @pytest.mark.parametrize( + "conv_kwargs, expectation", + [ + (dict(), does_not_raise()), + ( + dict(axis=0), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + ), + ( + dict(axis=1), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + ), + (dict(shift=True), does_not_raise()), + ( + dict(shift=True, axis=0), + pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + ), + ( + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + (dict(shift=True, predictor_causality="causal"), does_not_raise()), + ( + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), + ), + ], + ) + def test_init_conv_kwargs(self, conv_kwargs, expectation, cls): + with expectation: + cls["conv"](n_basis_funcs=5, window_size=200, conv_kwargs=conv_kwargs, **extra_decay_rates(cls["eval"], 5)) + + @pytest.mark.parametrize("label", [None, "label"]) + def test_init_label(self, label, cls): + bas = cls["eval"](n_basis_funcs=5, label=label, **extra_decay_rates(cls["eval"], 5)) + expected_label = str(label) if label is not None else cls["eval"].__name__ + assert bas.label == expected_label + + @pytest.mark.parametrize("n_input", [1, 2, 3]) + def test_set_num_output_features(self, n_input, cls): + bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) + assert bas.n_output_features is None + bas.compute_features(np.random.randn(20, n_input)) + assert bas.n_output_features == n_input * bas.n_basis_funcs + + @pytest.mark.parametrize("n_input", [1, 2, 3]) + def test_set_num_basis_input(self, n_input, cls): + bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) + assert bas.n_basis_input is None + bas.compute_features(np.random.randn(20, n_input)) + assert bas.n_basis_input == (n_input,) + assert bas._n_basis_input == (n_input,) + + @pytest.mark.parametrize( + "samples, vmin, vmax, expectation", + [ + (0.5, 0, 1, does_not_raise()), + (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), + (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + ], + ) + def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation, cls): + basis_obj = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + with expectation: + basis_obj.compute_features(samples) + + @pytest.mark.parametrize( + "bounds, samples, nan_idx, mn, mx", + [ + (None, np.arange(5), [4], 0, 1), + ((0, 3), np.arange(5), [4], 0, 3), + ((1, 4), np.arange(5), [0], 1, 4), + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], + ) + def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx, cls): + bas_no_range = cls["eval"](n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"](n_basis_funcs=5, bounds=bounds, 
**extra_decay_rates(cls["eval"], 5)) + x1, _ = bas.evaluate_on_grid(10) + x2, _ = bas_no_range.evaluate_on_grid(10) + assert np.allclose(x1, x2 * (mx - mn) + mn) + + @pytest.mark.parametrize( + "vmin, vmax, samples, nan_idx", + [ + (0, 3, np.arange(5), [4]), + (1, 4, np.arange(5), [0]), + (1, 3, np.arange(5), [0, 4]), + ], + ) + def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx, cls): + # MSPline integrates to 1 on domain so must be excluded from this check + if "MSpline" in cls["eval"].__name__: + return + bas_no_range = cls["eval"](n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"](n_basis_funcs=5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + _, out1 = bas.evaluate_on_grid(10) + _, out2 = bas_no_range.evaluate_on_grid(10) + assert np.allclose(out1, out2) + + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (1, 2, 3), + pytest.raises( + ValueError, match="The provided `bounds` must be of length two" + ), + ), + ], + ) + def test_vmin_vmax_init(self, bounds, expectation, cls): + with expectation: + bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + assert bounds == bas.bounds if bounds else bas.bounds is None + + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (1, 2, 3), + pytest.raises( + ValueError, match="The provided `bounds` must be of length two" + ), + ), + ], + ) + def test_vmin_vmax_init(self, bounds, expectation, cls): + with expectation: + bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + assert bounds == bas.bounds if bounds else bas.bounds is None + + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation, cls): + bas = cls["eval"](n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + + @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_call_basis_number(self, n_basis, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs) + + bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) x = np.linspace(0, 1, 10) assert bas(x).shape[1] == n_basis - def test_call_equivalent_in_conv(self,cls): - bas_con = cls["conv"](n_basis_funcs=5, window_size=10) - bas_eval = cls["eval"](n_basis_funcs=5) + @pytest.mark.parametrize("n_basis", [6]) + def test_call_equivalent_in_conv(self, n_basis, cls): + bas_con = cls["conv"](n_basis_funcs=n_basis, window_size=10, **extra_decay_rates(cls["conv"], n_basis)) + bas_eval = cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)) x = np.linspace(0, 1, 10) assert np.all(bas_con(x) == bas_eval(x)) @@ -305,8 +558,9 @@ def test_call_equivalent_in_conv(self,cls): ], ) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_input_num(self, num_input, mode, kwargs, expectation,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + @pytest.mark.parametrize("n_basis", [6]) + def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation,cls): + bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) with expectation: bas(*([np.linspace(0, 1, 10)] * num_input)) @@ -317,9 +571,10 @@ def test_call_input_num(self, num_input, mode, kwargs, expectation,cls): (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) + @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_input_shape(self, inp, mode, kwargs, expectation,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + def test_call_input_shape(self, inp, mode, kwargs, expectation,n_basis,cls): + bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) with expectation: bas(inp) @@ -333,29 +588,30 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation,cls): ), ], ) - def test_call_input_type(self, samples, expectation,cls): - bas = cls["eval"](n_basis_funcs=5) # Only eval mode is relevant here + @pytest.mark.parametrize("n_basis", [6]) + def test_call_input_type(self, samples, expectation, n_basis,cls): + bas = cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)) # Only eval mode is relevant here with expectation: bas(samples) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_call_nan(self, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x[3] = np.nan assert all(np.isnan(bas(x)[3])) - @pytest.mark.parametrize("n_basis", [2, 3]) + @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_call_non_empty(self, n_basis, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs) + bas = cls[mode](n_basis_funcs=n_basis, **kwargs, 
**extra_decay_rates(cls[mode], n_basis)) with pytest.raises(ValueError, match="All sample provided must"): bas(np.array([])) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_call_sample_axis(self, time_axis_shape, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape @pytest.mark.parametrize( @@ -367,7 +623,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs,cls): ) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_call_sample_range(self, mn, mx, expectation, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) with expectation: bas(np.linspace(mn, mx, 10)) @@ -381,7 +637,7 @@ def test_call_sample_range(self, mn, mx, expectation, mode, kwargs,cls): ) def test_compute_features_axis(self, kwargs, input1_shape, expectation,cls): with expectation: - basis_obj = cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs) + basis_obj = cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs, **extra_decay_rates(cls["conv"], 5)) basis_obj.compute_features(np.ones(input1_shape)) @pytest.mark.parametrize("n_basis_funcs", [4, 5]) @@ -434,16 +690,19 @@ def test_compute_features_conv_input( @pytest.mark.parametrize("eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])]) def test_compute_features_input(self, eval_input,cls): + # orth exp needs more inputs (orthogonalizaiton impossible otherwise) + if "OrthExp" in cls["eval"].__name__: + return basis_obj = cls["eval"](n_basis_funcs=5) basis_obj.compute_features(eval_input) @pytest.mark.parametrize( "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], + [[{"n_basis_funcs": n_basis}, 100] for n_basis in [6, 10, 13]], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})]) def test_compute_features_returns_expected_number_of_basis(self, args, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](**args, **kwargs) + basis_obj = cls[mode](**args, **kwargs, **extra_decay_rates(cls[mode], args["n_basis_funcs"])) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[1] == args["n_basis_funcs"], ( "Dimensions do not agree: The number of basis should match the first dimension " @@ -462,16 +721,72 @@ def test_compute_features_returns_expected_number_of_basis(self, args, sample_si ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation,cls): - basis_obj = cls["eval"](5, bounds=(vmin, vmax)) + if "OrthExp" in cls["eval"].__name__ and not hasattr(samples, "shape"): + return + basis_obj = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) with expectation: basis_obj.compute_features(samples) + @pytest.mark.parametrize( + "bounds, samples, exception", + [ + (None, np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((0, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((1, 4), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ((1, 3), np.arange(5), 
pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ], + ) + def test_vmin_vmax_mode_conv(self, bounds, samples, exception, cls): + with exception: + cls["conv"](n_basis_funcs=5, window_size=10, bounds=bounds, **extra_decay_rates(cls["conv"], 5)) + + @pytest.mark.parametrize( + "vmin, vmax, samples, nan_idx", + [ + (None, None, np.arange(5), []), + (0, 3, np.arange(5), [4]), + (1, 4, np.arange(5), [0]), + (1, 3, np.arange(5), [0, 4]), + ], + ) + def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx, cls): + bounds = None if vmin is None else (vmin, vmax) + bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + out = bas.compute_features(samples) + assert np.all(np.isnan(out[nan_idx])) + valid_idx = list(set(samples).difference(nan_idx)) + assert np.all(~np.isnan(out[valid_idx])) + + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation,cls): + bas = cls["eval"](n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + def test_conv_kwargs_error(self,cls): with pytest.raises(TypeError, match="got an unexpected keyword argument 'test'"): - cls["eval"](n_basis_funcs=5, test="hi") + cls["eval"](n_basis_funcs=5, test="hi", **extra_decay_rates(cls["eval"], 5)) def test_convolution_is_performed(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=10) + bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) x = np.random.normal(size=100) conv = bas.compute_features(x) conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) @@ -482,7 +797,9 @@ def test_convolution_is_performed(self,cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + if "OrthExp" in cls["eval"].__name__: + return + basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) if sample_size <= 0: with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): basis_obj.evaluate_on_grid(sample_size) @@ -493,7 +810,7 @@ def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs,cls): @pytest.mark.parametrize("n_input", [0, 1, 2]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) inputs = [10] * n_input if n_input == 0: expectation = pytest.raises( @@ -512,7 +829,9 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs,cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) 
@pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + if "OrthExp" in cls["eval"].__name__: + return + basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) if sample_size <= 0: with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): basis_obj.evaluate_on_grid(sample_size) @@ -521,14 +840,14 @@ def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs,cls): assert grid.shape[0] == sample_size def test_fit_kernel(self, cls): - bas = cls["conv"](n_basis_funcs=5, window_size=3) + bas = cls["conv"](n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5)) bas._set_kernel() assert bas.kernel_ is not None def test_fit_kernel_shape(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=3) + bas = cls["conv"](n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5)) bas._set_kernel() - assert bas.kernel_.shape == (3, 5) + assert bas.kernel_.shape == (30, 5) @pytest.mark.parametrize( "mode, ws, expectation", @@ -568,34 +887,42 @@ def test_fit_kernel_shape(self,cls): ) def test_init_window_size(self, mode, ws, expectation,cls): with expectation: - cls[mode](n_basis_funcs=5, window_size=ws) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs,cls): - if n_basis_funcs < 2: - with pytest.raises( - ValueError, match=f"Object class {cls[mode].__name__} requires >= 2 basis elements.", - ): - cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) - else: - cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) + cls[mode](n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5)) + + # @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + # @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) + # @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + # def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs, order, cls): + # min_per_basis = { + # "EvalMSpline": (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs), + # "EvalRaisedCosineLog": lambda x: x < 2, + # "EvalBSpline": lambda x: order > x, + # } + # if n_basis_funcs < 2: + # with pytest.raises( + # ValueError, match=f"Object class {cls[mode].__name__} requires >= 2 basis elements.", + # ): + # cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) + # else: + # cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) def test_non_empty_samples(self, samples, mode, kwargs,cls): + if "OrthExp" in cls["eval"].__name__: + return if mode == "conv" and len(samples) == 1: return if len(samples) == 0: with pytest.raises(ValueError, match="All sample provided must be non empty"): - cls[mode](5, **kwargs).compute_features(samples) + cls[mode](5, **kwargs, **extra_decay_rates(cls[mode], 5)).compute_features(samples) else: - cls[mode](5, **kwargs).compute_features(samples) + cls[mode](5, **kwargs, **extra_decay_rates(cls[mode], 5)).compute_features(samples) @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + 
@pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})]) def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs) + basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: expectation = pytest.raises(TypeError, match="missing 1 required positional argument") @@ -609,7 +936,7 @@ def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs, @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) def test_pynapple_support(self, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=5, **kwargs) + bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x_nap = nap.Tsd(t=np.arange(10), d=x) y = bas(x) @@ -627,15 +954,15 @@ def test_pynapple_support_compute_features(self, n_basis, sample_size,cls): d=np.linspace(0, 1, sample_size), time_support=iset, ) - out = cls["eval"](n_basis_funcs=n_basis).compute_features(inp) + out = cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)).compute_features(inp) assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + @pytest.mark.parametrize("n_basis_funcs", [5, 10, 80]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 90})]) def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_funcs, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) + basis_obj = cls[mode](n_basis_funcs=n_basis_funcs, **kwargs, **extra_decay_rates(cls[mode], n_basis_funcs)) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[0] == sample_size, ( f"Dimensions do not agree: The sample size of the output should match the input sample size. 
" @@ -652,10 +979,10 @@ def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_fun def test_set_bounds(self, mode, expectation,cls): kwargs = {"bounds": (1, 2)} with expectation: - cls[mode](n_basis_funcs=10, **kwargs) + cls[mode](n_basis_funcs=10, **kwargs, **extra_decay_rates(cls[mode], 10)) if mode == "conv": - bas = cls["conv"](n_basis_funcs=10, window_size=10) + bas = cls["conv"](n_basis_funcs=10, window_size=20, **extra_decay_rates(cls[mode], 10)) with pytest.raises(ValueError, match="Invalid parameter 'bounds' for estimator"): bas.set_params(bounds=(1, 2)) @@ -719,52 +1046,60 @@ def test_set_params( def test_set_window_size(self, mode, expectation,cls): kwargs = {"window_size": 10} with expectation: - cls[mode](n_basis_funcs=10, **kwargs) + cls[mode](n_basis_funcs=10, **kwargs, **extra_decay_rates(cls[mode], 10)) if mode == "conv": - bas = cls["conv"](n_basis_funcs=10, window_size=10) + bas = cls["conv"](n_basis_funcs=10, window_size=10, **extra_decay_rates(cls["conv"], 10)) with pytest.raises(ValueError, match="If the basis is in `conv` mode"): bas.set_params(window_size=None) if mode == "eval": - bas = cls["eval"](n_basis_funcs=10) + bas = cls["eval"](n_basis_funcs=10, **extra_decay_rates(cls["eval"], 10)) with pytest.raises(ValueError, match="Invalid parameter 'window_size' for estimator"): bas.set_params(window_size=10) def test_transform_fails(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=3) + bas = cls["conv"](n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5)) with pytest.raises( ValueError, match="You must call `_set_kernel` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) def test_transformer_get_params(self,cls): - bas = cls["eval"](n_basis_funcs=5) + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") params_basis = bas.get_params() + rates_1 = params_basis.pop("decay_rates", 1) + rates_2 = params_transf.pop("decay_rates", 1) assert params_transf == params_basis + assert np.all(rates_1 == rates_2) -class TestRaisedCosineLogBasis(BasisFuncsTesting): +class TestRaisedCosineLogBasis: cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} + @pytest.mark.parametrize("width", [1.5, 2, 2.5]) + def test_decay_to_zero_basis_number_match(self, width): + n_basis_funcs = 10 + _, ev = self.cls["conv"]( + n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True, window_size=5 + ).evaluate_on_grid(2) + assert ev.shape[1] == n_basis_funcs, ( + "Basis function number mismatch. " + f"Expected {n_basis_funcs}, got {ev.shape[1]} instead!" 
+ ) - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_non_empty_samples(self, samples, mode, kwargs): - if mode == "conv" and len(samples) == 1: - return - if len(samples) == 0: - with pytest.raises(ValueError, match="All sample provided must be non empty"): - self.cls[mode](5, **kwargs).compute_features(samples) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): + if n_basis_funcs < 2: + with pytest.raises( + ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", + ): + self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) else: - self.cls[mode](5, **kwargs).compute_features(samples) - - @pytest.mark.parametrize("eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])]) - def test_compute_features_input(self, eval_input): - basis_obj = self.cls["eval"](n_basis_funcs=5) - basis_obj.compute_features(eval_input) + self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) @pytest.mark.parametrize( "width, expectation", @@ -785,159 +1120,86 @@ def test_set_width(self, width, expectation, mode, kwargs): with expectation: basis_obj.set_params(width=width) + def test_time_scaling_property(self): + time_scaling = [0.1, 10, 100] + n_basis_funcs = 5 + _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) + corr = np.zeros(len(time_scaling)) + for idx, ts in enumerate(time_scaling): + basis_log = self.cls["eval"]( + n_basis_funcs=n_basis_funcs, + time_scaling=ts, + enforce_decay_to_zero=False, + ) + _, log_ev = basis_log.evaluate_on_grid(100) + corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + ) + assert np.all( + np.diff(corr) < 0 + ), "As time scales increases, deviation from linearity should increase!" + @pytest.mark.parametrize( - "kwargs, input1_shape, expectation", + "time_scaling, expectation", [ - (dict(), (10,), does_not_raise()), - (dict(axis=0), (10,), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), - (dict(axis=1), (2, 10), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), + (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + (0.1, does_not_raise()), + (10, does_not_raise()), ], ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation): + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): with expectation: - basis_obj = self.cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs) - basis_obj.compute_features(np.ones(input1_shape)) + self.cls[mode](n_basis_funcs=5, time_scaling=time_scaling, **kwargs) - @pytest.mark.parametrize("n_basis_funcs", [4, 5]) - @pytest.mark.parametrize("time_scaling", [50, 70]) - @pytest.mark.parametrize("enforce_decay", [True, False]) - @pytest.mark.parametrize("window_size", [10, 15]) @pytest.mark.parametrize( - "input_shape, expected_n_input", + "width, expectation", [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), + (-1, pytest.raises(ValueError, match="Invalid raised cosine width. 
")), + (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), + (1.5, does_not_raise()), + (2, does_not_raise()), + (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - def test_compute_features_conv_input( - self, - n_basis_funcs, - time_scaling, - enforce_decay, - window_size, - input_shape, - expected_n_input, - ): - x = np.ones(input_shape) - basis_obj = self.cls["conv"]( - n_basis_funcs=n_basis_funcs, - time_scaling=time_scaling, - window_size=window_size, - enforce_decay_to_zero=enforce_decay, - ) - out = basis_obj.compute_features(x) - assert out.shape[1] == expected_n_input * basis_obj.n_basis_funcs - - @pytest.mark.parametrize( - "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], - ) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_compute_features_returns_expected_number_of_basis(self, args, sample_size, mode, kwargs): - basis_obj = self.cls[mode](**args, **kwargs) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - assert eval_basis.shape[1] == args["n_basis_funcs"], ( - "Dimensions do not agree: The number of basis should match the first dimension " - f"of the evaluated basis. The number of basis is {args['n_basis_funcs']}, but the " - f"evaluated basis has dimension {eval_basis.shape[1]}" - ) + def test_width_values(self, width, expectation, mode, kwargs): + with expectation: + self.cls[mode](n_basis_funcs=5, width=width, **kwargs) - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_funcs, sample_size, mode, kwargs): - basis_obj = self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - assert eval_basis.shape[0] == sample_size, ( - f"Dimensions do not agree: The sample size of the output should match the input sample size. " - f"Expected {sample_size}, but got {eval_basis.shape[0]}." 
- ) - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - basis_obj = self.cls["eval"](5, bounds=(vmin, vmax)) - with expectation: - basis_obj.compute_features(samples) +class TestRaisedCosineLinearBasis(BasisFuncsTesting): + cls = {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): if n_basis_funcs < 2: with pytest.raises( - ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", + ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", ): self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) else: self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs): - basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input == 0: - expectation = pytest.raises(TypeError, match="missing 1 required positional argument") - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="takes 2 positional arguments but \d were given") - else: - expectation = does_not_raise() - + @pytest.mark.parametrize( + "width, expectation", + [ + (10, does_not_raise()), + (10.5, does_not_raise()), + (0.5, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), + (-10, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 
2\*width must be a positive")), + ], + ) + def test_set_width(self, width, expectation): + basis_obj = self.cls["eval"](n_basis_funcs=5) with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs): - basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) - if sample_size <= 0: - with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs): - basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) - if sample_size <= 0: - with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): - basis_obj.evaluate_on_grid(sample_size) - else: - _, eval_basis = basis_obj.evaluate_on_grid(sample_size) - assert eval_basis.shape[0] == sample_size - - @pytest.mark.parametrize("n_input", [0, 1, 2]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs): - basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) - inputs = [10] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, match=r"evaluate_on_grid\(\) missing 1 required positional argument", - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", - ) - else: - expectation = does_not_raise() - + basis_obj.width = width with expectation: - basis_obj.evaluate_on_grid(*inputs) + basis_obj.set_params(width=width) @pytest.mark.parametrize( "width, expectation", @@ -951,4497 +1213,365 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs): (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) def test_width_values(self, width, expectation, mode, kwargs): + """ + Test allowable widths: integer multiple of 1/2, greater than 1. + This test validates the behavior of both `eval` and `conv` modes. + """ + basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: - self.cls[mode](n_basis_funcs=5, width=width, **kwargs) - - @pytest.mark.parametrize("width", [1.5, 2, 2.5]) - def test_decay_to_zero_basis_number_match(self, width): - n_basis_funcs = 10 - _, ev = self.cls["conv"]( - n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True, window_size=5 - ).evaluate_on_grid(2) - assert ev.shape[1] == n_basis_funcs, ( - "Basis function number mismatch. " - f"Expected {n_basis_funcs}, got {ev.shape[1]} instead!" 
- ) - - @pytest.mark.parametrize( - "time_scaling, expectation", - [ - (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), - (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), - (0.1, does_not_raise()), - (10, does_not_raise()), - ], - ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) - def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): + basis_obj.width = width with expectation: - self.cls[mode](n_basis_funcs=5, time_scaling=time_scaling, **kwargs) + basis_obj.set_params(width=width) - def test_time_scaling_property(self): - time_scaling = [0.1, 10, 100] - n_basis_funcs = 5 - _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) - corr = np.zeros(len(time_scaling)) - for idx, ts in enumerate(time_scaling): - basis_log = self.cls["eval"]( - n_basis_funcs=n_basis_funcs, - time_scaling=ts, - enforce_decay_to_zero=False, - ) - _, log_ev = basis_log.evaluate_on_grid(100) - corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( - np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) - ) - assert np.all( - np.diff(corr) < 0 - ), "As time scales increases, deviation from linearity should increase!" - @pytest.mark.parametrize("sample_size", [30]) - @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size): - iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) - inp = nap.Tsd( - t=np.linspace(0, 1, sample_size), - d=np.linspace(0, 1, sample_size), - time_support=iset, - ) - out = self.cls["eval"](n_basis_funcs=n_basis).compute_features(inp) - assert isinstance(out, nap.TsdFrame) - assert np.all(out.time_support.values == inp.time_support.values) +class TestMSplineBasis(BasisFuncsTesting): + cls = {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline} - @pytest.mark.parametrize( - "num_input, expectation", + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) + and order < #basis are enforced. + """ + raise_exception = (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs) + if raise_exception: + with pytest.raises( + ValueError, + match=r"Spline order must be positive!|" + rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, order=order, **kwargs + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, order=order, **kwargs + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) + + @pytest.mark.parametrize("n_basis_funcs", [10]) + @pytest.mark.parametrize("order", [-1, 0, 1, 2]) + def test_order_is_positive(self, n_basis_funcs, order): + """ + Verifies that the order must be positive and less than or equal to the number of basis functions. 
+ """ + raise_exception = order < 1 + if raise_exception: + with pytest.raises(ValueError, match=r"Spline order must be positive!"): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + + @pytest.mark.parametrize("n_basis_funcs", [5]) + @pytest.mark.parametrize( + "order, expectation", [ - (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), + (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), + (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), + (0, pytest.raises(ValueError, match=r"Spline order must be positive")), (1, does_not_raise()), - (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), + (2, does_not_raise()), + ( + 10, + pytest.raises( + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", + ), + ), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_input_num(self, num_input, mode, kwargs, expectation): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) + def test_order_setter(self, n_basis_funcs, order, expectation): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=4) with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) + basis_obj.order = order + basis_obj.compute_features(np.linspace(0, 1, 10)) @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], + "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_input_shape(self, inp, mode, kwargs, expectation): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_sample_axis(self, time_axis_shape, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_nan(self, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) - x = np.linspace(0, 1, 10) - x[3] = np.nan - assert all(np.isnan(bas(x)[3])) + def test_samples_range_matches_compute_features_requirements(self, sample_range): + """ + Verifies that the compute_features() method can handle input range. 
+ """ + basis_obj = self.cls["eval"](n_basis_funcs=5, order=3) + basis_obj.compute_features(np.linspace(*sample_range, 100)) @pytest.mark.parametrize( - "samples, expectation", + "bounds, samples, nan_idx, scaling", [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), + (None, np.arange(5), [4], 1), + ((1, 4), np.arange(5), [0], 3), + ((1, 3), np.arange(5), [0, 4], 2), ], ) - def test_call_input_type(self, samples, expectation): - bas = self.cls["eval"](n_basis_funcs=5) # Only eval mode is relevant here - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls["conv"](n_basis_funcs=5, window_size=10) - bas_eval = self.cls["eval"](n_basis_funcs=5) - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eval(x)) + def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( + self, bounds, samples, nan_idx, scaling + ): + """ + Check that the MSpline has the expected scaling property. - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_pynapple_support(self, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) + The MSpline must integrate to one. If the support is reduced, the height of the spline increases. + """ + bas_no_range = self.cls["eval"](5, bounds=None) + bas = self.cls["eval"](5, bounds=bounds) + _, out1 = bas.evaluate_on_grid(10) + _, out2 = bas_no_range.evaluate_on_grid(10) + assert np.allclose(out1 * scaling, out2) - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_basis_number(self, n_basis, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=n_basis, **kwargs) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_non_empty(self, n_basis, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=n_basis, **kwargs) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) +class TestOrthExponentialBasis(BasisFuncsTesting): + cls = {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential} @pytest.mark.parametrize( - "mn, mx, expectation", - [ - (0, 1, does_not_raise()), - (-2, 2, does_not_raise()), - ], + "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_sample_range(self, mn, mx, expectation, mode, kwargs): - bas = self.cls[mode](n_basis_funcs=5, **kwargs) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls["conv"](n_basis_funcs=5, window_size=3) - bas._set_kernel() - assert bas.kernel_ is not None + def test_decay_rate_repetition(self, decay_rates): + """ + Tests whether the class instance correctly processes the decay rates without repetition. + A repeated rate causes linear algebra issues, and should raise a ValueError exception. 
+ """ + decay_rates = np.asarray(decay_rates, dtype=float) + raise_exception = len(set(decay_rates)) != len(decay_rates) + if raise_exception: + with pytest.raises( + ValueError, match=r"Two or more rates are repeated! Repeating rates will" + ): + self.cls["eval"](n_basis_funcs=len(decay_rates), decay_rates=decay_rates) + else: + self.cls["eval"](n_basis_funcs=len(decay_rates), decay_rates=decay_rates) - def test_fit_kernel_shape(self): - bas = self.cls["conv"](n_basis_funcs=5, window_size=3) - bas._set_kernel() - assert bas.kernel_.shape == (3, 5) + @pytest.mark.parametrize( + "decay_rates", [[], [1], [1, 2, 3], [1, 0.01, 0.02, 0.001]] + ) + @pytest.mark.parametrize("n_basis_funcs", [1, 2, 3, 4]) + def test_decay_rate_size_match_n_basis_funcs(self, decay_rates, n_basis_funcs): + """ + Tests whether the size of decay rates matches the number of basis functions. + """ + raise_exception = len(decay_rates) != n_basis_funcs + decay_rates = np.asarray(decay_rates, dtype=float) + if raise_exception: + with pytest.raises( + ValueError, match="The number of basis functions must match the" + ): + self.cls["eval"](n_basis_funcs=n_basis_funcs, decay_rates=decay_rates) + else: + self.cls["eval"](n_basis_funcs=n_basis_funcs, decay_rates=decay_rates) - def test_transform_fails(self): - bas = self.cls["conv"](n_basis_funcs=5, window_size=3) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): + """ + Tests whether the class instance has a minimum number of basis functions. 
+ """ + raise_exception = n_basis_funcs < 1 + decay_rates = np.arange(1, 1 + n_basis_funcs) if n_basis_funcs > 0 else [] + if raise_exception: + with pytest.raises( + ValueError, + match=f"Object class {self.cls[mode].__name__} requires >= 1 basis elements.", + ): + self.cls[mode]( + n_basis_funcs=n_basis_funcs, + decay_rates=decay_rates, + **kwargs, + ) + else: + self.cls[mode]( + n_basis_funcs=n_basis_funcs, + decay_rates=decay_rates, + **kwargs, + ) - @pytest.mark.parametrize("label", [None, "label"]) - def test_init_label(self, label): - bas = self.cls["eval"](n_basis_funcs=5, label=label) - expected_label = str(label) if label is not None else self.cls["eval"].__name__ - assert bas.label == expected_label - @pytest.mark.parametrize( - "attribute, value", - [ - ("label", None), - ("label", "label"), - ("n_basis_input", 1), - ("n_output_features", 5), - ], - ) - def test_attr_setter(self, attribute, value): - bas = self.cls["eval"](n_basis_funcs=5) - with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" - ): - setattr(bas, attribute, value) +class TestBSplineBasis(BasisFuncsTesting): + cls = {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline} - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_output_features(self, n_input): - bas = self.cls["conv"](n_basis_funcs=5, window_size=10) - assert bas.n_output_features is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_output_features == n_input * bas.n_basis_funcs + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) and + order < #basis are enforced. + """ + raise_exception = order > n_basis_funcs + if raise_exception: + with pytest.raises( + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, + order=order, + **kwargs, + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, + order=order, + **kwargs, + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_basis_input(self, n_input): - bas = self.cls["conv"](n_basis_funcs=5, window_size=10) - assert bas.n_basis_input is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) + @pytest.mark.parametrize("n_basis_funcs", [10]) + @pytest.mark.parametrize("order", [-1, 0, 1, 2]) + def test_order_is_positive(self, n_basis_funcs, order): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) and + order < #basis are enforced. 
+ """ + raise_exception = order < 1 + if raise_exception: + with pytest.raises(ValueError, match=r"Spline order must be positive!"): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + @pytest.mark.parametrize("n_basis_funcs", [5]) @pytest.mark.parametrize( - "n_input, expectation", + "order, expectation", [ + (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), + (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), + (0, pytest.raises(ValueError, match=r"Spline order must be positive")), + (1, does_not_raise()), (2, does_not_raise()), - (0, pytest.raises(ValueError, match="Input shape mismatch detected")), - (1, pytest.raises(ValueError, match="Input shape mismatch detected")), - (3, pytest.raises(ValueError, match="Input shape mismatch detected")), + ( + 10, + pytest.raises( + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", + ), + ), ], ) - def test_expected_input_number(self, n_input, expectation): - bas = self.cls["conv"](n_basis_funcs=5, window_size=10) - x = np.random.randn(20, 2) - bas.compute_features(x) + def test_order_setter(self, n_basis_funcs, order, expectation): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=4) with expectation: - bas.compute_features(np.random.randn(30, n_input)) + basis_obj.order = order + basis_obj.compute_features(np.linspace(0, 1, 10)) @pytest.mark.parametrize( - "conv_kwargs, expectation", - [ - (dict(), does_not_raise()), - ( - dict(axis=0), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), - ), - ( - dict(axis=1), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), - ), - (dict(shift=True), does_not_raise()), - ( - dict(shift=True, axis=0), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), - ), - ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - (dict(shift=True, predictor_causality="causal"), does_not_raise()), - ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - ], - ) - def test_init_conv_kwargs(self, conv_kwargs, expectation): - with expectation: - self.cls["conv"](n_basis_funcs=5, window_size=200, conv_kwargs=conv_kwargs) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - None, - pytest.raises( - ValueError, - match="If the basis is in `conv` mode, you must provide a ", - ), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, pytest.raises( - TypeError, - match=r"got an unexpected keyword argument 'window_size'", - )), - ( - "eval", - 10, - pytest.raises( - TypeError, - match=r"got an unexpected keyword argument 'window_size'", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls[mode](n_basis_funcs=5, window_size=ws) - - @pytest.mark.parametrize( - "enforce_decay_to_zero, time_scaling, width, window_size, n_basis_funcs, bounds, mode", - [ - (False, 15, 4, None, 10, (1, 2), "eval"), - (False, 15, 4, 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, - 
enforce_decay_to_zero, - time_scaling, - width, - window_size, - n_basis_funcs, - bounds, - mode: Literal["eval", "conv"], - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - enforce_decay_to_zero=enforce_decay_to_zero, - time_scaling=time_scaling, - width=width, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - if window_size is None: - pars.pop("window_size") - if bounds is None: - pars.pop("bounds") - - keys = list(pars.keys()) - bas = self.cls[mode]( - **pars - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas = bas.set_params(**par_set) - assert isinstance(bas, self.cls[mode]) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ], - ) - def test_set_bounds(self, mode, expectation): - kwargs = {"bounds": (1, 2)} - with expectation: - self.cls[mode](n_basis_funcs=10, **kwargs) - - if mode == "conv": - bas = self.cls["conv"](n_basis_funcs=10, window_size=10) - with pytest.raises(ValueError, match="Invalid parameter 'bounds' for estimator"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", does_not_raise()), - ("eval", pytest.raises(TypeError, match="got an unexpected keyword argument 'window_size'")), - ], + "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] ) - def test_set_window_size(self, mode, expectation): - kwargs = {"window_size": 10} - with expectation: - self.cls[mode](n_basis_funcs=10, **kwargs) + def test_samples_range_matches_compute_features_requirements(self, sample_range: tuple): + """ + Verifies that the compute_features() method can handle input range. + """ + basis_obj = self.cls["eval"](n_basis_funcs=5, order=3) + basis_obj.compute_features(np.linspace(*sample_range, 100)) - if mode == "conv": - bas = self.cls["conv"](n_basis_funcs=10, window_size=10) - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - if mode == "eval": - bas = self.cls["eval"](n_basis_funcs=10) - with pytest.raises(ValueError, match="Invalid parameter 'window_size' for estimator"): - bas.set_params(window_size=10) +class TestCyclicBSplineBasis(BasisFuncsTesting): + cls = {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline} - def test_convolution_is_performed(self): - bas = self.cls["conv"](n_basis_funcs=5, window_size=10) - x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) + @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) + @pytest.mark.parametrize("order", [2, 3, 4, 5]) + @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + """ + Verifies that the minimum number of basis functions and order required (i.e., at least 1) + and order < #basis are enforced. 
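+        Note added for clarity: the parametrized orders start at 2 here because
+        cyclic B-splines require ``order >= 2``; the ``order == 1`` case is
+        covered separately by ``test_order_1_invalid``.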
+ """ + raise_exception = order > n_basis_funcs + if raise_exception: + with pytest.raises( + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, + order=order, + **kwargs, + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls[mode]( + n_basis_funcs=n_basis_funcs, + order=order, + **kwargs, + ) + basis_obj.compute_features(np.linspace(0, 1, 10)) - def test_conv_kwargs_error(self): - with pytest.raises(TypeError, match="got an unexpected keyword argument 'test'"): - self.cls["eval"](n_basis_funcs=5, test="hi") + @pytest.mark.parametrize("n_basis_funcs", [10]) + @pytest.mark.parametrize("order", [1, 2, 3]) + def test_order_1_invalid(self, n_basis_funcs, order): + """ + Verifies that order >= 2 is required for cyclic B-splines. + """ + raise_exception = order == 1 + if raise_exception: + with pytest.raises( + ValueError, match=r"Order >= 2 required for cyclic B-spline" + ): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation): - with expectation: - bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize("n_basis_funcs", [10]) + @pytest.mark.parametrize("order", [-1, 0, 2, 3]) + def test_order_is_positive(self, n_basis_funcs, order): + """ + Verifies that the order is positive and < #basis. + """ + raise_exception = order < 1 + if raise_exception: + with pytest.raises(ValueError, match=r"Spline order must be positive!"): + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + else: + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) + basis_obj.compute_features(np.linspace(0, 1, 10)) + @pytest.mark.parametrize("n_basis_funcs", [5]) @pytest.mark.parametrize( - "bounds, expectation", + "order, expectation", [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), + (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), + (0, pytest.raises(ValueError, match=r"Spline order must be positive")), + (1, does_not_raise()), + (2, does_not_raise()), ( - (2, 1), + 10, pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", ), ), ], ) - def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls["eval"](n_basis_funcs=3, bounds=(1, 3)) + def test_order_setter(self, n_basis_funcs, order, expectation): + """ + Verifies that setting `order` validates the value correctly. + """ + basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=4) with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (None, None, np.arange(5), []), - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): - bounds = None if vmin is None else (vmin, vmax) - bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) - out = bas.compute_features(samples) - assert np.all(np.isnan(out[nan_idx])) - valid_idx = list(set(samples).difference(nan_idx)) - assert np.all(~np.isnan(out[valid_idx])) - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): - bas_no_range = self.cls["eval"](n_basis_funcs=3, bounds=None) - bas = self.cls["eval"](n_basis_funcs=3, bounds=(vmin, vmax)) - _, out1 = bas.evaluate_on_grid(10) - _, out2 = bas_no_range.evaluate_on_grid(10) - assert np.allclose(out1, out2) + basis_obj.order = order + basis_obj.compute_features(np.linspace(0, 1, 10)) @pytest.mark.parametrize( - "bounds, samples, nan_idx, mn, mx", - [ - (None, np.arange(5), [4], 0, 1), - ((0, 3), np.arange(5), [4], 0, 3), - ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3), - ], + "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] ) - def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): - bas_no_range = self.cls["eval"](n_basis_funcs=3, bounds=None) - bas = self.cls["eval"](n_basis_funcs=3, bounds=bounds) - x1, _ = bas.evaluate_on_grid(10) - x2, _ = bas_no_range.evaluate_on_grid(10) - assert np.allclose(x1, x2 * (mx - mn) + mn) - - @pytest.mark.parametrize( - "bounds, samples, exception", - [ - (None, np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((0, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((1, 4), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((1, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ], - ) - def test_vmin_vmax_mode_conv(self, bounds, samples, exception): - with exception: - self.cls["conv"](n_basis_funcs=3, window_size=10, bounds=bounds) - - def test_transformer_get_params(self): - bas = self.cls["eval"](n_basis_funcs=5) - bas_transformer = bas.to_transformer() - params_transf = bas_transformer.get_params() - params_transf.pop("_basis") - params_basis = bas.get_params() - assert params_transf == params_basis - - -class TestRaisedCosineLinearBasis(BasisFuncsTesting): - cls = basis.RaisedCosineBasisLinear - - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_non_empty_samples(self, samples, mode, window_size): - if mode == "conv" and len(samples) 
== 1: - return - if len(samples) == 0: - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - self.cls(5, mode=mode, window_size=window_size).compute_features( - samples - ) - else: - self.cls(5, mode=mode, window_size=window_size).compute_features(samples) - - @pytest.mark.parametrize( - "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] - ) - def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = self.cls(n_basis_funcs=5) - basis_obj.compute_features(eval_input) - - @pytest.mark.parametrize( - "width, expectation", - [ - (10, does_not_raise()), - (10.5, does_not_raise()), - ( - 0.5, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), - ( - 10.3, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), - ( - -10, - pytest.raises( - ValueError, - match=r"Invalid raised cosine width\. 2\*width must be a positive", - ), - ), - (None, pytest.raises(TypeError, match="'<=' not supported between")), - ], - ) - def test_set_width(self, width, expectation): - basis_obj = self.cls(n_basis_funcs=5) - with expectation: - basis_obj.width = width - with expectation: - basis_obj.set_params(width=width) - - @pytest.mark.parametrize( - "kwargs, input1_shape, expectation", - [ - (dict(), (10,), does_not_raise()), - ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ], - ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - with expectation: - basis_obj = self.cls(n_basis_funcs=5, mode="conv", window_size=5, **kwargs) - basis_obj.compute_features(np.ones(input1_shape)) - - @pytest.mark.parametrize("n_basis_funcs", [4, 5]) - @pytest.mark.parametrize("window_size", [10, 15]) - @pytest.mark.parametrize( - "input_shape, expected_n_input", - [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), - ], - ) - def test_compute_features_conv_input( - self, n_basis_funcs, window_size, input_shape, expected_n_input - ): - x = np.ones(input_shape) - bas = self.cls( - n_basis_funcs=n_basis_funcs, - mode="conv", - window_size=window_size, - ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * bas.n_basis_funcs - - @pytest.mark.parametrize( - "args, sample_size", - [[{"n_basis_funcs": n_basis}, 100] for n_basis in [2, 10, 100]], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_compute_features_returns_expected_number_of_basis( - self, args, mode, window_size, sample_size - ): - """ - Verifies that the compute_features() method returns the expected number of basis functions. - """ - basis_obj = self.cls(mode=mode, window_size=window_size, **args) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[1] != args["n_basis_funcs"]: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." 
- f"The number of basis is {args['n_basis_funcs']}", - f"The first dimension of the output features is {eval_basis.shape[1]}", - ) - return - - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [2, 10, 100]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, mode, window_size - ): - """ - Checks that the sample size of the output from the co ute_features() method matches the input sample size. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features." - f"The window size is {sample_size}", - f"The second dimension of the output features basis is {eval_basis.shape[0]}", - ) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, mode, window_size - ): - """ - Verifies that the minimum number of basis functions required (i.e., 1) is enforced. - """ - raise_exception = n_basis_funcs < 2 - if raise_exception: - with pytest.raises( - ValueError, - match=f"Object class {self.cls.__name__} " - r"requires >= 2 basis elements\.", - ): - self.cls( - n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size - ) - else: - self.cls(n_basis_funcs=n_basis_funcs, mode=mode, window_size=window_size) - - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): - """ - Confirms that the compute_features() method correctly handles the number of input samples that are provided. - """ - basis_obj = self.cls(n_basis_funcs=5, mode=mode, window_size=window_size) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, match="Input dimensionality mismatch" - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): - """ - Checks that the evaluate_on_grid() method returns a grid of the expected size. 
- """ - basis_obj = self.cls(n_basis_funcs=5) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): - """ - Ensures that the evaluate_on_grid() method returns basis functions of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): - basis_obj.evaluate_on_grid(sample_size) - else: - _, eval_basis = basis_obj.evaluate_on_grid(sample_size) - assert eval_basis.shape[0] == sample_size - - @pytest.mark.parametrize("n_input", [0, 1, 2]) - def test_evaluate_on_grid_input_number(self, n_input): + def test_samples_range_matches_compute_features_requirements(self, sample_range: tuple): """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. + Verifies that the compute_features() method can handle input ranges. """ - basis_obj = self.cls(n_basis_funcs=5) - inputs = [10] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) missing 1 required positional argument", - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.evaluate_on_grid(*inputs) - - @pytest.mark.parametrize( - "width ,expectation", - [ - (-1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), - (0, pytest.raises(ValueError, match="Invalid raised cosine width. ")), - (0.5, pytest.raises(ValueError, match="Invalid raised cosine width. ")), - (1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), - (1.5, does_not_raise()), - (2, does_not_raise()), - (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. 
")), - ], - ) - def test_width_values(self, width, expectation): - """Test allowable widths: integer multiple of 1/2, greater than 1.""" - with expectation: - self.cls(n_basis_funcs=5, width=width) - - @pytest.mark.parametrize("sample_size", [30]) - @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size): - iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) - inp = nap.Tsd( - t=np.linspace(0, 1, sample_size), - d=np.linspace(0, 1, sample_size), - time_support=iset, - ) - out = self.cls(n_basis).compute_features(inp) - assert isinstance(out, nap.TsdFrame) - assert np.all(out.time_support.values == inp.time_support.values) - - # TEST CALL - @pytest.mark.parametrize( - "num_input, expectation", - [ - (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), - (1, does_not_raise()), - (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_num(self, num_input, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) - - @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_nan(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x[3] = np.nan - assert all(np.isnan(bas(x)[3])) - - @pytest.mark.parametrize( - "samples, expectation", - [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), - ], - ) - def test_call_input_type(self, samples, expectation): - bas = self.cls(5) - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10) - bas_eva = self.cls(5, mode="eval") - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_pynapple_support(self, mode, window_size): - bas = 
self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) - - @pytest.mark.parametrize( - "mn, mx, expectation", - [ - (0, 1, does_not_raise()), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_ is not None - - def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_.shape == (3, 5) - - def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=3) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 2 - with expectation: - self.cls(5, mode=mode, window_size=window_size) - - @pytest.mark.parametrize("label", [None, "label"]) - def test_init_label(self, label): - bas = self.cls(5, label=label) - assert bas.label == (str(label) if label is not None else self.cls.__name__) - - @pytest.mark.parametrize( - "attribute, value", - [ - ("label", None), - ("label", "label"), - ("n_basis_input", 1), - ("n_output_features", 5), - ], - ) - def test_attr_setter(self, attribute, value): - bas = self.cls(5) - with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" - ): - setattr(bas, attribute, value) - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_output_features(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_output_features is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_output_features == n_input * bas.n_basis_funcs - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_basis_input(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_basis_input is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) - - @pytest.mark.parametrize( - "n_input, expectation", - [ - (2, does_not_raise()), - (0, pytest.raises(ValueError, match="Input shape mismatch detected")), - (1, 
pytest.raises(ValueError, match="Input shape mismatch detected")), - (3, pytest.raises(ValueError, match="Input shape mismatch detected")), - ], - ) - def test_expected_input_number(self, n_input, expectation): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.randn(20, 2) - bas.compute_features(x) - with expectation: - bas.compute_features(np.random.randn(30, n_input)) - - @pytest.mark.parametrize( - "conv_kwargs, expectation", - [ - (dict(), does_not_raise()), - ( - dict(axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - (dict(shift=True), does_not_raise()), - ( - dict(shift=True, axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - (dict(shift=True, predictor_causality="causal"), does_not_raise()), - ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - ], - ) - def test_init_conv_kwargs(self, conv_kwargs, expectation): - with expectation: - self.cls(5, mode="conv", window_size=200, **conv_kwargs) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, does_not_raise()), - ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - - @pytest.mark.parametrize( - "width, window_size, n_basis_funcs, bounds, mode", - [ - (4, None, 10, (1, 2), "eval"), - (4, 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, width, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - width=width, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - keys = list(pars.keys()) - bas = self.cls( - width=width, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), - ], - ) - def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) - with expectation: - self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", 
does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), - ], - ) - def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" - with expectation: - self.cls(window_size=10, n_basis_funcs=10, mode=mode) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - - bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) - - def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) - - def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, mode="eval", test="hi") - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation): - with expectation: - bas = self.cls(3, bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls(5, bounds=(1, 3)) - with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (None, None, np.arange(5), []), - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): - bounds = None if vmin is None else (vmin, vmax) - bas = self.cls(3, mode="eval", bounds=bounds) - out = bas.compute_features(samples) - assert np.all(np.isnan(out[nan_idx])) - valid_idx = list(set(samples).difference(nan_idx)) - assert np.all(~np.isnan(out[valid_idx])) - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval( - self, vmin, vmax, samples, nan_idx - ): - bas_no_range = self.cls(3, mode="eval", bounds=None) - bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) - _, out1 = bas.evaluate_on_grid(10) - _, out2 = bas_no_range.evaluate_on_grid(10) - assert np.allclose(out1, out2) - - @pytest.mark.parametrize( - "bounds, samples, nan_idx, mn, mx", - [ - (None, np.arange(5), [4], 0, 1), - ((0, 3), np.arange(5), [4], 0, 3), - ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3), - ], - ) - def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): - bas_no_range = self.cls(3, mode="eval", bounds=None) - bas = self.cls(3, mode="eval", bounds=bounds) - x1, _ = bas.evaluate_on_grid(10) - x2, _ = bas_no_range.evaluate_on_grid(10) - assert np.allclose(x1, x2 * (mx - mn) + mn) - - @pytest.mark.parametrize( - "bounds, samples, exception", - [ - (None, np.arange(5), does_not_raise()), - ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ], - ) - def test_vmin_vmax_mode_conv(self, bounds, samples, exception): - with exception: - self.cls(3, mode="conv", window_size=10, bounds=bounds) - - def test_transformer_get_params(self): - bas = self.cls(5) - bas_transformer = bas.to_transformer() - params_transf = bas_transformer.get_params() - params_transf.pop("_basis") - params_basis = bas.get_params() - assert params_transf == params_basis - - -class TestMSplineBasis(BasisFuncsTesting): - cls = basis.EvalMSpline - - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_non_empty_samples(self, samples, mode, window_size): - if mode == "conv" and len(samples) == 1: - return - if len(samples) == 0: - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - self.cls(5, mode=mode, window_size=window_size).compute_features( - samples - ) - else: - self.cls(5, mode=mode, window_size=window_size).compute_features(samples) - - @pytest.mark.parametrize( - "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] - ) - def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. 
- """ - basis_obj = self.cls(n_basis_funcs=5) - basis_obj.compute_features(eval_input) - - @pytest.mark.parametrize( - "kwargs, input1_shape, expectation", - [ - (dict(), (10,), does_not_raise()), - ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ], - ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - with expectation: - basis_obj = self.cls(n_basis_funcs=5, mode="conv", window_size=5, **kwargs) - basis_obj.compute_features(np.ones(input1_shape)) - - @pytest.mark.parametrize("n_basis_funcs", [2, 3]) - @pytest.mark.parametrize("order", [1, 2]) - @pytest.mark.parametrize("window_size", [10, 15]) - @pytest.mark.parametrize( - "input_shape, expected_n_input", - [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), - ], - ) - def test_compute_features_conv_input( - self, - n_basis_funcs, - order, - window_size, - input_shape, - expected_n_input, - ): - x = np.ones(input_shape) - bas = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode="conv", - window_size=window_size, - ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * bas.n_basis_funcs - - @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) - @pytest.mark.parametrize("order", range(1, 6)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_compute_features_returns_expected_number_of_basis( - self, n_basis_funcs: int, order: int, mode, window_size - ): - """ - Verifies that the compute_features() method returns the expected number of basis functions. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, 100)) - if eval_basis.shape[1] != n_basis_funcs: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." - f"The number of basis is {n_basis_funcs}", - f"The first dimension of the output features is {eval_basis.shape[1]}", - ) - - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [4, 10, 100]) - @pytest.mark.parametrize("order", [1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, order, mode, window_size - ): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features." 
- f"The window size is {sample_size}", - f"The second dimension of the output features is {eval_basis.shape[0]}", - ) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, order, mode, window_size - ): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs) - if raise_exception: - with pytest.raises( - ValueError, - match=r"Spline order must be positive!|" - rf"{self.cls.__name__} `order` parameter cannot be larger than", - ): - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] - ) - def test_samples_range_matches_compute_features_requirements( - self, sample_range: tuple - ): - """ - Verifies that the compute_features() method can handle input range. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - basis_obj.compute_features(np.linspace(*sample_range, 100)) - - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): - """ - Confirms that the compute_features() method correctly handles the number of input samples that are provided. - """ - basis_obj = self.cls( - n_basis_funcs=5, order=3, mode=mode, window_size=window_size - ) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): - """ - Checks that the evaluate_on_grid() method returns a grid of the expected size. 
- """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): - """ - Ensures that the evaluate_on_grid() method returns basis functions of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, match=r"All sample counts provided must be greater" - ): - basis_obj.evaluate_on_grid(sample_size) - else: - _, eval_basis = basis_obj.evaluate_on_grid(sample_size) - assert eval_basis.shape[0] == sample_size - - @pytest.mark.parametrize("n_input", [0, 1, 2]) - def test_evaluate_on_grid_input_number(self, n_input): - """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - inputs = [10] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) missing 1 required positional argument", - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.evaluate_on_grid(*inputs) - - @pytest.mark.parametrize("sample_size", [30]) - @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size): - iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) - inp = nap.Tsd( - t=np.linspace(0, 1, sample_size), - d=np.linspace(0, 1, sample_size), - time_support=iset, - ) - out = self.cls(n_basis).compute_features(inp) - assert isinstance(out, nap.TsdFrame) - assert np.all(out.time_support.values == inp.time_support.values) - - # TEST CALL - @pytest.mark.parametrize( - "num_input, expectation", - [ - (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), - (1, does_not_raise()), - (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_num(self, num_input, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) - - @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_nan(self, 
mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x[3] = np.nan - assert all(np.isnan(bas(x)[3])) - - @pytest.mark.parametrize( - "samples, expectation", - [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), - ], - ) - def test_call_input_type(self, samples, expectation): - bas = self.cls(5) - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10) - bas_eva = self.cls(5, mode="eval") - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_pynapple_support(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) - - @pytest.mark.parametrize("mn, mx, expectation", [(0, 1, does_not_raise())]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_ is not None - - def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_.shape == (3, 5) - - def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=3) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == 
"eval" else 2 - with expectation: - self.cls(5, mode=mode, window_size=window_size) - - @pytest.mark.parametrize("label", [None, "label"]) - def test_init_label(self, label): - bas = self.cls(5, label=label) - assert bas.label == (str(label) if label is not None else self.cls.__name__) - - @pytest.mark.parametrize( - "attribute, value", - [ - ("label", None), - ("label", "label"), - ("n_basis_input", 1), - ("n_output_features", 5), - ], - ) - def test_attr_setter(self, attribute, value): - bas = self.cls(5) - with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" - ): - setattr(bas, attribute, value) - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_output_features(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_output_features is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_output_features == n_input * bas.n_basis_funcs - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_basis_input(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_basis_input is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) - - @pytest.mark.parametrize( - "n_input, expectation", - [ - (2, does_not_raise()), - (0, pytest.raises(ValueError, match="Input shape mismatch detected")), - (1, pytest.raises(ValueError, match="Input shape mismatch detected")), - (3, pytest.raises(ValueError, match="Input shape mismatch detected")), - ], - ) - def test_expected_input_number(self, n_input, expectation): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.randn(20, 2) - bas.compute_features(x) - with expectation: - bas.compute_features(np.random.randn(30, n_input)) - - @pytest.mark.parametrize( - "conv_kwargs, expectation", - [ - (dict(), does_not_raise()), - ( - dict(axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - (dict(shift=True), does_not_raise()), - ( - dict(shift=True, axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - (dict(shift=True, predictor_causality="causal"), does_not_raise()), - ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - ], - ) - def test_init_conv_kwargs(self, conv_kwargs, expectation): - with expectation: - self.cls(5, mode="conv", window_size=200, **conv_kwargs) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, does_not_raise()), - ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - - @pytest.mark.parametrize( - "order, window_size, n_basis_funcs, bounds, mode", - [ - (4, None, 10, (1, 2), "eval"), - (4, 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, order, window_size, 
n_basis_funcs, bounds, mode: Literal["eval", "conv"] - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - order=order, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - keys = list(pars.keys()) - bas = self.cls( - order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) - - @pytest.mark.parametrize("n_basis_funcs", [10]) - @pytest.mark.parametrize("order", [-1, 0, 1, 2]) - def test_order_is_positive(self, n_basis_funcs, order): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = order < 1 - if raise_exception: - with pytest.raises(ValueError, match=r"Spline order must be positive!"): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [5]) - @pytest.mark.parametrize( - "order, expectation", - [ - (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), - (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), - (0, pytest.raises(ValueError, match=r"Spline order must be positive")), - (1, does_not_raise()), - (2, does_not_raise()), - ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), - ), - ], - ) - def test_order_setter(self, n_basis_funcs, order, expectation): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=4) - with expectation: - basis_obj.order = order - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), - ], - ) - def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) - with expectation: - self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), - ], - ) - def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" - with expectation: - self.cls(window_size=10, n_basis_funcs=10, mode=mode) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - - bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) - - def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10) - 
x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) - - def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, mode="eval", test="hi") - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation): - with expectation: - bas = self.cls(3, bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls(3, bounds=(1, 3)) - with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (None, None, np.arange(5), []), - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): - bounds = None if vmin is None else (vmin, vmax) - bas = self.cls(3, mode="eval", bounds=bounds) - out = bas.compute_features(samples) - assert np.all(np.isnan(out[nan_idx])) - valid_idx = list(set(samples).difference(nan_idx)) - assert np.all(~np.isnan(out[valid_idx])) - - @pytest.mark.parametrize( - "bounds, samples, nan_idx, scaling", - [ - (None, np.arange(5), [4], 1), - ((1, 4), np.arange(5), [0], 3), - ((1, 3), np.arange(5), [0, 4], 2), - ], - ) - def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( - self, bounds, samples, nan_idx, scaling - ): - """Check that the MSpline has the expected scaling property.""" - bas_no_range = self.cls(3, mode="eval", bounds=None) - bas = self.cls(3, mode="eval", bounds=bounds) - _, out1 = bas.evaluate_on_grid(10) - _, out2 = bas_no_range.evaluate_on_grid(10) - # multiply by scaling to get the invariance - # mspline must integrate to one, if the support - # is reduced, the height of the spline increases. 
-
-        assert np.allclose(out1 * scaling, out2)
-
-    @pytest.mark.parametrize(
-        "bounds, samples, nan_idx, mn, mx",
-        [
-            (None, np.arange(5), [4], 0, 1),
-            ((0, 3), np.arange(5), [4], 0, 3),
-            ((1, 4), np.arange(5), [0], 1, 4),
-            ((1, 3), np.arange(5), [0, 4], 1, 3),
-        ],
-    )
-    def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx):
-        bas_no_range = self.cls(3, mode="eval", bounds=None)
-        bas = self.cls(3, mode="eval", bounds=bounds)
-        x1, _ = bas.evaluate_on_grid(10)
-        x2, _ = bas_no_range.evaluate_on_grid(10)
-        assert np.allclose(x1, x2 * (mx - mn) + mn)
-
-    @pytest.mark.parametrize(
-        "bounds, samples, exception",
-        [
-            (None, np.arange(5), does_not_raise()),
-            ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
-            ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
-            ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
-        ],
-    )
-    def test_vmin_vmax_mode_conv(self, bounds, samples, exception):
-        with exception:
-            self.cls(3, mode="conv", window_size=10, bounds=bounds)
-
-    def test_transformer_get_params(self):
-        bas = self.cls(5)
-        bas_transformer = bas.to_transformer()
-        params_transf = bas_transformer.get_params()
-        params_transf.pop("_basis")
-        params_basis = bas.get_params()
-        assert params_transf == params_basis
-
-
-class TestOrthExponentialBasis(BasisFuncsTesting):
-    cls = basis.OrthExponentialBasis
-
-    # this class requires at least `n_basis` samples
-    @pytest.mark.parametrize("samples", [[], [0] * 30, [0] * 20])
-    @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)])
-    def test_non_empty_samples(self, samples, mode, window_size):
-        if mode == "conv" and len(samples) == 1:
-            return
-        if len(samples) == 0:
-            with pytest.raises(
-                ValueError, match="All sample provided must be non empty"
-            ):
-                self.cls(
-                    5, decay_rates=np.arange(1, 6), mode=mode, window_size=window_size
-                ).compute_features(samples)
-        else:
-            self.cls(
-                5, decay_rates=np.arange(1, 6), mode=mode, window_size=window_size
-            ).compute_features(samples)
-
-    @pytest.mark.parametrize(
-        "eval_input",
-        [0, [0] * 6, (0,) * 6, np.array([0] * 6), jax.numpy.array([0] * 6)],
-    )
-    def test_compute_features_input(self, eval_input):
-        """
-        Checks that compute_features() accepts the supported input types; a scalar input must raise for OrthExponentialBasis.
-        """
-        basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6))
-        if isinstance(eval_input, int):
-            # OrthExponentialBasis is special -- cannot accept int input
-            with pytest.raises(
-                ValueError,
-                match="OrthExponentialBasis requires at least as many samples",
-            ):
-                basis_obj.compute_features(eval_input)
-        else:
-            basis_obj.compute_features(eval_input)
-
-    @pytest.mark.parametrize(
-        "kwargs, input1_shape, expectation",
-        [
-            (dict(), (10,), does_not_raise()),
-            (
-                dict(axis=0),
-                (10,),
-                pytest.raises(
-                    ValueError, match="Setting the `axis` parameter is not allowed"
-                ),
-            ),
-            (
-                dict(axis=1),
-                (2, 10),
-                pytest.raises(
-                    ValueError, match="Setting the `axis` parameter is not allowed"
-                ),
-            ),
-        ],
-    )
-    def test_compute_features_axis(self, kwargs, input1_shape, expectation):
-        """
-        Checks that passing the `axis` keyword to a basis in `conv` mode raises a ValueError.
- """ - with expectation: - basis_obj = self.cls( - n_basis_funcs=5, - mode="conv", - window_size=5, - decay_rates=np.arange(1, 6), - **kwargs, - ) - basis_obj.compute_features(np.ones(input1_shape)) - - @pytest.mark.parametrize("n_basis_funcs", [2, 3]) - @pytest.mark.parametrize("window_size", [10, 15]) - @pytest.mark.parametrize( - "input_shape, expected_n_input", - [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), - ], - ) - def test_compute_features_conv_input( - self, - n_basis_funcs, - window_size, - input_shape, - expected_n_input, - ): - x = np.ones(input_shape) - bas = self.cls( - n_basis_funcs=n_basis_funcs, - mode="conv", - window_size=window_size, - decay_rates=0.1 * np.arange(1, n_basis_funcs + 1), - ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * n_basis_funcs - - @pytest.mark.parametrize("n_basis_funcs", [1, 2, 4, 8]) - @pytest.mark.parametrize("sample_size", [10, 1000]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_compute_features_returns_expected_number_of_basis( - self, n_basis_funcs, sample_size, mode, window_size - ): - """Tests whether the evaluate method returns the expected number of basis functions.""" - decay_rates = np.arange(1, 1 + n_basis_funcs) - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - decay_rates=decay_rates, - mode=mode, - window_size=window_size, - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[1] != n_basis_funcs: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." - f"The number of basis is {n_basis_funcs}", - f"The first dimension of the output features basis is {eval_basis.shape[1]}", - ) - return - - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [2, 10, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 30)]) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, mode, window_size - ): - """Tests whether the sample size of the features result matches that of the input.""" - decay_rates = np.linspace(0.1, 20, n_basis_funcs) - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - decay_rates=decay_rates, - mode=mode, - window_size=window_size, - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features." 
- f"The window size is {sample_size}", - f"The second dimension of the output features is {eval_basis.shape[0]}", - ) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - ( - np.linspace(-0.5, -0.001, 7), - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - ( - np.linspace(1.5, 2.0, 7), - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - ( - [-0.5, -0.1, -0.01, 1.5, 2, 3], - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax), decay_rates=np.linspace(0.1, 1, 5)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 30)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, mode, window_size - ): - """Tests whether the class instance has a minimum number of basis functions.""" - raise_exception = n_basis_funcs < 1 - decay_rates = np.arange(1, 1 + n_basis_funcs) - if raise_exception: - with pytest.raises( - ValueError, - match=f"Object class {self.cls.__name__} " - r"requires >= 1 basis elements\.", - ): - self.cls( - n_basis_funcs=n_basis_funcs, - decay_rates=decay_rates, - mode=mode, - window_size=window_size, - ) - else: - self.cls( - n_basis_funcs=n_basis_funcs, - decay_rates=decay_rates, - mode=mode, - window_size=window_size, - ) - - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): - """Tests whether the compute_features method correctly processes the number of required inputs.""" - basis_obj = self.cls( - n_basis_funcs=5, - decay_rates=np.arange(1, 6), - mode=mode, - window_size=window_size, - ) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 2, 3, 4, 5, 6, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): - """Tests whether the compute_features_on_grid method correctly outputs the grid mesh size.""" - basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = sample_size < 5 - if raise_exception: - with pytest.raises( - ValueError, - match=rf"{self.cls.__name__} requires at least as " - r"many samples as basis functions\!|" - r"All sample counts provided must be greater", - ): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): - """Tests whether the evaluate_on_grid method correctly outputs the basis size.""" - basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6)) - raise_exception = sample_size < 
5
-        if raise_exception:
-            with pytest.raises(
-                ValueError,
-                match=r"All sample counts provided must be greater|"
-                rf"{self.cls.__name__} requires at least as many samples as basis",
-            ):
-                basis_obj.evaluate_on_grid(sample_size)
-        else:
-            _, eval_basis = basis_obj.evaluate_on_grid(sample_size)
-            assert eval_basis.shape[0] == sample_size
-
-    @pytest.mark.parametrize("n_input", [0, 1, 2])
-    def test_evaluate_on_grid_input_number(self, n_input):
-        """Tests whether the evaluate_on_grid method correctly processes the input dimensionality."""
-        basis_obj = self.cls(n_basis_funcs=5, decay_rates=np.arange(1, 6))
-        inputs = [10] * n_input
-        if n_input == 0:
-            expectation = pytest.raises(
-                TypeError,
-                match=r"evaluate_on_grid\(\) missing 1 required positional argument",
-            )
-        elif n_input != basis_obj._n_input_dimensionality:
-            expectation = pytest.raises(
-                TypeError,
-                match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",
-            )
-        else:
-            expectation = does_not_raise()
-        with expectation:
-            basis_obj.evaluate_on_grid(*inputs)
-
-    @pytest.mark.parametrize(
-        "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]]
-    )
-    def test_decay_rate_repetition(self, decay_rates):
-        """
-        Tests whether the class instance correctly processes the decay rates without repetition.
-        A repeated rate causes linear algebra issues, and should raise a ValueError exception.
-        """
-        decay_rates = np.asarray(decay_rates, dtype=float)
-        # raise exception if any of the decay rates is repeated
-        raise_exception = len(set(decay_rates)) != len(decay_rates)
-        if raise_exception:
-            with pytest.raises(
-                ValueError, match=r"Two or more rate are repeated\! Repeating rate will"
-            ):
-                self.cls(n_basis_funcs=len(decay_rates), decay_rates=decay_rates)
-        else:
-            self.cls(n_basis_funcs=len(decay_rates), decay_rates=decay_rates)
-
-    @pytest.mark.parametrize(
-        "decay_rates", [[], [1], [1, 2, 3], [1, 0.01, 0.02, 0.001]]
-    )
-    @pytest.mark.parametrize("n_basis_func", [1, 2, 3, 4])
-    def test_decay_rate_size_match_n_basis_func(self, decay_rates, n_basis_func):
-        """Tests whether the size of decay rates matches the number of basis functions."""
-        raise_exception = len(decay_rates) != n_basis_func
-        decay_rates = np.asarray(decay_rates, dtype=float)
-        if raise_exception:
-            with pytest.raises(
-                ValueError, match="The number of basis functions must match the"
-            ):
-                self.cls(n_basis_funcs=n_basis_func, decay_rates=decay_rates)
-        else:
-            self.cls(n_basis_funcs=n_basis_func, decay_rates=decay_rates)
-
-    @pytest.mark.parametrize("sample_size", [30])
-    @pytest.mark.parametrize("n_basis", [5])
-    def test_pynapple_support_compute_features(self, n_basis, sample_size):
-        iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1])
-        inp = nap.Tsd(
-            t=np.linspace(0, 1, sample_size),
-            d=np.linspace(0, 1, sample_size),
-            time_support=iset,
-        )
-        out = self.cls(n_basis, np.arange(1, n_basis + 1)).compute_features(inp)
-        assert isinstance(out, nap.TsdFrame)
-        assert np.all(out.time_support.values == inp.time_support.values)
-
-    # TEST CALL
-    @pytest.mark.parametrize(
-        "num_input, expectation",
-        [
-            (0, pytest.raises(TypeError, match="Input dimensionality mismatch")),
-            (1, does_not_raise()),
-            (2, pytest.raises(TypeError, match="Input dimensionality mismatch")),
-        ],
-    )
-    @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)])
-    def test_call_input_num(self, num_input, mode, window_size, expectation):
-        bas = self.cls(
-            5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6)
-        )
-        
with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) - - @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls( - 5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6) - ) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls( - 5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6) - ) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_nan(self, mode, window_size): - bas = self.cls( - 5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6) - ) - x = np.linspace(0, 1, 15) - x[13] = np.nan - with does_not_raise(): - out = bas(x) - assert np.all(np.isnan(out[13])) - - @pytest.mark.parametrize( - "samples, expectation", - [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), - ], - ) - def test_call_input_type(self, samples, expectation): - bas = self.cls(5, np.linspace(0.1, 1, 5)) - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - bas_eva = self.cls(5, mode="eval", decay_rates=np.arange(1, 6)) - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - ( - np.linspace(-1, -0.5, 10), - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, decay_rates=np.linspace(0, 1, 5), bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_pynapple_support(self, mode, window_size): - bas = self.cls( - 5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6) - ) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls( - n_basis, - mode=mode, - window_size=window_size, - decay_rates=np.arange(1, n_basis + 1), - ) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - - @pytest.mark.parametrize("n_basis", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls( - n_basis, - mode=mode, - 
window_size=window_size, - decay_rates=np.arange(1, n_basis + 1), - ) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) - - @pytest.mark.parametrize("mn, mx, expectation", [(0, 1, does_not_raise())]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls( - 5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6) - ) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - bas._set_kernel(None) - assert bas.kernel_ is not None - - def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - bas._set_kernel(None) - assert bas.kernel_.shape == (10, 5) - - def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 10 - with expectation: - self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6)) - - @pytest.mark.parametrize("label", [None, "label"]) - def test_init_label(self, label): - bas = self.cls(5, label=label, decay_rates=np.arange(1, 6)) - assert bas.label == (str(label) if label is not None else self.cls.__name__) - - @pytest.mark.parametrize( - "attribute, value", - [ - ("label", None), - ("label", "label"), - ("n_basis_input", 1), - ("n_output_features", 5), - ], - ) - def test_attr_setter(self, attribute, value): - bas = self.cls(5, decay_rates=np.arange(1, 6)) - with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" - ): - setattr(bas, attribute, value) - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_output_features(self, n_input): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - assert bas.n_output_features is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_output_features == n_input * bas.n_basis_funcs - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_basis_input(self, n_input): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - assert bas.n_basis_input is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) - - @pytest.mark.parametrize( - "n_input, expectation", - [ - (2, does_not_raise()), - (0, pytest.raises(ValueError, match="Input shape mismatch detected")), - (1, pytest.raises(ValueError, match="Input shape mismatch detected")), - (3, pytest.raises(ValueError, match="Input shape mismatch detected")), - ], - ) - def test_expected_input_number(self, n_input, expectation): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - x = np.random.randn(20, 2) - bas.compute_features(x) - with expectation: - bas.compute_features(np.random.randn(30, n_input)) - - @pytest.mark.parametrize( - "conv_kwargs, expectation", - [ - (dict(), does_not_raise()), - ( - 
dict(axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - (dict(shift=True), does_not_raise()), - ( - dict(shift=True, axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - (dict(shift=True, predictor_causality="causal"), does_not_raise()), - ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - ], - ) - def test_init_conv_kwargs(self, conv_kwargs, expectation): - with expectation: - self.cls( - 5, - mode="conv", - window_size=200, - decay_rates=np.arange(1, 6), - **conv_kwargs, - ) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ("conv", 10, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, does_not_raise()), - ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6)) - - @pytest.mark.parametrize( - "decay_rates, window_size, n_basis_funcs, bounds, mode", - [ - (np.arange(1, 11), None, 10, (1, 2), "eval"), - (np.arange(1, 11), 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, - decay_rates, - window_size, - n_basis_funcs, - bounds, - mode: Literal["eval", "conv"], - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - decay_rates=decay_rates, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - keys = list(pars.keys()) - bas = self.cls( - decay_rates=decay_rates, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - mode=mode, - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), - ], - ) - def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) - with expectation: - self.cls( - decay_rates=np.arange(1, 11), - window_size=ws[mode], - n_basis_funcs=10, - mode=mode, - bounds=(1, 2), - ) - - bas = self.cls( - decay_rates=np.arange(1, 11), - window_size=10, - n_basis_funcs=10, - mode="conv", - bounds=None, - ) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), - ], - ) - def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" - with 
expectation: - self.cls( - decay_rates=np.arange(1, 11), - window_size=10, - n_basis_funcs=10, - mode=mode, - ) - - bas = self.cls( - decay_rates=np.arange(1, 11), window_size=10, n_basis_funcs=10, mode="conv" - ) - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - - bas = self.cls( - decay_rates=np.arange(1, 11), - window_size=None, - n_basis_funcs=10, - mode="eval", - ) - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) - - def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6)) - x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) - - def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, decay_rates=[1, 2, 3, 4, 5], mode="eval", test="hi") - - def test_transformer_get_params(self): - bas = self.cls(5, decay_rates=[1, 2, 3, 4, 5]) - bas_transformer = bas.to_transformer() - params_transf = bas_transformer.get_params() - params_transf.pop("_basis") - rates_transf = params_transf.pop("decay_rates") - params_basis = bas.get_params() - rates_basis = params_basis.pop("decay_rates") - assert params_transf == params_basis - assert np.all(rates_transf == rates_basis) - - -class TestBSplineBasis(BasisFuncsTesting): - cls = basis.BSplineBasis - - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_non_empty_samples(self, samples, mode, window_size): - if mode == "conv" and len(samples) == 1: - return - if len(samples) == 0: - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - self.cls(5, mode=mode, window_size=window_size).compute_features( - samples - ) - else: - self.cls(5, mode=mode, window_size=window_size).compute_features(samples) - - @pytest.mark.parametrize( - "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] - ) - def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = self.cls(n_basis_funcs=5) - basis_obj.compute_features(eval_input) - - @pytest.mark.parametrize( - "kwargs, input1_shape, expectation", - [ - (dict(), (10,), does_not_raise()), - ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ], - ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. 
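-        (In practice this only checks that passing the `axis` keyword to a basis in
-        `conv` mode raises a ValueError.)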
- """ - with expectation: - basis_obj = self.cls(n_basis_funcs=5, mode="conv", window_size=5, **kwargs) - basis_obj.compute_features(np.ones(input1_shape)) - - @pytest.mark.parametrize("n_basis_funcs", [2, 3]) - @pytest.mark.parametrize("order", [1, 2]) - @pytest.mark.parametrize("window_size", [10, 15]) - @pytest.mark.parametrize( - "input_shape, expected_n_input", - [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), - ], - ) - def test_compute_features_conv_input( - self, - n_basis_funcs, - order, - window_size, - input_shape, - expected_n_input, - ): - x = np.ones(input_shape) - bas = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode="conv", - window_size=window_size, - ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * n_basis_funcs - - @pytest.mark.parametrize("n_basis_funcs", [6, 8, 10]) - @pytest.mark.parametrize("order", range(1, 6)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_compute_features_returns_expected_number_of_basis( - self, n_basis_funcs: int, order: int, mode, window_size - ): - """ - Verifies that the compute_features() method returns the expected number of basis functions. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, 100)) - if eval_basis.shape[1] != n_basis_funcs: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." - f"The number of basis is {n_basis_funcs}", - f"The first dimension of the output features is {eval_basis.shape[1]}", - ) - return - - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [4, 10, 100]) - @pytest.mark.parametrize("order", [1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, order, mode, window_size - ): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features." 
- f"The window size is {sample_size}", - f"The second dimension of the output features is {eval_basis.shape[0]}", - ) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, order, mode, window_size - ): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = order > n_basis_funcs - if raise_exception: - with pytest.raises( - ValueError, - match=rf"{self.cls.__name__} `order` parameter cannot be larger than", - ): - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [10]) - @pytest.mark.parametrize("order", [-1, 0, 1, 2]) - def test_order_is_positive(self, n_basis_funcs, order): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = order < 1 - if raise_exception: - with pytest.raises(ValueError, match=r"Spline order must be positive!"): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [5]) - @pytest.mark.parametrize( - "order, expectation", - [ - (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), - (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), - (0, pytest.raises(ValueError, match=r"Spline order must be positive")), - (1, does_not_raise()), - (2, does_not_raise()), - ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), - ), - ], - ) - def test_order_setter(self, n_basis_funcs, order, expectation): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=4) - with expectation: - basis_obj.order = order - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] - ) - def test_samples_range_matches_compute_features_requirements( - self, sample_range: tuple - ): - """ - Verifies that the compute_features() method can handle input range. 
- """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - basis_obj.compute_features(np.linspace(*sample_range, 100)) - - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): - """ - Confirms that the compute_features() method correctly handles the number of input samples that are provided. - """ - basis_obj = self.cls( - n_basis_funcs=5, order=3, mode=mode, window_size=window_size - ) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): - """ - Checks that the evaluate_on_grid() method returns a grid of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, - match=r"Invalid input data|" - r"All sample counts provided must be greater", - ): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): - """ - Ensures that the evaluate_on_grid() method returns basis functions of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, - match=r"All sample counts provided must be greater|" - r"Invalid input data", - ): - basis_obj.evaluate_on_grid(sample_size) - else: - _, eval_basis = basis_obj.evaluate_on_grid(sample_size) - assert eval_basis.shape[0] == sample_size - - @pytest.mark.parametrize("n_input", [0, 1, 2]) - def test_evaluate_on_grid_input_number(self, n_input): - """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. 
- """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - inputs = [10] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) missing 1 required positional argument", - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.evaluate_on_grid(*inputs) - - @pytest.mark.parametrize("sample_size", [30]) - @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size): - iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) - inp = nap.Tsd( - t=np.linspace(0, 1, sample_size), - d=np.linspace(0, 1, sample_size), - time_support=iset, - ) - out = self.cls(n_basis).compute_features(inp) - assert isinstance(out, nap.TsdFrame) - assert np.all(out.time_support.values == inp.time_support.values) - - # TEST CALL - @pytest.mark.parametrize( - "num_input, expectation", - [ - (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), - (1, does_not_raise()), - (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_num(self, num_input, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) - - @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_nan(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x[3] = np.nan - assert all(np.isnan(bas(x)[3])) - - @pytest.mark.parametrize( - "samples, expectation", - [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), - ], - ) - def test_call_input_type(self, samples, expectation): - bas = self.cls(5) - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10) - bas_eva = self.cls(5, mode="eval") - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the 
samples"), - ), - ], - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_pynapple_support(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) - - @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - - @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) - - @pytest.mark.parametrize( - "mn, mx, expectation", [(0, 1, does_not_raise()), (-2, 2, does_not_raise())] - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_ is not None - - def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_.shape == (3, 5) - - def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=3) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", does_not_raise()), - ( - "invalid", - pytest.raises( - ValueError, match="`mode` should be either 'conv' or 'eval'" - ), - ), - ], - ) - def test_init_mode(self, mode, expectation): - window_size = None if mode == "eval" else 2 - with expectation: - self.cls(5, mode=mode, window_size=window_size) - - @pytest.mark.parametrize("label", [None, "label"]) - def test_init_label(self, label): - bas = self.cls(5, label=label) - assert bas.label == (str(label) if label is not None else self.cls.__name__) - - @pytest.mark.parametrize( - "attribute, value", - [ - ("label", None), - ("label", "label"), - ("n_basis_input", 1), - ("n_output_features", 5), - ], - ) - def test_attr_setter(self, attribute, value): - bas = self.cls(5) - with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" - ): - setattr(bas, attribute, value) - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_output_features(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_output_features is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_output_features == n_input * bas.n_basis_funcs - - @pytest.mark.parametrize("n_input", [1, 2, 3]) - def test_set_num_basis_input(self, n_input): - bas = self.cls(5, mode="conv", window_size=10) - assert bas.n_basis_input 
is None - bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) - - @pytest.mark.parametrize( - "n_input, expectation", - [ - (2, does_not_raise()), - (0, pytest.raises(ValueError, match="Input shape mismatch detected")), - (1, pytest.raises(ValueError, match="Input shape mismatch detected")), - (3, pytest.raises(ValueError, match="Input shape mismatch detected")), - ], - ) - def test_expected_input_number(self, n_input, expectation): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.randn(20, 2) - bas.compute_features(x) - with expectation: - bas.compute_features(np.random.randn(30, n_input)) - - @pytest.mark.parametrize( - "conv_kwargs, expectation", - [ - (dict(), does_not_raise()), - ( - dict(axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - (dict(shift=True), does_not_raise()), - ( - dict(shift=True, axis=0), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - (dict(shift=True, predictor_causality="causal"), does_not_raise()), - ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), - ), - ], - ) - def test_init_conv_kwargs(self, conv_kwargs, expectation): - with expectation: - self.cls(5, mode="conv", window_size=200, **conv_kwargs) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, does_not_raise()), - ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - - @pytest.mark.parametrize( - "order, window_size, n_basis_funcs, bounds, mode", - [ - (3, None, 10, (1, 2), "eval"), - (3, 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - order=order, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - keys = list(pars.keys()) - bas = self.cls( - order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), - ], - ) - def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) - with expectation: - 
self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), - ], - ) - def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" - with expectation: - self.cls(window_size=10, n_basis_funcs=10, mode=mode) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - - bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) - - def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) - - def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, mode="eval", test="hi") - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation): - with expectation: - bas = self.cls(5, bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls(5, bounds=(1, 3)) - with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (None, None, np.arange(5), []), - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): - bounds = None if vmin is None else (vmin, vmax) - bas = self.cls(5, mode="eval", bounds=bounds) - out = bas.compute_features(samples) - assert np.all(np.isnan(out[nan_idx])) - valid_idx = list(set(samples).difference(nan_idx)) - assert np.all(~np.isnan(out[valid_idx])) - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval( - self, vmin, vmax, samples, nan_idx - ): - bas_no_range = self.cls(5, mode="eval", bounds=None) - bas = self.cls(5, mode="eval", bounds=(vmin, vmax)) - _, out1 = bas.evaluate_on_grid(10) - _, out2 = bas_no_range.evaluate_on_grid(10) - assert np.allclose(out1, out2) - - @pytest.mark.parametrize( - "bounds, samples, nan_idx, mn, mx", - [ - (None, np.arange(5), [4], 0, 1), - ((0, 3), np.arange(5), [4], 0, 3), - ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3), - ], - ) - def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): - bas_no_range = self.cls(5, mode="eval", bounds=None) - bas = self.cls(5, mode="eval", bounds=bounds) - x1, _ = bas.evaluate_on_grid(10) - x2, _ = bas_no_range.evaluate_on_grid(10) - assert np.allclose(x1, x2 * (mx - mn) + mn) - - @pytest.mark.parametrize( - "bounds, samples, exception", - [ - (None, np.arange(5), does_not_raise()), - ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ], - ) - def test_vmin_vmax_mode_conv(self, bounds, samples, exception): - with exception: - self.cls(5, mode="conv", window_size=10, bounds=bounds) - - def test_transformer_get_params(self): - bas = self.cls(5) - bas_transformer = bas.to_transformer() - params_transf = bas_transformer.get_params() - params_transf.pop("_basis") - params_basis = bas.get_params() - assert params_transf == params_basis - - -class TestCyclicBSplineBasis(BasisFuncsTesting): - cls = basis.CyclicBSplineBasis - - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_non_empty_samples(self, samples, mode, window_size): - if mode == "conv" and len(samples) == 1: - return - if len(samples) == 0: - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - self.cls(5, mode=mode, window_size=window_size).compute_features( - samples - ) - else: - self.cls(5, mode=mode, window_size=window_size).compute_features(samples) - - @pytest.mark.parametrize( - "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] - ) - def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. 
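-        (In practice this only checks that compute_features() runs on scalar, list,
-        tuple, NumPy array, and JAX array inputs.)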
- """ - basis_obj = self.cls(n_basis_funcs=5) - basis_obj.compute_features(eval_input) - - @pytest.mark.parametrize( - "kwargs, input1_shape, expectation", - [ - (dict(), (10,), does_not_raise()), - ( - dict(axis=0), - (10,), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ( - dict(axis=1), - (2, 10), - pytest.raises( - ValueError, match="Setting the `axis` parameter is not allowed" - ), - ), - ], - ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - with expectation: - basis_obj = self.cls(n_basis_funcs=5, mode="conv", window_size=5, **kwargs) - basis_obj.compute_features(np.ones(input1_shape)) - - @pytest.mark.parametrize("n_basis_funcs", [4, 5]) - @pytest.mark.parametrize("order", [3, 2]) - @pytest.mark.parametrize("window_size", [10, 15]) - @pytest.mark.parametrize( - "input_shape, expected_n_input", - [ - ((20,), 1), - ((20, 1), 1), - ((20, 2), 2), - ((20, 1, 2), 2), - ((20, 2, 1), 2), - ((20, 2, 2), 4), - ], - ) - def test_compute_features_conv_input( - self, - n_basis_funcs, - order, - window_size, - input_shape, - expected_n_input, - ): - x = np.ones(input_shape) - bas = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode="conv", - window_size=window_size, - ) - out = bas.compute_features(x) - assert out.shape[1] == expected_n_input * n_basis_funcs - - @pytest.mark.parametrize("n_basis_funcs", [8, 10]) - @pytest.mark.parametrize("order", range(2, 6)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_compute_features_returns_expected_number_of_basis( - self, n_basis_funcs: int, order: int, mode, window_size - ): - """ - Verifies that the compute_features() method returns the expected number of basis functions. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, 100)) - if eval_basis.shape[1] != n_basis_funcs: - raise ValueError( - "Dimensions do not agree: The number of basis should match the first dimension of the output features." - f"The number of basis is {n_basis_funcs}", - f"The first dimension of the output features is {eval_basis.shape[0]}", - ) - return - - @pytest.mark.parametrize("sample_size", [100, 1000]) - @pytest.mark.parametrize("n_basis_funcs", [8, 10, 100]) - @pytest.mark.parametrize("order", [2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_funcs, sample_size, order, mode, window_size - ): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, order=order, mode=mode, window_size=window_size - ) - eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) - if eval_basis.shape[0] != sample_size: - raise ValueError( - f"Dimensions do not agree: The window size should match the second dimension of the output features." 
- f"The window size is {sample_size}", - f"The second dimension of the output features is {eval_basis.shape[1]}", - ) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("order", [2, 3, 4, 5]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_minimum_number_of_basis_required_is_matched( - self, n_basis_funcs, order, mode, window_size - ): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = order > n_basis_funcs - if raise_exception: - with pytest.raises( - ValueError, - match=rf"{self.cls.__name__} `order` parameter cannot be larger than", - ): - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls( - n_basis_funcs=n_basis_funcs, - order=order, - mode=mode, - window_size=window_size, - ) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [10]) - @pytest.mark.parametrize("order", [-1, 0, 2, 3]) - def test_order_is_positive(self, n_basis_funcs, order): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. - """ - raise_exception = order < 1 - if raise_exception: - with pytest.raises(ValueError, match=r"Spline order must be positive!"): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [5]) - @pytest.mark.parametrize( - "order, expectation", - [ - (1.5, pytest.raises(ValueError, match=r"Spline order must be an integer")), - (-1, pytest.raises(ValueError, match=r"Spline order must be positive")), - (0, pytest.raises(ValueError, match=r"Spline order must be positive")), - (1, does_not_raise()), - (2, does_not_raise()), - ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), - ), - ], - ) - def test_order_setter(self, n_basis_funcs, order, expectation): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=4) - with expectation: - basis_obj.order = order - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize("n_basis_funcs", [10]) - @pytest.mark.parametrize("order", [1, 2, 3]) - def test_order_1_invalid(self, n_basis_funcs, order): - """ - Verifies that the minimum number of basis functions and order required (i.e., at least 1) and - order < #basis are enforced. 
- """ - raise_exception = order == 1 - if raise_exception: - with pytest.raises( - ValueError, match=r"Order >= 2 required for cyclic B-spline" - ): - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - else: - basis_obj = self.cls(n_basis_funcs=n_basis_funcs, order=order) - basis_obj.compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] - ) - def test_samples_range_matches_compute_features_requirements( - self, sample_range: tuple - ): - """ - Verifies that the compute_features() method can handle input range. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - basis_obj.compute_features(np.linspace(*sample_range, 100)) - - @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features( - self, n_input, mode, window_size - ): - """ - Confirms that the compute_features() method correctly handles the number of input samples that are provided. - """ - basis_obj = self.cls( - n_basis_funcs=5, order=3, mode=mode, window_size=window_size - ) - inputs = [np.linspace(0, 1, 20)] * n_input - if n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match="Input dimensionality mismatch", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.compute_features(*inputs) - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size): - """ - Checks that the evaluate_on_grid() method returns a grid of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, - match=r"Empty sample array provided\. At least one sample is required|" - "All sample counts provided must be greater", - ): - basis_obj.evaluate_on_grid(sample_size) - else: - grid, _ = basis_obj.evaluate_on_grid(sample_size) - assert grid.shape[0] == sample_size - - @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - def test_evaluate_on_grid_basis_size(self, sample_size): - """ - Ensures that the evaluate_on_grid() method returns basis functions of the expected size. - """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - raise_exception = sample_size <= 0 - if raise_exception: - with pytest.raises( - ValueError, - match="All sample counts provided must be greater|" - r"Empty sample array provided\. At least one sample is required for", - ): - basis_obj.evaluate_on_grid(sample_size) - else: - _, eval_basis = basis_obj.evaluate_on_grid(sample_size) - assert eval_basis.shape[0] == sample_size - - @pytest.mark.parametrize("n_input", [0, 1, 2]) - def test_evaluate_on_grid_input_number(self, n_input): - """ - Validates that the evaluate_on_grid() method correctly handles the number of input samples that are provided. 
- """ - basis_obj = self.cls(n_basis_funcs=5, order=3) - inputs = [10] * n_input - if n_input == 0: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) missing 1 required positional argument", - ) - elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises( - TypeError, - match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", - ) - else: - expectation = does_not_raise() - with expectation: - basis_obj.evaluate_on_grid(*inputs) - - @pytest.mark.parametrize("sample_size", [30]) - @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size): - iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) - inp = nap.Tsd( - t=np.linspace(0, 1, sample_size), - d=np.linspace(0, 1, sample_size), - time_support=iset, - ) - out = self.cls(n_basis).compute_features(inp) - assert isinstance(out, nap.TsdFrame) - assert np.all(out.time_support.values == inp.time_support.values) - - # TEST CALL - @pytest.mark.parametrize( - "num_input, expectation", - [ - (0, pytest.raises(TypeError, match="Input dimensionality mismatch")), - (1, does_not_raise()), - (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_num(self, num_input, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) - - @pytest.mark.parametrize( - "inp, expectation", - [ - (np.linspace(0, 1, 10), does_not_raise()), - (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), - ], - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_input_shape(self, inp, mode, window_size, expectation): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(inp) - - @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_axis(self, time_axis_shape, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_nan(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x[3] = np.nan - assert all(np.isnan(bas(x)[3])) - - @pytest.mark.parametrize( - "samples, expectation", - [ - (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), - ), - ], - ) - def test_call_input_type(self, samples, expectation): - bas = self.cls(5) - with expectation: - bas(samples) - - def test_call_equivalent_in_conv(self): - bas_con = self.cls(5, mode="conv", window_size=10) - bas_eva = self.cls(5, mode="eval") - x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eva(x)) - - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the 
samples"), - ), - ], - ) - def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): - bas = self.cls(5, bounds=(vmin, vmax)) - with expectation: - bas(samples) - - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_pynapple_support(self, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) - assert isinstance(y_nap, nap.TsdFrame) - assert np.all(y == y_nap.d) - assert np.all(y_nap.t == x_nap.t) - - @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_basis_number(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis - - @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_non_empty(self, n_basis, mode, window_size): - bas = self.cls(n_basis, mode=mode, window_size=window_size) - with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) - - @pytest.mark.parametrize( - "mn, mx, expectation", [(0, 1, does_not_raise()), (-2, 2, does_not_raise())] - ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) - def test_call_sample_range(self, mn, mx, expectation, mode, window_size): - bas = self.cls(5, mode=mode, window_size=window_size) - with expectation: - bas(np.linspace(mn, mx, 10)) - - def test_fit_kernel(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_ is not None - - def test_fit_kernel_shape(self): - bas = self.cls(5, mode="conv", window_size=3) - bas._set_kernel(None) - assert bas.kernel_.shape == (3, 5) - - def test_transform_fails(self): - bas = self.cls(5, mode="conv", window_size=3) - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): - bas._compute_features(np.linspace(0, 1, 10)) - - @pytest.mark.parametrize( - "mode, ws, expectation", - [ - ("conv", 2, does_not_raise()), - ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), - ), - ("eval", None, does_not_raise()), - ( - "eval", - 10, - pytest.raises( - ValueError, - match=r"If basis is in `mode=='eval'`, `window_size` should be None", - ), - ), - ], - ) - def test_init_window_size(self, mode, ws, expectation): - with expectation: - self.cls(5, mode=mode, window_size=ws) - - @pytest.mark.parametrize( - "order, window_size, n_basis_funcs, bounds, mode", - [ - (3, None, 10, (1, 2), "eval"), - (3, 10, 10, None, "conv"), - ], - ) - def test_set_params( - self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] - ): - """Test the read-only and read/write property of the parameters.""" - pars = dict( - order=order, - window_size=window_size, - n_basis_funcs=n_basis_funcs, - bounds=bounds, - ) - keys = list(pars.keys()) - bas = self.cls( - order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode - ) - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} - bas.set_params(**par_set) - assert isinstance(bas, self.cls) - - for i in range(len(pars)): - for j in range(i + 1, len(pars)): - with pytest.raises( - 
AttributeError, - match="can't set attribute 'mode'|property 'mode' of ", - ): - par_set = { - keys[i]: pars[keys[i]], - keys[j]: pars[keys[j]], - "mode": mode, - } - bas.set_params(**par_set) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("eval", does_not_raise()), - ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), - ], - ) - def test_set_bounds(self, mode, expectation): - ws = dict(eval=None, conv=10) - with expectation: - self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) - with pytest.raises(ValueError, match="`bounds` should only be set"): - bas.set_params(bounds=(1, 2)) - - @pytest.mark.parametrize( - "mode, expectation", - [ - ("conv", does_not_raise()), - ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), - ], - ) - def test_set_window_size(self, mode, expectation): - """Test window size set behavior.""" - with expectation: - self.cls(window_size=10, n_basis_funcs=10, mode=mode) - - bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): - bas.set_params(window_size=None) - - bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") - with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): - bas.set_params(window_size=10) - - def test_convolution_is_performed(self): - bas = self.cls(5, mode="conv", window_size=10) - x = np.random.normal(size=100) - conv = bas.compute_features(x) - conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) - valid = ~np.isnan(conv) - assert np.all(conv[valid] == conv_2[valid]) - assert np.all(np.isnan(conv_2[~valid])) - - def test_conv_kwargs_error(self): - with pytest.raises(ValueError, match="kwargs should only be set"): - self.cls(5, mode="eval", test="hi") - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation): - with expectation: - bas = self.cls(5, bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_setter(self, bounds, expectation): - bas = self.cls(5, bounds=(1, 3)) - with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (None, None, np.arange(5), []), - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): - bounds = None if vmin is None else (vmin, vmax) - bas = self.cls(5, mode="eval", bounds=bounds) - out = bas.compute_features(samples) - assert np.all(np.isnan(out[nan_idx])) - valid_idx = list(set(samples).difference(nan_idx)) - assert np.all(~np.isnan(out[valid_idx])) - - @pytest.mark.parametrize( - "vmin, vmax, samples, nan_idx", - [ - (0, 3, np.arange(5), [4]), - (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]), - ], - ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval( - self, vmin, vmax, samples, nan_idx - ): - bas_no_range = self.cls(5, mode="eval", bounds=None) - bas = self.cls(5, mode="eval", bounds=(vmin, vmax)) - _, out1 = bas.evaluate_on_grid(10) - _, out2 = bas_no_range.evaluate_on_grid(10) - assert np.allclose(out1, out2) - - @pytest.mark.parametrize( - "bounds, samples, nan_idx, mn, mx", - [ - (None, np.arange(5), [4], 0, 1), - ((0, 3), np.arange(5), [4], 0, 3), - ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3), - ], - ) - def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): - bas_no_range = self.cls(5, mode="eval", bounds=None) - bas = self.cls(5, mode="eval", bounds=bounds) - x1, _ = bas.evaluate_on_grid(10) - x2, _ = bas_no_range.evaluate_on_grid(10) - assert np.allclose(x1, x2 * (mx - mn) + mn) - - @pytest.mark.parametrize( - "bounds, samples, exception", - [ - (None, np.arange(5), does_not_raise()), - ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ], - ) - def test_vmin_vmax_mode_conv(self, bounds, samples, exception): - with exception: - self.cls(5, mode="conv", window_size=10, bounds=bounds) - - def test_transformer_get_params(self): - bas = self.cls(5) - bas_transformer = bas.to_transformer() - params_transf = bas_transformer.get_params() - params_transf.pop("_basis") - params_basis = bas.get_params() - assert params_transf == params_basis + basis_obj = self.cls["eval"](n_basis_funcs=5, order=3) + basis_obj.compute_features(np.linspace(*sample_range, 100)) class CombinedBasis(BasisFuncsTesting): From 07cad1892e48d6a37ba4cdef22c03becdc1db755 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 12:10:39 -0500 Subject: [PATCH 042/109] fixed additive basis tests --- tests/test_basis.py | 394 +++++++++++++++++++++----------------------- 1 file changed, 190 insertions(+), 204 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index b1290fa3..e4da4b9b 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -11,6 +11,7 @@ import numpy as np import pynapple as nap import pytest +from scipy.sparse.csgraph import depth_first_tree from scipy.stats import expon import utils_testing @@ -50,6 +51,10 @@ def class_specific_params(): ) +def trim_kwargs(cls, kwargs, class_specific_params): + return {key: value for key, value in kwargs.items() if key in class_specific_params[cls.__name__]} + 
+ def extra_decay_rates(cls, n_basis): name = cls.__name__ if "OrthExp" in name: @@ -682,7 +687,7 @@ def test_compute_features_conv_input( width=width,) # figure out which kwargs needs to be removed - kwargs = {key: value for key, value in kwargs.items() if key in class_specific_params[cls["conv"].__name__]} + kwargs = trim_kwargs(cls["conv"], kwargs, class_specific_params) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -1585,51 +1590,30 @@ class CombinedBasis(BasisFuncsTesting): cls = None @staticmethod - def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10): + def instantiate_basis(n_basis, basis_class, class_specific_params, window_size=10): """Instantiate and return two basis of the type specified.""" - if mode == "eval": - window_size = None - - if basis_class == basis.EvalMSpline: - basis_obj = basis_class( - n_basis_funcs=n_basis, order=4, mode=mode, window_size=window_size - ) - elif basis_class in [basis.RaisedCosineBasisLinear, basis.RaisedCosineBasisLog]: - basis_obj = basis_class( - n_basis_funcs=n_basis, mode=mode, window_size=window_size - ) - elif basis_class == basis.OrthExponentialBasis: - basis_obj = basis_class( - n_basis_funcs=n_basis, - decay_rates=np.arange(1, 1 + n_basis), - mode=mode, - window_size=window_size, - ) - elif basis_class == basis.BSplineBasis: - basis_obj = basis_class( - n_basis_funcs=n_basis, order=3, mode=mode, window_size=window_size - ) - elif basis_class == basis.CyclicBSplineBasis: - basis_obj = basis_class( - n_basis_funcs=n_basis, order=3, mode=mode, window_size=window_size - ) - elif basis_class == AdditiveBasis: - b1 = basis.EvalMSpline( - n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size - ) - b2 = basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis + 1) + kwargs = { + "n_basis_funcs": n_basis, + "window_size": window_size, + "order": 2, + "decay_rates": np.arange(1, 1 + n_basis) + } + + if basis_class == AdditiveBasis: + kwargs_mspline = trim_kwargs(basis.EvalMSpline, kwargs, class_specific_params) + kwargs_raised_cosine = trim_kwargs(basis.ConvRaisedCosineLinear, kwargs, class_specific_params) + b1 = basis.EvalMSpline(**kwargs_mspline) + b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) basis_obj = b1 + b2 elif basis_class == MultiplicativeBasis: - b1 = basis.EvalMSpline( - n_basis_funcs=n_basis, order=2, mode=mode, window_size=window_size - ) - b2 = basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis + 1) + kwargs_mspline = trim_kwargs(basis.EvalMSpline, kwargs, class_specific_params) + kwargs_raised_cosine = trim_kwargs(basis.ConvRaisedCosineLinear, kwargs, class_specific_params) + b1 = basis.EvalMSpline(**kwargs_mspline) + b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) basis_obj = b1 * b2 else: - raise ValueError( - f"Test for basis addition not implemented for basis of type {basis_class}!" 
- ) + basis_obj = basis_class(**trim_kwargs(basis_class, kwargs, class_specific_params)) return basis_obj @@ -1637,15 +1621,13 @@ class TestAdditiveBasis(CombinedBasis): cls = AdditiveBasis @pytest.mark.parametrize( - "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] + "samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]] ) - @pytest.mark.parametrize("mode, ws", [("conv", 2), ("eval", None)]) - def test_non_empty_samples(self, samples, mode, ws): - if mode == "conv" and len(samples[0]) < 2: - return - basis_obj = basis.EvalMSpline(5, mode=mode, window_size=ws) + basis.EvalMSpline( - 5, mode=mode, window_size=ws - ) + @pytest.mark.parametrize("base_cls", [basis.EvalBSpline, basis.ConvBSpline]) + def test_non_empty_samples(self, base_cls, samples, class_specific_params): + kwargs = {"window_size": 2, "n_basis_funcs": 5} + kwargs = trim_kwargs(base_cls, kwargs, class_specific_params) + basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -1676,9 +1658,9 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("sample_size", [10, 1000]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [10]) def test_compute_features_returns_expected_number_of_basis( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -1686,10 +1668,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj @@ -1708,18 +1690,18 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [10]) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -1738,19 +1720,19 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize("n_input", [0, 1, 2, 3, 10, 30]) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [10]) def test_number_of_required_inputs_compute_features( - self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -1772,13 +1754,13 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b + self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params ): """ Test whether the resulting meshgrid size matches the sample size input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -1792,13 +1774,13 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b + self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. 
""" - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -1811,14 +1793,14 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b + self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input required_dim = ( @@ -1839,7 +1821,7 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size + self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -1847,9 +1829,10 @@ def test_pynapple_support_compute_features( d=np.linspace(0, 1, sample_size), time_support=iset, ) - basis_add = self.instantiate_basis(n_basis_a, basis_a) + self.instantiate_basis( - n_basis_b, basis_b - ) + basis_add = (self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + + self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + )) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) # check type @@ -1863,15 +1846,15 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) def test_call_input_num( - self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -1890,7 +1873,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) 
@pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -1902,55 +1885,55 @@ def test_call_input_shape( basis_a, basis_b, inp, - mode, window_size, expectation, + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: basis_obj(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_sample_axis( - self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size): + def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params): if ( basis_a == basis.OrthExponentialBasis or basis_b == basis.OrthExponentialBasis ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -1962,39 +1945,39 @@ def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_siz @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="eval", window_size=None + n_basis_a, basis_a, class_specific_params, window_size=3 ) basis_b_obj = self.instantiate_basis( - n_basis_b, 
basis_b, mode="eval", window_size=None + n_basis_b, basis_b, class_specific_params, window_size=3 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="conv", window_size=8 + n_basis_a, basis_a, class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=8 + n_basis_b, basis_b, class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con(*x) == bas_eva(*x)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2006,37 +1989,37 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided 
must"): @@ -2050,7 +2033,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2064,8 +2047,9 @@ def test_call_sample_range( mn, mx, expectation, - mode, + window_size, + class_specific_params ): if expectation == "check": if ( @@ -2078,10 +2062,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2091,15 +2075,15 @@ def test_call_sample_range( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="conv", window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - bas._set_kernel(None) + bas._set_kernel() def check_kernel(basis_obj): has_kern = [] @@ -2118,25 +2102,27 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="conv", window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): + if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: + context = does_not_raise() + else: + context = pytest.raises(ValueError, match="You must call `_set_kernel` before `_compute_features`") + with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality bas._compute_features(*x) @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(11, mode="conv", window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(11, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2147,8 +2133,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) 
@pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(10, mode="conv", window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2166,8 +2152,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(10, mode="conv", window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas = bas1 + bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2181,12 +2167,12 @@ class TestMultiplicativeBasis(CombinedBasis): @pytest.mark.parametrize( "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] ) - @pytest.mark.parametrize("mode, ws", [("conv", 2), ("eval", None)]) - def test_non_empty_samples(self, samples, mode, ws): + @pytest.mark.parametrize(" ws", [3, None]) + def test_non_empty_samples(self, samples, ws): if mode == "conv" and len(samples[0]) < 2: return - basis_obj = basis.EvalMSpline(5, mode=mode, window_size=ws) * basis.EvalMSpline( - 5, mode=mode, window_size=ws + basis_obj = basis.EvalMSpline(5, window_size=ws) * basis.EvalMSpline( + 5, window_size=ws ) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -2218,9 +2204,9 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("sample_size", [10, 1000]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [None, 10]) def test_compute_features_returns_expected_number_of_basis( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2228,10 +2214,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2251,19 +2237,19 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [None, 10]) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2281,19 +2267,19 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize("n_input", [0, 1, 2, 3, 10, 30]) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) + @pytest.mark.parametrize("window_size", [None, 10]) def test_number_of_required_inputs_compute_features( - self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2430,15 +2416,15 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) def test_call_input_num( - self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2457,7 +2443,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2469,55 +2455,55 @@ def test_call_input_shape( basis_a, basis_b, inp, - mode, + window_size, expectation, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: basis_obj(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", 
list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_sample_axis( - self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size): + def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size): if ( basis_a == basis.OrthExponentialBasis or basis_b == basis.OrthExponentialBasis ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2531,37 +2517,37 @@ def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_siz @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="eval", window_size=None + n_basis_a, basis_a, "eval", window_size=None ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="eval", window_size=None + n_basis_b, basis_b, window_size=None ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="conv", window_size=8 + n_basis_a, basis_a, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=8 + n_basis_b, basis_b, window_size=8 ) bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con(*x) == bas_eva(*x)) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -2573,37 +2559,37 @@ def 
test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, mode, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2617,7 +2603,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)]) + @pytest.mark.parametrize(" window_size", [None, 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2631,7 +2617,7 @@ def test_call_sample_range( mn, mx, expectation, - mode, + window_size, ): if expectation == "check": @@ -2645,10 +2631,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode=mode, window_size=window_size + n_basis_a, basis_a, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode=mode, window_size=window_size + n_basis_b, basis_b, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -2660,10 +2646,10 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, mode="conv", window_size=10 + n_basis_a, basis_a, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=10 + n_basis_b, basis_b, window_size=10 ) bas = basis_a_obj * basis_b_obj bas._set_kernel(None) @@ -2687,10 +2673,10 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b): basis_a_obj = self.instantiate_basis( - n_basis_a, 
basis_a, mode="conv", window_size=10 + n_basis_a, basis_a, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, mode="conv", window_size=10 + n_basis_b, basis_b, window_size=10 ) bas = basis_a_obj * basis_b_obj with pytest.raises( @@ -2702,8 +2688,8 @@ def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(11, mode="conv", window_size=10) + bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) + bas2 = basis.BSplineBasis(11, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2714,8 +2700,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(10, mode="conv", window_size=10) + bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) + bas2 = basis.BSplineBasis(10, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2733,8 +2719,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(10, mode="conv", window_size=10) + bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) + bas2 = basis.BSplineBasis(10, window_size=10) bas = bas1 * bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2744,8 +2730,8 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_n_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, mode="conv", window_size=10) - bas2 = basis.BSplineBasis(10, mode="conv", window_size=10) + bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) + bas2 = basis.BSplineBasis(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) @@ -3144,7 +3130,7 @@ def test_transformerbasis_dir(basis_cls): ], ) def test_transformerbasis_sk_clone_kernel_noned(basis_cls): - orig_bas = basis_cls(10, mode="conv", window_size=5) + orig_bas = basis_cls(10, window_size=5) trans_bas = basis.TransformerBasis(orig_bas) # kernel should be saved in the object after fit @@ -3230,14 +3216,14 @@ def test_multi_epoch_pynapple_basis( if basis_cls == AdditiveBasis: bas = basis.BSplineBasis( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, ) bas = bas + basis.RaisedCosineBasisLinear( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3245,7 +3231,7 @@ def test_multi_epoch_pynapple_basis( elif basis_cls == MultiplicativeBasis: bas = basis.RaisedCosineBasisLog( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3254,7 +3240,7 @@ def test_multi_epoch_pynapple_basis( else: bas = basis_cls( 5, - 
mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3318,14 +3304,14 @@ def test_multi_epoch_pynapple_basis_transformer( if basis_cls == AdditiveBasis: bas = basis.BSplineBasis( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, ) bas = bas + basis.RaisedCosineBasisLinear( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3333,7 +3319,7 @@ def test_multi_epoch_pynapple_basis_transformer( elif basis_cls == MultiplicativeBasis: bas = basis.RaisedCosineBasisLog( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3342,7 +3328,7 @@ def test_multi_epoch_pynapple_basis_transformer( else: bas = basis_cls( 5, - mode="conv", + window_size=window_size, predictor_causality=predictor_causality, shift=shift, @@ -3470,7 +3456,7 @@ def test__get_splitter( for i, val in enumerate( zip([bas1, bas2, bas3], [mode1, mode2, mode3], extra_kwargs) ): - bas, mode, kwrgs = val + bas, kwrgs = val if bas != basis.OrthExponentialBasis: kwrgs.pop("decay_rates") if mode == "eval": @@ -3650,13 +3636,13 @@ def test__get_splitter_split_by_input( bas1_instance = bas1( n_basis[0], - mode=mode, + **extra_kwargs[0], label="1", ) bas2_instance = bas2( n_basis[1], - mode=mode, + **extra_kwargs[1], label="2", ) @@ -3755,13 +3741,13 @@ def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes): bas1_instance = bas1( n_basis[0], - mode=mode, + **extra_kwargs[0], label="1", ) bas2_instance = bas2( n_basis[1], - mode=mode, + **extra_kwargs[1], label="2", ) From ea8e2083ae65cf3a6e8cd5d4f7b1657a38974b85 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 12:24:09 -0500 Subject: [PATCH 043/109] fixed multiplicative basis tests --- tests/test_basis.py | 176 ++++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index e4da4b9b..bcd83860 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2167,13 +2167,9 @@ class TestMultiplicativeBasis(CombinedBasis): @pytest.mark.parametrize( "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] ) - @pytest.mark.parametrize(" ws", [3, None]) + @pytest.mark.parametrize(" ws", [3]) def test_non_empty_samples(self, samples, ws): - if mode == "conv" and len(samples[0]) < 2: - return - basis_obj = basis.EvalMSpline(5, window_size=ws) * basis.EvalMSpline( - 5, window_size=ws - ) + basis_obj = basis.EvalMSpline(5) * basis.EvalRaisedCosineLinear(5) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -2204,9 +2200,9 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("sample_size", [10, 1000]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("window_size", [None, 10]) + @pytest.mark.parametrize("window_size", [10]) def test_compute_features_returns_expected_number_of_basis( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2214,10 +2210,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = 
self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2237,19 +2233,19 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("window_size", [None, 10]) + @pytest.mark.parametrize("window_size", [10]) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size + self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2267,19 +2263,19 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize("n_input", [0, 1, 2, 3, 10, 30]) @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("window_size", [None, 10]) + @pytest.mark.parametrize("window_size", [10]) def test_number_of_required_inputs_compute_features( - self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size + self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2301,13 +2297,13 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b + self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params ): """ Test whether the resulting meshgrid size matches the sample size input. 
""" - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -2321,13 +2317,13 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b + self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -2340,14 +2336,14 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b + self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. 
""" - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input required_dim = ( @@ -2362,19 +2358,19 @@ def test_evaluate_on_grid_input_number( with expectation: basis_obj.evaluate_on_grid(*inputs) - @pytest.mark.parametrize("basis_a", [basis.EvalMSpline]) - @pytest.mark.parametrize("basis_b", [basis.OrthExponentialBasis]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) @pytest.mark.parametrize("sample_size_a", [11, 12]) @pytest.mark.parametrize("sample_size_b", [11, 12]) def test_inconsistent_sample_sizes( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size_a, sample_size_b + self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size_a, sample_size_b, class_specific_params ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b) + basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) basis_obj = basis_a_obj * basis_b_obj if raise_exception: with pytest.raises( @@ -2395,7 +2391,7 @@ def test_inconsistent_sample_sizes( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size + self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2404,8 +2400,8 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_prod = self.instantiate_basis( - n_basis_a, basis_a - ) * self.instantiate_basis(n_basis_b, basis_b) + n_basis_a, basis_a, class_specific_params, window_size=10 + ) * self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) out = basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) @@ -2416,15 +2412,15 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) def test_call_input_num( - self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj 
if num_input == basis_obj._n_input_dimensionality: @@ -2443,7 +2439,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2458,52 +2454,53 @@ def test_call_input_shape( window_size, expectation, + class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: basis_obj(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_sample_axis( - self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size): + def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params): if ( basis_a == basis.OrthExponentialBasis or basis_b == basis.OrthExponentialBasis ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2515,39 +2512,39 @@ def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, "eval", window_size=None + n_basis_a, 
basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=None + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=8 + n_basis_a, basis_a, class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=8 + n_basis_b, basis_b, class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con(*x) == bas_eva(*x)) - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -2559,37 +2556,37 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs - @pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2603,7 +2600,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - 
@pytest.mark.parametrize(" window_size", [None, 3]) + @pytest.mark.parametrize(" window_size", [ 3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2619,6 +2616,7 @@ def test_call_sample_range( expectation, window_size, + class_specific_params ): if expectation == "check": if ( @@ -2631,10 +2629,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -2644,15 +2642,15 @@ def test_call_sample_range( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - bas._set_kernel(None) + bas._set_kernel() def check_kernel(basis_obj): has_kern = [] @@ -2671,25 +2669,27 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b): + def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" - ): + if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: + context = does_not_raise() + else: + context = pytest.raises(ValueError, match="You must call `_set_kernel` before `_compute_features`") + with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality bas._compute_features(*x) @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) - bas2 = basis.BSplineBasis(11, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(11, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2700,8 +2700,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) - bas2 = basis.BSplineBasis(10, window_size=10) + 
bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2719,8 +2719,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) - bas2 = basis.BSplineBasis(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas = bas1 * bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2730,8 +2730,8 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_n_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.RaisedCosineBasisLinear(10, window_size=10) - bas2 = basis.BSplineBasis(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) From c112ab986317701cde6fb853655eae837c2cc6e8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 13:37:33 -0500 Subject: [PATCH 044/109] fixed test splitters --- tests/test_basis.py | 496 +++++++++++++------------------------------- 1 file changed, 139 insertions(+), 357 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index bcd83860..057a454d 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -11,16 +11,13 @@ import numpy as np import pynapple as nap import pytest -from scipy.sparse.csgraph import depth_first_tree -from scipy.stats import expon import utils_testing from sklearn.base import clone as sk_clone import nemos.basis.basis as basis import nemos.convolve as convolve -from nemos.basis import EvalOrthExponential -from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring +from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring, TransformerBasis from nemos.basis._decaying_exponential import OrthExponentialBasis from nemos.basis._raised_cosine_basis import ( RaisedCosineBasisLinear, @@ -63,17 +60,19 @@ def extra_decay_rates(cls, n_basis): # automatic define user accessible basis and check the methods -def list_all_basis_classes() -> list[type]: +def list_all_basis_classes(filter_basis="all") -> list[type]: """ Return all the classes in nemos.basis which are a subclass of Basis, which should be all concrete classes except TransformerBasis. """ - return [ + all_basis = [ class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis) if issubclass(class_obj, Basis) ] - + if filter_basis != "all": + all_basis = [a for a in all_basis if filter_basis in a.__name__] + return all_basis def test_all_basis_are_tested() -> None: """Meta-test. 
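The `filter_basis` argument added to `list_all_basis_classes` in the hunk above lets the suite target only the evaluation-mode (`Eval*`) or convolution-mode (`Conv*`) classes produced by this refactor, by matching a substring of the class name. A minimal usage sketch (the exact classes returned depend on what `nemos.basis.basis` exports at this point in the series):

    list_all_basis_classes()        # every concrete Basis subclass, e.g. EvalBSpline, ConvBSpline, ...
    list_all_basis_classes("Conv")  # convolution-mode classes only, e.g. ConvBSpline, ConvRaisedCosineLinear
    list_all_basis_classes("Eval")  # evaluation-mode classes only, e.g. EvalBSpline, EvalRaisedCosineLinear
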
@@ -1590,15 +1589,22 @@ class CombinedBasis(BasisFuncsTesting): cls = None @staticmethod - def instantiate_basis(n_basis, basis_class, class_specific_params, window_size=10): + def instantiate_basis(n_basis, basis_class, class_specific_params, window_size=10, **kwargs): """Instantiate and return two basis of the type specified.""" - kwargs = { + # Set non-optional args + default_kwargs = { "n_basis_funcs": n_basis, "window_size": window_size, - "order": 2, "decay_rates": np.arange(1, 1 + n_basis) } + repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) + if repeated_keys: + raise ValueError("Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs") + + # Merge with provided extra kwargs + kwargs = {**default_kwargs, **kwargs} + if basis_class == AdditiveBasis: kwargs_mspline = trim_kwargs(basis.EvalMSpline, kwargs, class_specific_params) @@ -2743,7 +2749,7 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): "exponent", [-1, 0, 0.5, basis.EvalRaisedCosineLog(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) -def test_power_of_basis(exponent, basis_class): +def test_power_of_basis(exponent, basis_class, class_specific_params): """Test if the power behaves as expected.""" raise_exception_type = not type(exponent) is int @@ -2752,7 +2758,7 @@ def test_power_of_basis(exponent, basis_class): else: raise_exception_value = False - basis_obj = CombinedBasis.instantiate_basis(5, basis_class) + basis_obj = CombinedBasis.instantiate_basis(5, basis_class, class_specific_params, window_size=10) if raise_exception_type: with pytest.raises(TypeError, match=r"Exponent should be an integer\!"): @@ -2782,40 +2788,28 @@ def test_power_of_basis(exponent, basis_class): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_basis_to_transformer(basis_cls): +def test_basis_to_transformer(basis_cls, class_specific_params): n_basis_funcs = 5 - bas = basis_cls(n_basis_funcs) + bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) trans_bas = bas.to_transformer() - assert isinstance(trans_bas, basis.TransformerBasis) + assert isinstance(trans_bas, TransformerBasis) # check that things like n_basis_funcs are the same as the original basis for k in bas.__dict__.keys(): - assert getattr(bas, k) == getattr(trans_bas, k) + assert np.all(getattr(bas, k) == getattr(trans_bas, k)) @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformer_has_the_same_public_attributes_as_basis(basis_cls): +def test_transformer_has_the_same_public_attributes_as_basis(basis_cls, class_specific_params): n_basis_funcs = 5 - bas = basis_cls(n_basis_funcs) + bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} public_attrs_transformerbasis = { @@ -2833,20 +2827,14 @@ def test_transformer_has_the_same_public_attributes_as_basis(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def 
test_to_transformer_and_constructor_are_equivalent(basis_cls): +def test_to_transformer_and_constructor_are_equivalent(basis_cls, class_specific_params): n_basis_funcs = 5 - bas = basis_cls(n_basis_funcs) + bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) trans_bas_a = bas.to_transformer() - trans_bas_b = basis.TransformerBasis(bas) + trans_bas_b = TransformerBasis(bas) # they both just have a _basis assert ( @@ -2855,29 +2843,24 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls): == ["_basis"] ) # and those bases are the same + assert np.all(trans_bas_a._basis.__dict__.pop("_decay_rates", 1) == trans_bas_b._basis.__dict__.pop("_decay_rates", 1)) assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_basis_to_transformer_makes_a_copy(basis_cls): - bas_a = basis_cls(5) +def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): + bas_a = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) trans_bas_a = bas_a.to_transformer() # changing an attribute in bas should not change trans_bas bas_a.n_basis_funcs = 10 assert trans_bas_a.n_basis_funcs == 5 - # changing an attribute in the transformerbasis should not change the original - bas_b = basis_cls(5) + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) trans_bas_b = bas_b.to_transformer() trans_bas_b.n_basis_funcs = 100 assert bas_b.n_basis_funcs == 5 @@ -2885,34 +2868,26 @@ def test_basis_to_transformer_makes_a_copy(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) @pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) -def test_transformerbasis_getattr(basis_cls, n_basis_funcs): - trans_basis = basis.TransformerBasis(basis_cls(n_basis_funcs)) +def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): + trans_basis = TransformerBasis( + CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) + ) assert trans_basis.n_basis_funcs == n_basis_funcs @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) @pytest.mark.parametrize("n_basis_funcs_init", [5]) @pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_funcs_new): - trans_basis = basis.TransformerBasis(basis_cls(n_basis_funcs_init)) +def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params): + trans_basis = TransformerBasis( + CombinedBasis().instantiate_basis(n_basis_funcs_init, basis_cls, class_specific_params, window_size=10) + ) trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) assert trans_basis.n_basis_funcs == n_basis_funcs_new @@ -2921,18 +2896,14 @@ def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_func @pytest.mark.parametrize( "basis_cls", - [ - 
basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_setattr_basis(basis_cls): +def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it - trans_bas = basis.TransformerBasis(basis_cls(10)) - trans_bas._basis = basis_cls(20) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10) + ) + trans_bas._basis = CombinedBasis().instantiate_basis(20, basis_cls, class_specific_params, window_size=10) assert trans_bas.n_basis_funcs == 20 assert trans_bas._basis.n_basis_funcs == 20 @@ -2941,18 +2912,12 @@ def test_transformerbasis_setattr_basis(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_setattr_basis_attribute(basis_cls): +def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): # setting an attribute that is an attribute of the underlying _basis # should propagate setting it on _basis itself - trans_bas = basis.TransformerBasis(basis_cls(10)) + trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10)) trans_bas.n_basis_funcs = 20 assert trans_bas.n_basis_funcs == 20 @@ -2962,19 +2927,13 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_copy_basis_on_contsruct(basis_cls): +def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): # modifying the transformerbasis's attributes shouldn't # touch the original basis that was used to create it - orig_bas = basis_cls(10) - trans_bas = basis.TransformerBasis(orig_bas) + orig_bas = CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10) + trans_bas = TransformerBasis(orig_bas) trans_bas.n_basis_funcs = 20 assert orig_bas.n_basis_funcs == 10 @@ -2985,18 +2944,12 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_setattr_illegal_attribute(basis_cls): +def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): # changing an attribute that is not _basis or an attribute of _basis # is not allowed - trans_bas = basis.TransformerBasis(basis_cls(10)) + trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10)) with pytest.raises( ValueError, @@ -3007,21 +2960,17 @@ def test_transformerbasis_setattr_illegal_attribute(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_addition(basis_cls): +def test_transformerbasis_addition(basis_cls, class_specific_params): n_basis_funcs_a = 
5 n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = basis.TransformerBasis(basis_cls(n_basis_funcs_a)) - trans_bas_b = basis.TransformerBasis(basis_cls(n_basis_funcs_b)) + bas_a = CombinedBasis().instantiate_basis(n_basis_funcs_a, basis_cls, class_specific_params, window_size=10) + bas_b = CombinedBasis().instantiate_basis(n_basis_funcs_b, basis_cls, class_specific_params, window_size=10) + trans_bas_a = TransformerBasis(bas_a) + trans_bas_b = TransformerBasis(bas_b) trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, basis.TransformerBasis) + assert isinstance(trans_bas_sum, TransformerBasis) assert isinstance(trans_bas_sum._basis, AdditiveBasis) assert ( trans_bas_sum.n_basis_funcs @@ -3037,21 +2986,15 @@ def test_transformerbasis_addition(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_multiplication(basis_cls): +def test_transformerbasis_multiplication(basis_cls, class_specific_params): n_basis_funcs_a = 5 n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = basis.TransformerBasis(basis_cls(n_basis_funcs_a)) - trans_bas_b = basis.TransformerBasis(basis_cls(n_basis_funcs_b)) + trans_bas_a = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs_a, basis_cls, class_specific_params, window_size=10)) + trans_bas_b = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs_b, basis_cls, class_specific_params, window_size=10)) trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, basis.TransformerBasis) + assert isinstance(trans_bas_prod, TransformerBasis) assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) assert ( trans_bas_prod.n_basis_funcs @@ -3067,13 +3010,7 @@ def test_transformerbasis_multiplication(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) @pytest.mark.parametrize( "exponent, error_type, error_message", @@ -3085,29 +3022,23 @@ def test_transformerbasis_multiplication(basis_cls): ], ) def test_transformerbasis_exponentiation( - basis_cls, exponent: int, error_type, error_message + basis_cls, exponent: int, error_type, error_message, class_specific_params ): - trans_bas = basis.TransformerBasis(basis_cls(5)) + trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10)) if not isinstance(exponent, int): with pytest.raises(error_type, match=error_message): trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, basis.TransformerBasis) + assert isinstance(trans_bas_exp, TransformerBasis) assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) -def test_transformerbasis_dir(basis_cls): - trans_bas = basis.TransformerBasis(basis_cls(5)) +def test_transformerbasis_dir(basis_cls, class_specific_params): + trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10)) for attr_name in ( "fit", "transform", @@ -3116,22 +3047,18 @@ def test_transformerbasis_dir(basis_cls): "mode", "window_size", ): 
+ if attr_name == "window_size" and "Eval" in trans_bas._basis.__class__.__name__: + continue assert attr_name in dir(trans_bas) @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes("Conv"), ) -def test_transformerbasis_sk_clone_kernel_noned(basis_cls): - orig_bas = basis_cls(10, window_size=5) - trans_bas = basis.TransformerBasis(orig_bas) +def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): + orig_bas = CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=20) + trans_bas = TransformerBasis(orig_bas) # kernel should be saved in the object after fit trans_bas.fit(np.random.randn(100, 20)) @@ -3148,25 +3075,19 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls): @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - ], + list_all_basis_classes(), ) @pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs): +def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs, class_specific_params): # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = basis.TransformerBasis(basis_cls(n_basis_funcs)) + trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10)) filepath = tmpdir / "transformerbasis.pickle" with open(filepath, "wb") as f: pickle.dump(trans_bas, f) with open(filepath, "rb") as f: trans_bas2 = pickle.load(f) - assert isinstance(trans_bas2, basis.TransformerBasis) + assert isinstance(trans_bas2, TransformerBasis) assert trans_bas2.n_basis_funcs == n_basis_funcs @@ -3199,52 +3120,21 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs): ) @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - AdditiveBasis, - MultiplicativeBasis, - ], + list_all_basis_classes("Conv"), ) def test_multi_epoch_pynapple_basis( - basis_cls, tsd, window_size, shift, predictor_causality, nan_index + basis_cls, tsd, window_size, shift, predictor_causality, nan_index, class_specific_params ): """Test nan location in multi-epoch pynapple tsd.""" - if basis_cls == AdditiveBasis: - bas = basis.BSplineBasis( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - bas = bas + basis.RaisedCosineBasisLinear( - 5, + kwargs = dict(conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality)) - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - elif basis_cls == MultiplicativeBasis: - bas = basis.RaisedCosineBasisLog( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - bas = basis.EvalMSpline(3) * bas + # require a ws of at least nbasis funcs. + if "OrthExp" in basis_cls.__name__: + nbasis = 2 + # splines requires at least 1 basis more than the order of the spline. 
else: - bas = basis_cls( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) + nbasis = 5 + bas = CombinedBasis().instantiate_basis(nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs) n_input = bas._n_input_dimensionality @@ -3287,57 +3177,26 @@ def test_multi_epoch_pynapple_basis( ) @pytest.mark.parametrize( "basis_cls", - [ - basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, - AdditiveBasis, - MultiplicativeBasis, - ], + list_all_basis_classes("Conv"), ) def test_multi_epoch_pynapple_basis_transformer( - basis_cls, tsd, window_size, shift, predictor_causality, nan_index + basis_cls, tsd, window_size, shift, predictor_causality, nan_index, class_specific_params ): """Test nan location in multi-epoch pynapple tsd.""" - if basis_cls == AdditiveBasis: - bas = basis.BSplineBasis( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - bas = bas + basis.RaisedCosineBasisLinear( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - elif basis_cls == MultiplicativeBasis: - bas = basis.RaisedCosineBasisLog( - 5, - - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) - bas = basis.EvalMSpline(3) * bas + kwargs = dict(conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality)) + # require a ws of at least nbasis funcs. + if "OrthExp" in basis_cls.__name__: + nbasis = 2 + # splines requires at least 1 basis more than the order of the spline. else: - bas = basis_cls( - 5, + nbasis = 5 - window_size=window_size, - predictor_causality=predictor_causality, - shift=shift, - ) + bas = CombinedBasis().instantiate_basis(nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs) n_input = bas._n_input_dimensionality # pass through transformer - bas = basis.TransformerBasis(bas) + bas = TransformerBasis(bas) # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) @@ -3358,14 +3217,10 @@ def test_multi_epoch_pynapple_basis_transformer( "bas1, bas2, bas3", list( itertools.product( - *[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 3 + *[list_all_basis_classes()] * 3 ) ), ) -@pytest.mark.parametrize( - "mode1, mode2, mode3", - list(itertools.product(["eval", "conv"], ["eval", "conv"], ["eval", "conv"])), -) @pytest.mark.parametrize( "operator1, operator2, compute_slice", [ @@ -3437,50 +3292,22 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - mode1, mode2, mode3, bas1, bas2, bas3, operator1, operator2, compute_slice + bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params ): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) for bas in [bas1, bas2, bas3] ): return # define the basis n_basis = [5, 6, 7] n_input_basis = [1, 2, 3] - extra_kwargs = ( - {"decay_rates": np.arange(1, n_basis[0] + 1), "window_size": 5}, - {"decay_rates": np.arange(1, n_basis[1] + 1), "window_size": 5}, - {"decay_rates": np.arange(1, n_basis[2] + 1), "window_size": 5}, - ) - for i, val in enumerate( - zip([bas1, bas2, bas3], [mode1, mode2, mode3], extra_kwargs) - ): - bas, kwrgs = val - if bas != basis.OrthExponentialBasis: - kwrgs.pop("decay_rates") - if mode == "eval": - n_input_basis[i] = 1 - kwrgs.pop("window_size") - - bas1_instance = 
bas1( - n_basis[0], - mode=mode1, - **extra_kwargs[0], - label="1", - ) - bas2_instance = bas2( - n_basis[1], - mode=mode2, - **extra_kwargs[1], - label="2", - ) - bas3_instance = bas3( - n_basis[2], - mode=mode3, - **extra_kwargs[2], - label="3", - ) + + combine_basis = CombinedBasis() + bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") + bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") + bas3_instance = combine_basis.instantiate_basis(n_basis[2], bas3, class_specific_params, window_size=10, label="3") func1 = getattr(bas1_instance, operator1) func2 = getattr(bas2_instance, operator2) @@ -3497,7 +3324,7 @@ def test__get_splitter( "bas1, bas2", list( itertools.product( - *[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 2 + *[list_all_basis_classes()] * 2 ) ), ) @@ -3614,38 +3441,19 @@ def test__get_splitter( ], ) def test__get_splitter_split_by_input( - bas1, bas2, operator, n_input_basis_1, n_input_basis_2, compute_slice + bas1, bas2, operator, n_input_basis_1, n_input_basis_2, compute_slice, class_specific_params ): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) for bas in [bas1, bas2] ): return # define the basis n_basis = [5, 6] - mode = "conv" - extra_kwargs = ( - {"decay_rates": np.arange(1, n_basis[0] + 1), "window_size": 5}, - {"decay_rates": np.arange(1, n_basis[1] + 1), "window_size": 5}, - ) - for i, val in enumerate(zip([bas1, bas2], extra_kwargs)): - bas, kwrgs = val - if bas != basis.OrthExponentialBasis: - kwrgs.pop("decay_rates") - - bas1_instance = bas1( - n_basis[0], - - **extra_kwargs[0], - label="1", - ) - bas2_instance = bas2( - n_basis[1], - - **extra_kwargs[1], - label="2", - ) + combine_basis = CombinedBasis() + bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") + bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") func1 = getattr(bas1_instance, operator) bas12 = func1(bas2_instance) @@ -3664,32 +3472,24 @@ def test__get_splitter_split_by_input( "bas1, bas2, bas3", list( itertools.product( - *[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 3 + *[list_all_basis_classes()] * 3 ) ), ) -def test_duplicate_keys(bas1, bas2, bas3): +def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) for bas in [bas1, bas2, bas3] ): return - extra_kwargs = ( - {"decay_rates": np.arange(1, 5 + 1)}, - {"decay_rates": np.arange(1, 5 + 1)}, - {"decay_rates": np.arange(1, 5 + 1)}, - ) - for bas, kwrgs in zip((bas1, bas2, bas3), extra_kwargs): - if bas != basis.OrthExponentialBasis: - kwrgs.pop("decay_rates") - - bas_obj = ( - bas1(5, **extra_kwargs[0], label="label") - + bas2(5, **extra_kwargs[1], label="label") - + bas3(5, **extra_kwargs[2], label="label") - ) + combine_basis = CombinedBasis() + bas1_instance = combine_basis.instantiate_basis(5, bas1, class_specific_params, window_size=10, label="label") + bas2_instance = combine_basis.instantiate_basis(5, bas2, class_specific_params, window_size=10, label="label") + bas3_instance = combine_basis.instantiate_basis(5, bas3, class_specific_params, window_size=10, 
label="label") + bas_obj = bas1_instance + bas2_instance + bas3_instance + inps = [np.zeros((1,)) for n in range(3)] bas_obj._set_num_output_features(*inps) slice_dict = bas_obj._get_feature_slicing()[0] @@ -3700,7 +3500,7 @@ def test_duplicate_keys(bas1, bas2, bas3): "bas1, bas2", list( itertools.product( - *[tuple((getattr(basis, basis_name) for basis_name in dir(basis)))] * 2 + *[list_all_basis_classes()] * 2 ) ), ) @@ -3720,37 +3520,19 @@ def test_duplicate_keys(bas1, bas2, bas3): ), ], ) -def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes): +def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) for bas in [bas1, bas2] ): return # define the basis n_basis = [5, 6] - mode = "conv" - extra_kwargs = ( - {"decay_rates": np.arange(1, n_basis[0] + 1), "window_size": 5}, - {"decay_rates": np.arange(1, n_basis[1] + 1), "window_size": 5}, - ) - for i, val in enumerate(zip([bas1, bas2], extra_kwargs)): - bas, kwrgs = val - if bas != basis.OrthExponentialBasis: - kwrgs.pop("decay_rates") - - bas1_instance = bas1( - n_basis[0], + combine_basis = CombinedBasis() + bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") + bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") - **extra_kwargs[0], - label="1", - ) - bas2_instance = bas2( - n_basis[1], - - **extra_kwargs[1], - label="2", - ) bas = bas1_instance + bas2_instance bas._set_num_output_features(np.zeros((1, 2)), np.zeros((1, 3))) with expectation: From 5f65b714cf8662abe1ecac9f5db4295d6d1187ad Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 13:38:21 -0500 Subject: [PATCH 045/109] linted tests --- src/nemos/basis/_basis_mixin.py | 6 +- tests/test_basis.py | 1549 +++++++++++++++++++++---------- 2 files changed, 1069 insertions(+), 486 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 1f3b342d..ab95ab21 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -104,8 +104,10 @@ def _compute_features(self, *xi: ArrayLike): """ if self.kernel_ is None: - raise ValueError("You must call `_set_kernel` before `_compute_features`! " - "Convolution kernel is not set.") + raise ValueError( + "You must call `_set_kernel` before `_compute_features`! " + "Convolution kernel is not set." + ) # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. 
diff --git a/tests/test_basis.py b/tests/test_basis.py index 057a454d..f07d297e 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -11,13 +11,18 @@ import numpy as np import pynapple as nap import pytest - import utils_testing from sklearn.base import clone as sk_clone import nemos.basis.basis as basis import nemos.convolve as convolve -from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring, TransformerBasis +from nemos.basis._basis import ( + AdditiveBasis, + Basis, + MultiplicativeBasis, + TransformerBasis, + add_docstring, +) from nemos.basis._decaying_exponential import OrthExponentialBasis from nemos.basis._raised_cosine_basis import ( RaisedCosineBasisLinear, @@ -33,23 +38,31 @@ def class_specific_params(): eval_params = ["bounds"] conv_params = ["window_size", "conv_kwargs"] return dict( - EvalBSpline = shared_params + eval_params + ["order"], - ConvBSpline = shared_params + conv_params + ["order"], - EvalMSpline = shared_params + eval_params + ["order"], - ConvMSpline = shared_params + conv_params +["order"], - EvalCyclicBSpline = shared_params + eval_params + ["order"], - ConvCyclicBSpline = shared_params + conv_params +["order"], - EvalRaisedCosineLinear= shared_params + eval_params + ["width"], - ConvRaisedCosineLinear=shared_params + conv_params +["width"], - EvalRaisedCosineLog= shared_params + eval_params + ["width", "time_scaling", "enforce_decay_to_zero"], - ConvRaisedCosineLog= shared_params + conv_params +["width", "time_scaling", "enforce_decay_to_zero"], - EvalOrthExponential= shared_params + eval_params + ["decay_rates"], - ConvOrthExponential = shared_params + conv_params +["decay_rates"] + EvalBSpline=shared_params + eval_params + ["order"], + ConvBSpline=shared_params + conv_params + ["order"], + EvalMSpline=shared_params + eval_params + ["order"], + ConvMSpline=shared_params + conv_params + ["order"], + EvalCyclicBSpline=shared_params + eval_params + ["order"], + ConvCyclicBSpline=shared_params + conv_params + ["order"], + EvalRaisedCosineLinear=shared_params + eval_params + ["width"], + ConvRaisedCosineLinear=shared_params + conv_params + ["width"], + EvalRaisedCosineLog=shared_params + + eval_params + + ["width", "time_scaling", "enforce_decay_to_zero"], + ConvRaisedCosineLog=shared_params + + conv_params + + ["width", "time_scaling", "enforce_decay_to_zero"], + EvalOrthExponential=shared_params + eval_params + ["decay_rates"], + ConvOrthExponential=shared_params + conv_params + ["decay_rates"], ) def trim_kwargs(cls, kwargs, class_specific_params): - return {key: value for key, value in kwargs.items() if key in class_specific_params[cls.__name__]} + return { + key: value + for key, value in kwargs.items() + if key in class_specific_params[cls.__name__] + } def extra_decay_rates(cls, n_basis): @@ -74,6 +87,7 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: all_basis = [a for a in all_basis if filter_basis in a.__name__] return all_basis + def test_all_basis_are_tested() -> None: """Meta-test. 
@@ -296,8 +310,8 @@ def cls(self): {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline}, {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline}, {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline}, - {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential} - ] + {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential}, + ], ) class TestSharedMethods: @@ -306,23 +320,23 @@ class TestSharedMethods: [ (0.5, 0, 1, does_not_raise()), ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), ), ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), ), ], ) @@ -345,11 +359,10 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): def test_attr_setter(self, attribute, value, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) with pytest.raises( - AttributeError, match=rf"can't set attribute|property '{attribute}' of" + AttributeError, match=rf"can't set attribute|property '{attribute}' of" ): setattr(bas, attribute, value) - @pytest.mark.parametrize( "n_input, expectation", [ @@ -360,7 +373,9 @@ def test_attr_setter(self, attribute, value, cls): ], ) def test_expected_input_number(self, n_input, expectation, cls): - bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["eval"], 5)) + bas = cls["conv"]( + n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["eval"], 5) + ) x = np.random.randn(20, 2) bas.compute_features(x) with expectation: @@ -371,49 +386,66 @@ def test_expected_input_number(self, n_input, expectation, cls): [ (dict(), does_not_raise()), ( - dict(axis=0), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + dict(axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), ), ( - dict(axis=1), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + dict(axis=1), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), ), (dict(shift=True), does_not_raise()), ( - dict(shift=True, axis=0), - pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed"), + dict(shift=True, axis=0), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), ), ( - dict(shifts=True), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), + dict(shifts=True), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), ), (dict(shift=True, predictor_causality="causal"), does_not_raise()), ( - dict(shift=True, time_series=np.arange(10)), - pytest.raises(ValueError, match="Unrecognized keyword arguments"), + dict(shift=True, time_series=np.arange(10)), + pytest.raises(ValueError, match="Unrecognized keyword arguments"), ), ], ) def test_init_conv_kwargs(self, conv_kwargs, expectation, cls): with expectation: - cls["conv"](n_basis_funcs=5, window_size=200, conv_kwargs=conv_kwargs, **extra_decay_rates(cls["eval"], 5)) + cls["conv"]( + n_basis_funcs=5, + window_size=200, + 
conv_kwargs=conv_kwargs, + **extra_decay_rates(cls["eval"], 5), + ) @pytest.mark.parametrize("label", [None, "label"]) def test_init_label(self, label, cls): - bas = cls["eval"](n_basis_funcs=5, label=label, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"]( + n_basis_funcs=5, label=label, **extra_decay_rates(cls["eval"], 5) + ) expected_label = str(label) if label is not None else cls["eval"].__name__ assert bas.label == expected_label @pytest.mark.parametrize("n_input", [1, 2, 3]) def test_set_num_output_features(self, n_input, cls): - bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) + bas = cls["conv"]( + n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) + ) assert bas.n_output_features is None bas.compute_features(np.random.randn(20, n_input)) assert bas.n_output_features == n_input * bas.n_basis_funcs @pytest.mark.parametrize("n_input", [1, 2, 3]) def test_set_num_basis_input(self, n_input, cls): - bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) + bas = cls["conv"]( + n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) + ) assert bas.n_basis_input is None bas.compute_features(np.random.randn(20, n_input)) assert bas.n_basis_input == (n_input,) @@ -423,14 +455,31 @@ def test_set_num_basis_input(self, n_input, cls): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation, cls): - basis_obj = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + basis_obj = cls["eval"]( + 5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5) + ) with expectation: basis_obj.compute_features(samples) @@ -443,9 +492,15 @@ def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation, cls) ((1, 3), np.arange(5), [0, 4], 1, 3), ], ) - def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx, cls): - bas_no_range = cls["eval"](n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5)) - bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + def test_vmin_vmax_eval_on_grid_affects_x( + self, bounds, samples, nan_idx, mn, mx, cls + ): + bas_no_range = cls["eval"]( + n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5) + ) + bas = cls["eval"]( + n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5) + ) x1, _ = bas.evaluate_on_grid(10) x2, _ = bas_no_range.evaluate_on_grid(10) assert np.allclose(x1, x2 * (mx - mn) + mn) @@ -458,12 +513,18 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (1, 3, np.arange(5), [0, 4]), ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx, cls): + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx, cls + 
): # MSPline integrates to 1 on domain so must be excluded from this check - if "MSpline" in cls["eval"].__name__: + if "MSpline" in cls["eval"].__name__: return - bas_no_range = cls["eval"](n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5)) - bas = cls["eval"](n_basis_funcs=5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + bas_no_range = cls["eval"]( + n_basis_funcs=5, bounds=None, **extra_decay_rates(cls["eval"], 5) + ) + bas = cls["eval"]( + n_basis_funcs=5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5) + ) _, out1 = bas.evaluate_on_grid(10) _, out2 = bas_no_range.evaluate_on_grid(10) assert np.allclose(out1, out2) @@ -479,16 +540,18 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), + (1, 2, 3), + pytest.raises( + ValueError, match="The provided `bounds` must be of length two" + ), ), ], ) def test_vmin_vmax_init(self, bounds, expectation, cls): with expectation: - bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"]( + n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5) + ) assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize( @@ -502,16 +565,18 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), + (1, 2, 3), + pytest.raises( + ValueError, match="The provided `bounds` must be of length two" + ), ), ], ) def test_vmin_vmax_init(self, bounds, expectation, cls): with expectation: - bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"]( + n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5) + ) assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize( @@ -525,31 +590,43 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), ), ], ) def test_vmin_vmax_setter(self, bounds, expectation, cls): - bas = cls["eval"](n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"]( + n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5) + ) with expectation: bas.set_params(bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_basis_number(self, n_basis, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_basis_number(self, n_basis, mode, kwargs, cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) + bas = cls[mode]( + n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) + ) x = np.linspace(0, 1, 10) assert bas(x).shape[1] == n_basis @pytest.mark.parametrize("n_basis", [6]) def test_call_equivalent_in_conv(self, n_basis, cls): - bas_con = cls["conv"](n_basis_funcs=n_basis, window_size=10, **extra_decay_rates(cls["conv"], n_basis)) - bas_eval = cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)) + bas_con = cls["conv"]( + n_basis_funcs=n_basis, + window_size=10, + **extra_decay_rates(cls["conv"], n_basis), + ) + bas_eval = cls["eval"]( + n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) + ) x = np.linspace(0, 1, 10) assert np.all(bas_con(x) == bas_eval(x)) @@ -561,10 +638,14 @@ def test_call_equivalent_in_conv(self, n_basis, cls): (2, pytest.raises(TypeError, match="Input dimensionality mismatch")), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) @pytest.mark.parametrize("n_basis", [6]) - def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation,cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) + def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls): + bas = cls[mode]( + n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) + ) with expectation: bas(*([np.linspace(0, 1, 10)] * num_input)) @@ -576,9 +657,13 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation,cls) ], ) @pytest.mark.parametrize("n_basis", [6]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_input_shape(self, inp, mode, kwargs, expectation,n_basis,cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): + bas = cls[mode]( + n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) + ) with expectation: bas(inp) @@ -587,34 +672,44 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation,n_basis,cls): [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), ( - np.array(["a", "1", "2", "3", "4", "5"]), - pytest.raises(TypeError, match="Input samples must"), + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), ), ], ) @pytest.mark.parametrize("n_basis", [6]) - def test_call_input_type(self, samples, expectation, n_basis,cls): - bas = 
cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)) # Only eval mode is relevant here + def test_call_input_type(self, samples, expectation, n_basis, cls): + bas = cls["eval"]( + n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) + ) # Only eval mode is relevant here with expectation: bas(samples) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_nan(self, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x[3] = np.nan assert all(np.isnan(bas(x)[3])) @pytest.mark.parametrize("n_basis", [6, 7]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_non_empty(self, n_basis, mode, kwargs,cls): - bas = cls[mode](n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis)) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_non_empty(self, n_basis, mode, kwargs, cls): + bas = cls[mode]( + n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) + ) with pytest.raises(ValueError, match="All sample provided must"): bas(np.array([])) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_sample_axis(self, time_axis_shape, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape @@ -625,8 +720,10 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs,cls): (-2, 2, does_not_raise()), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_call_sample_range(self, mn, mx, expectation, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) with expectation: bas(np.linspace(mn, mx, 10)) @@ -635,13 +732,30 @@ def test_call_sample_range(self, mn, mx, expectation, mode, kwargs,cls): "kwargs, input1_shape, expectation", [ (dict(), (10,), does_not_raise()), - (dict(axis=0), (10,), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), - (dict(axis=1), (2, 10), pytest.raises(ValueError, match="Setting the `axis` parameter is not allowed")), + ( + dict(axis=0), + (10,), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), + ( + dict(axis=1), + (2, 10), + pytest.raises( + ValueError, match="Setting the `axis` parameter is not allowed" + ), + ), ], ) - def test_compute_features_axis(self, kwargs, input1_shape, expectation,cls): + def test_compute_features_axis(self, kwargs, input1_shape, expectation, cls): with expectation: - basis_obj = cls["conv"](n_basis_funcs=5, window_size=5, conv_kwargs=kwargs, **extra_decay_rates(cls["conv"], 5)) + basis_obj = cls["conv"]( + n_basis_funcs=5, + window_size=5, + conv_kwargs=kwargs, + **extra_decay_rates(cls["conv"], 5), + 
) basis_obj.compute_features(np.ones(input1_shape)) @pytest.mark.parametrize("n_basis_funcs", [4, 5]) @@ -662,28 +776,29 @@ def test_compute_features_axis(self, kwargs, input1_shape, expectation,cls): ], ) def test_compute_features_conv_input( - self, - n_basis_funcs, - time_scaling, - enforce_decay, - window_size, - input_shape, - expected_n_input, - order, - width, - cls, - class_specific_params, - ): + self, + n_basis_funcs, + time_scaling, + enforce_decay, + window_size, + input_shape, + expected_n_input, + order, + width, + cls, + class_specific_params, + ): x = np.ones(input_shape) kwargs = dict( n_basis_funcs=n_basis_funcs, - decay_rates=np.arange(1, n_basis_funcs+1), + decay_rates=np.arange(1, n_basis_funcs + 1), time_scaling=time_scaling, window_size=window_size, enforce_decay_to_zero=enforce_decay, order=order, - width=width,) + width=width, + ) # figure out which kwargs needs to be removed kwargs = trim_kwargs(cls["conv"], kwargs, class_specific_params) @@ -692,8 +807,10 @@ def test_compute_features_conv_input( out = basis_obj.compute_features(x) assert out.shape[1] == expected_n_input * basis_obj.n_basis_funcs - @pytest.mark.parametrize("eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])]) - def test_compute_features_input(self, eval_input,cls): + @pytest.mark.parametrize( + "eval_input", [0, [0], (0,), np.array([0]), jax.numpy.array([0])] + ) + def test_compute_features_input(self, eval_input, cls): # orth exp needs more inputs (orthogonalizaiton impossible otherwise) if "OrthExp" in cls["eval"].__name__: return @@ -704,9 +821,15 @@ def test_compute_features_input(self, eval_input,cls): "args, sample_size", [[{"n_basis_funcs": n_basis}, 100] for n_basis in [6, 10, 13]], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})]) - def test_compute_features_returns_expected_number_of_basis(self, args, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](**args, **kwargs, **extra_decay_rates(cls[mode], args["n_basis_funcs"])) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})] + ) + def test_compute_features_returns_expected_number_of_basis( + self, args, sample_size, mode, kwargs, cls + ): + basis_obj = cls[mode]( + **args, **kwargs, **extra_decay_rates(cls[mode], args["n_basis_funcs"]) + ) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[1] == args["n_basis_funcs"], ( "Dimensions do not agree: The number of basis should match the first dimension " @@ -718,31 +841,77 @@ def test_compute_features_returns_expected_number_of_basis(self, args, sample_si "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), ], ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation,cls): + def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation, cls): if "OrthExp" in 
cls["eval"].__name__ and not hasattr(samples, "shape"): return - basis_obj = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) + basis_obj = cls["eval"]( + 5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5) + ) with expectation: basis_obj.compute_features(samples) @pytest.mark.parametrize( "bounds, samples, exception", [ - (None, np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((0, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((1, 4), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), - ((1, 3), np.arange(5), pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ( + None, + np.arange(5), + pytest.raises( + TypeError, match="got an unexpected keyword argument 'bounds'" + ), + ), + ( + (0, 3), + np.arange(5), + pytest.raises( + TypeError, match="got an unexpected keyword argument 'bounds'" + ), + ), + ( + (1, 4), + np.arange(5), + pytest.raises( + TypeError, match="got an unexpected keyword argument 'bounds'" + ), + ), + ( + (1, 3), + np.arange(5), + pytest.raises( + TypeError, match="got an unexpected keyword argument 'bounds'" + ), + ), ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception, cls): with exception: - cls["conv"](n_basis_funcs=5, window_size=10, bounds=bounds, **extra_decay_rates(cls["conv"], 5)) + cls["conv"]( + n_basis_funcs=5, + window_size=10, + bounds=bounds, + **extra_decay_rates(cls["conv"], 5), + ) @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", @@ -755,7 +924,9 @@ def test_vmin_vmax_mode_conv(self, bounds, samples, exception, cls): ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx, cls): bounds = None if vmin is None else (vmin, vmax) - bas = cls["eval"](n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5)) + bas = cls["eval"]( + n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5) + ) out = bas.compute_features(samples) assert np.all(np.isnan(out[nan_idx])) valid_idx = list(set(samples).difference(nan_idx)) @@ -772,25 +943,31 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx, cls): ((1, "a"), pytest.raises(TypeError, match="Could not convert")), (("a", "a"), pytest.raises(TypeError, match="Could not convert")), ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), ), ], ) - def test_vmin_vmax_setter(self, bounds, expectation,cls): - bas = cls["eval"](n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5)) + def test_vmin_vmax_setter(self, bounds, expectation, cls): + bas = cls["eval"]( + n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5) + ) with expectation: bas.set_params(bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None - def test_conv_kwargs_error(self,cls): - with pytest.raises(TypeError, match="got an unexpected keyword argument 'test'"): + def test_conv_kwargs_error(self, cls): + with pytest.raises( + TypeError, match="got an unexpected keyword argument 'test'" + ): cls["eval"](n_basis_funcs=5, test="hi", **extra_decay_rates(cls["eval"], 5)) - def test_convolution_is_performed(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5)) + def test_convolution_is_performed(self, cls): + bas = cls["conv"]( + n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) + ) x = np.random.normal(size=100) conv = bas.compute_features(x) conv_2 = convolve.create_convolutional_predictor(bas.kernel_, x) @@ -799,30 +976,42 @@ def test_convolution_is_performed(self,cls): assert np.all(np.isnan(conv_2[~valid])) @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: return - basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) + basis_obj = cls[mode]( + n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5) + ) if sample_size <= 0: - with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): + with pytest.raises( + ValueError, match=r"All sample counts provided must be greater" + ): basis_obj.evaluate_on_grid(sample_size) else: _, eval_basis = basis_obj.evaluate_on_grid(sample_size) assert eval_basis.shape[0] == sample_size @pytest.mark.parametrize("n_input", [0, 1, 2]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): + basis_obj = cls[mode]( + n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5) + ) inputs = [10] * n_input if n_input == 0: expectation = pytest.raises( - TypeError, match=r"evaluate_on_grid\(\) missing 1 required positional argument", + TypeError, + match=r"evaluate_on_grid\(\) missing 1 required positional argument", ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( - TypeError, match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + TypeError, + match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", ) else: expectation = does_not_raise() @@ -831,25 +1020,35 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs,cls): basis_obj.evaluate_on_grid(*inputs) 
@pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: return - basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) + basis_obj = cls[mode]( + n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5) + ) if sample_size <= 0: - with pytest.raises(ValueError, match=r"All sample counts provided must be greater"): + with pytest.raises( + ValueError, match=r"All sample counts provided must be greater" + ): basis_obj.evaluate_on_grid(sample_size) else: grid, _ = basis_obj.evaluate_on_grid(sample_size) assert grid.shape[0] == sample_size def test_fit_kernel(self, cls): - bas = cls["conv"](n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5)) + bas = cls["conv"]( + n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5) + ) bas._set_kernel() assert bas.kernel_ is not None - def test_fit_kernel_shape(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5)) + def test_fit_kernel_shape(self, cls): + bas = cls["conv"]( + n_basis_funcs=5, window_size=30, **extra_decay_rates(cls["conv"], 5) + ) bas._set_kernel() assert bas.kernel_.shape == (30, 5) @@ -858,40 +1057,46 @@ def test_fit_kernel_shape(self,cls): [ ("conv", 2, does_not_raise()), ( - "conv", - -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), + "conv", + -1, + pytest.raises(ValueError, match="`window_size` must be a positive "), ), ( - "conv", - None, - pytest.raises( - ValueError, - match="If the basis is in `conv` mode, you must provide a ", - ), + "conv", + None, + pytest.raises( + ValueError, + match="If the basis is in `conv` mode, you must provide a ", + ), ), ( - "conv", - 1.5, - pytest.raises(ValueError, match="`window_size` must be a positive "), + "conv", + 1.5, + pytest.raises(ValueError, match="`window_size` must be a positive "), ), - ("eval", None, pytest.raises( - TypeError, - match=r"got an unexpected keyword argument 'window_size'", - )), ( - "eval", - 10, - pytest.raises( - TypeError, - match=r"got an unexpected keyword argument 'window_size'", - ), + "eval", + None, + pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + ), + ), + ( + "eval", + 10, + pytest.raises( + TypeError, + match=r"got an unexpected keyword argument 'window_size'", + ), ), ], ) - def test_init_window_size(self, mode, ws, expectation,cls): + def test_init_window_size(self, mode, ws, expectation, cls): with expectation: - cls[mode](n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5)) + cls[mode]( + n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5) + ) # @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) # @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @@ -911,35 +1116,55 @@ def test_init_window_size(self, mode, ws, expectation,cls): # cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_non_empty_samples(self, samples, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), 
("conv", {"window_size": 2})] + ) + def test_non_empty_samples(self, samples, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: return if mode == "conv" and len(samples) == 1: return if len(samples) == 0: - with pytest.raises(ValueError, match="All sample provided must be non empty"): - cls[mode](5, **kwargs, **extra_decay_rates(cls[mode], 5)).compute_features(samples) + with pytest.raises( + ValueError, match="All sample provided must be non empty" + ): + cls[mode]( + 5, **kwargs, **extra_decay_rates(cls[mode], 5) + ).compute_features(samples) else: - cls[mode](5, **kwargs, **extra_decay_rates(cls[mode], 5)).compute_features(samples) + cls[mode](5, **kwargs, **extra_decay_rates(cls[mode], 5)).compute_features( + samples + ) @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})] + ) + def test_number_of_required_inputs_compute_features( + self, n_input, mode, kwargs, cls + ): + basis_obj = cls[mode]( + n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5) + ) inputs = [np.linspace(0, 1, 20)] * n_input if n_input == 0: - expectation = pytest.raises(TypeError, match="missing 1 required positional argument") + expectation = pytest.raises( + TypeError, match="missing 1 required positional argument" + ) elif n_input != basis_obj._n_input_dimensionality: - expectation = pytest.raises(TypeError, match="takes 2 positional arguments but \d were given") + expectation = pytest.raises( + TypeError, match="takes 2 positional arguments but \d were given" + ) else: expectation = does_not_raise() with expectation: basis_obj.compute_features(*inputs) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})]) - def test_pynapple_support(self, mode, kwargs,cls): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + ) + def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x_nap = nap.Tsd(t=np.arange(10), d=x) @@ -951,22 +1176,32 @@ def test_pynapple_support(self, mode, kwargs,cls): @pytest.mark.parametrize("sample_size", [30]) @pytest.mark.parametrize("n_basis", [5]) - def test_pynapple_support_compute_features(self, n_basis, sample_size,cls): + def test_pynapple_support_compute_features(self, n_basis, sample_size, cls): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( t=np.linspace(0, 1, sample_size), d=np.linspace(0, 1, sample_size), time_support=iset, ) - out = cls["eval"](n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis)).compute_features(inp) + out = cls["eval"]( + n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) + ).compute_features(inp) assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) @pytest.mark.parametrize("sample_size", [100, 1000]) @pytest.mark.parametrize("n_basis_funcs", [5, 10, 80]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 90})]) - def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_funcs, sample_size, mode, kwargs,cls): - basis_obj = cls[mode](n_basis_funcs=n_basis_funcs, **kwargs, 
**extra_decay_rates(cls[mode], n_basis_funcs)) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 90})] + ) + def test_sample_size_of_compute_features_matches_that_of_input( + self, n_basis_funcs, sample_size, mode, kwargs, cls + ): + basis_obj = cls[mode]( + n_basis_funcs=n_basis_funcs, + **kwargs, + **extra_decay_rates(cls[mode], n_basis_funcs), + ) eval_basis = basis_obj.compute_features(np.linspace(0, 1, sample_size)) assert eval_basis.shape[0] == sample_size, ( f"Dimensions do not agree: The sample size of the output should match the input sample size. " @@ -977,17 +1212,26 @@ def test_sample_size_of_compute_features_matches_that_of_input(self, n_basis_fun "mode, expectation", [ ("eval", does_not_raise()), - ("conv", pytest.raises(TypeError, match="got an unexpected keyword argument 'bounds'")), + ( + "conv", + pytest.raises( + TypeError, match="got an unexpected keyword argument 'bounds'" + ), + ), ], ) - def test_set_bounds(self, mode, expectation,cls): + def test_set_bounds(self, mode, expectation, cls): kwargs = {"bounds": (1, 2)} with expectation: cls[mode](n_basis_funcs=10, **kwargs, **extra_decay_rates(cls[mode], 10)) if mode == "conv": - bas = cls["conv"](n_basis_funcs=10, window_size=20, **extra_decay_rates(cls[mode], 10)) - with pytest.raises(ValueError, match="Invalid parameter 'bounds' for estimator"): + bas = cls["conv"]( + n_basis_funcs=10, window_size=20, **extra_decay_rates(cls[mode], 10) + ) + with pytest.raises( + ValueError, match="Invalid parameter 'bounds' for estimator" + ): bas.set_params(bounds=(1, 2)) @pytest.mark.parametrize( @@ -1004,17 +1248,19 @@ def test_set_bounds(self, mode, expectation,cls): ], ) def test_set_params( - self, - enforce_decay_to_zero, - time_scaling, - width, - window_size, - n_basis_funcs, - bounds, - mode: Literal["eval", "conv"], - order, decay_rates, conv_kwargs, - cls, - class_specific_params + self, + enforce_decay_to_zero, + time_scaling, + width, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], + order, + decay_rates, + conv_kwargs, + cls, + class_specific_params, ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -1028,12 +1274,14 @@ def test_set_params( decay_rates=decay_rates, conv_kwargs=conv_kwargs, ) - pars = {key: value for key, value in pars.items() if key in class_specific_params[cls[mode].__name__]} + pars = { + key: value + for key, value in pars.items() + if key in class_specific_params[cls[mode].__name__] + } keys = list(pars.keys()) - bas = cls[mode]( - **pars - ) + bas = cls[mode](**pars) for i in range(len(pars)): for j in range(i + 1, len(pars)): par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} @@ -1044,32 +1292,43 @@ def test_set_params( "mode, expectation", [ ("conv", does_not_raise()), - ("eval", pytest.raises(TypeError, match="got an unexpected keyword argument 'window_size'")), + ( + "eval", + pytest.raises( + TypeError, match="got an unexpected keyword argument 'window_size'" + ), + ), ], ) - def test_set_window_size(self, mode, expectation,cls): + def test_set_window_size(self, mode, expectation, cls): kwargs = {"window_size": 10} with expectation: cls[mode](n_basis_funcs=10, **kwargs, **extra_decay_rates(cls[mode], 10)) if mode == "conv": - bas = cls["conv"](n_basis_funcs=10, window_size=10, **extra_decay_rates(cls["conv"], 10)) + bas = cls["conv"]( + n_basis_funcs=10, window_size=10, **extra_decay_rates(cls["conv"], 10) + ) with pytest.raises(ValueError, match="If the basis is in `conv` mode"): 
bas.set_params(window_size=None) if mode == "eval": bas = cls["eval"](n_basis_funcs=10, **extra_decay_rates(cls["eval"], 10)) - with pytest.raises(ValueError, match="Invalid parameter 'window_size' for estimator"): + with pytest.raises( + ValueError, match="Invalid parameter 'window_size' for estimator" + ): bas.set_params(window_size=10) - def test_transform_fails(self,cls): - bas = cls["conv"](n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5)) + def test_transform_fails(self, cls): + bas = cls["conv"]( + n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5) + ) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + ValueError, match="You must call `_set_kernel` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) - def test_transformer_get_params(self,cls): + def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() @@ -1083,11 +1342,15 @@ def test_transformer_get_params(self,cls): class TestRaisedCosineLogBasis: cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} + @pytest.mark.parametrize("width", [1.5, 2, 2.5]) def test_decay_to_zero_basis_number_match(self, width): n_basis_funcs = 10 _, ev = self.cls["conv"]( - n_basis_funcs=n_basis_funcs, width=width, enforce_decay_to_zero=True, window_size=5 + n_basis_funcs=n_basis_funcs, + width=width, + enforce_decay_to_zero=True, + window_size=5, ).evaluate_on_grid(2) assert ev.shape[1] == n_basis_funcs, ( "Basis function number mismatch. " @@ -1095,11 +1358,16 @@ def test_decay_to_zero_basis_number_match(self, width): ) @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, mode, kwargs + ): if n_basis_funcs < 2: with pytest.raises( - ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", + ValueError, + match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", ): self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) else: @@ -1110,13 +1378,33 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, [ (10, does_not_raise()), (10.5, does_not_raise()), - (0.5, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), - (10.3, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), - (-10, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), + ( + 0.5, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), + ), + ( + 10.3, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), + ), + ( + -10, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 
2\*width must be a positive", + ), + ), (None, pytest.raises(TypeError, match="'<=' not supported between")), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] + ) def test_set_width(self, width, expectation, mode, kwargs): basis_obj = self.cls[mode](n_basis_funcs=5, **kwargs) with expectation: @@ -1137,7 +1425,7 @@ def test_time_scaling_property(self): ) _, log_ev = basis_log.evaluate_on_grid(100) corr[idx] = (lin_ev.flatten() @ log_ev.flatten()) / ( - np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) + np.linalg.norm(lin_ev.flatten()) * np.linalg.norm(log_ev.flatten()) ) assert np.all( np.diff(corr) < 0 @@ -1146,13 +1434,25 @@ def test_time_scaling_property(self): @pytest.mark.parametrize( "time_scaling, expectation", [ - (-1, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), - (0, pytest.raises(ValueError, match="Only strictly positive time_scaling are allowed")), + ( + -1, + pytest.raises( + ValueError, match="Only strictly positive time_scaling are allowed" + ), + ), + ( + 0, + pytest.raises( + ValueError, match="Only strictly positive time_scaling are allowed" + ), + ), (0.1, does_not_raise()), (10, does_not_raise()), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] + ) def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): with expectation: self.cls[mode](n_basis_funcs=5, time_scaling=time_scaling, **kwargs) @@ -1169,7 +1469,9 @@ def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) def test_width_values(self, width, expectation, mode, kwargs): with expectation: self.cls[mode](n_basis_funcs=5, width=width, **kwargs) @@ -1179,11 +1481,16 @@ class TestRaisedCosineLinearBasis(BasisFuncsTesting): cls = {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, mode, kwargs + ): if n_basis_funcs < 2: with pytest.raises( - ValueError, match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", + ValueError, + match=f"Object class {self.cls[mode].__name__} requires >= 2 basis elements.", ): self.cls[mode](n_basis_funcs=n_basis_funcs, **kwargs) else: @@ -1194,8 +1501,20 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, [ (10, does_not_raise()), (10.5, does_not_raise()), - (0.5, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), - (-10, pytest.raises(ValueError, match=r"Invalid raised cosine width\. 2\*width must be a positive")), + ( + 0.5, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 
2\*width must be a positive", + ), + ), + ( + -10, + pytest.raises( + ValueError, + match=r"Invalid raised cosine width\. 2\*width must be a positive", + ), + ), ], ) def test_set_width(self, width, expectation): @@ -1217,7 +1536,9 @@ def test_set_width(self, width, expectation): (2.1, pytest.raises(ValueError, match="Invalid raised cosine width. ")), ], ) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})]) + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] + ) def test_width_values(self, width, expectation, mode, kwargs): """ Test allowable widths: integer multiple of 1/2, greater than 1. @@ -1235,8 +1556,12 @@ class TestMSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, order, mode, kwargs + ): """ Verifies that the minimum number of basis functions and order required (i.e., at least 1) and order < #basis are enforced. @@ -1244,9 +1569,9 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, raise_exception = (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs) if raise_exception: with pytest.raises( - ValueError, - match=r"Spline order must be positive!|" - rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ValueError, + match=r"Spline order must be positive!|" + rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", ): basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, order=order, **kwargs @@ -1283,11 +1608,11 @@ def test_order_is_positive(self, n_basis_funcs, order): (1, does_not_raise()), (2, does_not_raise()), ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), + 10, + pytest.raises( + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", + ), ), ], ) @@ -1316,7 +1641,7 @@ def test_samples_range_matches_compute_features_requirements(self, sample_range) ], ) def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( - self, bounds, samples, nan_idx, scaling + self, bounds, samples, nan_idx, scaling ): """ Check that the MSpline has the expected scaling property. @@ -1345,9 +1670,12 @@ def test_decay_rate_repetition(self, decay_rates): raise_exception = len(set(decay_rates)) != len(decay_rates) if raise_exception: with pytest.raises( - ValueError, match=r"Two or more rates are repeated! Repeating rates will" + ValueError, + match=r"Two or more rates are repeated! 
Repeating rates will", ): - self.cls["eval"](n_basis_funcs=len(decay_rates), decay_rates=decay_rates) + self.cls["eval"]( + n_basis_funcs=len(decay_rates), decay_rates=decay_rates + ) else: self.cls["eval"](n_basis_funcs=len(decay_rates), decay_rates=decay_rates) @@ -1363,15 +1691,19 @@ def test_decay_rate_size_match_n_basis_funcs(self, decay_rates, n_basis_funcs): decay_rates = np.asarray(decay_rates, dtype=float) if raise_exception: with pytest.raises( - ValueError, match="The number of basis functions must match the" + ValueError, match="The number of basis functions must match the" ): self.cls["eval"](n_basis_funcs=n_basis_funcs, decay_rates=decay_rates) else: self.cls["eval"](n_basis_funcs=n_basis_funcs, decay_rates=decay_rates) @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 30})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, mode, kwargs + ): """ Tests whether the class instance has a minimum number of basis functions. """ @@ -1379,8 +1711,8 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, decay_rates = np.arange(1, 1 + n_basis_funcs) if n_basis_funcs > 0 else [] if raise_exception: with pytest.raises( - ValueError, - match=f"Object class {self.cls[mode].__name__} requires >= 1 basis elements.", + ValueError, + match=f"Object class {self.cls[mode].__name__} requires >= 1 basis elements.", ): self.cls[mode]( n_basis_funcs=n_basis_funcs, @@ -1400,8 +1732,12 @@ class TestBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, order, mode, kwargs + ): """ Verifies that the minimum number of basis functions and order required (i.e., at least 1) and order < #basis are enforced. 
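The OrthExponential hunks above pin down two constructor invariants for the decay rates: they must all be distinct, and their number must equal `n_basis_funcs`. A short sketch of those checks, assuming the error messages in the diff reflect the validation inside `OrthExponentialBasis`; the `validate_decay_rates` helper below is hypothetical and for illustration only:

```python
import numpy as np


def validate_decay_rates(n_basis_funcs, decay_rates):
    """Hypothetical helper mirroring the checks these tests exercise."""
    decay_rates = np.asarray(decay_rates, dtype=float)
    if len(set(decay_rates.tolist())) != decay_rates.size:
        # matches: "Two or more rates are repeated! Repeating rates will ..."
        raise ValueError("Two or more rates are repeated!")
    if decay_rates.size != n_basis_funcs:
        # matches: "The number of basis functions must match the ..."
        raise ValueError(
            "The number of basis functions must match the number of decay rates."
        )
    return decay_rates


validate_decay_rates(3, [0.5, 1.0, 2.0])    # ok: unique rates, length matches
# validate_decay_rates(3, [1.0, 1.0, 2.0])  # raises: repeated rates
# validate_decay_rates(2, [0.5, 1.0, 2.0])  # raises: length mismatch
```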
@@ -1409,8 +1745,8 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, raise_exception = order > n_basis_funcs if raise_exception: with pytest.raises( - ValueError, - match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", ): basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, @@ -1452,11 +1788,11 @@ def test_order_is_positive(self, n_basis_funcs, order): (1, does_not_raise()), (2, does_not_raise()), ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), + 10, + pytest.raises( + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", + ), ), ], ) @@ -1469,7 +1805,9 @@ def test_order_setter(self, n_basis_funcs, order, expectation): @pytest.mark.parametrize( "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] ) - def test_samples_range_matches_compute_features_requirements(self, sample_range: tuple): + def test_samples_range_matches_compute_features_requirements( + self, sample_range: tuple + ): """ Verifies that the compute_features() method can handle input range. """ @@ -1482,8 +1820,12 @@ class TestCyclicBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) - @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) - def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, mode, kwargs): + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + ) + def test_minimum_number_of_basis_required_is_matched( + self, n_basis_funcs, order, mode, kwargs + ): """ Verifies that the minimum number of basis functions and order required (i.e., at least 1) and order < #basis are enforced. @@ -1491,8 +1833,8 @@ def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, order, raise_exception = order > n_basis_funcs if raise_exception: with pytest.raises( - ValueError, - match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", ): basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, @@ -1517,7 +1859,7 @@ def test_order_1_invalid(self, n_basis_funcs, order): raise_exception = order == 1 if raise_exception: with pytest.raises( - ValueError, match=r"Order >= 2 required for cyclic B-spline" + ValueError, match=r"Order >= 2 required for cyclic B-spline" ): basis_obj = self.cls["eval"](n_basis_funcs=n_basis_funcs, order=order) basis_obj.compute_features(np.linspace(0, 1, 10)) @@ -1550,11 +1892,11 @@ def test_order_is_positive(self, n_basis_funcs, order): (1, does_not_raise()), (2, does_not_raise()), ( - 10, - pytest.raises( - ValueError, - match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", - ), + 10, + pytest.raises( + ValueError, + match=r"[a-z]+|[A-Z]+ `order` parameter cannot be larger", + ), ), ], ) @@ -1570,7 +1912,9 @@ def test_order_setter(self, n_basis_funcs, order, expectation): @pytest.mark.parametrize( "sample_range", [(0, 1), (0.1, 0.9), (-0.5, 1), (0, 1.5), (-0.5, 1.5)] ) - def test_samples_range_matches_compute_features_requirements(self, sample_range: tuple): + def test_samples_range_matches_compute_features_requirements( + self, sample_range: tuple + ): """ Verifies that the compute_features() method can handle input ranges. 
""" @@ -1589,46 +1933,57 @@ class CombinedBasis(BasisFuncsTesting): cls = None @staticmethod - def instantiate_basis(n_basis, basis_class, class_specific_params, window_size=10, **kwargs): + def instantiate_basis( + n_basis, basis_class, class_specific_params, window_size=10, **kwargs + ): """Instantiate and return two basis of the type specified.""" # Set non-optional args default_kwargs = { "n_basis_funcs": n_basis, "window_size": window_size, - "decay_rates": np.arange(1, 1 + n_basis) + "decay_rates": np.arange(1, 1 + n_basis), } repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) if repeated_keys: - raise ValueError("Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs") + raise ValueError( + "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" + ) # Merge with provided extra kwargs kwargs = {**default_kwargs, **kwargs} - if basis_class == AdditiveBasis: - kwargs_mspline = trim_kwargs(basis.EvalMSpline, kwargs, class_specific_params) - kwargs_raised_cosine = trim_kwargs(basis.ConvRaisedCosineLinear, kwargs, class_specific_params) + kwargs_mspline = trim_kwargs( + basis.EvalMSpline, kwargs, class_specific_params + ) + kwargs_raised_cosine = trim_kwargs( + basis.ConvRaisedCosineLinear, kwargs, class_specific_params + ) b1 = basis.EvalMSpline(**kwargs_mspline) b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) basis_obj = b1 + b2 elif basis_class == MultiplicativeBasis: - kwargs_mspline = trim_kwargs(basis.EvalMSpline, kwargs, class_specific_params) - kwargs_raised_cosine = trim_kwargs(basis.ConvRaisedCosineLinear, kwargs, class_specific_params) + kwargs_mspline = trim_kwargs( + basis.EvalMSpline, kwargs, class_specific_params + ) + kwargs_raised_cosine = trim_kwargs( + basis.ConvRaisedCosineLinear, kwargs, class_specific_params + ) b1 = basis.EvalMSpline(**kwargs_mspline) b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) basis_obj = b1 * b2 else: - basis_obj = basis_class(**trim_kwargs(basis_class, kwargs, class_specific_params)) + basis_obj = basis_class( + **trim_kwargs(basis_class, kwargs, class_specific_params) + ) return basis_obj class TestAdditiveBasis(CombinedBasis): cls = AdditiveBasis - @pytest.mark.parametrize( - "samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]] - ) + @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.EvalBSpline, basis.ConvBSpline]) def test_non_empty_samples(self, base_cls, samples, class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} @@ -1666,7 +2021,14 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("window_size", [10]) def test_compute_features_returns_expected_number_of_basis( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + sample_size, + basis_a, + basis_b, + window_size, + class_specific_params, ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -1698,16 +2060,23 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("window_size", [10]) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + sample_size, + basis_a, + basis_b, + window_size, + 
class_specific_params, ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -1728,17 +2097,24 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("window_size", [10]) def test_number_of_required_inputs_compute_features( - self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_input, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -1765,8 +2141,12 @@ def test_evaluate_on_grid_meshgrid_size( """ Test whether the resulting meshgrid size matches the sample size input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -1785,8 +2165,12 @@ def test_evaluate_on_grid_basis_size( """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -1805,8 +2189,12 @@ def test_evaluate_on_grid_input_number( Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. 
""" - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input required_dim = ( @@ -1835,10 +2223,11 @@ def test_pynapple_support_compute_features( d=np.linspace(0, 1, sample_size), time_support=iset, ) - basis_add = (self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) + - self.instantiate_basis( + basis_add = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + self.instantiate_basis( n_basis_b, basis_b, class_specific_params, window_size=10 - )) + ) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) # check type @@ -1854,13 +2243,20 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize(" window_size", [3]) def test_call_input_num( - self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + num_input, + window_size, + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -1896,10 +2292,10 @@ def test_call_input_shape( class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: @@ -1912,13 +2308,20 @@ def test_call_input_shape( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_sample_axis( - self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + time_axis_shape, + window_size, + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality @@ -1929,17 +2332,19 @@ def test_call_sample_axis( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, 
class_specific_params): + def test_call_nan( + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + ): if ( basis_a == basis.OrthExponentialBasis or basis_b == basis.OrthExponentialBasis ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -1951,20 +2356,22 @@ def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, cl @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_call_equivalent_in_conv( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=3 + n_basis_a, basis_a, class_specific_params, window_size=3 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=3 + n_basis_b, basis_b, class_specific_params, window_size=3 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj @@ -1977,13 +2384,13 @@ def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, c @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2001,13 +2408,13 @@ def test_pynapple_support( @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2019,13 +2426,13 @@ def test_call_basis_number( @pytest.mark.parametrize("n_basis_a", [5]) 
@pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2053,9 +2460,8 @@ def test_call_sample_range( mn, mx, expectation, - window_size, - class_specific_params + class_specific_params, ): if expectation == "check": if ( @@ -2068,10 +2474,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2081,12 +2487,14 @@ def test_call_sample_range( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_fit_kernel( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj bas._set_kernel() @@ -2108,18 +2516,23 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_transform_fails( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: context = does_not_raise() else: - context = pytest.raises(ValueError, match="You must call `_set_kernel` before `_compute_features`") + context = pytest.raises( + ValueError, + match="You must call `_set_kernel` before `_compute_features`", + ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality bas._compute_features(*x) @@ -2127,8 +2540,8 @@ def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_spe @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, 
n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(11, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2139,8 +2552,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2158,8 +2571,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas = bas1 + bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2174,7 +2587,7 @@ class TestMultiplicativeBasis(CombinedBasis): "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] ) @pytest.mark.parametrize(" ws", [3]) - def test_non_empty_samples(self, samples, ws): + def test_non_empty_samples(self, samples, ws): basis_obj = basis.EvalMSpline(5) * basis.EvalRaisedCosineLinear(5) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -2208,7 +2621,14 @@ def test_compute_features_input(self, eval_input): @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("window_size", [10]) def test_compute_features_returns_expected_number_of_basis( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + sample_size, + basis_a, + basis_b, + window_size, + class_specific_params, ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2216,10 +2636,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2241,17 +2661,24 @@ def test_compute_features_returns_expected_number_of_basis( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("window_size", [10]) def test_sample_size_of_compute_features_matches_that_of_input( - self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + sample_size, + basis_a, + basis_b, + window_size, + class_specific_params, ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. 
""" basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2271,17 +2698,24 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("window_size", [10]) def test_number_of_required_inputs_compute_features( - self, n_input, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_input, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2308,8 +2742,12 @@ def test_evaluate_on_grid_meshgrid_size( """ Test whether the resulting meshgrid size matches the sample size input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -2328,8 +2766,12 @@ def test_evaluate_on_grid_basis_size( """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( *[sample_size] * basis_obj._n_input_dimensionality @@ -2348,8 +2790,12 @@ def test_evaluate_on_grid_input_number( Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. 
""" - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input required_dim = ( @@ -2371,12 +2817,23 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("sample_size_a", [11, 12]) @pytest.mark.parametrize("sample_size_b", [11, 12]) def test_inconsistent_sample_sizes( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size_a, sample_size_b, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size_a, + sample_size_b, + class_specific_params, ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b - basis_a_obj = self.instantiate_basis(n_basis_a, basis_a, class_specific_params, window_size=10) - basis_b_obj = self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) basis_obj = basis_a_obj * basis_b_obj if raise_exception: with pytest.raises( @@ -2407,7 +2864,9 @@ def test_pynapple_support_compute_features( ) basis_prod = self.instantiate_basis( n_basis_a, basis_a, class_specific_params, window_size=10 - ) * self.instantiate_basis(n_basis_b, basis_b, class_specific_params, window_size=10) + ) * self.instantiate_basis( + n_basis_b, basis_b, class_specific_params, window_size=10 + ) out = basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) assert np.all(out.time_support.values == inp.time_support.values) @@ -2420,13 +2879,20 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize(" window_size", [3]) def test_call_input_num( - self, n_basis_a, n_basis_b, basis_a, basis_b, num_input, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + num_input, + window_size, + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2457,16 +2923,15 @@ def test_call_input_shape( basis_a, basis_b, inp, - window_size, expectation, - class_specific_params + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: @@ -2479,13 +2944,20 @@ def test_call_input_shape( 
@pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_sample_axis( - self, n_basis_a, n_basis_b, basis_a, basis_b, time_axis_shape, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + time_axis_shape, + window_size, + class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality @@ -2496,17 +2968,19 @@ def test_call_sample_axis( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params): + def test_call_nan( + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + ): if ( basis_a == basis.OrthExponentialBasis or basis_b == basis.OrthExponentialBasis ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2518,20 +2992,22 @@ def test_call_nan(self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, cl @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_call_equivalent_in_conv( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj @@ -2544,13 +3020,13 @@ def test_call_equivalent_in_conv(self, n_basis_a, n_basis_b, basis_a, basis_b, c @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size 
+ n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -2568,13 +3044,13 @@ def test_pynapple_support( @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2586,13 +3062,13 @@ def test_call_basis_number( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2606,7 +3082,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [ 3]) + @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2620,9 +3096,8 @@ def test_call_sample_range( mn, mx, expectation, - window_size, - class_specific_params + class_specific_params, ): if expectation == "check": if ( @@ -2635,10 +3110,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -2648,12 +3123,14 @@ def test_call_sample_range( @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_fit_kernel(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_fit_kernel( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj bas._set_kernel() @@ -2675,18 +3152,23 @@ def check_kernel(basis_obj): 
@pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) - def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params): + def test_transform_fails( + self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: + if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: context = does_not_raise() else: - context = pytest.raises(ValueError, match="You must call `_set_kernel` before `_compute_features`") + context = pytest.raises( + ValueError, + match="You must call `_set_kernel` before `_compute_features`", + ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality bas._compute_features(*x) @@ -2694,8 +3176,8 @@ def test_transform_fails(self, n_basis_a, n_basis_b, basis_a, basis_b, class_spe @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(11, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2706,8 +3188,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2725,8 +3207,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas = bas1 * bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2736,8 +3218,8 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_n_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) + bas2 = basis.ConvBSpline(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) @@ -2758,7 +3240,9 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): else: raise_exception_value = False - basis_obj = CombinedBasis.instantiate_basis(5, basis_class, 
class_specific_params, window_size=10) + basis_obj = CombinedBasis.instantiate_basis( + 5, basis_class, class_specific_params, window_size=10 + ) if raise_exception_type: with pytest.raises(TypeError, match=r"Exponent should be an integer\!"): @@ -2792,7 +3276,9 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): ) def test_basis_to_transformer(basis_cls, class_specific_params): n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) trans_bas = bas.to_transformer() @@ -2807,9 +3293,13 @@ def test_basis_to_transformer(basis_cls, class_specific_params): "basis_cls", list_all_basis_classes(), ) -def test_transformer_has_the_same_public_attributes_as_basis(basis_cls, class_specific_params): +def test_transformer_has_the_same_public_attributes_as_basis( + basis_cls, class_specific_params +): n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} public_attrs_transformerbasis = { @@ -2829,9 +3319,13 @@ def test_transformer_has_the_same_public_attributes_as_basis(basis_cls, class_sp "basis_cls", list_all_basis_classes(), ) -def test_to_transformer_and_constructor_are_equivalent(basis_cls, class_specific_params): +def test_to_transformer_and_constructor_are_equivalent( + basis_cls, class_specific_params +): n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) + bas = CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) trans_bas_a = bas.to_transformer() trans_bas_b = TransformerBasis(bas) @@ -2843,7 +3337,10 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls, class_specific == ["_basis"] ) # and those bases are the same - assert np.all(trans_bas_a._basis.__dict__.pop("_decay_rates", 1) == trans_bas_b._basis.__dict__.pop("_decay_rates", 1)) + assert np.all( + trans_bas_a._basis.__dict__.pop("_decay_rates", 1) + == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) + ) assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ @@ -2852,7 +3349,9 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls, class_specific list_all_basis_classes(), ) def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): - bas_a = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) + bas_a = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) trans_bas_a = bas_a.to_transformer() # changing an attribute in bas should not change trans_bas @@ -2860,7 +3359,9 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): assert trans_bas_a.n_basis_funcs == 5 # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) trans_bas_b = bas_b.to_transformer() trans_bas_b.n_basis_funcs = 100 assert bas_b.n_basis_funcs == 5 @@ -2873,7 +3374,9 @@ def 
test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): @pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): trans_basis = TransformerBasis( - CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10) + CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) ) assert trans_basis.n_basis_funcs == n_basis_funcs @@ -2884,9 +3387,13 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_param ) @pytest.mark.parametrize("n_basis_funcs_init", [5]) @pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params): +def test_transformerbasis_set_params( + basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params +): trans_basis = TransformerBasis( - CombinedBasis().instantiate_basis(n_basis_funcs_init, basis_cls, class_specific_params, window_size=10) + CombinedBasis().instantiate_basis( + n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 + ) ) trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) @@ -2901,9 +3408,13 @@ def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_func def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it trans_bas = TransformerBasis( - CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10) + CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) + ) + trans_bas._basis = CombinedBasis().instantiate_basis( + 20, basis_cls, class_specific_params, window_size=10 ) - trans_bas._basis = CombinedBasis().instantiate_basis(20, basis_cls, class_specific_params, window_size=10) assert trans_bas.n_basis_funcs == 20 assert trans_bas._basis.n_basis_funcs == 20 @@ -2917,7 +3428,11 @@ def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): # setting an attribute that is an attribute of the underlying _basis # should propagate setting it on _basis itself - trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10)) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) + ) trans_bas.n_basis_funcs = 20 assert trans_bas.n_basis_funcs == 20 @@ -2932,7 +3447,9 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_para def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): # modifying the transformerbasis's attributes shouldn't # touch the original basis that was used to create it - orig_bas = CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=10) + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) trans_bas = TransformerBasis(orig_bas) trans_bas.n_basis_funcs = 20 @@ -2949,7 +3466,11 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_para def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): # changing an attribute that is not _basis or an attribute of _basis # is not allowed - trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(10, 
basis_cls, class_specific_params, window_size=10)) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=10 + ) + ) with pytest.raises( ValueError, @@ -2965,8 +3486,12 @@ def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_pa def test_transformerbasis_addition(basis_cls, class_specific_params): n_basis_funcs_a = 5 n_basis_funcs_b = n_basis_funcs_a * 2 - bas_a = CombinedBasis().instantiate_basis(n_basis_funcs_a, basis_cls, class_specific_params, window_size=10) - bas_b = CombinedBasis().instantiate_basis(n_basis_funcs_b, basis_cls, class_specific_params, window_size=10) + bas_a = CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 + ) + bas_b = CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 + ) trans_bas_a = TransformerBasis(bas_a) trans_bas_b = TransformerBasis(bas_b) trans_bas_sum = trans_bas_a + trans_bas_b @@ -2991,8 +3516,16 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): def test_transformerbasis_multiplication(basis_cls, class_specific_params): n_basis_funcs_a = 5 n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs_a, basis_cls, class_specific_params, window_size=10)) - trans_bas_b = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs_b, basis_cls, class_specific_params, window_size=10)) + trans_bas_a = TransformerBasis( + CombinedBasis().instantiate_basis( + n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 + ) + ) + trans_bas_b = TransformerBasis( + CombinedBasis().instantiate_basis( + n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 + ) + ) trans_bas_prod = trans_bas_a * trans_bas_b assert isinstance(trans_bas_prod, TransformerBasis) assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) @@ -3024,7 +3557,11 @@ def test_transformerbasis_multiplication(basis_cls, class_specific_params): def test_transformerbasis_exponentiation( basis_cls, exponent: int, error_type, error_message, class_specific_params ): - trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10)) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) + ) if not isinstance(exponent, int): with pytest.raises(error_type, match=error_message): @@ -3038,7 +3575,11 @@ def test_transformerbasis_exponentiation( list_all_basis_classes(), ) def test_transformerbasis_dir(basis_cls, class_specific_params): - trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10)) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) + ) for attr_name in ( "fit", "transform", @@ -3057,7 +3598,9 @@ def test_transformerbasis_dir(basis_cls, class_specific_params): list_all_basis_classes("Conv"), ) def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): - orig_bas = CombinedBasis().instantiate_basis(10, basis_cls, class_specific_params, window_size=20) + orig_bas = CombinedBasis().instantiate_basis( + 10, basis_cls, class_specific_params, window_size=20 + ) trans_bas = TransformerBasis(orig_bas) # kernel should be saved in the object after fit @@ -3078,9 +3621,15 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls, 
class_specific_params list_all_basis_classes(), ) @pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs, class_specific_params): +def test_transformerbasis_pickle( + tmpdir, basis_cls, n_basis_funcs, class_specific_params +): # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = TransformerBasis(CombinedBasis().instantiate_basis(n_basis_funcs, basis_cls, class_specific_params, window_size=10)) + trans_bas = TransformerBasis( + CombinedBasis().instantiate_basis( + n_basis_funcs, basis_cls, class_specific_params, window_size=10 + ) + ) filepath = tmpdir / "transformerbasis.pickle" with open(filepath, "wb") as f: pickle.dump(trans_bas, f) @@ -3123,10 +3672,18 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs, class_specifi list_all_basis_classes("Conv"), ) def test_multi_epoch_pynapple_basis( - basis_cls, tsd, window_size, shift, predictor_causality, nan_index, class_specific_params + basis_cls, + tsd, + window_size, + shift, + predictor_causality, + nan_index, + class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" - kwargs = dict(conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality)) + kwargs = dict( + conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality) + ) # require a ws of at least nbasis funcs. if "OrthExp" in basis_cls.__name__: @@ -3134,7 +3691,9 @@ def test_multi_epoch_pynapple_basis( # splines requires at least 1 basis more than the order of the spline. else: nbasis = 5 - bas = CombinedBasis().instantiate_basis(nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs) + bas = CombinedBasis().instantiate_basis( + nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + ) n_input = bas._n_input_dimensionality @@ -3180,10 +3739,18 @@ def test_multi_epoch_pynapple_basis( list_all_basis_classes("Conv"), ) def test_multi_epoch_pynapple_basis_transformer( - basis_cls, tsd, window_size, shift, predictor_causality, nan_index, class_specific_params + basis_cls, + tsd, + window_size, + shift, + predictor_causality, + nan_index, + class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" - kwargs = dict(conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality)) + kwargs = dict( + conv_kwargs=dict(shift=shift, predictor_causality=predictor_causality) + ) # require a ws of at least nbasis funcs. 
if "OrthExp" in basis_cls.__name__: nbasis = 2 @@ -3191,7 +3758,9 @@ def test_multi_epoch_pynapple_basis_transformer( else: nbasis = 5 - bas = CombinedBasis().instantiate_basis(nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs) + bas = CombinedBasis().instantiate_basis( + nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + ) n_input = bas._n_input_dimensionality @@ -3215,11 +3784,7 @@ def test_multi_epoch_pynapple_basis_transformer( @pytest.mark.parametrize( "bas1, bas2, bas3", - list( - itertools.product( - *[list_all_basis_classes()] * 3 - ) - ), + list(itertools.product(*[list_all_basis_classes()] * 3)), ) @pytest.mark.parametrize( "operator1, operator2, compute_slice", @@ -3292,7 +3857,7 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params + bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params ): # skip nested if any( @@ -3305,9 +3870,15 @@ def test__get_splitter( n_input_basis = [1, 2, 3] combine_basis = CombinedBasis() - bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") - bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") - bas3_instance = combine_basis.instantiate_basis(n_basis[2], bas3, class_specific_params, window_size=10, label="3") + bas1_instance = combine_basis.instantiate_basis( + n_basis[0], bas1, class_specific_params, window_size=10, label="1" + ) + bas2_instance = combine_basis.instantiate_basis( + n_basis[1], bas2, class_specific_params, window_size=10, label="2" + ) + bas3_instance = combine_basis.instantiate_basis( + n_basis[2], bas3, class_specific_params, window_size=10, label="3" + ) func1 = getattr(bas1_instance, operator1) func2 = getattr(bas2_instance, operator2) @@ -3322,11 +3893,7 @@ def test__get_splitter( @pytest.mark.parametrize( "bas1, bas2", - list( - itertools.product( - *[list_all_basis_classes()] * 2 - ) - ), + list(itertools.product(*[list_all_basis_classes()] * 2)), ) @pytest.mark.parametrize( "operator, n_input_basis_1, n_input_basis_2, compute_slice", @@ -3441,7 +4008,13 @@ def test__get_splitter( ], ) def test__get_splitter_split_by_input( - bas1, bas2, operator, n_input_basis_1, n_input_basis_2, compute_slice, class_specific_params + bas1, + bas2, + operator, + n_input_basis_1, + n_input_basis_2, + compute_slice, + class_specific_params, ): # skip nested if any( @@ -3452,8 +4025,12 @@ def test__get_splitter_split_by_input( # define the basis n_basis = [5, 6] combine_basis = CombinedBasis() - bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") - bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") + bas1_instance = combine_basis.instantiate_basis( + n_basis[0], bas1, class_specific_params, window_size=10, label="1" + ) + bas2_instance = combine_basis.instantiate_basis( + n_basis[1], bas2, class_specific_params, window_size=10, label="2" + ) func1 = getattr(bas1_instance, operator) bas12 = func1(bas2_instance) @@ -3470,11 +4047,7 @@ def test__get_splitter_split_by_input( @pytest.mark.parametrize( "bas1, bas2, bas3", - list( - itertools.product( - *[list_all_basis_classes()] * 3 - ) - ), + list(itertools.product(*[list_all_basis_classes()] * 3)), ) def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): # 
skip nested @@ -3485,9 +4058,15 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): return combine_basis = CombinedBasis() - bas1_instance = combine_basis.instantiate_basis(5, bas1, class_specific_params, window_size=10, label="label") - bas2_instance = combine_basis.instantiate_basis(5, bas2, class_specific_params, window_size=10, label="label") - bas3_instance = combine_basis.instantiate_basis(5, bas3, class_specific_params, window_size=10, label="label") + bas1_instance = combine_basis.instantiate_basis( + 5, bas1, class_specific_params, window_size=10, label="label" + ) + bas2_instance = combine_basis.instantiate_basis( + 5, bas2, class_specific_params, window_size=10, label="label" + ) + bas3_instance = combine_basis.instantiate_basis( + 5, bas3, class_specific_params, window_size=10, label="label" + ) bas_obj = bas1_instance + bas2_instance + bas3_instance inps = [np.zeros((1,)) for n in range(3)] @@ -3498,11 +4077,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): @pytest.mark.parametrize( "bas1, bas2", - list( - itertools.product( - *[list_all_basis_classes()] * 2 - ) - ), + list(itertools.product(*[list_all_basis_classes()] * 2)), ) @pytest.mark.parametrize( "x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28 @@ -3520,7 +4095,9 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): ), ], ) -def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params): +def test_split_feature_axis( + bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params +): # skip nested if any( bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) @@ -3530,8 +4107,12 @@ def test_split_feature_axis(bas1, bas2, x, axis, expectation, exp_shapes, class_ # define the basis n_basis = [5, 6] combine_basis = CombinedBasis() - bas1_instance = combine_basis.instantiate_basis(n_basis[0], bas1, class_specific_params, window_size=10, label="1") - bas2_instance = combine_basis.instantiate_basis(n_basis[1], bas2, class_specific_params, window_size=10, label="2") + bas1_instance = combine_basis.instantiate_basis( + n_basis[0], bas1, class_specific_params, window_size=10, label="1" + ) + bas2_instance = combine_basis.instantiate_basis( + n_basis[1], bas2, class_specific_params, window_size=10, label="2" + ) bas = bas1_instance + bas2_instance bas._set_num_output_features(np.zeros((1, 2)), np.zeros((1, 3))) From 2edb2058a1bc0789f46490068e9a0e2050fcd34c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 14:48:48 -0500 Subject: [PATCH 046/109] fixed some tests on docstrings --- src/nemos/basis/_basis.py | 106 ++++++++++++++++++++++++++++++++++++-- tests/test_basis.py | 59 ++++++++++++--------- 2 files changed, 136 insertions(+), 29 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index a17c5f9b..d3b9c8e5 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -3,7 +3,7 @@ import abc import copy -from functools import wraps +from functools import wraps, partial from typing import Callable, Generator, Literal, Optional, Tuple, Union import jax @@ -880,7 +880,8 @@ class TransformerBasis: Examples -------- - >>> from nemos.basis import EvalBSpline, TransformerBasis + >>> from nemos.basis import EvalBSpline + >>> from nemos.basis._basis import TransformerBasis >>> from nemos.glm import GLM >>> from sklearn.pipeline import Pipeline >>> from sklearn.model_selection import GridSearchCV @@ -1244,6 +1245,10 @@ def __pow__(self, exponent: int) -> 
TransformerBasis: return TransformerBasis(self._basis**exponent) +add_docstring_additive = partial(add_docstring, cls=Basis) +add_docstring_multiplicative = partial(add_docstring, cls=Basis) + + class AdditiveBasis(Basis): """ Class representing the addition of two Basis objects. @@ -1326,6 +1331,21 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix: : The basis function evaluated at the samples, shape (n_samples, n_basis_funcs) + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + >>> x, y = np.random.normal(size=(2, 30)) + + >>> # define two basis objects and add them + >>> basis_1 = nmo.basis.EvalBSpline(10) + >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> additive_basis = basis_1 + basis_2 + + >>> # call the basis. + >>> out = additive_basis(x, y) + """ X = np.hstack( ( @@ -1335,6 +1355,24 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix: ) return X + @add_docstring_additive("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + r""" + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.glm import GLM + >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") + >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis_add = basis1 + basis2 + >>> X_multi = basis_add.compute_features(np.random.randn(20), np.random.randn(20, 2)) + >>> print(X_multi.shape) # num_features: 17 = 5 + 2*6 + (20, 17) + + """ + return super().compute_features(*xi) + def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ Compute features for added bases and concatenate. @@ -1509,7 +1547,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Parameters ---------- n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of + The number of points in the uniformly spaced grid. The length of n_samples must equal the number of combined bases. Returns @@ -1536,6 +1574,20 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`. This differs from the numpy.meshgrid default, which uses Cartesian indexing. For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`. + + Examples + -------- + >>> import numpy as np + >>> import nemos as nmo + + >>> # define two basis objects and add them + >>> basis_1 = nmo.basis.EvalBSpline(10) + >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> additive_basis = basis_1 + basis_2 + + >>> # evaluate on a grid of 10 x 10 equi-spaced samples + >>> X, Y, Z = additive_basis.evaluate_on_grid(10, 10) + """ return super().evaluate_on_grid(*n_samples) @@ -1699,7 +1751,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Parameters ---------- n_samples[0],...,n_samples[n] - The number of samples in each axis of the grid. The length of + The number of points in the uniformly spaced grid. The length of n_samples must equal the number of combined bases. 
Returns @@ -1735,3 +1787,49 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: >>> X, Y, Z = mult_basis.evaluate_on_grid(10, 10) """ return super().evaluate_on_grid(*n_samples) + + @add_docstring_multiplicative("compute_features") + def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.glm import GLM + >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") + >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis_mul = basis1 * basis2 + >>> X_multi = basis_mul.compute_features(np.random.randn(20), np.random.randn(20, 2)) + >>> print(X_multi.shape) # num_features: 60 = 5 * 2 * 6 + (20, 60) + + """ + return super().compute_features(*xi) + + @add_docstring_multiplicative("split_by_feature") + def split_by_feature( + self, + x: NDArray, + axis: int = 1, + ): + """ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.glm import GLM + >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") + >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis_mul = basis1 * basis2 + >>> X_multi = basis_mul.compute_features(np.random.randn(20), np.random.randn(20, 2)) + >>> print(X_multi.shape) # num_features: 60 = 5 * 2 * 6 + (20, 60) + + >>> # The multiplicative basis is a single 2D component. + >>> split_features = basis_mul.split_by_feature(X_multi, axis=1) + >>> for feature, arr in split_features.items(): + ... print(f"{feature}: shape {arr.shape}") + (one_input * two_inputs): shape (20, 1, 60) + + """ + return super().split_by_feature(x, axis=axis) diff --git a/tests/test_basis.py b/tests/test_basis.py index f07d297e..77318b7c 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -30,6 +30,7 @@ ) from nemos.basis._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from nemos.utils import pynapple_concatenate_numpy +import nemos as nmo @pytest.fixture() @@ -82,7 +83,7 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis) if issubclass(class_obj, Basis) - ] + ] + [bas for _, bas in utils_testing.get_non_abstract_classes(nmo.basis._basis) if bas != TransformerBasis] if filter_basis != "all": all_basis = [a for a in all_basis if filter_basis in a.__name__] return all_basis @@ -106,7 +107,7 @@ def test_all_basis_are_tested() -> None: ] # Create the set of basis function objects that are tested using the cls definition - tested_bases = {test_cls.cls for test_cls in subclasses} + tested_bases = {test_cls.cls[mode] for mode in ["eval", "conv"] for test_cls in subclasses if test_cls != CombinedBasis} # Create the set of all the concrete basis classes all_bases = set(list_all_basis_classes()) @@ -117,23 +118,26 @@ def test_all_basis_are_tested() -> None: f"The following classes are not tested: {[bas.__qualname__ for bas in all_bases.difference(tested_bases)]}" ) + pytest_marks = getattr(TestSharedMethods, "pytestmark", []) + + # Find the parametrize mark for TestSharedMethods + out = None + for mark in pytest_marks: + if mark.name == "parametrize": + # Return the arguments of the parametrize mark + out = mark.args[1] # The second argument contains the list + + if out is None: + raise ValueError("cannot find parametrization.") + + basis_tested_in_shared_methods = {o[key] 
for key in ("eval", "conv") for o in out} + all_one_dim_basis = set(list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + assert basis_tested_in_shared_methods == all_one_dim_basis + @pytest.mark.parametrize( - "basis_instance", - [ - basis.EvalBSpline(10), - basis.ConvBSpline(10, window_size=11), - basis.EvalCyclicBSpline(10), - basis.ConvCyclicBSpline(10, window_size=11), - basis.EvalMSpline(10), - basis.ConvMSpline(10, window_size=11), - basis.EvalRaisedCosineLinear(10), - basis.ConvRaisedCosineLinear(10, window_size=11), - basis.EvalRaisedCosineLog(10), - basis.ConvRaisedCosineLog(10, window_size=11), - basis.EvalOrthExponential(10, np.arange(1, 11)), - basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), - ], + "basis_cls", + list_all_basis_classes(), ) @pytest.mark.parametrize( "method_name, descr_match", @@ -149,7 +153,9 @@ def test_all_basis_are_tested() -> None: ), ], ) -def test_example_docstrings_add(basis_instance, method_name, descr_match): +def test_example_docstrings_add(basis_cls, method_name, descr_match, class_specific_params): + + basis_instance = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) method = getattr(basis_instance, method_name) doc = method.__doc__ examp_delim = "\n Examples\n --------" @@ -161,10 +167,13 @@ def test_example_docstrings_add(basis_instance, method_name, descr_match): assert re.search(descr_match, doc_components[0]) # check that the basis name is in the example - assert basis_instance.__class__.__name__ in doc_components[1] + if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: + assert basis_cls.__name__ in doc_components[1] - # check that no other basis name is in the example + # check that no other basis name is in the example (except for additive and multiplicative) for basis_name in basis.__dir__(): + if basis_cls in [AdditiveBasis, MultiplicativeBasis]: + continue if basis_name == basis_instance.__class__.__name__: continue assert basis_name not in doc_components[1] @@ -1340,7 +1349,7 @@ def test_transformer_get_params(self, cls): assert np.all(rates_1 == rates_2) -class TestRaisedCosineLogBasis: +class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} @pytest.mark.parametrize("width", [1.5, 2, 2.5]) @@ -1961,7 +1970,7 @@ def instantiate_basis( basis.ConvRaisedCosineLinear, kwargs, class_specific_params ) b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) + b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) basis_obj = b1 + b2 elif basis_class == MultiplicativeBasis: kwargs_mspline = trim_kwargs( @@ -1971,7 +1980,7 @@ def instantiate_basis( basis.ConvRaisedCosineLinear, kwargs, class_specific_params ) b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.RaisedCosineBasisLinear(**kwargs_raised_cosine) + b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) basis_obj = b1 * b2 else: basis_obj = basis_class( @@ -1981,7 +1990,7 @@ def instantiate_basis( class TestAdditiveBasis(CombinedBasis): - cls = AdditiveBasis + cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.EvalBSpline, basis.ConvBSpline]) @@ -2581,7 +2590,7 @@ def test_expected_input_number(self, n_input, expectation): class TestMultiplicativeBasis(CombinedBasis): - cls = MultiplicativeBasis + cls = {"eval": MultiplicativeBasis, "conv": 
MultiplicativeBasis} @pytest.mark.parametrize( "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] From 76a4b5cecc89d5d1bd16466524db255d83867004 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 15:03:09 -0500 Subject: [PATCH 047/109] fixed tests that assumed 1d --- src/nemos/basis/_basis.py | 2 +- tests/test_basis.py | 44 ++++++++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index d3b9c8e5..6f9ca14b 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -3,7 +3,7 @@ import abc import copy -from functools import wraps, partial +from functools import partial, wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union import jax diff --git a/tests/test_basis.py b/tests/test_basis.py index 77318b7c..4c569a66 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -14,6 +14,7 @@ import utils_testing from sklearn.base import clone as sk_clone +import nemos as nmo import nemos.basis.basis as basis import nemos.convolve as convolve from nemos.basis._basis import ( @@ -30,7 +31,6 @@ ) from nemos.basis._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from nemos.utils import pynapple_concatenate_numpy -import nemos as nmo @pytest.fixture() @@ -83,7 +83,11 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: class_obj for _, class_obj in utils_testing.get_non_abstract_classes(basis) if issubclass(class_obj, Basis) - ] + [bas for _, bas in utils_testing.get_non_abstract_classes(nmo.basis._basis) if bas != TransformerBasis] + ] + [ + bas + for _, bas in utils_testing.get_non_abstract_classes(nmo.basis._basis) + if bas != TransformerBasis + ] if filter_basis != "all": all_basis = [a for a in all_basis if filter_basis in a.__name__] return all_basis @@ -107,7 +111,12 @@ def test_all_basis_are_tested() -> None: ] # Create the set of basis function objects that are tested using the cls definition - tested_bases = {test_cls.cls[mode] for mode in ["eval", "conv"] for test_cls in subclasses if test_cls != CombinedBasis} + tested_bases = { + test_cls.cls[mode] + for mode in ["eval", "conv"] + for test_cls in subclasses + if test_cls != CombinedBasis + } # Create the set of all the concrete basis classes all_bases = set(list_all_basis_classes()) @@ -131,7 +140,9 @@ def test_all_basis_are_tested() -> None: raise ValueError("cannot fine parametrization.") basis_tested_in_shared_methods = {o[key] for key in ("eval", "conv") for o in out} - all_one_dim_basis = set(list_all_basis_classes("Eval") + list_all_basis_classes("Conv")) + all_one_dim_basis = set( + list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) assert basis_tested_in_shared_methods == all_one_dim_basis @@ -153,9 +164,13 @@ def test_all_basis_are_tested() -> None: ), ], ) -def test_example_docstrings_add(basis_cls, method_name, descr_match, class_specific_params): +def test_example_docstrings_add( + basis_cls, method_name, descr_match, class_specific_params +): - basis_instance = CombinedBasis().instantiate_basis(5, basis_cls, class_specific_params, window_size=10) + basis_instance = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) method = getattr(basis_instance, method_name) doc = method.__doc__ examp_delim = "\n Examples\n --------" @@ -2843,19 +2858,17 @@ def test_inconsistent_sample_sizes( basis_b_obj = self.instantiate_basis( n_basis_b, basis_b, class_specific_params, window_size=10 
) + input_a = [np.linspace(0, 1, sample_size_a)] * basis_a_obj._n_input_dimensionality + input_b = [np.linspace(0, 1, sample_size_b)] * basis_b_obj._n_input_dimensionality basis_obj = basis_a_obj * basis_b_obj if raise_exception: with pytest.raises( ValueError, match=r"Sample size mismatch\. Input elements have inconsistent", ): - basis_obj.compute_features( - np.linspace(0, 1, sample_size_a), np.linspace(0, 1, sample_size_b) - ) + basis_obj.compute_features(*input_a, *input_b) else: - basis_obj.compute_features( - np.linspace(0, 1, sample_size_a), np.linspace(0, 1, sample_size_b) - ) + basis_obj.compute_features(*input_a, *input_b) @pytest.mark.parametrize("sample_size", [30]) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3273,10 +3286,13 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): elif exponent == 3: basis_obj = basis_obj * basis_obj * basis_obj + non_nan = ~np.isnan(eval_pow) + out = basis_obj.compute_features(*[samples] * basis_obj._n_input_dimensionality) assert np.allclose( - eval_pow, - basis_obj.compute_features(*[samples] * basis_obj._n_input_dimensionality), + eval_pow[non_nan], + out[non_nan], ) + assert np.all(np.isnan(out[~non_nan])) @pytest.mark.parametrize( From df1cb6bb3c11a6cf1c8a796b8afe2b567a2e9fd4 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 15:23:36 -0500 Subject: [PATCH 048/109] fixed all basis tests and linted --- tests/test_basis.py | 82 ++++++++++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 4c569a66..91e2a3b5 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2858,8 +2858,12 @@ def test_inconsistent_sample_sizes( basis_b_obj = self.instantiate_basis( n_basis_b, basis_b, class_specific_params, window_size=10 ) - input_a = [np.linspace(0, 1, sample_size_a)] * basis_a_obj._n_input_dimensionality - input_b = [np.linspace(0, 1, sample_size_b)] * basis_b_obj._n_input_dimensionality + input_a = [ + np.linspace(0, 1, sample_size_a) + ] * basis_a_obj._n_input_dimensionality + input_b = [ + np.linspace(0, 1, sample_size_b) + ] * basis_b_obj._n_input_dimensionality basis_obj = basis_a_obj * basis_b_obj if raise_exception: with pytest.raises( @@ -3311,6 +3315,9 @@ def test_basis_to_transformer(basis_cls, class_specific_params): # check that things like n_basis_funcs are the same as the original basis for k in bas.__dict__.keys(): + # skip for add and multiplicative. 
+ if basis_cls in [AdditiveBasis, MultiplicativeBasis]: + continue assert np.all(getattr(bas, k) == getattr(trans_bas, k)) @@ -3342,7 +3349,7 @@ def test_transformer_has_the_same_public_attributes_as_basis( @pytest.mark.parametrize( "basis_cls", - list_all_basis_classes(), + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) def test_to_transformer_and_constructor_are_equivalent( basis_cls, class_specific_params @@ -3380,16 +3387,28 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): trans_bas_a = bas_a.to_transformer() # changing an attribute in bas should not change trans_bas - bas_a.n_basis_funcs = 10 - assert trans_bas_a.n_basis_funcs == 5 + if basis_cls in [AdditiveBasis, MultiplicativeBasis]: + bas_a._basis1.n_basis_funcs = 10 + assert trans_bas_a._basis._basis1.n_basis_funcs == 5 - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b.n_basis_funcs = 100 - assert bas_b.n_basis_funcs == 5 + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) + trans_bas_b = bas_b.to_transformer() + trans_bas_b._basis._basis1.n_basis_funcs = 100 + assert bas_b._basis1.n_basis_funcs == 5 + else: + bas_a.n_basis_funcs = 10 + assert trans_bas_a.n_basis_funcs == 5 + + # changing an attribute in the transformer basis should not change the original + bas_b = CombinedBasis().instantiate_basis( + 5, basis_cls, class_specific_params, window_size=10 + ) + trans_bas_b = bas_b.to_transformer() + trans_bas_b.n_basis_funcs = 100 + assert bas_b.n_basis_funcs == 5 @pytest.mark.parametrize( @@ -3403,12 +3422,18 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_param n_basis_funcs, basis_cls, class_specific_params, window_size=10 ) ) - assert trans_basis.n_basis_funcs == n_basis_funcs + if basis_cls in [AdditiveBasis, MultiplicativeBasis]: + for bas in [ + getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") + ]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_basis.n_basis_funcs == n_basis_funcs @pytest.mark.parametrize( "basis_cls", - list_all_basis_classes(), + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) @pytest.mark.parametrize("n_basis_funcs_init", [5]) @pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) @@ -3428,7 +3453,7 @@ def test_transformerbasis_set_params( @pytest.mark.parametrize( "basis_cls", - list_all_basis_classes(), + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it @@ -3448,7 +3473,7 @@ def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): @pytest.mark.parametrize( "basis_cls", - list_all_basis_classes(), + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): # setting an attribute that is an attribute of the underlying _basis @@ -3467,7 +3492,7 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_para @pytest.mark.parametrize( "basis_cls", - list_all_basis_classes(), + list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), ) def 
test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): # modifying the transformerbasis's attributes shouldn't @@ -3530,8 +3555,9 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): trans_bas_sum._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality ) - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b + if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: + assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b @pytest.mark.parametrize( @@ -3562,8 +3588,9 @@ def test_transformerbasis_multiplication(basis_cls, class_specific_params): trans_bas_prod._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality ) - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b + if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: + assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a + assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b @pytest.mark.parametrize( @@ -3613,7 +3640,10 @@ def test_transformerbasis_dir(basis_cls, class_specific_params): "mode", "window_size", ): - if attr_name == "window_size" and "Eval" in trans_bas._basis.__class__.__name__: + if ( + attr_name == "window_size" + and "Conv" not in trans_bas._basis.__class__.__name__ + ): continue assert attr_name in dir(trans_bas) @@ -3662,7 +3692,13 @@ def test_transformerbasis_pickle( trans_bas2 = pickle.load(f) assert isinstance(trans_bas2, TransformerBasis) - assert trans_bas2.n_basis_funcs == n_basis_funcs + if basis_cls in [AdditiveBasis, MultiplicativeBasis]: + for bas in [ + getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") + ]: + assert bas.n_basis_funcs == n_basis_funcs + else: + assert trans_bas2.n_basis_funcs == n_basis_funcs @pytest.mark.parametrize( From 839db6994226ec890e6e7b1e25b364163b9225a2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 15:26:16 -0500 Subject: [PATCH 049/109] fixed all basis tests and linted --- tests/test_basis.py | 83 --------------------------------------------- 1 file changed, 83 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 91e2a3b5..e9668dbe 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -475,38 +475,6 @@ def test_set_num_basis_input(self, n_input, cls): assert bas.n_basis_input == (n_input,) assert bas._n_basis_input == (n_input,) - @pytest.mark.parametrize( - "samples, vmin, vmax, expectation", - [ - (0.5, 0, 1, does_not_raise()), - ( - -0.5, - 0, - 1, - pytest.raises(ValueError, match="All the samples lie outside"), - ), - (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - ( - np.linspace(-1, 0, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ( - np.linspace(1, 2, 10), - 0, - 1, - pytest.warns(UserWarning, match="More than 90% of the samples"), - ), - ], - ) - def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation, cls): - basis_obj = cls["eval"]( - 5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5) - ) - with expectation: - basis_obj.compute_features(samples) - @pytest.mark.parametrize( "bounds, samples, nan_idx, mn, mx", [ @@ -578,57 +546,6 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): ) assert bounds == bas.bounds if bounds else bas.bounds is None 
- @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (1, 2, 3), - pytest.raises( - ValueError, match="The provided `bounds` must be of length two" - ), - ), - ], - ) - def test_vmin_vmax_init(self, bounds, expectation, cls): - with expectation: - bas = cls["eval"]( - n_basis_funcs=5, bounds=bounds, **extra_decay_rates(cls["eval"], 5) - ) - assert bounds == bas.bounds if bounds else bas.bounds is None - - @pytest.mark.parametrize( - "bounds, expectation", - [ - (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), - ((1, 3), does_not_raise()), - (("a", 3), pytest.raises(TypeError, match="Could not convert")), - ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")), - ( - (2, 1), - pytest.raises( - ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" - ), - ), - ], - ) - def test_vmin_vmax_setter(self, bounds, expectation, cls): - bas = cls["eval"]( - n_basis_funcs=5, bounds=(1, 3), **extra_decay_rates(cls["eval"], 5) - ) - with expectation: - bas.set_params(bounds=bounds) - assert bounds == bas.bounds if bounds else bas.bounds is None - @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] From d2440a67254d92c78c936aeb2c4841c4e54c6d4a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 15:39:13 -0500 Subject: [PATCH 050/109] fixed other tests relying on basis --- tests/conftest.py | 2 +- tests/test_identifiability_constraints.py | 12 ++-- tests/test_pipeline.py | 86 +++++++++++------------ tests/test_simulation.py | 2 +- 4 files changed, 51 insertions(+), 51 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 77af28b5..4b225125 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -286,7 +286,7 @@ def coupled_model_simulate(): ) # shrink the filters for simulation stability coupling_filter_bank *= 0.8 - basis = nmo.basis.RaisedCosineBasisLog(20) + basis = nmo.basis.EvalRaisedCosineLog(20) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ab0a6cbe..ea7a1875 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from nemos.basis.basis import BSplineBasis, RaisedCosineBasisLinear +from nemos.basis.basis import EvalBSpline, ConvBSpline, EvalRaisedCosineLinear from nemos.identifiability_constraints import ( _WARN_FLOAT32_MESSAGE, _find_drop_column, @@ -92,20 +92,20 @@ def test_apply_identifiability_constraints_add_constant(add_intercept, expected_ @pytest.mark.parametrize( "basis, input_shape, output_shape, expected_columns", [ - (RaisedCosineBasisLinear(10, width=4), (50,), (50, 10), jnp.arange(10)), + (EvalRaisedCosineLinear(10, width=4), (50,), (50, 10), jnp.arange(10)), ( - BSplineBasis(5) 
+ BSplineBasis(6), + EvalBSpline(5) + EvalBSpline(6), (20,), (20, 9), jnp.array([1, 2, 3, 4, 6, 7, 8, 9, 10]), ), ( - BSplineBasis(5, mode="conv", window_size=10) + BSplineBasis(6), + ConvBSpline(5, window_size=10) + EvalBSpline(6), (20,), (20, 10), jnp.array([0, 1, 2, 3, 4, 6, 7, 8, 9, 10]), ), - (BSplineBasis(5), (10,), (10, 4), jnp.arange(1, 5)), + (EvalBSpline(5), (10,), (10, 4), jnp.arange(1, 5)), ], ) def test_apply_identifiability_constraints_by_basis_component( @@ -207,7 +207,7 @@ def test_apply_constraint_with_invalid(invalid_entries): ) def test_apply_constraint_by_basis_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" - basis = BSplineBasis(5) + basis = EvalBSpline(5) x = basis.compute_features( np.random.randn( 10, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5487b497..10a1ae28 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,21 +6,21 @@ from sklearn.model_selection import GridSearchCV from nemos import basis - +from nemos.basis._basis import TransformerBasis @pytest.mark.parametrize( "bas", [ basis.EvalMSpline(5), - basis.BSplineBasis(5), - basis.CyclicBSplineBasis(5), - basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), - basis.RaisedCosineBasisLinear(5), + basis.EvalBSpline(5), + basis.EvalCyclicBSpline(5), + basis.EvalOrthExponential(5, decay_rates=np.arange(1, 6)), + basis.EvalRaisedCosineLinear(5), ], ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = basis.TransformerBasis(bas) + bas = TransformerBasis(bas) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) @@ -30,15 +30,15 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): "bas", [ basis.EvalMSpline(5), - basis.BSplineBasis(5), - basis.CyclicBSplineBasis(5), - basis.RaisedCosineBasisLinear(5), - basis.RaisedCosineBasisLog(5), + basis.EvalBSpline(5), + basis.EvalCyclicBSpline(5), + basis.EvalRaisedCosineLinear(5), + basis.EvalRaisedCosineLog(5), ], ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = basis.TransformerBasis(bas) + bas = TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -49,17 +49,17 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): "bas", [ basis.EvalMSpline(5), - basis.BSplineBasis(5), - basis.CyclicBSplineBasis(5), - basis.RaisedCosineBasisLinear(5), - basis.RaisedCosineBasisLog(5), + basis.EvalBSpline(5), + basis.EvalCyclicBSpline(5), + basis.EvalRaisedCosineLinear(5), + basis.EvalRaisedCosineLog(5), ], ) def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = basis.TransformerBasis(bas) + bas = TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -74,17 +74,17 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( "bas_cls", [ basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, + basis.EvalMSpline, + basis.EvalCyclicBSpline, + 
basis.EvalRaisedCosineLinear, + basis.EvalRaisedCosineLog, ], ) def test_sklearn_transformer_pipeline_cv_directly_over_basis( bas_cls, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = basis.TransformerBasis(bas_cls(5)) + bas = TransformerBasis(bas_cls(5)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -95,17 +95,17 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( "bas_cls", [ basis.EvalMSpline, - basis.BSplineBasis, - basis.CyclicBSplineBasis, - basis.RaisedCosineBasisLinear, - basis.RaisedCosineBasisLog, + basis.EvalMSpline, + basis.EvalCyclicBSpline, + basis.EvalRaisedCosineLinear, + basis.EvalRaisedCosineLog, ], ) def test_sklearn_transformer_pipeline_cv_illegal_combination( bas_cls, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = basis.TransformerBasis(bas_cls(5)) + bas = TransformerBasis(bas_cls(5)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), @@ -123,35 +123,35 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( "bas, expected_nans", [ (basis.EvalMSpline(5), 0), - (basis.BSplineBasis(5), 0), - (basis.CyclicBSplineBasis(5), 0), - (basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), 0), - (basis.RaisedCosineBasisLinear(5), 0), - (basis.RaisedCosineBasisLog(5), 0), - (basis.RaisedCosineBasisLog(5) + basis.EvalMSpline(5), 0), - (basis.EvalMSpline(5, mode="conv", window_size=3), 6), - (basis.BSplineBasis(5, mode="conv", window_size=3), 6), + (basis.EvalBSpline(5), 0), + (basis.EvalCyclicBSpline(5), 0), + (basis.EvalOrthExponential(5, decay_rates=np.arange(1, 6)), 0), + (basis.EvalRaisedCosineLinear(5), 0), + (basis.EvalRaisedCosineLog(5), 0), + (basis.EvalRaisedCosineLog(5) + basis.EvalMSpline(5), 0), + (basis.ConvMSpline(5, window_size=3), 6), + (basis.ConvBSpline(5, window_size=3), 6), ( - basis.CyclicBSplineBasis( - 5, mode="conv", window_size=3, predictor_causality="acausal" + basis.ConvCyclicBSpline( + 5, window_size=3, conv_kwargs=dict(predictor_causality="acausal") ), 4, ), ( - basis.OrthExponentialBasis( - 5, decay_rates=np.linspace(0.1, 1, 5), mode="conv", window_size=7 + basis.ConvOrthExponential( + 5, decay_rates=np.linspace(0.1, 1, 5), window_size=7 ), 14, ), - (basis.RaisedCosineBasisLinear(5, mode="conv", window_size=3), 6), - (basis.RaisedCosineBasisLog(5, mode="conv", window_size=3), 6), + (basis.ConvRaisedCosineLinear(5, window_size=3), 6), + (basis.ConvRaisedCosineLog(5, window_size=3), 6), ( - basis.RaisedCosineBasisLog(5, mode="conv", window_size=3) + basis.ConvRaisedCosineLog(5, window_size=3) + basis.EvalMSpline(5), 6, ), ( - basis.RaisedCosineBasisLog(5, mode="conv", window_size=3) + basis.ConvRaisedCosineLog(5, window_size=3) * basis.EvalMSpline(5), 6, ), @@ -166,7 +166,7 @@ def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) - bas = basis.TransformerBasis(bas) + bas = TransformerBasis(bas) # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap) diff 
--git a/tests/test_simulation.py b/tests/test_simulation.py index f6444bc1..e64072a7 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -213,7 +213,7 @@ def test_least_square_correctness(): # set up problem dimensionality ws, n_neurons_receiver, n_neurons_sender, n_basis_funcs = 100, 1, 2, 10 # evaluate a basis - _, eval_basis = basis.RaisedCosineBasisLog(n_basis_funcs).evaluate_on_grid(ws) + _, eval_basis = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(ws) # generate random weights to define filters weights = np.random.normal( size=(n_neurons_receiver, n_neurons_sender, n_basis_funcs) From 3f1094251a387cad9c0b7cd2a605a84451e3a656 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 15:43:54 -0500 Subject: [PATCH 051/109] linted --- tests/test_identifiability_constraints.py | 2 +- tests/test_pipeline.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ea7a1875..e09742fb 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from nemos.basis.basis import EvalBSpline, ConvBSpline, EvalRaisedCosineLinear +from nemos.basis.basis import ConvBSpline, EvalBSpline, EvalRaisedCosineLinear from nemos.identifiability_constraints import ( _WARN_FLOAT32_MESSAGE, _find_drop_column, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 10a1ae28..a12bafff 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -8,6 +8,7 @@ from nemos import basis from nemos.basis._basis import TransformerBasis + @pytest.mark.parametrize( "bas", [ @@ -146,13 +147,11 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( (basis.ConvRaisedCosineLinear(5, window_size=3), 6), (basis.ConvRaisedCosineLog(5, window_size=3), 6), ( - basis.ConvRaisedCosineLog(5, window_size=3) - + basis.EvalMSpline(5), + basis.ConvRaisedCosineLog(5, window_size=3) + basis.EvalMSpline(5), 6, ), ( - basis.ConvRaisedCosineLog(5, window_size=3) - * basis.EvalMSpline(5), + basis.ConvRaisedCosineLog(5, window_size=3) * basis.EvalMSpline(5), 6, ), ], From 17e7d0628b64fe693b80da46caeae1151f78e419 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 16:48:11 -0500 Subject: [PATCH 052/109] fix basis naming --- docs/background/plot_01_1D_basis_function.md | 24 +++++----- docs/background/plot_02_ND_basis_function.md | 13 +++--- docs/background/plot_03_1D_convolution.md | 31 +++++++------ docs/developers_notes/04-basis_module.md | 2 +- docs/how_to_guide/plot_02_glm_demo.md | 2 +- docs/how_to_guide/plot_04_batch_glm.md | 2 +- .../plot_05_sklearn_pipeline_cv_demo.md | 44 +++++++++---------- docs/how_to_guide/plot_06_glm_pytree.md | 6 +-- docs/quickstart.md | 6 +-- docs/tutorials/plot_02_head_direction.md | 8 ++-- docs/tutorials/plot_03_grid_cells.md | 4 +- docs/tutorials/plot_04_v1_cells.md | 2 +- docs/tutorials/plot_05_place_cells.md | 12 ++--- docs/tutorials/plot_06_calcium_imaging.md | 4 +- src/nemos/_documentation_utils/plotting.py | 4 +- 15 files changed, 83 insertions(+), 81 deletions(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 4d823717..97fedd74 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -45,7 +45,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object 
of the type [`MSplineBasis`](nemos.basis.MSplineBasis). +We'll start by defining a 1D basis function object of the type [`EvalMSpline`](nemos.basis.EvalMSpline). The hyperparameters required to initialize this class are: - The number of basis functions, which should be a positive integer. @@ -63,7 +63,7 @@ order = 4 n_basis = 10 # Define the 1D basis function object -bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order) +bspline = nmo.basis.EvalBSpline(n_basis_funcs=n_basis, order=order) ``` ## Evaluating a Basis @@ -119,7 +119,7 @@ parameter at initialization. Evaluating the basis at any sample outside the boun ```{code-cell} ipython3 -bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) +bspline_range = nmo.basis.EvalBSpline(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) print("Evaluated basis:") # 0.5 is within the support, 0.1 is outside the support @@ -140,20 +140,18 @@ axs[1].set_title("bounds=[0.2, 0.8]") plt.tight_layout() ``` -## Basis `mode` -In constructing features, [`Basis`](nemos.basis.Basis) objects can be used in two modalities: `"eval"` for evaluate or `"conv"` -for convolve. These two modalities change the behavior of the [`compute_features`](nemos.basis.Basis.compute_features) method of [`Basis`](nemos.basis.Basis), in particular, +## Feature Computation +The bases in the module `nemos.basis` can be classified in two categories: -- If a basis is in mode `"eval"`, then [`compute_features`](nemos.basis.Basis.compute_features) simply returns the evaluated basis. -- If a basis is in mode `"conv"`, then [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a kernel of basis - with `window_size` specified by the user. +- **Evaluation Bases**: Objects for which [`compute_features`](nemos.basis.Basis.compute_features) that returns the evaluated basis. This means that the basis are applying a non-linear transformation of the input. The class name for this kind of bases starts with "Eval", e.g. "EvalBSpline". +- **Convolution Bases**: Objects for which [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a kernel of basis elements with `window_size` specified by the user. The class name for this kind of bases starts with "Conv", e.g. "ConvBSpline". Let's see how this two modalities operate. ```{code-cell} ipython3 -eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval") -conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100) +eval_mode = nmo.basis.EvalMSpline(n_basis_funcs=n_basis) +conv_mode = nmo.basis.ConvMSpline(n_basis_funcs=n_basis, window_size=100) # define an input angles = np.linspace(0, np.pi*4, 201) @@ -228,8 +226,8 @@ evaluate a log-spaced cosine raised function basis. 
```{code-cell} ipython3 -# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) +# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter +raised_cosine_log = nmo.basis.EvalRaisedCosineLog(n_basis_funcs=10, width=1.5, time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/plot_02_ND_basis_function.md index dd2ca2c9..5cd4c1ec 100644 --- a/docs/background/plot_02_ND_basis_function.md +++ b/docs/background/plot_02_ND_basis_function.md @@ -127,10 +127,11 @@ Here, we simply add two basis objects, `a_basis` and `b_basis`, together to defi ```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np +import nemos as nmo # Define 1D basis objects -a_basis = nemos.basis.basis.EvalMSpline(n_basis_funcs=15, order=3) -b_basis = nemos.basis.basis.RaisedCosineBasisLog(n_basis_funcs=14) +a_basis = nmo.basis.EvalMSpline(n_basis_funcs=15, order=3) +b_basis = nmo.basis.EvalRaisedCosineLog(n_basis_funcs=14) # Define the 2D additive basis object additive_basis = a_basis + b_basis @@ -279,7 +280,7 @@ for i, j in element_pairs: # select & plot the corresponding product basis element k = i * b_basis.n_basis_funcs + j axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap='Blues') - axs[cc, 2].set_title(f"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b') + axs[cc, 2].set_title(fr"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b') axs[cc, 2].set_xlabel('x-coord') axs[cc, 2].set_ylabel('y-coord') axs[cc, 2].set_aspect("equal") @@ -339,9 +340,9 @@ will output a $K^N \times T$ matrix. T = 10 n_basis = 8 -a_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -b_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) -c_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis) +a_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) +b_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) +c_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md index fa7335e5..38ae9e7f 100644 --- a/docs/background/plot_03_1D_convolution.md +++ b/docs/background/plot_03_1D_convolution.md @@ -82,7 +82,7 @@ see [jax.numpy.convolve](https://jax.readthedocs.io/en/latest/_autosummary/jax.n ```{code-cell} ipython3 # create three filters -basis_obj = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=3) +basis_obj = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=3) _, w = basis_obj.evaluate_on_grid(ws) plt.plot(w) @@ -187,26 +187,31 @@ if path.exists(): ``` ## Convolve using [`Basis.compute_features`](nemos.basis.Basis.compute_features) -All the parameters of [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) can be passed to a [`Basis`](nemos.basis.Basis) directly -at initialization. Note that you must set `mode == "conv"` to actually perform convolution -with [`Basis.compute_features`](nemos.basis.Basis.compute_features). Let's see how we can get the same results through [`Basis`](nemos.basis.Basis). 
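The refactor in this patch replaces the `mode="conv"` flag with dedicated "Conv"-prefixed classes. As a minimal sketch of the resulting workflow (a NumPy input is assumed; the kernels themselves were plotted earlier in this guide):

```{code-cell} ipython3
import numpy as np
import nemos as nmo

ws = 100                        # convolution window size, in samples
x = np.random.normal(size=1000) # a toy 1D time series

# a "Conv" basis: compute_features convolves x with the basis kernels
conv_bas = nmo.basis.ConvRaisedCosineLinear(n_basis_funcs=3, window_size=ws)

# one output column per basis function; samples without a full
# (causal, by default) history window are NaN-padded
X = conv_bas.compute_features(x)
print(X.shape)

# the kernels used as filters can be inspected with evaluate_on_grid
_, kernels = conv_bas.evaluate_on_grid(ws)
print(kernels.shape)
```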
+ +Every basis in the `nemos.basis` module whose class name starts with "Conv" will perform a 1D convolution over the +provided input when the `compute_features` method is called. The basis elements will be used as filters for the +convolution. + +All the parameters of [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) can be passed to the object directly at initialization. +Let's see how we can get the same results through [`Basis`](nemos.basis.Basis). ```{code-cell} ipython3 # define basis with different predictor causality -causal_basis = nmo.basis.RaisedCosineBasisLinear( - n_basis_funcs=3, mode="conv", window_size=ws, - predictor_causality="causal" +causal_basis = nmo.basis.ConvRaisedCosineLinear( + n_basis_funcs=3, window_size=ws, + conv_kwargs=dict(predictor_causality="causal") + ) -acausal_basis = nmo.basis.RaisedCosineBasisLinear( - n_basis_funcs=3, mode="conv", window_size=ws, - predictor_causality="acausal" +acausal_basis = nmo.basis.ConvRaisedCosineLinear( + n_basis_funcs=3, window_size=ws, + conv_kwargs=dict(predictor_causality="acausal") ) -anticausal_basis = nmo.basis.RaisedCosineBasisLinear( - n_basis_funcs=3, mode="conv", window_size=ws, - predictor_causality="anti-causal" +anticausal_basis = nmo.basis.ConvRaisedCosineLinear( + n_basis_funcs=3, window_size=ws, + conv_kwargs=dict(predictor_causality="anti-causal") ) # compute convolutions diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index d0f939a0..c0bfd2a8 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -21,7 +21,7 @@ Abstract Class Basis │ ├─ Concrete Subclass RaisedCosineBasisLinear │ │ -│ └─ Concrete Subclass RaisedCosineBasisLog +│ └─ Concrete Subclass EvalRaisedCosineLog │ └─ Concrete Subclass OrthExponentialBasis ``` diff --git a/docs/how_to_guide/plot_02_glm_demo.md b/docs/how_to_guide/plot_02_glm_demo.md index fe18d3ec..d89d10c9 100644 --- a/docs/how_to_guide/plot_02_glm_demo.md +++ b/docs/how_to_guide/plot_02_glm_demo.md @@ -329,7 +329,7 @@ coupling_filter_bank *= 0.8 # define a basis function n_basis_funcs = 20 -basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs) +basis = nmo.basis.EvalRaisedCosineLog(n_basis_funcs) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/docs/how_to_guide/plot_04_batch_glm.md b/docs/how_to_guide/plot_04_batch_glm.md index 707e90da..84e58ad6 100644 --- a/docs/how_to_guide/plot_04_batch_glm.md +++ b/docs/how_to_guide/plot_04_batch_glm.md @@ -106,7 +106,7 @@ Here we instantiate the basis. `ws` is 40 time bins. It corresponds to a 200 ms ```{code-cell} ipython3 ws = 40 -basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=ws) +basis = nmo.basis.ConvRaisedCosineLog(5, window_size=ws) ``` ## Batch definition diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index 7684857c..60f7f4b5 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -152,31 +152,29 @@ sns.despine(ax=ax) ### Converting NeMoS `Basis` to a transformer In order to use NeMoS [`Basis`](nemos.basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis.TransformerBasis) wrapper class. 
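Before looking at how to construct one, it helps to know what the wrapper promises: the standard `fit` / `transform` / `fit_transform` methods operating on a 2D input with one column per basis input, which is what lets it slot into scikit-learn tooling. A minimal sketch of that contract (using `to_transformer()`, introduced just below, and a single-input evaluation basis):

```{code-cell} ipython3
import numpy as np
import nemos as nmo

x = np.random.normal(size=(1000, 1))   # 2D input: one column per basis input

trans_bas_sketch = nmo.basis.EvalBSpline(5).to_transformer()
X = trans_bas_sketch.fit_transform(x)  # evaluate the 5 basis functions on that column
print(X.shape)
```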
-Instantiating a [`TransformerBasis`](nemos.basis.TransformerBasis) can be done either using the constructor directly or with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer): ```{code-cell} ipython3 -bas = nmo.basis.RaisedCosineBasisLinear(5, mode="conv", window_size=5) -# these two ways of creating the TransformerBasis are equivalent -trans_bas_a = nmo.basis.TransformerBasis(bas) -trans_bas_b = bas.to_transformer() +bas = nmo.basis.ConvRaisedCosineLinear(5, window_size=5) +trans_bas = bas.to_transformer() ``` [`TransformerBasis`](nemos.basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis.Basis) object's attributes: ```{code-cell} ipython3 -print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs) +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` We can also set attributes of the underlying [`Basis`](nemos.basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis.Basis), and neither does changing the original [`Basis`](nemos.basis.Basis) change [`TransformerBasis`](nemos.basis.TransformerBasis) we created: ```{code-cell} ipython3 -trans_bas_a.n_basis_funcs = 10 +trans_bas.n_basis_funcs = 10 bas.n_basis_funcs = 100 -print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs) +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` ### Creating and fitting a pipeline @@ -190,7 +188,7 @@ pipeline = Pipeline( [ ( "transformerbasis", - nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6)), + nmo.basis.EvalRaisedCosineLinear(6).to_transformer(), ), ( "glm", @@ -326,7 +324,7 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds)) coeffs = {} # initialize basis and model -basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6)) +basis = nmo.basis.TransformerBasis(nmo.basis.EvalRaisedCosineLinear(6)) model = nmo.glm.GLM(regularizer="Ridge") # loop over combinations @@ -453,12 +451,12 @@ Here we include `transformerbasis___basis` in the parameter grid to try differen param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis___basis=( - nmo.basis.RaisedCosineBasisLinear(5), - nmo.basis.RaisedCosineBasisLinear(10), - nmo.basis.RaisedCosineBasisLog(5), - nmo.basis.RaisedCosineBasisLog(10), - nmo.basis.MSplineBasis(5), - nmo.basis.MSplineBasis(10), + nmo.basis.EvalRaisedCosineLinear(5), + nmo.basis.EvalRaisedCosineLinear(10), + nmo.basis.EvalRaisedCosineLog(5), + nmo.basis.EvalRaisedCosineLog(10), + nmo.basis.EvalMSpline(5), + nmo.basis.EvalMSpline(10), ), ) ``` @@ -498,7 +496,7 @@ cvdf_wide = cvdf.pivot( doc_plots.plot_heatmap_cv_results(cvdf_wide) ``` -As shown in the table, the model with the highest score, highlighted in blue, used a RaisedCosineBasisLinear basis (as used above), which appears to be a suitable choice for our toy data. +As shown in the table, the model with the highest score, highlighted in blue, used a EvalRaisedCosineLinear basis (as used above), which appears to be a suitable choice for our toy data. 
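If you prefer to read the winning configuration off the fitted object rather than the heatmap, `GridSearchCV` stores it in `best_params_`, keyed by the names used in the parameter grid (assuming the fitted search object is called `gridsearch`, as in the previous cells):

```{code-cell} ipython3
# basis object and regularizer strength selected by the cross-validation
print(gridsearch.best_params_["transformerbasis___basis"])
print(gridsearch.best_params_["glm__regularizer_strength"])
```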
We can confirm that by plotting the firing rate predictions: @@ -539,12 +537,12 @@ param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), transformerbasis___basis=( - nmo.basis.RaisedCosineBasisLinear(5), - nmo.basis.RaisedCosineBasisLinear(10), - nmo.basis.RaisedCosineBasisLog(5), - nmo.basis.RaisedCosineBasisLog(10), - nmo.basis.MSplineBasis(5), - nmo.basis.MSplineBasis(10), + nmo.basis.EvalRaisedCosineLinear(5), + nmo.basis.EvalRaisedCosineLinear(10), + nmo.basis.EvalRaisedCosineLog(5), + nmo.basis.EvalRaisedCosineLog(10), + nmo.basis.EvalMSpline(5), + nmo.basis.EvalMSpline(10), ), ) ``` diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 6945460e..9dbd7ddc 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -274,7 +274,7 @@ Okay, let's use unit number 7. Now let's set up our design matrix. First, let's fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. -[`CyclicBSplineBasis`](nemos.basis.CyclicBSplineBasis) is one such basis. +[`EvalCyclicBSpline`](nemos.basis.EvalCyclicBSpline) is one such basis. Let's create our basis and then arrange our data properly. @@ -283,7 +283,7 @@ Let's create our basis and then arrange our data properly. unit_no = 7 spikes = nwb['units'][unit_no] -basis = nmo.basis.CyclicBSplineBasis(10, order=5) +basis = nmo.basis.EvalCyclicBSpline(10, order=5) x = np.linspace(-np.pi, np.pi, 100) plt.figure() plt.plot(x, basis(x)) @@ -351,7 +351,7 @@ our data similarly. ```{code-cell} ipython3 -pos_basis = nmo.basis.RaisedCosineBasisLinear(10) * nmo.basis.RaisedCosineBasisLinear(10) +pos_basis = nmo.basis.EvalRaisedCosineLinear(10) * nmo.basis.EvalRaisedCosineLinear(10) spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data) X['spatial_position'] = pos_basis(*spatial_pos.values.T) diff --git a/docs/quickstart.md b/docs/quickstart.md index 062bdb25..8420d078 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -165,7 +165,7 @@ you need to specify the number of basis functions. For some `basis` objects, add >>> import nemos as nmo >>> n_basis_funcs = 10 ->>> basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs) +>>> basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs) ``` @@ -205,7 +205,7 @@ number of sample points. >>> n_basis_funcs = 10 >>> # define a filter bank of 10 basis function, 200 samples long. ->>> basis = nmo.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) +>>> basis = nmo.basis.ConvBSpline(n_basis_funcs, window_size=200) ``` @@ -350,7 +350,7 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa >>> upsampled_head_dir = head_dir.bin_average(0.01) >>> # create your features ->>> X = nmo.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir) +>>> X = nmo.basis.EvalCyclicBSpline(10).compute_features(upsampled_head_dir) >>> # add a neuron axis and fit model >>> model = nmo.glm.GLM().fit(X, counts) diff --git a/docs/tutorials/plot_02_head_direction.md b/docs/tutorials/plot_02_head_direction.md index cf408662..e4402053 100644 --- a/docs/tutorials/plot_02_head_direction.md +++ b/docs/tutorials/plot_02_head_direction.md @@ -419,8 +419,8 @@ the cost of adding additional parameters. ```{code-cell} ipython3 # a basis object can be instantiated in "conv" mode for convolving the input. 
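# note: with the renaming in this patch, convolutional behavior comes from choosing a
# Conv-prefixed class (ConvRaisedCosineLog below) rather than from passing mode="conv".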
-basis = nmo.basis.RaisedCosineBasisLog( - n_basis_funcs=8, mode="conv", window_size=window_size +basis = nmo.basis.ConvRaisedCosineLog( + n_basis_funcs=8, window_size=window_size ) # `basis.evaluate_on_grid` is a convenience method to view all basis functions @@ -600,8 +600,8 @@ to get an array of predictors of shape, `(num_time_points, num_neurons * num_bas ```{code-cell} ipython3 # re-initialize basis -basis = nmo.basis.RaisedCosineBasisLog( - n_basis_funcs=8, mode="conv", window_size=window_size +basis = nmo.basis.ConvRaisedCosineLog( + n_basis_funcs=8, window_size=window_size ) # convolve all the neurons diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index 2e1588d4..a7f767ef 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -146,9 +146,9 @@ We can define a two-dimensional basis for position by multiplying two one-dimens see [here](../../background/plot_02_ND_basis_function) for more details. ```{code-cell} ipython3 -basis_2d = nmo.basis.RaisedCosineBasisLinear( +basis_2d = nmo.basis.EvalRaisedCosineLinear( n_basis_funcs=10 -) * nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=10) +) * nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=10) ``` Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid. diff --git a/docs/tutorials/plot_04_v1_cells.md b/docs/tutorials/plot_04_v1_cells.md index bd91d5ef..45e8efa8 100644 --- a/docs/tutorials/plot_04_v1_cells.md +++ b/docs/tutorials/plot_04_v1_cells.md @@ -345,7 +345,7 @@ GLM: ```{code-cell} ipython3 window_size = 100 -basis = nmo.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size) +basis = nmo.basis.ConvRaisedCosineLog(8 window_size=window_size) convolved_input = basis.compute_features(filtered_stimulus) ``` diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index 82256887..87e97d88 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -335,15 +335,15 @@ print(count.shape) For each feature, we will use a different set of basis : - - position : [`MSplineBasis`](nemos.basis.MSplineBasis) - - theta phase : [`CyclicBSplineBasis`](nemos.basis.CyclicBSplineBasis) - - speed : [`MSplineBasis`](nemos.basis.MSplineBasis) + - position : [`EvalMSpline`](nemos.basis.EvalMSpline) + - theta phase : [`EvalCyclicBSpline`](nemos.basis.EvalCyclicBSpline) + - speed : [`EvalMSpline`](nemos.basis.EvalMSpline) ```{code-cell} ipython3 -position_basis = nmo.basis.MSplineBasis(n_basis_funcs=10) -phase_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12) -speed_basis = nmo.basis.MSplineBasis(n_basis_funcs=15) +position_basis = nmo.basis.EvalMSpline(n_basis_funcs=10) +phase_basis = nmo.basis.EvalCyclicBSpline(n_basis_funcs=12) +speed_basis = nmo.basis.EvalMSpline(n_basis_funcs=15) ``` In addition, we will consider position and phase to be a joint variable. In NeMoS, we can combine basis by multiplying them and adding them. In this case the final basis object for our model can be made in one line : diff --git a/docs/tutorials/plot_06_calcium_imaging.md b/docs/tutorials/plot_06_calcium_imaging.md index 1dbfc46e..3eeed024 100644 --- a/docs/tutorials/plot_06_calcium_imaging.md +++ b/docs/tutorials/plot_06_calcium_imaging.md @@ -180,8 +180,8 @@ We can combine the two bases. 
```{code-cell} ipython3 -heading_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12) -coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=10) +heading_basis = nmo.basis.EvalCyclicBSpline(n_basis_funcs=12) +coupling_basis = nmo.basis.ConvRaisedCosineLog(3, window_size=10) ``` We need to make sure the design matrix will be full-rank by applying identifiability constraints to the Cyclic Bspline, and then combine the bases (the resturned object will be an [`AdditiveBasis`](nemos.basis.AdditiveBasis) object). diff --git a/src/nemos/_documentation_utils/plotting.py b/src/nemos/_documentation_utils/plotting.py index 086ccb2c..58a94bc0 100644 --- a/src/nemos/_documentation_utils/plotting.py +++ b/src/nemos/_documentation_utils/plotting.py @@ -33,7 +33,7 @@ from matplotlib.patches import Rectangle from numpy.typing import NDArray -from ..basis import RaisedCosineBasisLog +from ..basis import EvalRaisedCosineLog warnings.warn( "plotting functions contained within `_documentation_utils` are intended for nemos's documentation. " @@ -682,7 +682,7 @@ def plot_rates_and_smoothed_counts( def plot_basis(n_basis_funcs=8, window_size_sec=0.8): fig = plt.figure() - basis = RaisedCosineBasisLog(n_basis_funcs=n_basis_funcs) + basis = EvalRaisedCosineLog(n_basis_funcs=n_basis_funcs) time, basis_kernels = basis.evaluate_on_grid(1000) time *= window_size_sec plt.plot(time, basis_kernels) From 17a9b932c7cddcb6cd73cd30b36785213b4dd312 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 16:56:28 -0500 Subject: [PATCH 053/109] fixed bug --- docs/tutorials/plot_04_v1_cells.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/plot_04_v1_cells.md b/docs/tutorials/plot_04_v1_cells.md index 45e8efa8..aa479928 100644 --- a/docs/tutorials/plot_04_v1_cells.md +++ b/docs/tutorials/plot_04_v1_cells.md @@ -345,7 +345,7 @@ GLM: ```{code-cell} ipython3 window_size = 100 -basis = nmo.basis.ConvRaisedCosineLog(8 window_size=window_size) +basis = nmo.basis.ConvRaisedCosineLog(8, window_size=window_size) convolved_input = basis.compute_features(filtered_stimulus) ``` From 93f185857d0b99135c506749ffbf10bee5404328 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 17:25:56 -0500 Subject: [PATCH 054/109] fix double plotting --- docs/tutorials/plot_01_current_injection.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/plot_01_current_injection.md b/docs/tutorials/plot_01_current_injection.md index 70ebac16..5d0df7f5 100644 --- a/docs/tutorials/plot_01_current_injection.md +++ b/docs/tutorials/plot_01_current_injection.md @@ -357,7 +357,7 @@ if you are interested. 
::: ```{code-cell} ipython3 -doc_plots.current_injection_plot(current, spikes, firing_rate) +doc_plots.current_injection_plot(current, spikes, firing_rate); ``` So now that we can view the details of our experiment a little more clearly, From 1aaea3be15c8c752fc16fe60acf4d190fc5a19bd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 20:49:15 -0500 Subject: [PATCH 055/109] move transformer basis out --- src/nemos/basis/__init__.py | 1 + src/nemos/basis/_basis.py | 435 +------------------------- src/nemos/basis/_basis_mixin.py | 35 +++ src/nemos/basis/_transformer_basis.py | 393 +++++++++++++++++++++++ src/nemos/basis/basis.py | 44 ++- tests/test_basis.py | 60 ++-- 6 files changed, 492 insertions(+), 476 deletions(-) create mode 100644 src/nemos/basis/_transformer_basis.py diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index fa3fc70d..3a08ad2e 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -1,6 +1,7 @@ from ._basis import AdditiveBasis, Basis, MultiplicativeBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._spline_basis import BSplineBasis +from ._transformer_basis import TransformerBasis from .basis import ( ConvBSpline, ConvCyclicBSpline, diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 6f9ca14b..c37cec78 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -16,6 +16,7 @@ from ..typing import FeatureMatrix from ..utils import row_wise_kron from ..validation import check_fraction_valid_samples +from ._basis_mixin import BasisTransformerMixin def add_docstring(method_name, cls=None): @@ -112,21 +113,9 @@ class Basis(Base, abc.ABC): mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. Raises ------ @@ -550,35 +539,6 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis: result = result * self return result - def to_transformer(self) -> TransformerBasis: - """ - Turn the Basis into a TransformerBasis for use with scikit-learn. - - Examples - -------- - Jointly cross-validating basis and GLM parameters with scikit-learn. - - >>> import nemos as nmo - >>> from sklearn.pipeline import Pipeline - >>> from sklearn.model_selection import GridSearchCV - >>> # load some data - >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.EvalRaisedCosineLinear(10).to_transformer() - >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) - >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) - >>> param_grid = dict( - ... 
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), - ... basis__n_basis_funcs=(3, 5, 10, 20, 100), - ... ) - >>> gridsearch = GridSearchCV( - ... pipeline, - ... param_grid=param_grid, - ... cv=5, - ... ) - >>> gridsearch = gridsearch.fit(X, y) - """ - return TransformerBasis(copy.deepcopy(self)) - def _get_feature_slicing( self, n_inputs: Optional[tuple] = None, @@ -860,396 +820,11 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: return self -class TransformerBasis: - """Basis as ``scikit-learn`` transformers. - - This class abstracts the underlying basis function details, offering methods - similar to scikit-learn's transformers but specifically designed for basis - transformations. It supports fitting to data (calculating any necessary parameters - of the basis functions), transforming data (applying the basis functions to - data), and both fitting and transforming in one step. - - ``TransformerBasis``, unlike ``Basis``, is compatible with scikit-learn pipelining and - model selection, enabling the cross-validation of the basis type and parameters, - for example ``n_basis_funcs``. See the example section below. - - Parameters - ---------- - basis : - A concrete subclass of ``Basis``. - - Examples - -------- - >>> from nemos.basis import EvalBSpline - >>> from nemos.basis._basis import TransformerBasis - >>> from nemos.glm import GLM - >>> from sklearn.pipeline import Pipeline - >>> from sklearn.model_selection import GridSearchCV - >>> import numpy as np - >>> np.random.seed(123) - - >>> # Generate data - >>> num_samples, num_features = 10000, 1 - >>> x = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalBSpline(10) - >>> features = basis.compute_features(x) # basis transformed time series - >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights - >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts - - >>> # transformer can be used in pipelines - >>> transformer = TransformerBasis(basis) - >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),]) - >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API - >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas - >>> # TransformerBasis parameter can be cross-validated. - >>> # 5-fold cross-validate the number of basis - >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10]) - >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5) - >>> grid_cv = grid_cv.fit(x[:, None], y) - >>> print("Cross-validated number of basis:", grid_cv.best_params_) - Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} - """ - - def __init__(self, basis: Basis): - self._basis = copy.deepcopy(basis) - - @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack impute without using transpose. - - Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. - - Parameters - ---------- - X: - The inputs horizontally stacked. - - Returns - ------- - : - A tuple of each individual input. - - """ - return (X[:, k] for k in range(X.shape[1])) - - def fit(self, X: FeatureMatrix, y=None): - """ - Compute the convolutional kernels. - - If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. 
- - Parameters - ---------- - X : - The data to fit the basis functions to, shape (num_samples, num_input). - y : ignored - Not used, present for API consistency by convention. - - Returns - ------- - self : - The transformer object. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis - - >>> # Example input - >>> X = np.random.normal(size=(100, 2)) - - >>> # Define and fit tranformation basis - >>> basis = EvalMSpline(10) - >>> transformer = TransformerBasis(basis) - >>> transformer_fitted = transformer.fit(X) - """ - self._basis._set_kernel() - return self - - def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: - """ - Transform the data using the fitted basis functions. - - Parameters - ---------- - X : - The data to transform using the basis functions, shape (num_samples, num_input). - y : - Not used, present for API consistency by convention. - - Returns - ------- - : - The data transformed by the basis functions. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis - - >>> # Example input - >>> X = np.random.normal(size=(10000, 2)) - - >>> # Define and fit tranformation basis - >>> basis = EvalMSpline(10, mode="conv", window_size=200) - >>> transformer = TransformerBasis(basis) - >>> # Before calling `fit` the convolution kernel is not set - >>> transformer.kernel_ - - >>> transformer_fitted = transformer.fit(X) - >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) - >>> transformer_fitted.kernel_.shape - (200, 10) - - >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) - """ - # transpose does not work with pynapple - # can't use func(*X.T) to unwrap - - return self._basis._compute_features(*self._unpack_inputs(X)) - - def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: - """ - Compute the kernels and the features. - - This method is a convenience that combines fit and transform into - one step. - - Parameters - ---------- - X : - The data to fit the basis functions to and then transform. - y : - Not used, present for API consistency by convention. - - Returns - ------- - array-like - The data transformed by the basis functions, after fitting the basis - functions to the data. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis - - >>> # Example input - >>> X = np.random.normal(size=(100, 1)) - - >>> # Define tranformation basis - >>> basis = EvalMSpline(10) - >>> transformer = TransformerBasis(basis) - - >>> # Fit and transform basis - >>> feature_transformed = transformer.fit_transform(X) - """ - return self._basis.compute_features(*self._unpack_inputs(X)) - - def __getstate__(self): - """ - Explicitly define how to pickle TransformerBasis object. - - See https://docs.python.org/3/library/pickle.html#object.__getstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - return {"_basis": self._basis} - - def __setstate__(self, state): - """ - Define how to populate the object's state when unpickling. - - Note that during unpickling a new object is created without calling __init__. - Needed to avoid infinite recursion in __getattr__ when unpickling. 
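As a concrete illustration of why only ``_basis`` needs to be carried through pickling, a small round-trip sketch (assuming the ``TransformerBasis`` and ``EvalMSpline`` classes handled by this patch; the printed value is simply the attribute forwarded from the wrapped basis):

```python
import pickle

import nemos as nmo

trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10))

# __getstate__/__setstate__ round trip; no __init__ call happens on unpickling
restored = pickle.loads(pickle.dumps(trans_bas))

# attribute access is forwarded to the wrapped basis, so this still works
print(restored.n_basis_funcs)  # 10
```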
- - See https://docs.python.org/3/library/pickle.html#object.__setstate__ - and https://docs.python.org/3/library/pickle.html#pickle-state - """ - self._basis = state["_basis"] - - def __getattr__(self, name: str): - """ - Enable easy access to attributes of the underlying Basis object. - - Examples - -------- - >>> from nemos import basis - >>> bas = basis.EvalRaisedCosineLinear(5) - >>> trans_bas = basis.TransformerBasis(bas) - >>> bas.n_basis_funcs - 5 - >>> trans_bas.n_basis_funcs - 5 - """ - return getattr(self._basis, name) - - def __setattr__(self, name: str, value) -> None: - r""" - Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. - - Setting any other attribute is not allowed. - - Returns - ------- - None - - Raises - ------ - ValueError - If the attribute being set is not ``_basis`` or an attribute of ``_basis``. - - Examples - -------- - >>> import nemos as nmo - >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10)) - >>> # allowed - >>> trans_bas._basis = nmo.basis.EvalBSpline(10) - >>> # allowed - >>> trans_bas.n_basis_funcs = 20 - >>> # not allowed - >>> try: - ... trans_bas.random_attribute_name = "some value" - ... except ValueError as e: - ... print(repr(e)) - ValueError('Only setting _basis or existing attributes of _basis is allowed.') - """ - # allow self._basis = basis - if name == "_basis": - super().__setattr__(name, value) - # allow changing existing attributes of self._basis - elif hasattr(self._basis, name): - setattr(self._basis, name, value) - # don't allow setting any other attribute - else: - raise ValueError( - "Only setting _basis or existing attributes of _basis is allowed." - ) - - def __sklearn_clone__(self) -> TransformerBasis: - """ - Customize how TransformerBasis objects are cloned when used with sklearn.model_selection. - - By default, scikit-learn tries to clone the object by calling __init__ using the output of get_params, - which fails in our case. - - For more info: https://scikit-learn.org/stable/developers/develop.html#cloning - """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) - cloned_obj._basis.kernel_ = None - return cloned_obj - - def set_params(self, **parameters) -> TransformerBasis: - """ - Set TransformerBasis parameters. - - When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly - or the parameters of the underlying Basis, but not both. - - Examples - -------- - >>> from nemos.basis import EvalBSpline, EvalMSpline, TransformerBasis - >>> basis = EvalMSpline(10) - >>> transformer_basis = TransformerBasis(basis=basis) - - >>> # setting parameters of _basis is allowed - >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) - 8 - >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=EvalBSpline(10))._basis)) - - >>> # mixing is not allowed, this will raise an exception - >>> try: - ... transformer_basis.set_params(_basis=EvalBSpline(10), n_basis_funcs=2) - ... except ValueError as e: - ... print(repr(e)) - ValueError('Set either new _basis object or parameters for existing _basis, not both.') - """ - new_basis = parameters.pop("_basis", None) - if new_basis is not None: - self._basis = new_basis - if len(parameters) > 0: - raise ValueError( - "Set either new _basis object or parameters for existing _basis, not both." 
- ) - else: - self._basis = self._basis.set_params(**parameters) - - return self - - def get_params(self, deep: bool = True) -> dict: - """Extend the dict of parameters from the underlying Basis with _basis.""" - return {"_basis": self._basis, **self._basis.get_params(deep)} - - def __dir__(self) -> list[str]: - """Extend the list of properties of methods with the ones from the underlying Basis.""" - return super().__dir__() + self._basis.__dir__() - - def __add__(self, other: TransformerBasis) -> TransformerBasis: - """ - Add two TransformerBasis objects. - - Parameters - ---------- - other - The other TransformerBasis object to add. - - Returns - ------- - : TransformerBasis - The resulting Basis object. - """ - return TransformerBasis(self._basis + other._basis) - - def __mul__(self, other: TransformerBasis) -> TransformerBasis: - """ - Multiply two TransformerBasis objects. - - Parameters - ---------- - other - The other TransformerBasis object to multiply. - - Returns - ------- - : - The resulting Basis object. - """ - return TransformerBasis(self._basis * other._basis) - - def __pow__(self, exponent: int) -> TransformerBasis: - """Exponentiation of a TransformerBasis object. - - Define the power of a basis by repeatedly applying the method __mul__. - The exponent must be a positive integer. - - Parameters - ---------- - exponent : - Positive integer exponent - - Returns - ------- - : - The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. - - Raises - ------ - TypeError - If the provided exponent is not an integer. - ValueError - If the integer is zero or negative. - """ - # errors are handled by Basis.__pow__ - return TransformerBasis(self._basis**exponent) - - add_docstring_additive = partial(add_docstring, cls=Basis) add_docstring_multiplicative = partial(add_docstring, cls=Basis) -class AdditiveBasis(Basis): +class AdditiveBasis(Basis, BasisTransformerMixin): """ Class representing the addition of two Basis objects. @@ -1294,7 +869,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._label = "(" + basis1.label + " + " + basis2.label + ")" self._basis1 = basis1 self._basis2 = basis2 - return + BasisTransformerMixin.__init__(self) def _set_num_output_features(self, *xi: NDArray) -> Basis: self._n_basis_input = ( @@ -1592,7 +1167,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) -class MultiplicativeBasis(Basis): +class MultiplicativeBasis(Basis, BasisTransformerMixin): """ Class representing the multiplication (external product) of two Basis objects. @@ -1637,7 +1212,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._label = "(" + basis1.label + " * " + basis2.label + ")" self._basis1 = basis1 self._basis2 = basis2 - return + BasisTransformerMixin.__init__(self) def _check_n_basis_min(self) -> None: pass diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index ab95ab21..8ac42af1 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -1,5 +1,6 @@ """Mixin classes for basis.""" +import copy import inspect from typing import Optional, Tuple, Union @@ -7,6 +8,7 @@ from numpy.typing import ArrayLike from ..convolve import create_convolutional_predictor +from ._transformer_basis import TransformerBasis class EvalBasisMixin: @@ -220,3 +222,36 @@ def _check_convolution_kwargs(conv_kwargs: dict): f"Unrecognized keyword arguments: {invalid}. 
" f"Allowed convolution keyword arguments are: {convolve_configs}." ) + + +class BasisTransformerMixin: + """Mixin class for constructing a transformer""" + + def to_transformer(self) -> TransformerBasis: + """ + Turn the Basis into a TransformerBasis for use with scikit-learn. + + Examples + -------- + Jointly cross-validating basis and GLM parameters with scikit-learn. + + >>> import nemos as nmo + >>> from sklearn.pipeline import Pipeline + >>> from sklearn.model_selection import GridSearchCV + >>> # load some data + >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) + >>> basis = nmo.basis.EvalRaisedCosineLinear(10).to_transformer() + >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) + >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) + >>> param_grid = dict( + ... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), + ... basis__n_basis_funcs=(3, 5, 10, 20, 100), + ... ) + >>> gridsearch = GridSearchCV( + ... pipeline, + ... param_grid=param_grid, + ... cv=5, + ... ) + >>> gridsearch = gridsearch.fit(X, y) + """ + return TransformerBasis(copy.deepcopy(self)) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py new file mode 100644 index 00000000..624fb2c4 --- /dev/null +++ b/src/nemos/basis/_transformer_basis.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +from ..typing import FeatureMatrix + +if TYPE_CHECKING: + from ._basis import Basis + + +class TransformerBasis: + """Basis as ``scikit-learn`` transformers. + + This class abstracts the underlying basis function details, offering methods + similar to scikit-learn's transformers but specifically designed for basis + transformations. It supports fitting to data (calculating any necessary parameters + of the basis functions), transforming data (applying the basis functions to + data), and both fitting and transforming in one step. + + ``TransformerBasis``, unlike ``Basis``, is compatible with scikit-learn pipelining and + model selection, enabling the cross-validation of the basis type and parameters, + for example ``n_basis_funcs``. See the example section below. + + Parameters + ---------- + basis : + A concrete subclass of ``Basis``. + + Examples + -------- + >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import TransformerBasis + >>> from nemos.glm import GLM + >>> from sklearn.pipeline import Pipeline + >>> from sklearn.model_selection import GridSearchCV + >>> import numpy as np + >>> np.random.seed(123) + + >>> # Generate data + >>> num_samples, num_features = 10000, 1 + >>> x = np.random.normal(size=(num_samples, )) # raw time series + >>> basis = EvalBSpline(10) + >>> features = basis.compute_features(x) # basis transformed time series + >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights + >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts + + >>> # transformer can be used in pipelines + >>> transformer = TransformerBasis(basis) + >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),]) + >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API + >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas + >>> # TransformerBasis parameter can be cross-validated. 
+ >>> # 5-fold cross-validate the number of basis + >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10]) + >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5) + >>> grid_cv = grid_cv.fit(x[:, None], y) + >>> print("Cross-validated number of basis:", grid_cv.best_params_) + Cross-validated number of basis: {'compute_features__n_basis_funcs': 10} + """ + + def __init__(self, basis: Basis): + self._basis = copy.deepcopy(basis) + + @staticmethod + def _unpack_inputs(X: FeatureMatrix): + """Unpack impute without using transpose. + + Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, + returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` + exception since ``pynapple`` assumes that the time axis is the first axis. + + Parameters + ---------- + X: + The inputs horizontally stacked. + + Returns + ------- + : + A tuple of each individual input. + + """ + return (X[:, k] for k in range(X.shape[1])) + + def fit(self, X: FeatureMatrix, y=None): + """ + Compute the convolutional kernels. + + If any of the 1D basis in self._basis is in "conv" mode, it computes the convolutional kernels. + + Parameters + ---------- + X : + The data to fit the basis functions to, shape (num_samples, num_input). + y : ignored + Not used, present for API consistency by convention. + + Returns + ------- + self : + The transformer object. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(100, 2)) + + >>> # Define and fit tranformation basis + >>> basis = EvalMSpline(10) + >>> transformer = TransformerBasis(basis) + >>> transformer_fitted = transformer.fit(X) + """ + self._basis._set_kernel() + return self + + def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: + """ + Transform the data using the fitted basis functions. + + Parameters + ---------- + X : + The data to transform using the basis functions, shape (num_samples, num_input). + y : + Not used, present for API consistency by convention. + + Returns + ------- + : + The data transformed by the basis functions. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(10000, 2)) + + >>> # Define and fit tranformation basis + >>> basis = EvalMSpline(10, mode="conv", window_size=200) + >>> transformer = TransformerBasis(basis) + >>> # Before calling `fit` the convolution kernel is not set + >>> transformer.kernel_ + + >>> transformer_fitted = transformer.fit(X) + >>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs) + >>> transformer_fitted.kernel_.shape + (200, 10) + + >>> # Transform basis + >>> feature_transformed = transformer.transform(X[:, 0:1]) + """ + # transpose does not work with pynapple + # can't use func(*X.T) to unwrap + return self._basis._compute_features(*self._unpack_inputs(X)) + + def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: + """ + Compute the kernels and the features. + + This method is a convenience that combines fit and transform into + one step. + + Parameters + ---------- + X : + The data to fit the basis functions to and then transform. + y : + Not used, present for API consistency by convention. + + Returns + ------- + array-like + The data transformed by the basis functions, after fitting the basis + functions to the data. 
+ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalMSpline, TransformerBasis + + >>> # Example input + >>> X = np.random.normal(size=(100, 1)) + + >>> # Define tranformation basis + >>> basis = EvalMSpline(10) + >>> transformer = TransformerBasis(basis) + + >>> # Fit and transform basis + >>> feature_transformed = transformer.fit_transform(X) + """ + return self._basis.compute_features(*self._unpack_inputs(X)) + + def __getstate__(self): + """ + Explicitly define how to pickle TransformerBasis object. + + See https://docs.python.org/3/library/pickle.html#object.__getstate__ + and https://docs.python.org/3/library/pickle.html#pickle-state + """ + return {"_basis": self._basis} + + def __setstate__(self, state): + """ + Define how to populate the object's state when unpickling. + + Note that during unpickling a new object is created without calling __init__. + Needed to avoid infinite recursion in __getattr__ when unpickling. + + See https://docs.python.org/3/library/pickle.html#object.__setstate__ + and https://docs.python.org/3/library/pickle.html#pickle-state + """ + self._basis = state["_basis"] + + def __getattr__(self, name: str): + """ + Enable easy access to attributes of the underlying Basis object. + + Examples + -------- + >>> from nemos import basis + >>> bas = basis.EvalRaisedCosineLinear(5) + >>> trans_bas = basis.TransformerBasis(bas) + >>> bas.n_basis_funcs + 5 + >>> trans_bas.n_basis_funcs + 5 + """ + return getattr(self._basis, name) + + def __setattr__(self, name: str, value) -> None: + r""" + Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax. + + Setting any other attribute is not allowed. + + Returns + ------- + None + + Raises + ------ + ValueError + If the attribute being set is not ``_basis`` or an attribute of ``_basis``. + + Examples + -------- + >>> import nemos as nmo + >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10)) + >>> # allowed + >>> trans_bas._basis = nmo.basis.EvalBSpline(10) + >>> # allowed + >>> trans_bas.n_basis_funcs = 20 + >>> # not allowed + >>> try: + ... trans_bas.random_attribute_name = "some value" + ... except ValueError as e: + ... print(repr(e)) + ValueError('Only setting _basis or existing attributes of _basis is allowed.') + """ + # allow self._basis = basis + if name == "_basis": + super().__setattr__(name, value) + # allow changing existing attributes of self._basis + elif hasattr(self._basis, name): + setattr(self._basis, name, value) + # don't allow setting any other attribute + else: + raise ValueError( + "Only setting _basis or existing attributes of _basis is allowed." + ) + + def __sklearn_clone__(self) -> TransformerBasis: + """ + Customize how TransformerBasis objects are cloned when used with sklearn.model_selection. + + By default, scikit-learn tries to clone the object by calling __init__ using the output of get_params, + which fails in our case. + + For more info: https://scikit-learn.org/stable/developers/develop.html#cloning + """ + cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) + cloned_obj._basis.kernel_ = None + return cloned_obj + + def set_params(self, **parameters) -> TransformerBasis: + """ + Set TransformerBasis parameters. + + When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly + or the parameters of the underlying Basis, but not both. 
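In a grid search this means the parameter grid may target either the wrapped basis's parameters or the ``_basis`` object itself, just not both in the same candidate. A rough sketch under those assumptions (the step name, toy data, and Ridge strength are illustrative and mirror the ``to_transformer`` example elsewhere in this patch, not a tested recipe):

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

import nemos as nmo

X, y = np.random.normal(size=(50, 1)), np.random.poisson(size=50)
pipe = Pipeline([
    ("compute_features", nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(6))),
    ("glm", nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.0)),
])

# cross-validate a parameter of the wrapped basis ...
grid_params = GridSearchCV(pipe, {"compute_features__n_basis_funcs": [4, 8]}, cv=2)
grid_params = grid_params.fit(X, y)

# ... or swap in different basis objects altogether (but never mix the two styles
# in a single candidate, which would trigger the ValueError this method raises)
grid_basis = GridSearchCV(
    pipe,
    {"compute_features___basis": [nmo.basis.EvalMSpline(6), nmo.basis.EvalBSpline(6)]},
    cv=2,
)
grid_basis = grid_basis.fit(X, y)
```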
+ + Examples + -------- + >>> from nemos.basis import EvalBSpline, EvalMSpline, TransformerBasis + >>> basis = EvalMSpline(10) + >>> transformer_basis = TransformerBasis(basis=basis) + + >>> # setting parameters of _basis is allowed + >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) + 8 + >>> # setting _basis directly is allowed + >>> print(type(transformer_basis.set_params(_basis=EvalBSpline(10))._basis)) + + >>> # mixing is not allowed, this will raise an exception + >>> try: + ... transformer_basis.set_params(_basis=EvalBSpline(10), n_basis_funcs=2) + ... except ValueError as e: + ... print(repr(e)) + ValueError('Set either new _basis object or parameters for existing _basis, not both.') + """ + new_basis = parameters.pop("_basis", None) + if new_basis is not None: + self._basis = new_basis + if len(parameters) > 0: + raise ValueError( + "Set either new _basis object or parameters for existing _basis, not both." + ) + else: + self._basis = self._basis.set_params(**parameters) + + return self + + def get_params(self, deep: bool = True) -> dict: + """Extend the dict of parameters from the underlying Basis with _basis.""" + return {"_basis": self._basis, **self._basis.get_params(deep)} + + def __dir__(self) -> list[str]: + """Extend the list of properties of methods with the ones from the underlying Basis.""" + return list(super().__dir__()) + list(self._basis.__dir__()) + + def __add__(self, other: TransformerBasis) -> TransformerBasis: + """ + Add two TransformerBasis objects. + + Parameters + ---------- + other + The other TransformerBasis object to add. + + Returns + ------- + : TransformerBasis + The resulting Basis object. + """ + return TransformerBasis(self._basis + other._basis) + + def __mul__(self, other: TransformerBasis) -> TransformerBasis: + """ + Multiply two TransformerBasis objects. + + Parameters + ---------- + other + The other TransformerBasis object to multiply. + + Returns + ------- + : + The resulting Basis object. + """ + return TransformerBasis(self._basis * other._basis) + + def __pow__(self, exponent: int) -> TransformerBasis: + """Exponentiation of a TransformerBasis object. + + Define the power of a basis by repeatedly applying the method __mul__. + The exponent must be a positive integer. + + Parameters + ---------- + exponent : + Positive integer exponent + + Returns + ------- + : + The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self. + + Raises + ------ + TypeError + If the provided exponent is not an integer. + ValueError + If the integer is zero or negative. 
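Because these operators return new ``TransformerBasis`` objects, composite transformers can be built just like composite bases. A minimal sketch (the particular bases are arbitrary; the summed ``n_basis_funcs`` reflects the behaviour asserted in the test suite later in this patch):

```python
import nemos as nmo

a = nmo.basis.TransformerBasis(nmo.basis.EvalBSpline(5))
b = nmo.basis.TransformerBasis(nmo.basis.EvalRaisedCosineLinear(6))

summed = a + b    # wraps an AdditiveBasis
product = a * b   # wraps a MultiplicativeBasis
squared = a ** 2  # repeated multiplication, still a TransformerBasis

print(summed.n_basis_funcs)  # 11, i.e. 5 + 6 forwarded from the wrapped AdditiveBasis
```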
+ """ + # errors are handled by Basis.__pow__ + return TransformerBasis(self._basis**exponent) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 7d5c56c0..d49279a0 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -8,7 +8,7 @@ from numpy.typing import ArrayLike, NDArray from ..typing import FeatureMatrix -from ._basis_mixin import ConvBasisMixin, EvalBasisMixin +from ._basis_mixin import BasisTransformerMixin, ConvBasisMixin, EvalBasisMixin from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring from ._raised_cosine_basis import ( RaisedCosineBasisLinear, @@ -24,6 +24,7 @@ add_docstrings_cyclic_bspline, add_docstrings_mspline, ) +from ._transformer_basis import TransformerBasis __all__ = [ "EvalMSpline", @@ -38,6 +39,7 @@ "ConvRaisedCosineLog", "EvalOrthExponential", "ConvOrthExponential", + "TransformerBasis", ] @@ -45,7 +47,7 @@ def __dir__() -> list[str]: return __all__ -class EvalBSpline(EvalBasisMixin, BSplineBasis): +class EvalBSpline(EvalBasisMixin, BSplineBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -61,6 +63,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_bspline("split_by_feature") def split_by_feature( @@ -128,7 +131,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return BSplineBasis.evaluate_on_grid(self, n_samples) -class ConvBSpline(ConvBasisMixin, BSplineBasis): +class ConvBSpline(ConvBasisMixin, BSplineBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -145,6 +148,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_bspline("split_by_feature") def split_by_feature( @@ -212,7 +216,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return BSplineBasis.evaluate_on_grid(self, n_samples) -class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): +class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -228,6 +232,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( @@ -295,7 +300,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) -class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): +class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -312,6 +317,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( @@ -379,7 +385,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) -class EvalMSpline(EvalBasisMixin, MSplineBasis): +class EvalMSpline(EvalBasisMixin, MSplineBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -395,6 +401,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_mspline("split_by_feature") def split_by_feature( @@ -462,7 +469,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return MSplineBasis.evaluate_on_grid(self, n_samples) -class ConvMSpline(ConvBasisMixin, MSplineBasis): +class ConvMSpline(ConvBasisMixin, MSplineBasis, BasisTransformerMixin): def __init__( self, 
n_basis_funcs: int, @@ -479,6 +486,7 @@ def __init__( order=order, label=label, ) + BasisTransformerMixin.__init__(self) @add_docstrings_mspline("split_by_feature") def split_by_feature( @@ -546,7 +554,9 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return MSplineBasis.evaluate_on_grid(self, n_samples) -class EvalRaisedCosineLinear(EvalBasisMixin, RaisedCosineBasisLinear): +class EvalRaisedCosineLinear( + EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin +): def __init__( self, n_basis_funcs: int, @@ -562,6 +572,7 @@ def __init__( mode="eval", label=label, ) + BasisTransformerMixin.__init__(self) @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -622,7 +633,9 @@ def split_by_feature( return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) -class ConvRaisedCosineLinear(ConvBasisMixin, RaisedCosineBasisLinear): +class ConvRaisedCosineLinear( + ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin +): def __init__( self, n_basis_funcs: int, @@ -639,6 +652,7 @@ def __init__( width=width, label=label, ) + BasisTransformerMixin.__init__(self) @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -699,7 +713,7 @@ def split_by_feature( return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) -class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): +class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -719,6 +733,7 @@ def __init__( mode="eval", label=label, ) + BasisTransformerMixin.__init__(self) @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -779,7 +794,7 @@ def split_by_feature( return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) -class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): +class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -800,6 +815,7 @@ def __init__( enforce_decay_to_zero=enforce_decay_to_zero, label=label, ) + BasisTransformerMixin.__init__(self) @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -860,7 +876,7 @@ def split_by_feature( return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) -class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): +class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis, BasisTransformerMixin): def __init__( self, n_basis_funcs: int, @@ -907,6 +923,7 @@ def __init__( mode="eval", label=label, ) + BasisTransformerMixin.__init__(self) @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -971,7 +988,7 @@ def split_by_feature( return OrthExponentialBasis.split_by_feature(self, x, axis=axis) -class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): +class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis, BasisTransformerMixin): """ Examples -------- @@ -1004,6 +1021,7 @@ def __init__( decay_rates=decay_rates, label=label, ) + BasisTransformerMixin.__init__(self) @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: diff --git a/tests/test_basis.py b/tests/test_basis.py index e9668dbe..c8412897 100644 --- 
a/tests/test_basis.py +++ b/tests/test_basis.py @@ -17,13 +17,7 @@ import nemos as nmo import nemos.basis.basis as basis import nemos.convolve as convolve -from nemos.basis._basis import ( - AdditiveBasis, - Basis, - MultiplicativeBasis, - TransformerBasis, - add_docstring, -) +from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring from nemos.basis._decaying_exponential import OrthExponentialBasis from nemos.basis._raised_cosine_basis import ( RaisedCosineBasisLinear, @@ -86,7 +80,7 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: ] + [ bas for _, bas in utils_testing.get_non_abstract_classes(nmo.basis._basis) - if bas != TransformerBasis + if bas != basis.TransformerBasis ] if filter_basis != "all": all_basis = [a for a in all_basis if filter_basis in a.__name__] @@ -3228,7 +3222,7 @@ def test_basis_to_transformer(basis_cls, class_specific_params): trans_bas = bas.to_transformer() - assert isinstance(trans_bas, TransformerBasis) + assert isinstance(trans_bas, basis.TransformerBasis) # check that things like n_basis_funcs are the same as the original basis for k in bas.__dict__.keys(): @@ -3277,7 +3271,7 @@ def test_to_transformer_and_constructor_are_equivalent( ) trans_bas_a = bas.to_transformer() - trans_bas_b = TransformerBasis(bas) + trans_bas_b = basis.TransformerBasis(bas) # they both just have a _basis assert ( @@ -3334,7 +3328,7 @@ def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): ) @pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): - trans_basis = TransformerBasis( + trans_basis = basis.TransformerBasis( CombinedBasis().instantiate_basis( n_basis_funcs, basis_cls, class_specific_params, window_size=10 ) @@ -3357,7 +3351,7 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_param def test_transformerbasis_set_params( basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params ): - trans_basis = TransformerBasis( + trans_basis = basis.TransformerBasis( CombinedBasis().instantiate_basis( n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 ) @@ -3374,7 +3368,7 @@ def test_transformerbasis_set_params( ) def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): # setting the _basis attribute should change it - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=10 ) @@ -3395,7 +3389,7 @@ def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): # setting an attribute that is an attribute of the underlying _basis # should propagate setting it on _basis itself - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=10 ) @@ -3417,7 +3411,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_para orig_bas = CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=10 ) - trans_bas = TransformerBasis(orig_bas) + trans_bas = basis.TransformerBasis(orig_bas) trans_bas.n_basis_funcs = 20 assert orig_bas.n_basis_funcs == 10 @@ -3433,7 +3427,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_para def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): # 
changing an attribute that is not _basis or an attribute of _basis # is not allowed - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=10 ) @@ -3459,10 +3453,10 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): bas_b = CombinedBasis().instantiate_basis( n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 ) - trans_bas_a = TransformerBasis(bas_a) - trans_bas_b = TransformerBasis(bas_b) + trans_bas_a = basis.TransformerBasis(bas_a) + trans_bas_b = basis.TransformerBasis(bas_b) trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, TransformerBasis) + assert isinstance(trans_bas_sum, basis.TransformerBasis) assert isinstance(trans_bas_sum._basis, AdditiveBasis) assert ( trans_bas_sum.n_basis_funcs @@ -3484,18 +3478,18 @@ def test_transformerbasis_addition(basis_cls, class_specific_params): def test_transformerbasis_multiplication(basis_cls, class_specific_params): n_basis_funcs_a = 5 n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = TransformerBasis( + trans_bas_a = basis.TransformerBasis( CombinedBasis().instantiate_basis( n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 ) ) - trans_bas_b = TransformerBasis( + trans_bas_b = basis.TransformerBasis( CombinedBasis().instantiate_basis( n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 ) ) trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, TransformerBasis) + assert isinstance(trans_bas_prod, basis.TransformerBasis) assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) assert ( trans_bas_prod.n_basis_funcs @@ -3526,7 +3520,7 @@ def test_transformerbasis_multiplication(basis_cls, class_specific_params): def test_transformerbasis_exponentiation( basis_cls, exponent: int, error_type, error_message, class_specific_params ): - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( 5, basis_cls, class_specific_params, window_size=10 ) @@ -3535,7 +3529,7 @@ def test_transformerbasis_exponentiation( if not isinstance(exponent, int): with pytest.raises(error_type, match=error_message): trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, TransformerBasis) + assert isinstance(trans_bas_exp, basis.TransformerBasis) assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) @@ -3544,7 +3538,7 @@ def test_transformerbasis_exponentiation( list_all_basis_classes(), ) def test_transformerbasis_dir(basis_cls, class_specific_params): - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( 5, basis_cls, class_specific_params, window_size=10 ) @@ -3573,7 +3567,7 @@ def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params orig_bas = CombinedBasis().instantiate_basis( 10, basis_cls, class_specific_params, window_size=20 ) - trans_bas = TransformerBasis(orig_bas) + trans_bas = basis.TransformerBasis(orig_bas) # kernel should be saved in the object after fit trans_bas.fit(np.random.randn(100, 20)) @@ -3597,7 +3591,7 @@ def test_transformerbasis_pickle( tmpdir, basis_cls, n_basis_funcs, class_specific_params ): # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = TransformerBasis( + trans_bas = basis.TransformerBasis( CombinedBasis().instantiate_basis( n_basis_funcs, basis_cls, class_specific_params, window_size=10 ) @@ -3608,7 +3602,7 @@ def 
test_transformerbasis_pickle( with open(filepath, "rb") as f: trans_bas2 = pickle.load(f) - assert isinstance(trans_bas2, TransformerBasis) + assert isinstance(trans_bas2, basis.TransformerBasis) if basis_cls in [AdditiveBasis, MultiplicativeBasis]: for bas in [ getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") @@ -3743,7 +3737,7 @@ def test_multi_epoch_pynapple_basis_transformer( n_input = bas._n_input_dimensionality # pass through transformer - bas = TransformerBasis(bas) + bas = basis.TransformerBasis(bas) # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) @@ -3839,7 +3833,7 @@ def test__get_splitter( ): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2, bas3] ): return @@ -3996,7 +3990,7 @@ def test__get_splitter_split_by_input( ): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2] ): return @@ -4030,7 +4024,7 @@ def test__get_splitter_split_by_input( def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2, bas3] ): return @@ -4078,7 +4072,7 @@ def test_split_feature_axis( ): # skip nested if any( - bas in (AdditiveBasis, MultiplicativeBasis, TransformerBasis) + bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) for bas in [bas1, bas2] ): return From 9985784b1c646c3a4fb837792069092363223dc0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 27 Nov 2024 21:14:29 -0500 Subject: [PATCH 056/109] added to api refs --- docs/api_reference.rst | 43 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index a10be9b4..9068ff0d 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -22,6 +22,10 @@ Classes for creating Generalized Linear Models (GLMs) for both single neurons an The ``nemos.basis`` module -------------------------- Provides basis function classes to construct and transform features for model inputs. +Basis can be grouped according to the mode of operation into basis that performs convolution and basis that operates +as non-linear maps. + + .. currentmodule:: nemos.basis @@ -31,9 +35,46 @@ Provides basis function classes to construct and transform features for model in :nosignatures: Basis - EvalOrthExponential + +**Bases For Convolution:** + +.. autosummary:: + :toctree: generated/basis + :recursive: + :nosignatures: + + + ConvMSpline + ConvBSpline + ConvCyclicBSpline + ConvRaisedCosineLinear + ConvRaisedCosineLog ConvOrthExponential +**Bases For Non-Linear Mapping:** + +.. autosummary:: + :toctree: generated/basis + :recursive: + :nosignatures: + + EvalMSpline + EvalBSpline + EvalCyclicBSpline + EvalRaisedCosineLinear + EvalRaisedCosineLog + EvalOrthExponential + +**Composite Bases:** + +.. autosummary:: + :toctree: generated/basis + :recursive: + :nosignatures: + + AdditiveBasis + MultiplicativeBasis + .. 
_observation_models: The ``nemos.observation_models`` module -------------------------------------- From 7cec3c0478a1a128dec15c2bc14faf366693e713 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 10:12:00 -0500 Subject: [PATCH 057/109] fix auto-imports --- src/nemos/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index 42314b90..7f88be21 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -6,7 +6,7 @@ import jaxopt import pynapple as nap from jax.typing import ArrayLike -from statsmodels.tools.typing import NDArray +from numpy.typing import NDArray from .pytrees import FeaturePytree From 7fd2ef78c5787772e000fd36c105406fc696d836 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:14:14 -0500 Subject: [PATCH 058/109] fix ome links --- docs/api_reference.rst | 26 ++++++++++++++++--- docs/conf.py | 4 ++- .../plot_05_sklearn_pipeline_cv_demo.md | 12 ++++----- docs/tutorials/plot_06_calcium_imaging.md | 2 +- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 9068ff0d..7c2cc6bb 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -26,11 +26,12 @@ Basis can be grouped according to the mode of operation into basis that performs as non-linear maps. +**The abstract class `Basis`:** -.. currentmodule:: nemos.basis +.. currentmodule:: nemos.basis._basis .. autosummary:: - :toctree: generated/basis + :toctree: generated/_basis :recursive: :nosignatures: @@ -38,6 +39,8 @@ as non-linear maps. **Bases For Convolution:** +.. currentmodule:: nemos.basis.basis + .. autosummary:: :toctree: generated/basis :recursive: @@ -53,6 +56,8 @@ as non-linear maps. **Bases For Non-Linear Mapping:** +.. currentmodule:: nemos.basis.basis + .. autosummary:: :toctree: generated/basis :recursive: @@ -67,14 +72,27 @@ as non-linear maps. **Composite Bases:** +.. currentmodule:: nemos.basis._basis + .. autosummary:: - :toctree: generated/basis + :toctree: generated/_basis :recursive: :nosignatures: AdditiveBasis MultiplicativeBasis +**Basis as scikit-learn tranformers:** + +.. currentmodule:: nemos.basis._transformer_basis + +.. autosummary:: + :toctree: generated/_transformer_basis + :recursive: + :nosignatures: + + TransformerBasis + .. _observation_models: The ``nemos.observation_models`` module -------------------------------------- @@ -163,7 +181,7 @@ These objects can be provided as input to nemos GLM methods. .. currentmodule:: nemos.pytrees .. autosummary:: - :toctree: generated/identifiability_constraints + :toctree: generated/pytree :recursive: :nosignatures: diff --git a/docs/conf.py b/docs/conf.py index 5a5b2c43..c8f5e3a2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -157,4 +157,6 @@ exclude_tutorials = os.environ.get("EXCLUDE_TUTORIALS", "false").lower() == "true" if exclude_tutorials: - nb_execution_excludepatterns = ["tutorials/**", "how_to_guide/**", "background/**"] \ No newline at end of file + nb_execution_excludepatterns = ["tutorials/**", "how_to_guide/**", "background/**"] + +viewcode_follow_imported_members = True diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index 60f7f4b5..a3fee426 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -48,7 +48,7 @@ In particular, we will learn: 1. What a scikit-learn pipeline is. 2. Why pipelines are useful. -3. 
How to combine NeMoS [`Basis`](nemos.basis.Basis) and [`GLM`](nemos.glm.GLM) objects in a pipeline. +3. How to combine NeMoS [`Basis`](nemos.basis._basis.Basis) and [`GLM`](nemos.glm.GLM) objects in a pipeline. 4. How to select the number of bases and the basis type through cross-validation (or any other hyperparameter in the pipeline). 5. How to use a custom scoring metric to quantify the performance of each configuration. @@ -150,9 +150,9 @@ sns.despine(ax=ax) ``` ### Converting NeMoS `Basis` to a transformer -In order to use NeMoS [`Basis`](nemos.basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis.TransformerBasis) wrapper class. +In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer): ```{code-cell} ipython3 @@ -160,14 +160,14 @@ bas = nmo.basis.ConvRaisedCosineLinear(5, window_size=5) trans_bas = bas.to_transformer() ``` -[`TransformerBasis`](nemos.basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis.Basis) object's attributes: +[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: ```{code-cell} ipython3 print(bas.n_basis_funcs, trans_bas.n_basis_funcs) ``` -We can also set attributes of the underlying [`Basis`](nemos.basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis.Basis), and neither does changing the original [`Basis`](nemos.basis.Basis) change [`TransformerBasis`](nemos.basis.TransformerBasis) we created: +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: ```{code-cell} ipython3 @@ -442,7 +442,7 @@ We are now able to capture the distribution of the firing rate appropriately: bo ### Evaluating different bases directly -In the previous example we set the number of basis functions of the [`Basis`](nemos.basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis.TransformerBasis). However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well. +In the previous example we set the number of basis functions of the [`Basis`](nemos.basis._basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis). 
However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well. Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`: diff --git a/docs/tutorials/plot_06_calcium_imaging.md b/docs/tutorials/plot_06_calcium_imaging.md index 3eeed024..e896affc 100644 --- a/docs/tutorials/plot_06_calcium_imaging.md +++ b/docs/tutorials/plot_06_calcium_imaging.md @@ -184,7 +184,7 @@ heading_basis = nmo.basis.EvalCyclicBSpline(n_basis_funcs=12) coupling_basis = nmo.basis.ConvRaisedCosineLog(3, window_size=10) ``` -We need to make sure the design matrix will be full-rank by applying identifiability constraints to the Cyclic Bspline, and then combine the bases (the resturned object will be an [`AdditiveBasis`](nemos.basis.AdditiveBasis) object). +We need to make sure the design matrix will be full-rank by applying identifiability constraints to the Cyclic Bspline, and then combine the bases (the resturned object will be an [`AdditiveBasis`](nemos.basis._basis.AdditiveBasis) object). ```{code-cell} ipython3 From 6b7486aedda9dc544b86b1eb6c3d5bf42e83d02c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:14:30 -0500 Subject: [PATCH 059/109] simplified inheritance --- docs/how_to_guide/plot_06_glm_pytree.md | 2 +- docs/tutorials/plot_05_place_cells.md | 6 ++-- src/nemos/basis/_basis.py | 7 ++-- src/nemos/basis/basis.py | 44 ++++++++++++------------- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 9dbd7ddc..1d6ca5f4 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -274,7 +274,7 @@ Okay, let's use unit number 7. Now let's set up our design matrix. First, let's fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. -[`EvalCyclicBSpline`](nemos.basis.EvalCyclicBSpline) is one such basis. +[`EvalCyclicBSpline`](nemos.basis.basis.EvalCyclicBSpline) is one such basis. Let's create our basis and then arrange our data properly. diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index 87e97d88..597959c3 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -335,9 +335,9 @@ print(count.shape) For each feature, we will use a different set of basis : - - position : [`EvalMSpline`](nemos.basis.EvalMSpline) - - theta phase : [`EvalCyclicBSpline`](nemos.basis.EvalCyclicBSpline) - - speed : [`EvalMSpline`](nemos.basis.EvalMSpline) + - position : [`EvalMSpline`](nemos.basis.basis.EvalMSpline) + - theta phase : [`EvalCyclicBSpline`](nemos.basis.basis.EvalCyclicBSpline) + - speed : [`EvalMSpline`](nemos.basis.basis.EvalMSpline) ```{code-cell} ipython3 diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index c37cec78..b6ec1fb7 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -98,7 +98,7 @@ def min_max_rescale_samples( return sample_pts, scaling -class Basis(Base, abc.ABC): +class Basis(Base, abc.ABC, BasisTransformerMixin): """ Abstract base class for defining basis functions for feature transformation. 
@@ -824,7 +824,7 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: add_docstring_multiplicative = partial(add_docstring, cls=Basis) -class AdditiveBasis(Basis, BasisTransformerMixin): +class AdditiveBasis(Basis): """ Class representing the addition of two Basis objects. @@ -869,7 +869,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._label = "(" + basis1.label + " + " + basis2.label + ")" self._basis1 = basis1 self._basis2 = basis2 - BasisTransformerMixin.__init__(self) def _set_num_output_features(self, *xi: NDArray) -> Basis: self._n_basis_input = ( @@ -1167,7 +1166,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) -class MultiplicativeBasis(Basis, BasisTransformerMixin): +class MultiplicativeBasis(Basis): """ Class representing the multiplication (external product) of two Basis objects. diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index d49279a0..b70a7104 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -47,7 +47,7 @@ def __dir__() -> list[str]: return __all__ -class EvalBSpline(EvalBasisMixin, BSplineBasis, BasisTransformerMixin): +class EvalBSpline(EvalBasisMixin, BSplineBasis): def __init__( self, n_basis_funcs: int, @@ -63,7 +63,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_bspline("split_by_feature") def split_by_feature( @@ -131,7 +131,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return BSplineBasis.evaluate_on_grid(self, n_samples) -class ConvBSpline(ConvBasisMixin, BSplineBasis, BasisTransformerMixin): +class ConvBSpline(ConvBasisMixin, BSplineBasis): def __init__( self, n_basis_funcs: int, @@ -148,7 +148,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_bspline("split_by_feature") def split_by_feature( @@ -216,7 +216,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return BSplineBasis.evaluate_on_grid(self, n_samples) -class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis, BasisTransformerMixin): +class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): def __init__( self, n_basis_funcs: int, @@ -232,7 +232,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( @@ -300,7 +300,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) -class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis, BasisTransformerMixin): +class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): def __init__( self, n_basis_funcs: int, @@ -317,7 +317,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( @@ -385,7 +385,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) -class EvalMSpline(EvalBasisMixin, MSplineBasis, BasisTransformerMixin): +class EvalMSpline(EvalBasisMixin, MSplineBasis): def __init__( self, n_basis_funcs: int, @@ -401,7 +401,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_mspline("split_by_feature") def split_by_feature( @@ -469,7 +469,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return 
MSplineBasis.evaluate_on_grid(self, n_samples) -class ConvMSpline(ConvBasisMixin, MSplineBasis, BasisTransformerMixin): +class ConvMSpline(ConvBasisMixin, MSplineBasis): def __init__( self, n_basis_funcs: int, @@ -486,7 +486,7 @@ def __init__( order=order, label=label, ) - BasisTransformerMixin.__init__(self) + @add_docstrings_mspline("split_by_feature") def split_by_feature( @@ -572,7 +572,7 @@ def __init__( mode="eval", label=label, ) - BasisTransformerMixin.__init__(self) + @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -652,7 +652,7 @@ def __init__( width=width, label=label, ) - BasisTransformerMixin.__init__(self) + @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -713,7 +713,7 @@ def split_by_feature( return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) -class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog, BasisTransformerMixin): +class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): def __init__( self, n_basis_funcs: int, @@ -733,7 +733,7 @@ def __init__( mode="eval", label=label, ) - BasisTransformerMixin.__init__(self) + @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -794,7 +794,7 @@ def split_by_feature( return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) -class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog, BasisTransformerMixin): +class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): def __init__( self, n_basis_funcs: int, @@ -815,7 +815,7 @@ def __init__( enforce_decay_to_zero=enforce_decay_to_zero, label=label, ) - BasisTransformerMixin.__init__(self) + @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -876,7 +876,7 @@ def split_by_feature( return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) -class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis, BasisTransformerMixin): +class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): def __init__( self, n_basis_funcs: int, @@ -923,7 +923,7 @@ def __init__( mode="eval", label=label, ) - BasisTransformerMixin.__init__(self) + @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -988,7 +988,7 @@ def split_by_feature( return OrthExponentialBasis.split_by_feature(self, x, axis=axis) -class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis, BasisTransformerMixin): +class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): """ Examples -------- @@ -1021,7 +1021,7 @@ def __init__( decay_rates=decay_rates, label=label, ) - BasisTransformerMixin.__init__(self) + @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: From 0ed90401abdeee4f810ffdafef019254c0fd0ec9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:28:22 -0500 Subject: [PATCH 060/109] fixed other links and added SplineBasis --- docs/api_reference.rst | 11 +++++++- docs/background/plot_01_1D_basis_function.md | 10 +++---- docs/background/plot_03_1D_convolution.md | 4 +-- docs/developers_notes/04-basis_module.md | 28 +++++++++---------- .../plot_05_sklearn_pipeline_cv_demo.md | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 
7c2cc6bb..6f56e33a 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -26,7 +26,7 @@ Basis can be grouped according to the mode of operation into basis that performs as non-linear maps. -**The abstract class `Basis`:** +**The abstract classes:** .. currentmodule:: nemos.basis._basis @@ -37,6 +37,15 @@ as non-linear maps. Basis +.. currentmodule:: nemos.basis._spline_basis +.. autosummary:: + :toctree: generated/_basis + :recursive: + :nosignatures: + + SplineBasis + + **Bases For Convolution:** .. currentmodule:: nemos.basis.basis diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 97fedd74..527f8363 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -68,8 +68,8 @@ bspline = nmo.basis.EvalBSpline(n_basis_funcs=n_basis, order=order) ## Evaluating a Basis -The [`Basis`](nemos.basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis -is defined by the samples that we input to the [`__call__`](nemos.basis.Basis.__call__) method, and covers from the smallest to the largest value. +The [`Basis`](nemos.basis._basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis +is defined by the samples that we input to the [`__call__`](nemos.basis._basis.Basis.__call__) method, and covers from the smallest to the largest value. ```{code-cell} ipython3 @@ -143,8 +143,8 @@ plt.tight_layout() ## Feature Computation The bases in the module `nemos.basis` can be classified in two categories: -- **Evaluation Bases**: Objects for which [`compute_features`](nemos.basis.Basis.compute_features) that returns the evaluated basis. This means that the basis are applying a non-linear transformation of the input. The class name for this kind of bases starts with "Eval", e.g. "EvalBSpline". -- **Convolution Bases**: Objects for which [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a kernel of basis elements with `window_size` specified by the user. The class name for this kind of bases starts with "Conv", e.g. "ConvBSpline". +- **Evaluation Bases**: Objects for which [`compute_features`](nemos.basis._basis.Basis.compute_features) that returns the evaluated basis. This means that the basis are applying a non-linear transformation of the input. The class name for this kind of bases starts with "Eval", e.g. "EvalBSpline". +- **Convolution Bases**: Objects for which [`compute_features`](nemos.basis._basis.Basis.compute_features) will convolve the input with a kernel of basis elements with `window_size` specified by the user. The class name for this kind of bases starts with "Conv", e.g. "ConvBSpline". Let's see how this two modalities operate. @@ -198,7 +198,7 @@ check out the tutorial on [1D convolutions](plot_03_1D_convolution). Plotting the Basis Function Elements: -------------------------------------- We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points -and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns +and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns the equi-spaced samples along with the evaluated basis functions. 
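As a rough, illustrative sketch of that workflow (not part of the diff itself; it assumes the refactored `EvalBSpline` name exposed at `nmo.basis`, matching its usage elsewhere in this patch series, with arbitrary hyperparameters):

```python
# Sketch: visualize a basis by evaluating it on an equi-spaced grid.
import matplotlib.pyplot as plt
import nemos as nmo

basis = nmo.basis.EvalBSpline(n_basis_funcs=10, order=4)

# evaluate_on_grid returns the equi-spaced sample points (shape (100,))
# and the basis evaluated at those points (shape (100, n_basis_funcs)).
x, values = basis.evaluate_on_grid(100)

plt.plot(x, values)
plt.xlabel("Domain")
plt.ylabel("Basis value")
plt.show()
```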
The benefits of using Basis.evaluate_on_grid become particularly evident when working with multidimensional basis functions. You can find more details and visual background in the diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md index 38ae9e7f..18fed52b 100644 --- a/docs/background/plot_03_1D_convolution.md +++ b/docs/background/plot_03_1D_convolution.md @@ -186,14 +186,14 @@ if path.exists(): fig.savefig(path / "plot_03_1D_convolution.svg") ``` -## Convolve using [`Basis.compute_features`](nemos.basis.Basis.compute_features) +## Convolve using [`Basis.compute_features`](nemos.basis._basis.Basis.compute_features) Every basis in the `nemos.basis` module whose class name starts with "Conv" will perform a 1D convolution over the provided input when the `compute_features` method is called. The basis elements will be used as filters for the convolution. All the parameters of [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) can be passed to the object directly at initialization. -Let's see how we can get the same results through [`Basis`](nemos.basis.Basis). +Let's see how we can get the same results through [`Basis`](nemos.basis._basis.Basis). ```{code-cell} ipython3 diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index c0bfd2a8..13f89413 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -26,23 +26,23 @@ Abstract Class Basis └─ Concrete Subclass OrthExponentialBasis ``` -The super-class [`Basis`](nemos.basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`__call__`](nemos.basis.Basis.__call__) that is specific for each concrete class. See below for more details. +The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`__call__`](nemos.basis._basis.Basis.__call__) that is specific for each concrete class. See below for more details. -## The Class `nemos.basis.Basis` +## The Class `nemos.basis._basis.Basis` (the-public-method-compute_features)= ### The Public Method `compute_features` -The [`compute_features`](nemos.basis.Basis.compute_features) method checks input consistency and applies the basis function to the inputs. -[`Basis`](nemos.basis.Basis) can operate in two modes defined at initialization: `"eval"` and `"conv"`. When a basis is in mode `"eval"`, -[`compute_features`](nemos.basis.Basis.compute_features) evaluates the basis at the given input samples. When in mode `"conv"`, it will convolve the samples +The [`compute_features`](nemos.basis._basis.Basis.compute_features) method checks input consistency and applies the basis function to the inputs. +[`Basis`](nemos.basis._basis.Basis) can operate in two modes defined at initialization: `"eval"` and `"conv"`. 
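A minimal sketch of the two behaviours, detailed in the next paragraph (assuming the refactored `EvalBSpline`/`ConvBSpline` names used in the doctests of this patch series):

```python
# Sketch: the same samples passed through an "eval" basis and a "conv" basis.
import numpy as np
import nemos as nmo

x = np.linspace(0, 1, 1000)

# "eval": evaluate the basis at each sample (a non-linear map of the input).
X_eval = nmo.basis.EvalBSpline(n_basis_funcs=5).compute_features(x)

# "conv": convolve the samples with a bank of kernels, one per basis function.
X_conv = nmo.basis.ConvBSpline(n_basis_funcs=5, window_size=100).compute_features(x)

# Both return a feature matrix of shape (n_samples, n_basis_funcs);
# the convolved output may contain NaNs at the window boundary.
print(X_eval.shape, X_conv.shape)  # (1000, 5) (1000, 5)
```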
When a basis is in mode `"eval"`, +[`compute_features`](nemos.basis._basis.Basis.compute_features) evaluates the basis at the given input samples. When in mode `"conv"`, it will convolve the samples with a bank of kernels, one per basis function. It accepts one or more NumPy array or pynapple `Tsd` object as input, and performs the following steps: 1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case. 2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. -3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). +3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). 4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples. :::{admonition} Multiple epochs @@ -55,20 +55,20 @@ input. (the-public-method-evaluate_on_grid)= ### The Public Method `evaluate_on_grid` -The [`compute_features`](nemos.basis.Basis.compute_features) method evaluates the basis set on a grid of equidistant sample points. The user specifies the input as a series of integers, one for each dimension of the basis function, that indicate the number of sample points in each coordinate of the grid. +The [`compute_features`](nemos.basis._basis.Basis.compute_features) method evaluates the basis set on a grid of equidistant sample points. The user specifies the input as a series of integers, one for each dimension of the basis function, that indicate the number of sample points in each coordinate of the grid. This method performs the following steps: 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the [`__call__`](nemos.basis.Basis.__call__) method. +3. Calls the [`__call__`](nemos.basis._basis.Basis.__call__) method. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods -The [`nemos.basis.Basis`](nemos.basis.Basis) class has the following abstract methods, which every concrete subclass must implement: +The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. [`__call__`](nemos.basis.Basis.__call__): Evaluates a basis over some specified samples. +1. 
[`__call__`](nemos.basis._basis.Basis.__call__): Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines @@ -76,8 +76,8 @@ The [`nemos.basis.Basis`](nemos.basis.Basis) class has the following abstract me ### Implementing Concrete Basis Objects To write a usable (i.e., concrete, non-abstract) basis object, you -- **Must** inherit the abstract superclass [`Basis`](nemos.basis.Basis) -- **Must** define the [`__call__`](nemos.basis.Basis.__call__) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. -- **Should not** overwrite the [`compute_features`](nemos.basis.Basis.compute_features) and [`compute_features`](nemos.basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis.Basis). -- **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis.SplineBasis)). +- **Must** inherit the abstract superclass [`Basis`](nemos.basis._basis.Basis) +- **Must** define the [`__call__`](nemos.basis._basis.Basis.__call__) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. +- **Should not** overwrite the [`compute_features`](nemos.basis._basis.Basis.compute_features) and [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis._basis.Basis). +- **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis._spline_basis.SplineBasis)). diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index a3fee426..f3382de5 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -152,7 +152,7 @@ sns.despine(ax=ax) ### Converting NeMoS `Basis` to a transformer In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. -Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): ```{code-cell} ipython3 From 7f927ff16491578f3f26396742880dcc28d4d031 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:41:42 -0500 Subject: [PATCH 061/109] fixed all relative links --- docs/api_reference.rst | 2 ++ docs/background/plot_01_1D_basis_function.md | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 6f56e33a..b38179ae 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -28,6 +28,8 @@ as non-linear maps. **The abstract classes:** +These classes are building blocks for concrete basis classes. + .. currentmodule:: nemos.basis._basis .. 
autosummary:: diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 527f8363..265a2828 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -45,7 +45,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`EvalMSpline`](nemos.basis.EvalMSpline). +We'll start by defining a 1D basis function object of the type [`EvalMSpline`](nemos.basis.basis.EvalMSpline). The hyperparameters required to initialize this class are: - The number of basis functions, which should be a positive integer. From 92c07574b72086391e3c757b0859d302a678b6f5 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:44:28 -0500 Subject: [PATCH 062/109] ignore timeouts --- .readthedocs.yaml | 2 +- src/nemos/basis/basis.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index ecaca19d..37e28a3a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,7 +15,7 @@ build: - gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9 post_build: # Check everything except 403s and a jneurosci, which returns 404 but the link works when clicking. - - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" + - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" # The auto-generated animation doesn't have a alt or src/srcset; I am able to ignore missing alt, but I cannot work around a missing src/srcset # therefore for this file I am not checking the figures. 
- htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003" diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index b70a7104..23e0fa8b 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -64,7 +64,6 @@ def __init__( label=label, ) - @add_docstrings_bspline("split_by_feature") def split_by_feature( self, @@ -149,7 +148,6 @@ def __init__( label=label, ) - @add_docstrings_bspline("split_by_feature") def split_by_feature( self, @@ -233,7 +231,6 @@ def __init__( label=label, ) - @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( self, @@ -318,7 +315,6 @@ def __init__( label=label, ) - @add_docstrings_cyclic_bspline("split_by_feature") def split_by_feature( self, @@ -402,7 +398,6 @@ def __init__( label=label, ) - @add_docstrings_mspline("split_by_feature") def split_by_feature( self, @@ -487,7 +482,6 @@ def __init__( label=label, ) - @add_docstrings_mspline("split_by_feature") def split_by_feature( self, @@ -573,7 +567,6 @@ def __init__( label=label, ) - @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ @@ -653,7 +646,6 @@ def __init__( label=label, ) - @add_raised_cosine_linear_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ @@ -734,7 +726,6 @@ def __init__( label=label, ) - @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ @@ -816,7 +807,6 @@ def __init__( label=label, ) - @add_raised_cosine_log_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ @@ -924,7 +914,6 @@ def __init__( label=label, ) - @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ @@ -1022,7 +1011,6 @@ def __init__( label=label, ) - @add_orth_exp_decay_docstring("evaluate_on_grid") def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ From 9217087bd2faebd603abb2af811a077484b17311 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 13:51:33 -0500 Subject: [PATCH 063/109] uniform caps --- docs/api_reference.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index b38179ae..7f6b6aea 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -26,9 +26,9 @@ Basis can be grouped according to the mode of operation into basis that performs as non-linear maps. -**The abstract classes:** +**The Abstract Classes:** -These classes are building blocks for concrete basis classes. +These classes are the building blocks for the concrete basis classes. .. currentmodule:: nemos.basis._basis @@ -93,7 +93,7 @@ These classes are building blocks for concrete basis classes. AdditiveBasis MultiplicativeBasis -**Basis as scikit-learn tranformers:** +**Basis As `scikit-learn` Tranformers:** .. 
currentmodule:: nemos.basis._transformer_basis From b95f2f96cac4e75bcec3cafca21ae7dcc5fadbf5 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 16:18:38 -0500 Subject: [PATCH 064/109] fix text --- docs/background/plot_01_1D_basis_function.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 265a2828..3e494265 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -141,10 +141,11 @@ plt.tight_layout() ``` ## Feature Computation -The bases in the module `nemos.basis` can be classified in two categories: +The bases in the `nemos.basis` module can be grouped into two categories: -- **Evaluation Bases**: Objects for which [`compute_features`](nemos.basis._basis.Basis.compute_features) that returns the evaluated basis. This means that the basis are applying a non-linear transformation of the input. The class name for this kind of bases starts with "Eval", e.g. "EvalBSpline". -- **Convolution Bases**: Objects for which [`compute_features`](nemos.basis._basis.Basis.compute_features) will convolve the input with a kernel of basis elements with `window_size` specified by the user. The class name for this kind of bases starts with "Conv", e.g. "ConvBSpline". +1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `EvalBSpline`. + +2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `ConvBSpline`. Let's see how this two modalities operate. From db72fdaa6cecaf38c635fa6c378594f69df4e7ed Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 19:33:28 -0500 Subject: [PATCH 065/109] fix doctests --- src/nemos/basis/_raised_cosine_basis.py | 18 -------- src/nemos/basis/_spline_basis.py | 17 -------- src/nemos/basis/_transformer_basis.py | 7 ++-- src/nemos/basis/basis.py | 52 ++++++++++++------------ src/nemos/identifiability_constraints.py | 8 ++-- src/nemos/simulation.py | 6 +-- 6 files changed, 36 insertions(+), 72 deletions(-) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 7bab06e3..541a6818 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -50,15 +50,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): Note that one cannot change the default value for the ``axis`` parameter. Basis assumes that the convolution axis is ``axis=0``. - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLinear - >>> X = np.random.normal(size=(1000, 1)) - >>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - References ---------- .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., @@ -256,15 +247,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): Note that one cannot change the default value for the ``axis`` parameter. 
Basis assumes that the convolution axis is ``axis=0``. - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import RaisedCosineBasisLog - >>> X = np.random.normal(size=(1000, 1)) - >>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cosine_basis(sample_points) - References ---------- .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 0eba604b..617138b0 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -343,14 +343,6 @@ class BSplineBasis(SplineBasis, abc.ABC): ---------- .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import BSplineBasis - >>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis(sample_points) """ def __init__( @@ -463,15 +455,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): Number of basis functions, int. order : Order of the splines used in basis functions, int. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import EvalCyclicBSpline - >>> X = np.random.normal(size=(1000, 1)) - >>> cyclic_basis = EvalCyclicBSpline(n_basis_funcs=5, order=3, mode="conv", window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_basis(sample_points) """ def __init__( diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 624fb2c4..461a9399 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -136,13 +136,12 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis + >>> from nemos.basis import ConvMSpline, TransformerBasis >>> # Example input >>> X = np.random.normal(size=(10000, 2)) - >>> # Define and fit tranformation basis - >>> basis = EvalMSpline(10, mode="conv", window_size=200) + >>> basis = ConvMSpline(10, window_size=200) >>> transformer = TransformerBasis(basis) >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ @@ -306,7 +305,7 @@ def set_params(self, **parameters) -> TransformerBasis: 8 >>> # setting _basis directly is allowed >>> print(type(transformer_basis.set_params(_basis=EvalBSpline(10))._basis)) - + >>> # mixing is not allowed, this will raise an exception >>> try: ... 
transformer_basis.set_params(_basis=EvalBSpline(10), n_basis_funcs=2) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 23e0fa8b..e2520f97 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -76,12 +76,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import EvalBSpline >>> from nemos.glm import GLM - >>> basis = EvalBSpline(n_basis_funcs=6, label="two_inputs") - >>> X_multi = basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> basis = EvalBSpline(n_basis_funcs=6, label="one_input") + >>> X = basis.compute_features(np.random.randn(20,)) + >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): ... print(f"{feature}, shape {sub_dict.shape}") - two_inputs, shape (20, 2, 6) + one_input, shape (20, 1, 6) """ return BSplineBasis.split_by_feature(self, x, axis=axis) @@ -120,7 +120,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') >>> plt.title('B-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') + Text(0.5, 1.0, 'B-Spline Basis Functions') >>> plt.xlabel('Domain') Text(0.5, 0, 'Domain') >>> plt.ylabel('Basis Function Value') @@ -204,7 +204,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') >>> plt.title('B-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') + Text(0.5, 1.0, 'B-Spline Basis Functions') >>> plt.xlabel('Domain') Text(0.5, 0, 'Domain') >>> plt.ylabel('Basis Function Value') @@ -243,12 +243,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import EvalCyclicBSpline >>> from nemos.glm import GLM - >>> basis = EvalCyclicBSpline(n_basis_funcs=6, label="two_inputs") - >>> X_multi = basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> basis = EvalCyclicBSpline(n_basis_funcs=6, label="one_input") + >>> X = basis.compute_features(np.random.randn(20,)) + >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): ... print(f"{feature}, shape {sub_dict.shape}") - two_inputs, shape (20, 2, 6) + one_input, shape (20, 1, 6) """ return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) @@ -287,7 +287,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') >>> plt.title('Cyclic B-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') + Text(0.5, 1.0, 'Cyclic B-Spline Basis Functions') >>> plt.xlabel('Domain') Text(0.5, 0, 'Domain') >>> plt.ylabel('Basis Function Value') @@ -371,7 +371,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> for i in range(4): ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') >>> plt.title('Cyclic B-Spline Basis Functions') - Text(0.5, 1.0, 'M-Spline Basis Functions') + Text(0.5, 1.0, 'Cyclic B-Spline Basis Functions') >>> plt.xlabel('Domain') Text(0.5, 0, 'Domain') >>> plt.ylabel('Basis Function Value') @@ -410,12 +410,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import EvalMSpline >>> from nemos.glm import GLM - >>> basis = EvalMSpline(n_basis_funcs=6, label="two_inputs") - >>> X_multi = basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> basis = EvalMSpline(n_basis_funcs=6, label="one_input") + >>> X = basis.compute_features(np.random.randn(20)) + >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): ... print(f"{feature}, shape {sub_dict.shape}") - two_inputs, shape (20, 2, 6) + one_input, shape (20, 1, 6) """ return MSplineBasis.split_by_feature(self, x, axis=axis) @@ -615,12 +615,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import EvalRaisedCosineLinear >>> from nemos.glm import GLM - >>> basis = EvalRaisedCosineLinear(n_basis_funcs=6, label="two_inputs") - >>> X_multi = basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> basis = EvalRaisedCosineLinear(n_basis_funcs=6, label="one_input") + >>> X = basis.compute_features(np.random.randn(20,)) + >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): ... print(f"{feature}, shape {sub_dict.shape}") - two_inputs, shape (20, 2, 6) + one_input, shape (20, 1, 6) """ return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) @@ -774,12 +774,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import EvalRaisedCosineLog >>> from nemos.glm import GLM - >>> basis = EvalRaisedCosineLog(n_basis_funcs=6, label="two_inputs") - >>> X_multi = basis.compute_features(np.random.randn(20, 2)) - >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) + >>> basis = EvalRaisedCosineLog(n_basis_funcs=6, label="one_input") + >>> X = basis.compute_features(np.random.randn(20,)) + >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): ... 
print(f"{feature}, shape {sub_dict.shape}") - two_inputs, shape (20, 2, 6) + one_input, shape (20, 1, 6) """ return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) @@ -963,7 +963,7 @@ def split_by_feature( >>> from nemos.basis import EvalOrthExponential >>> from nemos.glm import GLM >>> # Define an additive basis - >>> basis = EvalOrthExponential(n_basis_funcs=5, label="feature") + >>> basis = EvalOrthExponential(n_basis_funcs=5, decay_rates=np.arange(1, 6), label="feature") >>> # Generate a sample input array and compute features >>> x = np.random.randn(20) >>> X = basis.compute_features(x) @@ -1059,7 +1059,7 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import ConvOrthExponential >>> from nemos.glm import GLM - >>> basis = ConvOrthExponential(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = ConvOrthExponential(n_basis_funcs=6, decay_rates=np.arange(1, 7), window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index 0659d92f..6098bfd0 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -216,10 +216,10 @@ def apply_identifiability_constraints( -------- >>> import numpy as np >>> from nemos.identifiability_constraints import apply_identifiability_constraints - >>> from nemos.basis import BSplineBasis + >>> from nemos.basis import EvalBSpline >>> from nemos.glm import GLM >>> # define a feature matrix - >>> bas = BSplineBasis(5) + BSplineBasis(6) + >>> bas = EvalBSpline(5) + EvalBSpline(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) >>> # apply constraints >>> constrained_x, kept_columns = apply_identifiability_constraints(feature_matrix) @@ -281,10 +281,10 @@ def apply_identifiability_constraints_by_basis_component( -------- >>> import numpy as np >>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component - >>> from nemos.basis import BSplineBasis + >>> from nemos.basis import EvalBSpline >>> from nemos.glm import GLM >>> # define a feature matrix - >>> bas = BSplineBasis(5) + BSplineBasis(6) + >>> bas = EvalBSpline(5) + EvalBSpline(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) >>> # apply constraints >>> constrained_x, kept_columns = apply_identifiability_constraints_by_basis_component(bas, feature_matrix) diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index 48af7008..cbfd674f 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -151,11 +151,11 @@ def regress_filter(coupling_filters: NDArray, eval_basis: NDArray) -> NDArray: >>> import numpy as np >>> import matplotlib.pyplot as plt >>> from nemos.simulation import regress_filter, difference_of_gammas - >>> from nemos.basis import RaisedCosineBasisLog + >>> from nemos.basis import EvalRaisedCosineLog >>> filter_duration = 100 >>> n_basis_funcs = 20 >>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1) - >>> _, basis = RaisedCosineBasisLog(10).evaluate_on_grid(filter_duration) + >>> _, basis = EvalRaisedCosineLog(10).evaluate_on_grid(filter_duration) >>> weights = regress_filter(filter_bank, basis)[0, 0] >>> print("Weights shape:", weights.shape) Weights shape: (10,) @@ -275,7 +275,7 @@ 
def simulate_recurrent( >>> feedforward_input = np.random.normal(size=(1000, n_neurons, 1)) >>> coupling_basis = np.random.normal(size=(coupling_duration, 10)) >>> coupling_coef = np.random.normal(size=(n_neurons, n_neurons, 10)) - >>> intercept = -8 * np.ones(n_neurons) + >>> intercept = -9 * np.ones(n_neurons) >>> init_spikes = np.zeros((coupling_duration, n_neurons)) >>> random_key = jax.random.key(123) >>> spikes, rates = simulate_recurrent( From cbfc4a76fcf8440ec9caee9a0972860402255148 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 19:34:34 -0500 Subject: [PATCH 066/109] linted --- src/nemos/basis/basis.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index e2520f97..b1c35c37 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -1059,7 +1059,12 @@ def split_by_feature( >>> import numpy as np >>> from nemos.basis import ConvOrthExponential >>> from nemos.glm import GLM - >>> basis = ConvOrthExponential(n_basis_funcs=6, decay_rates=np.arange(1, 7), window_size=10, label="two_inputs") + >>> basis = ConvOrthExponential( + ... n_basis_funcs=6, + ... decay_rates=np.arange(1, 7), + ... window_size=10, + ... label="two_inputs" + ... ) >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): From 462670b4c8622df72700bab92341cb123adc2118 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 19:41:57 -0500 Subject: [PATCH 067/109] fix test pipeline --- tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a12bafff..872cced8 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,7 +6,7 @@ from sklearn.model_selection import GridSearchCV from nemos import basis -from nemos.basis._basis import TransformerBasis +from nemos.basis._transformer_basis import TransformerBasis @pytest.mark.parametrize( From 6a0574dc6c765a9e956c8230205dbd799e60d890 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 28 Nov 2024 19:47:14 -0500 Subject: [PATCH 068/109] fix warns --- tests/test_identifiability_constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index e09742fb..ace846ea 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -215,7 +215,7 @@ def test_apply_constraint_by_basis_with_invalid(invalid_entries): ) # add invalid x[:2, 2] = invalid_entries - constrained_x, kept_cols = apply_identifiability_constraints(x) + constrained_x, kept_cols = apply_identifiability_constraints(x, warn_if_float32=False) assert jnp.array_equal(kept_cols, jnp.arange(1, 5)) assert constrained_x.shape[0] == x.shape[0] assert jnp.all(jnp.isnan(constrained_x[:2])) From b7a7b60c9ffc0f4003f915bf94c6d8ca1d9f8e64 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 09:11:27 -0500 Subject: [PATCH 069/109] generalized tests --- tests/test_basis.py | 32 +++++------------------ tests/test_identifiability_constraints.py | 4 ++- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index c8412897..3f86ebe7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -27,31 +27,6 @@ from nemos.utils import pynapple_concatenate_numpy -@pytest.fixture() 
-def class_specific_params(): - shared_params = ["n_basis_funcs", "label"] - eval_params = ["bounds"] - conv_params = ["window_size", "conv_kwargs"] - return dict( - EvalBSpline=shared_params + eval_params + ["order"], - ConvBSpline=shared_params + conv_params + ["order"], - EvalMSpline=shared_params + eval_params + ["order"], - ConvMSpline=shared_params + conv_params + ["order"], - EvalCyclicBSpline=shared_params + eval_params + ["order"], - ConvCyclicBSpline=shared_params + conv_params + ["order"], - EvalRaisedCosineLinear=shared_params + eval_params + ["width"], - ConvRaisedCosineLinear=shared_params + conv_params + ["width"], - EvalRaisedCosineLog=shared_params - + eval_params - + ["width", "time_scaling", "enforce_decay_to_zero"], - ConvRaisedCosineLog=shared_params - + conv_params - + ["width", "time_scaling", "enforce_decay_to_zero"], - EvalOrthExponential=shared_params + eval_params + ["decay_rates"], - ConvOrthExponential=shared_params + conv_params + ["decay_rates"], - ) - - def trim_kwargs(cls, kwargs, class_specific_params): return { key: value @@ -87,6 +62,13 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: return all_basis +@pytest.fixture() +def class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + def test_all_basis_are_tested() -> None: """Meta-test. diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ace846ea..0fda51e9 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -215,7 +215,9 @@ def test_apply_constraint_by_basis_with_invalid(invalid_entries): ) # add invalid x[:2, 2] = invalid_entries - constrained_x, kept_cols = apply_identifiability_constraints(x, warn_if_float32=False) + constrained_x, kept_cols = apply_identifiability_constraints( + x, warn_if_float32=False + ) assert jnp.array_equal(kept_cols, jnp.arange(1, 5)) assert constrained_x.shape[0] == x.shape[0] assert jnp.all(jnp.isnan(constrained_x[:2])) From 053554fbca4703f25eac942098d8082d5f876355 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 18:54:46 -0500 Subject: [PATCH 070/109] added class lev docstrings for splines --- src/nemos/basis/basis.py | 244 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index b1c35c37..5d4b67c2 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -55,6 +55,44 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalBSpline", ): + """ + B-spline 1-dimensional basis functions. + + Implementation of the one-dimensional BSpline basis [1]_. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + order : + Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + order : + Spline order. + + + References + ---------- + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + Mathematics and Visualization. 
Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> bspline_basis = EvalBSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = bspline_basis(sample_points) + """ EvalBasisMixin.__init__(self, bounds=bounds) BSplineBasis.__init__( self, @@ -139,6 +177,46 @@ def __init__( label: Optional[str] = "ConvBSpline", conv_kwargs: Optional[dict] = None, ): + """ + B-spline 1-dimensional basis functions. + + Implementation of the one-dimensional BSpline basis [1]_. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + window_size : + The window size for convolution in number of samples. + order : + Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + order : + Spline order. + + + References + ---------- + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. + Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = bspline_basis(sample_points) + """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( self, @@ -222,6 +300,38 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalCyclicBSpline", ): + """ + B-spline 1-dimensional basis functions for cyclic splines. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + order : + Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + n_basis_funcs : + Number of basis functions, int. + order : + Order of the splines used in basis functions, int. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalCyclicBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cyclic_bspline_basis(sample_points) + """ EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( self, @@ -306,6 +416,40 @@ def __init__( label: Optional[str] = "ConvCyclicBSpline", conv_kwargs: Optional[dict] = None, ): + """ + B-spline 1-dimensional basis functions for cyclic splines. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + window_size : + The window size for convolution in number of samples. 
+ order : + Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + n_basis_funcs : + Number of basis functions, int. + order : + Order of the splines used in basis functions, int. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvCyclicBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> cyclic_bspline_basis = ConvCyclicBSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = cyclic_bspline_basis.compute_features(sample_points) + """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( self, @@ -389,6 +533,55 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalMSpline", ): + r""" + M-spline basis functions for modeling and data transformation. + + M-splines [1]_ are a type of spline basis function used for smooth curve fitting + and data representation. They are positive and integrate to one, making them + suitable for probabilistic models and density estimation. The order of an + M-spline defines its smoothness, with higher orders resulting in smoother + splines. + + This class provides functionality to create M-spline basis functions, allowing + for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` + abstract class, providing specific implementations for M-splines. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. + order : + The order of the splines used in basis functions. Must be between [1, + n_basis_funcs]. Default is 2. Higher order splines have more continuous + derivatives at each interior knot, resulting in smoother basis functions. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, + 3(4), 425-441. + + Notes + ----- + ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis) values + will shrink by a factor of :math:`1/\alpha`. + For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. + If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = mspline_basis(sample_points) + """ EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( self, @@ -473,6 +666,57 @@ def __init__( label: Optional[str] = "ConvMSpline", conv_kwargs: Optional[dict] = None, ): + r""" + M-spline basis functions for modeling and data transformation. 
+ + M-splines [1]_ are a type of spline basis function used for smooth curve fitting + and data representation. They are positive and integrate to one, making them + suitable for probabilistic models and density estimation. The order of an + M-spline defines its smoothness, with higher orders resulting in smoother + splines. + + This class provides functionality to create M-spline basis functions, allowing + for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` + abstract class, providing specific implementations for M-splines. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. + order : + The order of the splines used in basis functions. Must be between [1, + n_basis_funcs]. Default is 2. Higher order splines have more continuous + derivatives at each interior knot, resulting in smoother basis functions. + window_size : + The window size for convolution in number of samples. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, + 3(4), 425-441. + + Notes + ----- + ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis) values + will shrink by a factor of :math:`1/\alpha`. + For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. + If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = mspline_basis(sample_points) + """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( self, From 04b529ff4f58ce6a2d72493daef20e5794b2f7ba Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 18:57:49 -0500 Subject: [PATCH 071/109] added class lev docstrings for orth exp --- src/nemos/basis/basis.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 5d4b67c2..3c202e9e 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -68,6 +68,10 @@ def __init__( Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. The B-splines have (order-2) continuous derivatives at each interior knot. The higher this number, the smoother the basis representation will be. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. @@ -311,6 +315,10 @@ def __init__( Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. The B-splines have (order-2) continuous derivatives at each interior knot. 
The higher this number, the smoother the basis representation will be. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. @@ -555,6 +563,10 @@ def __init__( The order of the splines used in basis functions. Must be between [1, n_basis_funcs]. Default is 2. Higher order splines have more continuous derivatives at each interior knot, resulting in smoother basis functions. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. @@ -1222,7 +1234,20 @@ def split_by_feature( class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): - """ + """Set of 1D basis decaying exponential functions numerically orthogonalized. + + Parameters + ---------- + n_basis_funcs + Number of basis functions. + window_size : + The window size for convolution in number of samples. + decay_rates : + Decay rates of the exponentials, shape ``(n_basis_funcs,)``. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + Examples -------- >>> import numpy as np From 38c4138194f46d3166078a344b42b10b6046474d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:08:02 -0500 Subject: [PATCH 072/109] fix all docstrings --- src/nemos/basis/basis.py | 169 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 163 insertions(+), 6 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 3c202e9e..e2ba5409 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -95,7 +95,7 @@ def __init__( >>> order = 3 >>> bspline_basis = EvalBSpline(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis(sample_points) + >>> basis_functions = bspline_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) BSplineBasis.__init__( @@ -219,7 +219,7 @@ def __init__( >>> order = 3 >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis(sample_points) + >>> basis_functions = bspline_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( @@ -338,7 +338,7 @@ def __init__( >>> order = 3 >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_bspline_basis(sample_points) + >>> basis_functions = cyclic_bspline_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( @@ -592,7 +592,7 @@ def __init__( >>> order = 3 >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) + >>> basis_functions = mspline_basis.compute_features(sample_points) """ 
EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( @@ -727,7 +727,7 @@ def __init__( >>> order = 3 >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) + >>> basis_functions = mspline_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( @@ -807,6 +807,43 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class EvalRaisedCosineLinear( EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin ): + """ + Represent linearly-spaced raised cosine basis functions. + + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. By default, it's set to 2.0. + bounds : + The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLinear + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = EvalRaisedCosineLinear(n_basis_funcs) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + """ def __init__( self, n_basis_funcs: int, @@ -885,6 +922,41 @@ def split_by_feature( class ConvRaisedCosineLinear( ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin ): + """ + Represent linearly-spaced raised cosine basis functions. + + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. By default, it's set to 2.0. + window_size : + The window size for convolution in number of samples. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. 
+ + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLinear + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = ConvRaisedCosineLinear(n_basis_funcs, window_size=10) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + """ def __init__( self, n_basis_funcs: int, @@ -971,6 +1043,49 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalRaisedCosineLog", ): + """Represent log-spaced raised cosine basis functions. + + Similar to ``EvalRaisedCosineLinear`` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. + time_scaling : + Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with + larger values resulting in more stretching. As this approaches 0, the + transformation becomes linear. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. + window_size : + The window size for convolution. Required if mode is 'conv'. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLog + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = EvalRaisedCosineLog(n_basis_funcs) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + """ EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( self, @@ -1052,6 +1167,49 @@ def __init__( label: Optional[str] = "ConvRaisedCosineLog", conv_kwargs: Optional[dict] = None, ): + """Represent log-spaced raised cosine basis functions. + + Similar to ``ConvRaisedCosineLinear`` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. + time_scaling : + Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with + larger values resulting in more stretching. As this approaches 0, the + transformation becomes linear. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. + window_size : + The window size for convolution. Required if mode is 'conv'. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. 
J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLog + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = ConvRaisedCosineLog(n_basis_funcs, window_size=10) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( self, @@ -1260,7 +1418,6 @@ class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> basis_functions = ortho_basis.compute_features(sample_points) - """ def __init__( From 7c754bfd3860069db266334e1bb78911b5b72cd4 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:16:00 -0500 Subject: [PATCH 073/109] removed unnecessary kwargs --- src/nemos/basis/_decaying_exponential.py | 10 ----- src/nemos/basis/_raised_cosine_basis.py | 15 -------- src/nemos/basis/_spline_basis.py | 48 +++--------------------- 3 files changed, 5 insertions(+), 68 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index d6b92e59..0a6f5003 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -34,8 +34,6 @@ class OrthExponentialBasis(Basis, abc.ABC): mode : The mode of operation. ``'eval'`` for evaluation at sample points, ``'conv'`` for convolutional operation. - window_size : - The window size for convolution. Required if mode is ``'conv'``. bounds : The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the minimum and the maximum of the samples provided when evaluating the basis. @@ -43,12 +41,6 @@ class OrthExponentialBasis(Basis, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. """ def __init__( @@ -57,13 +49,11 @@ def __init__( decay_rates: NDArray[np.floating], mode="eval", label: Optional[str] = "OrthExponentialBasis", - **kwargs, ): super().__init__( n_basis_funcs, mode=mode, label=label, - **kwargs, ) self.decay_rates = decay_rates self._check_rates() diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 541a6818..508a2226 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -43,12 +43,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. 
- For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. References ---------- @@ -70,7 +64,6 @@ def __init__( n_basis_funcs, mode=mode, label=label, - **kwargs, ) self._n_input_dimensionality = 1 self._check_width(width) @@ -240,12 +233,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. References ---------- @@ -263,13 +250,11 @@ def __init__( time_scaling: float = None, enforce_decay_to_zero: bool = True, label: Optional[str] = "RaisedCosineBasisLog", - **kwargs, ) -> None: super().__init__( n_basis_funcs, mode=mode, width=width, - **kwargs, label=label, ) # The samples are scaled appropriately in the self._transform_samples which scales diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 617138b0..6342faad 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -4,7 +4,7 @@ import abc import copy from functools import partial -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import numpy as np from numpy.typing import ArrayLike, NDArray @@ -34,21 +34,9 @@ class SplineBasis(Basis, abc.ABC): 'conv' for convolutional operation. order : optional Spline order. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to ``nemos.convolve.create_convolutional_predictor`` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. Attributes ---------- @@ -61,13 +49,13 @@ def __init__( n_basis_funcs: int, order: int = 2, label: Optional[str] = None, - **kwargs, + mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order super().__init__( n_basis_funcs, label=label, - **kwargs, + mode=mode, ) self._n_input_dimensionality = 1 @@ -186,17 +174,9 @@ class MSplineBasis(SplineBasis, abc.ABC): The order of the splines used in basis functions. Must be between [1, n_basis_funcs]. Default is 2. Higher order splines have more continuous derivatives at each interior knot, resulting in smoother basis functions. - window_size : - The window size for convolution. Required if mode is 'conv'. 
label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs: - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. References ---------- @@ -225,13 +205,13 @@ class MSplineBasis(SplineBasis, abc.ABC): def __init__( self, n_basis_funcs: int, + mode: Literal["eval", "conv"] = "eval", order: int = 2, label: Optional[str] = "EvalMSpline", - **kwargs, ) -> None: super().__init__( n_basis_funcs, - mode="eval", + mode=mode, order=order, label=label, ) @@ -326,12 +306,6 @@ class BSplineBasis(SplineBasis, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. Attributes ---------- @@ -351,14 +325,12 @@ def __init__( mode="eval", order: int = 4, label: Optional[str] = "BSplineBasis", - **kwargs, ): super().__init__( n_basis_funcs, mode=mode, order=order, label=label, - **kwargs, ) @support_pynapple(conv_type="numpy") @@ -442,12 +414,6 @@ class CyclicBSplineBasis(SplineBasis, abc.ABC): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - **kwargs : - Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when - ``mode='conv'``; These arguments are used to change the default behavior of the convolution. - For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. - Note that one cannot change the default value for the ``axis`` parameter. Basis assumes - that the convolution axis is ``axis=0``. Attributes ---------- @@ -462,17 +428,13 @@ def __init__( n_basis_funcs: int, mode="eval", order: int = 4, - window_size: Optional[int] = None, - bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "CyclicBSplineBasis", - **kwargs, ): super().__init__( n_basis_funcs, mode=mode, order=order, label=label, - **kwargs, ) if self.order < 2: raise ValueError( From 4493196735ae852947695f8c0008aad81e98cccd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:16:23 -0500 Subject: [PATCH 074/109] removed unnecessary kwargs --- src/nemos/basis/_raised_cosine_basis.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 508a2226..6570d346 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -34,12 +34,6 @@ class RaisedCosineBasisLinear(Basis, abc.ABC): 'conv' for convolutional operation. width : Width of the raised cosine. By default, it's set to 2.0. - window_size : - The window size for convolution. 
Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. @@ -58,7 +52,6 @@ def __init__( mode="eval", width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", - **kwargs, ) -> None: super().__init__( n_basis_funcs, From 00a80f5e2e1e5d5c0bc7c2229f40e034457f3a1b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:22:57 -0500 Subject: [PATCH 075/109] pydocstyle --- src/nemos/basis/_basis.py | 2 ++ src/nemos/basis/_basis_mixin.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index b6ec1fb7..ceed7a0f 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -20,6 +20,7 @@ def add_docstring(method_name, cls=None): + """Prepend super-class docstrings.""" attr = getattr(cls, method_name, None) if attr is None: raise AttributeError(f"{cls.__name__} has no attribute {method_name}!") @@ -49,6 +50,7 @@ def wrapper(self: Basis, *xi: ArrayLike, **kwargs) -> NDArray: def check_one_dimensional(func: Callable) -> Callable: + """Check if the input is one-dimensional.""" @wraps(func) def wrapper(self: Basis, *xi: ArrayLike, **kwargs): if any(x.ndim != 1 for x in xi): diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 8ac42af1..502ad791 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -12,6 +12,7 @@ class EvalBasisMixin: + """Mixin class for evaluational basis.""" def __init__(self, bounds: Optional[Tuple[float, float]] = None): self.bounds = bounds @@ -80,6 +81,7 @@ def bounds(self, values: Union[None, Tuple[float, float]]): class ConvBasisMixin: + """Mixin class for convolutional basis.""" def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.window_size = window_size @@ -153,7 +155,6 @@ def window_size(self): @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" - if window_size is None: raise ValueError( "If the basis is in `conv` mode, you must provide a window_size!" 
@@ -225,7 +226,7 @@ def _check_convolution_kwargs(conv_kwargs: dict): class BasisTransformerMixin: - """Mixin class for constructing a transformer""" + """Mixin class for constructing a transformer.""" def to_transformer(self) -> TransformerBasis: """ From ff51725c66c7d41f26ab5df3eb271cb0f7ecab1f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:25:20 -0500 Subject: [PATCH 076/109] linted --- src/nemos/basis/_basis.py | 1 + src/nemos/basis/basis.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index ceed7a0f..980b9657 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -51,6 +51,7 @@ def wrapper(self: Basis, *xi: ArrayLike, **kwargs) -> NDArray: def check_one_dimensional(func: Callable) -> Callable: """Check if the input is one-dimensional.""" + @wraps(func) def wrapper(self: Basis, *xi: ArrayLike, **kwargs): if any(x.ndim != 1 for x in xi): diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index e2ba5409..92f3dcf6 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -84,8 +84,9 @@ def __init__( References ---------- - .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. - Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline + Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. + https://doi.org/10.1007/978-3-662-04919-8_5 Examples -------- @@ -208,8 +209,9 @@ def __init__( References ---------- - .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques. - Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5 + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: + Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. + https://doi.org/10.1007/978-3-662-04919-8_5 Examples -------- @@ -579,8 +581,8 @@ def __init__( Notes ----- ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis) values - will shrink by a factor of :math:`1/\alpha`. + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain + (y-axis) values will shrink by a factor of :math:`1/\alpha`. For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. @@ -714,8 +716,8 @@ def __init__( Notes ----- ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis) values - will shrink by a factor of :math:`1/\alpha`. + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain + (y-axis) values will shrink by a factor of :math:`1/\alpha`. For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. 
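The normalization note above lends itself to a quick numerical check. A minimal sketch, assuming the ``EvalMSpline`` name used here and its ``bounds`` argument: stretching the domain by a factor of two should roughly halve the peak value, whatever the exact configuration.

```python
import numpy as np
from nemos.basis import EvalMSpline

x1 = np.linspace(0, 1, 1_000)
x2 = np.linspace(0, 2, 1_000)

peak_1 = np.nanmax(EvalMSpline(n_basis_funcs=5).compute_features(x1))
peak_2 = np.nanmax(EvalMSpline(n_basis_funcs=5, bounds=(0, 2)).compute_features(x2))

# co-domain values scale as 1/alpha when the domain is stretched by alpha
print(peak_1 / peak_2)  # expected to be close to 2
```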
@@ -844,6 +846,7 @@ class EvalRaisedCosineLinear( >>> # convolve the basis >>> basis_functions = raised_cosine_basis.compute_features(sample_points) """ + def __init__( self, n_basis_funcs: int, @@ -957,6 +960,7 @@ class ConvRaisedCosineLinear( >>> # convolve the basis >>> basis_functions = raised_cosine_basis.compute_features(sample_points) """ + def __init__( self, n_basis_funcs: int, From 4a588f01acf881881f07fcd0e557c2b0a3fce6b6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:26:35 -0500 Subject: [PATCH 077/109] fixed naming --- src/nemos/basis/basis.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 92f3dcf6..452cf4b3 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -221,7 +221,7 @@ def __init__( >>> order = 3 >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis.compute_features(sample_points) + >>> features = bspline_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( @@ -340,7 +340,7 @@ def __init__( >>> order = 3 >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_bspline_basis.compute_features(sample_points) + >>> features = cyclic_bspline_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( @@ -458,7 +458,7 @@ def __init__( >>> order = 3 >>> cyclic_bspline_basis = ConvCyclicBSpline(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = cyclic_bspline_basis.compute_features(sample_points) + >>> features = cyclic_bspline_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( @@ -594,7 +594,7 @@ def __init__( >>> order = 3 >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis.compute_features(sample_points) + >>> features = mspline_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( @@ -729,7 +729,7 @@ def __init__( >>> order = 3 >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis.compute_features(sample_points) + >>> features = mspline_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( @@ -844,7 +844,7 @@ class EvalRaisedCosineLinear( >>> raised_cosine_basis = EvalRaisedCosineLinear(n_basis_funcs) >>> sample_points = np.random.randn(100) >>> # convolve the basis - >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + >>> features = raised_cosine_basis.compute_features(sample_points) """ def __init__( @@ -958,7 +958,7 @@ class ConvRaisedCosineLinear( >>> raised_cosine_basis = ConvRaisedCosineLinear(n_basis_funcs, window_size=10) >>> sample_points = np.random.randn(100) >>> # convolve the basis - >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + >>> features = raised_cosine_basis.compute_features(sample_points) """ def __init__( @@ -1088,7 +1088,7 @@ def __init__( >>> 
raised_cosine_basis = EvalRaisedCosineLog(n_basis_funcs) >>> sample_points = np.random.randn(100) >>> # convolve the basis - >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + >>> features = raised_cosine_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( @@ -1212,7 +1212,7 @@ def __init__( >>> raised_cosine_basis = ConvRaisedCosineLog(n_basis_funcs, window_size=10) >>> sample_points = np.random.randn(100) >>> # convolve the basis - >>> basis_functions = raised_cosine_basis.compute_features(sample_points) + >>> features = raised_cosine_basis.compute_features(sample_points) """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( @@ -1320,7 +1320,7 @@ def __init__( >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates) >>> sample_points = linspace(0, 1, 100) >>> # evaluate the basis - >>> basis_functions = ortho_basis.compute_features(sample_points) + >>> features = ortho_basis.compute_features(sample_points) """ EvalBasisMixin.__init__(self, bounds=bounds) @@ -1421,7 +1421,7 @@ class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): >>> ortho_basis = ConvOrthExponential(n_basis_funcs, window_size, decay_rates) >>> sample_points = np.random.randn(100) >>> # convolve the basis - >>> basis_functions = ortho_basis.compute_features(sample_points) + >>> features = ortho_basis.compute_features(sample_points) """ def __init__( From be5a5b89d831e428dd77e6a53773ce3df23c8254 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:38:55 -0500 Subject: [PATCH 078/109] moved docstrings from init --- src/nemos/basis/basis.py | 759 ++++++++++++++++++++------------------- 1 file changed, 384 insertions(+), 375 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 452cf4b3..ca36d32a 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -48,6 +48,50 @@ def __dir__() -> list[str]: class EvalBSpline(EvalBasisMixin, BSplineBasis): + """ + B-spline 1-dimensional basis functions. + + Implementation of the one-dimensional BSpline basis [1]_. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + order : + Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + order : + Spline order. + + + References + ---------- + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline + Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. 
+ https://doi.org/10.1007/978-3-662-04919-8_5 + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> bspline_basis = EvalBSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> basis_functions = bspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -55,49 +99,6 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalBSpline", ): - """ - B-spline 1-dimensional basis functions. - - Implementation of the one-dimensional BSpline basis [1]_. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - order : - Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - bounds : - The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - Attributes - ---------- - order : - Spline order. - - - References - ---------- - .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline - Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. - https://doi.org/10.1007/978-3-662-04919-8_5 - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import EvalBSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> bspline_basis = EvalBSpline(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = bspline_basis.compute_features(sample_points) - """ EvalBasisMixin.__init__(self, bounds=bounds) BSplineBasis.__init__( self, @@ -174,6 +175,48 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class ConvBSpline(ConvBasisMixin, BSplineBasis): + """ + B-spline 1-dimensional basis functions. + + Implementation of the one-dimensional BSpline basis [1]_. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + window_size : + The window size for convolution in number of samples. + order : + Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + order : + Spline order. + + + References + ---------- + .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: + Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. 
+ https://doi.org/10.1007/978-3-662-04919-8_5 + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> features = bspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -182,47 +225,6 @@ def __init__( label: Optional[str] = "ConvBSpline", conv_kwargs: Optional[dict] = None, ): - """ - B-spline 1-dimensional basis functions. - - Implementation of the one-dimensional BSpline basis [1]_. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - window_size : - The window size for convolution in number of samples. - order : - Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - Attributes - ---------- - order : - Spline order. - - - References - ---------- - .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: - Bézier and B-Spline Techniques. Mathematics and Visualization. Springer, Berlin, Heidelberg. - https://doi.org/10.1007/978-3-662-04919-8_5 - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import ConvBSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> features = bspline_basis.compute_features(sample_points) - """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) BSplineBasis.__init__( self, @@ -299,6 +301,43 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): + """ + B-spline 1-dimensional basis functions for cyclic splines. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + order : + Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + n_basis_funcs : + Number of basis functions, int. + order : + Order of the splines used in basis functions, int. 
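Cyclic splines target circular variables, so a short sketch is useful here (assuming the ``EvalCyclicBSpline`` name used here and its ``bounds`` argument); if the basis is periodic as intended, the two endpoints of the cycle should produce approximately the same row.

```python
import numpy as np
from nemos.basis import EvalCyclicBSpline

angle = np.linspace(0, 2 * np.pi, 629)  # e.g. head direction over one full cycle
basis = EvalCyclicBSpline(n_basis_funcs=8, order=4, bounds=(0, 2 * np.pi))
X = basis.compute_features(angle)

print(X.shape)                              # (629, 8)
print(np.allclose(X[0], X[-1], atol=1e-6))  # endpoints of the cycle should match
```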
+ + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalCyclicBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> features = cyclic_bspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -306,42 +345,6 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalCyclicBSpline", ): - """ - B-spline 1-dimensional basis functions for cyclic splines. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - order : - Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - bounds : - The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - Attributes - ---------- - n_basis_funcs : - Number of basis functions, int. - order : - Order of the splines used in basis functions, int. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import EvalCyclicBSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> features = cyclic_bspline_basis.compute_features(sample_points) - """ EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( self, @@ -418,6 +421,41 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): + """ + B-spline 1-dimensional basis functions for cyclic splines. + + Parameters + ---------- + n_basis_funcs : + Number of basis functions. + window_size : + The window size for convolution in number of samples. + order : + Order of the splines used in basis functions. Order must lie within [2, n_basis_funcs]. + The B-splines have (order-2) continuous derivatives at each interior knot. + The higher this number, the smoother the basis representation will be. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Attributes + ---------- + n_basis_funcs : + Number of basis functions, int. + order : + Order of the splines used in basis functions, int. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvCyclicBSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> cyclic_bspline_basis = ConvCyclicBSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> features = cyclic_bspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -426,40 +464,6 @@ def __init__( label: Optional[str] = "ConvCyclicBSpline", conv_kwargs: Optional[dict] = None, ): - """ - B-spline 1-dimensional basis functions for cyclic splines. - - Parameters - ---------- - n_basis_funcs : - Number of basis functions. - window_size : - The window size for convolution in number of samples. - order : - Order of the splines used in basis functions. 
Order must lie within [2, n_basis_funcs]. - The B-splines have (order-2) continuous derivatives at each interior knot. - The higher this number, the smoother the basis representation will be. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - Attributes - ---------- - n_basis_funcs : - Number of basis functions, int. - order : - Order of the splines used in basis functions, int. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import ConvCyclicBSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> cyclic_bspline_basis = ConvCyclicBSpline(n_basis_funcs, order=order, window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> features = cyclic_bspline_basis.compute_features(sample_points) - """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) CyclicBSplineBasis.__init__( self, @@ -536,6 +540,60 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class EvalMSpline(EvalBasisMixin, MSplineBasis): + r""" + M-spline basis functions for modeling and data transformation. + + M-splines [1]_ are a type of spline basis function used for smooth curve fitting + and data representation. They are positive and integrate to one, making them + suitable for probabilistic models and density estimation. The order of an + M-spline defines its smoothness, with higher orders resulting in smoother + splines. + + This class provides functionality to create M-spline basis functions, allowing + for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` + abstract class, providing specific implementations for M-splines. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. + order : + The order of the splines used in basis functions. Must be between [1, + n_basis_funcs]. Default is 2. Higher order splines have more continuous + derivatives at each interior knot, resulting in smoother basis functions. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, + 3(4), 425-441. + + Notes + ----- + ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain + (y-axis) values will shrink by a factor of :math:`1/\alpha`. + For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. + If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. 
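The integrate-to-one property stated here can also be checked numerically. A small sketch, assuming the ``EvalMSpline`` name used here (a plain Riemann sum is enough for the check):

```python
import numpy as np
from nemos.basis import EvalMSpline

x = np.linspace(0, 1, 10_000)
X = EvalMSpline(n_basis_funcs=5, order=3).compute_features(x)

# approximate the area under each basis element over the (0, 1) domain
areas = X.sum(axis=0) * (x[1] - x[0])
print(np.round(areas, 3))  # each entry is expected to be close to 1
```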
+ + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import EvalMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> sample_points = linspace(0, 1, 100) + >>> features = mspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -543,59 +601,6 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalMSpline", ): - r""" - M-spline basis functions for modeling and data transformation. - - M-splines [1]_ are a type of spline basis function used for smooth curve fitting - and data representation. They are positive and integrate to one, making them - suitable for probabilistic models and density estimation. The order of an - M-spline defines its smoothness, with higher orders resulting in smoother - splines. - - This class provides functionality to create M-spline basis functions, allowing - for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` - abstract class, providing specific implementations for M-splines. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. - order : - The order of the splines used in basis functions. Must be between [1, - n_basis_funcs]. Default is 2. Higher order splines have more continuous - derivatives at each interior knot, resulting in smoother basis functions. - bounds : - The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - References - ---------- - .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, - 3(4), 425-441. - - Notes - ----- - ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain - (y-axis) values will shrink by a factor of :math:`1/\alpha`. - For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. - If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import EvalMSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) - >>> sample_points = linspace(0, 1, 100) - >>> features = mspline_basis.compute_features(sample_points) - """ EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( self, @@ -672,6 +677,58 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class ConvMSpline(ConvBasisMixin, MSplineBasis): + r""" + M-spline basis functions for modeling and data transformation. + + M-splines [1]_ are a type of spline basis function used for smooth curve fitting + and data representation. They are positive and integrate to one, making them + suitable for probabilistic models and density estimation. The order of an + M-spline defines its smoothness, with higher orders resulting in smoother + splines. 
+ + This class provides functionality to create M-spline basis functions, allowing + for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` + abstract class, providing specific implementations for M-splines. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions to generate. More basis functions allow for + more flexible data modeling but can lead to overfitting. + order : + The order of the splines used in basis functions. Must be between [1, + n_basis_funcs]. Default is 2. Higher order splines have more continuous + derivatives at each interior knot, resulting in smoother basis functions. + window_size : + The window size for convolution in number of samples. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, + 3(4), 425-441. + + Notes + ----- + ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain + (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain + (y-axis) values will shrink by a factor of :math:`1/\alpha`. + For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. + If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. + + Examples + -------- + >>> from numpy import linspace + >>> from nemos.basis import ConvMSpline + >>> n_basis_funcs = 5 + >>> order = 3 + >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) + >>> sample_points = linspace(0, 1, 100) + >>> features = mspline_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -680,57 +737,6 @@ def __init__( label: Optional[str] = "ConvMSpline", conv_kwargs: Optional[dict] = None, ): - r""" - M-spline basis functions for modeling and data transformation. - - M-splines [1]_ are a type of spline basis function used for smooth curve fitting - and data representation. They are positive and integrate to one, making them - suitable for probabilistic models and density estimation. The order of an - M-spline defines its smoothness, with higher orders resulting in smoother - splines. - - This class provides functionality to create M-spline basis functions, allowing - for flexible and smooth modeling of data. It inherits from the ``SplineBasis`` - abstract class, providing specific implementations for M-splines. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions to generate. More basis functions allow for - more flexible data modeling but can lead to overfitting. - order : - The order of the splines used in basis functions. Must be between [1, - n_basis_funcs]. Default is 2. Higher order splines have more continuous - derivatives at each interior knot, resulting in smoother basis functions. - window_size : - The window size for convolution in number of samples. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - References - ---------- - .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science, - 3(4), 425-441. - - Notes - ----- - ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). 
Therefore, if the domain - (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain - (y-axis) values will shrink by a factor of :math:`1/\alpha`. - For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18. - If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2. - - Examples - -------- - >>> from numpy import linspace - >>> from nemos.basis import ConvMSpline - >>> n_basis_funcs = 5 - >>> order = 3 - >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) - >>> sample_points = linspace(0, 1, 100) - >>> features = mspline_basis.compute_features(sample_points) - """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) MSplineBasis.__init__( self, @@ -1038,6 +1044,50 @@ def split_by_feature( class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): + """Represent log-spaced raised cosine basis functions. + + Similar to ``EvalRaisedCosineLinear`` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. + time_scaling : + Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with + larger values resulting in more stretching. As this approaches 0, the + transformation becomes linear. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. + window_size : + The window size for convolution. Required if mode is 'conv'. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import EvalRaisedCosineLog + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = EvalRaisedCosineLog(n_basis_funcs) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> features = raised_cosine_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -1047,49 +1097,6 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalRaisedCosineLog", ): - """Represent log-spaced raised cosine basis functions. - - Similar to ``EvalRaisedCosineLinear`` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al. [1]_ - to uniformly tile the internal points of the domain. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - width : - Width of the raised cosine. - time_scaling : - Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with - larger values resulting in more stretching. As this approaches 0, the - transformation becomes linear. - enforce_decay_to_zero: - If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements - and subsequently trims off the extra basis elements. 
This ensures that the final basis element - decays to 0. - window_size : - The window size for convolution. Required if mode is 'conv'. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - References - ---------- - .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLog - >>> n_basis_funcs = 5 - >>> raised_cosine_basis = EvalRaisedCosineLog(n_basis_funcs) - >>> sample_points = np.random.randn(100) - >>> # convolve the basis - >>> features = raised_cosine_basis.compute_features(sample_points) - """ EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( self, @@ -1161,6 +1168,50 @@ def split_by_feature( class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): + """Represent log-spaced raised cosine basis functions. + + Similar to ``ConvRaisedCosineLinear`` but the basis functions are log-spaced. + This implementation is based on the cosine bumps used by Pillow et al. [1]_ + to uniformly tile the internal points of the domain. + + Parameters + ---------- + n_basis_funcs : + The number of basis functions. + width : + Width of the raised cosine. + time_scaling : + Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with + larger values resulting in more stretching. As this approaches 0, the + transformation becomes linear. + enforce_decay_to_zero: + If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements + and subsequently trims off the extra basis elements. This ensures that the final basis element + decays to 0. + window_size : + The window size for convolution. Required if mode is 'conv'. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + References + ---------- + .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., + C. E. (2005). Prediction and decoding of retinal ganglion cell responses + with a probabilistic spiking model. Journal of Neuroscience, 25(47), + 11003–11013. + + Examples + -------- + >>> import numpy as np + >>> from nemos.basis import ConvRaisedCosineLog + >>> n_basis_funcs = 5 + >>> raised_cosine_basis = ConvRaisedCosineLog(n_basis_funcs, window_size=10) + >>> sample_points = np.random.randn(100) + >>> # convolve the basis + >>> features = raised_cosine_basis.compute_features(sample_points) + """ + def __init__( self, n_basis_funcs: int, @@ -1171,49 +1222,6 @@ def __init__( label: Optional[str] = "ConvRaisedCosineLog", conv_kwargs: Optional[dict] = None, ): - """Represent log-spaced raised cosine basis functions. - - Similar to ``ConvRaisedCosineLinear`` but the basis functions are log-spaced. - This implementation is based on the cosine bumps used by Pillow et al. [1]_ - to uniformly tile the internal points of the domain. - - Parameters - ---------- - n_basis_funcs : - The number of basis functions. - width : - Width of the raised cosine. - time_scaling : - Non-negative hyper-parameter controlling the logarithmic stretch magnitude, with - larger values resulting in more stretching. As this approaches 0, the - transformation becomes linear. 
- enforce_decay_to_zero: - If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements - and subsequently trims off the extra basis elements. This ensures that the final basis element - decays to 0. - window_size : - The window size for convolution. Required if mode is 'conv'. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. - - References - ---------- - .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J., - C. E. (2005). Prediction and decoding of retinal ganglion cell responses - with a probabilistic spiking model. Journal of Neuroscience, 25(47), - 11003–11013. - - Examples - -------- - >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLog - >>> n_basis_funcs = 5 - >>> raised_cosine_basis = ConvRaisedCosineLog(n_basis_funcs, window_size=10) - >>> sample_points = np.random.randn(100) - >>> # convolve the basis - >>> features = raised_cosine_basis.compute_features(sample_points) - """ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) RaisedCosineBasisLog.__init__( self, @@ -1285,6 +1293,38 @@ def split_by_feature( class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): + """Set of 1D basis decaying exponential functions numerically orthogonalized. + + Parameters + ---------- + n_basis_funcs + Number of basis functions. + decay_rates : + Decay rates of the exponentials, shape ``(n_basis_funcs,)``. + bounds : + The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the + minimum and the maximum of the samples provided when evaluating the basis. + If a sample is outside the bounds, the basis will return NaN. + label : + The label of the basis, intended to be descriptive of the task variable being processed. + For example: velocity, position, spike_counts. + + Examples + -------- + >>> import numpy as np + >>> from numpy import linspace + >>> from nemos.basis import EvalOrthExponential + >>> X = np.random.normal(size=(1000, 1)) + >>> n_basis_funcs = 5 + >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates + >>> window_size = 10 + >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates) + >>> sample_points = linspace(0, 1, 100) + >>> # evaluate the basis + >>> features = ortho_basis.compute_features(sample_points) + + """ + def __init__( self, n_basis_funcs: int, @@ -1292,37 +1332,6 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "EvalOrthExponential", ): - """Set of 1D basis decaying exponential functions numerically orthogonalized. - - Parameters - ---------- - n_basis_funcs - Number of basis functions. - decay_rates : - Decay rates of the exponentials, shape ``(n_basis_funcs,)``. - bounds : - The bounds for the basis domain. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. - label : - The label of the basis, intended to be descriptive of the task variable being processed. - For example: velocity, position, spike_counts. 
- - Examples - -------- - >>> import numpy as np - >>> from numpy import linspace - >>> from nemos.basis import EvalOrthExponential - >>> X = np.random.normal(size=(1000, 1)) - >>> n_basis_funcs = 5 - >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates - >>> window_size = 10 - >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates) - >>> sample_points = linspace(0, 1, 100) - >>> # evaluate the basis - >>> features = ortho_basis.compute_features(sample_points) - - """ EvalBasisMixin.__init__(self, bounds=bounds) OrthExponentialBasis.__init__( self, From ed414becd3ff5ab327ce700c0d85b1300c1d11cb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 19:46:05 -0500 Subject: [PATCH 079/109] removed attrs from class docstrings --- src/nemos/basis/basis.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index ca36d32a..eb8fa770 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -69,11 +69,6 @@ class EvalBSpline(EvalBasisMixin, BSplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - Attributes - ---------- - order : - Spline order. - References ---------- @@ -194,12 +189,6 @@ class ConvBSpline(ConvBasisMixin, BSplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - Attributes - ---------- - order : - Spline order. - - References ---------- .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: @@ -320,13 +309,6 @@ class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - Attributes - ---------- - n_basis_funcs : - Number of basis functions, int. - order : - Order of the splines used in basis functions, int. - Examples -------- >>> from numpy import linspace @@ -438,13 +420,6 @@ class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. - Attributes - ---------- - n_basis_funcs : - Number of basis functions, int. - order : - Order of the splines used in basis functions, int. - Examples -------- >>> from numpy import linspace From 2b7b123af524ef63131b7e31f37e49099d47e46d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 3 Dec 2024 11:13:41 -0500 Subject: [PATCH 080/109] removed args from docstrings --- src/nemos/basis/_decaying_exponential.py | 4 ---- src/nemos/basis/_raised_cosine_basis.py | 6 ------ 2 files changed, 10 deletions(-) diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 0a6f5003..15202e6d 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -34,10 +34,6 @@ class OrthExponentialBasis(Basis, abc.ABC): mode : The mode of operation. ``'eval'`` for evaluation at sample points, ``'conv'`` for convolutional operation. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. 
label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 6570d346..bf7bda95 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -217,12 +217,6 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear, abc.ABC): If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements and subsequently trims off the extra basis elements. This ensures that the final basis element decays to 0. - window_size : - The window size for convolution. Required if mode is 'conv'. - bounds : - The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the - minimum and the maximum of the samples provided when evaluating the basis. - If a sample is outside the bounds, the basis will return NaN. label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. From 77009a7f1a7825f06af9735c78140698178a6f47 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Tue, 3 Dec 2024 11:57:35 -0500 Subject: [PATCH 081/109] Update typing.py --- src/nemos/typing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nemos/typing.py b/src/nemos/typing.py index 7f88be21..62c6525f 100644 --- a/src/nemos/typing.py +++ b/src/nemos/typing.py @@ -1,5 +1,7 @@ """Collection of nemos typing.""" +from __future__ import annotations + from typing import Any, Callable, NamedTuple, Tuple, Union import jax.numpy as jnp From f96d08f430b31077e081c48df741ec5aeedc47a8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 3 Dec 2024 17:54:49 -0500 Subject: [PATCH 082/109] fixed tests and renamed funcs --- docs/api_reference.rst | 30 +- docs/background/plot_01_1D_basis_function.md | 120 ++--- docs/background/plot_02_ND_basis_function.md | 10 +- docs/background/plot_03_1D_convolution.md | 8 +- docs/conf.py | 2 +- docs/developers_notes/04-basis_module.md | 12 +- docs/how_to_guide/plot_02_glm_demo.md | 2 +- docs/how_to_guide/plot_04_batch_glm.md | 2 +- .../plot_05_sklearn_pipeline_cv_demo.md | 32 +- docs/how_to_guide/plot_06_glm_pytree.md | 6 +- docs/quickstart.md | 6 +- docs/tutorials/plot_02_head_direction.md | 4 +- docs/tutorials/plot_03_grid_cells.md | 4 +- docs/tutorials/plot_04_v1_cells.md | 2 +- docs/tutorials/plot_05_place_cells.md | 12 +- docs/tutorials/plot_06_calcium_imaging.md | 4 +- src/nemos/_documentation_utils/plotting.py | 4 +- src/nemos/basis/__init__.py | 28 +- src/nemos/basis/_basis.py | 92 ++-- src/nemos/basis/_basis_mixin.py | 6 +- src/nemos/basis/_decaying_exponential.py | 7 +- src/nemos/basis/_raised_cosine_basis.py | 13 +- src/nemos/basis/_spline_basis.py | 21 +- src/nemos/basis/_transformer_basis.py | 32 +- src/nemos/basis/basis.py | 429 +++++++++--------- src/nemos/identifiability_constraints.py | 8 +- src/nemos/simulation.py | 4 +- tests/conftest.py | 2 +- tests/test_basis.py | 235 +++++----- tests/test_identifiability_constraints.py | 12 +- tests/test_pipeline.py | 80 ++-- tests/test_simulation.py | 2 +- 32 files changed, 580 insertions(+), 651 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 7f6b6aea..f41c4c02 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -50,7 +50,7 @@ These classes are the building blocks for the concrete basis classes. **Bases For Convolution:** -.. 
currentmodule:: nemos.basis.basis +.. currentmodule:: nemos.basis .. autosummary:: :toctree: generated/basis @@ -58,28 +58,30 @@ These classes are the building blocks for the concrete basis classes. :nosignatures: - ConvMSpline - ConvBSpline - ConvCyclicBSpline - ConvRaisedCosineLinear - ConvRaisedCosineLog - ConvOrthExponential + MSplineConv + BSplineConv + CyclicBSplineConv + RaisedCosineLinearConv + RaisedCosineLogConv + OrthExponentialConv + +.. check for a config that prints only nemos.basis.Name **Bases For Non-Linear Mapping:** -.. currentmodule:: nemos.basis.basis +.. currentmodule:: nemos.basis .. autosummary:: :toctree: generated/basis :recursive: :nosignatures: - EvalMSpline - EvalBSpline - EvalCyclicBSpline - EvalRaisedCosineLinear - EvalRaisedCosineLog - EvalOrthExponential + MSplineEval + BSplineEval + CyclicBSplineEval + RaisedCosineLinearEval + RaisedCosineLogEval + OrthExponentialEval **Composite Bases:** diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 3e494265..424be525 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -45,7 +45,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`EvalMSpline`](nemos.basis.basis.EvalMSpline). +We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.basis.MSplineEval). The hyperparameters required to initialize this class are: - The number of basis functions, which should be a positive integer. @@ -63,96 +63,22 @@ order = 4 n_basis = 10 # Define the 1D basis function object -bspline = nmo.basis.EvalBSpline(n_basis_funcs=n_basis, order=order) -``` - -## Evaluating a Basis - -The [`Basis`](nemos.basis._basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis -is defined by the samples that we input to the [`__call__`](nemos.basis._basis.Basis.__call__) method, and covers from the smallest to the largest value. - - -```{code-cell} ipython3 - -# Generate a time series of sample points -samples = nap.Tsd(t=np.arange(1001), d=np.linspace(0, 1,1001)) - -# Evaluate the basis at the sample points -eval_basis = bspline(samples) - -# Output information about the evaluated basis -print(f"Evaluated B-spline of order {order} with {eval_basis.shape[1]} " - f"basis element and {eval_basis.shape[0]} samples.") - -fig = plt.figure() -plt.title("B-spline basis") -plt.plot(samples, eval_basis); -``` - -```{code-cell} ipython3 -:tags: [hide-input] - -# save image for thumbnail -from pathlib import Path -import os - -root = os.environ.get("READTHEDOCS_OUTPUT") -if root: - path = Path(root) / "html/_static/thumbnails/background" -# if local store in ../_build/html/... -else: - path = Path("../_build/html/_static/thumbnails/background") - -# make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): - path.mkdir(parents=True, exist_ok=True) - -if path.exists(): - fig.savefig(path / "plot_01_1D_basis_function.svg") -``` - -## Setting the basis support -Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that -your basis covers the same range across multiple experimental sessions. -You can specify a range for the support of your basis by setting the `bounds` -parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN. 
- - -```{code-cell} ipython3 -bspline_range = nmo.basis.EvalBSpline(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) - -print("Evaluated basis:") -# 0.5 is within the support, 0.1 is outside the support -print(np.round(bspline_range([0.5, 0.1]), 3)) -``` - -Let's compare the default behavior of basis (estimating the range from the samples) with -the fixed range basis. - - -```{code-cell} ipython3 -fig, axs = plt.subplots(2,1, sharex=True) -plt.suptitle("B-spline basis ") -axs[0].plot(samples, bspline(samples), color="k") -axs[0].set_title("default") -axs[1].plot(samples, bspline_range(samples), color="tomato") -axs[1].set_title("bounds=[0.2, 0.8]") -plt.tight_layout() +bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order) ``` ## Feature Computation The bases in the `nemos.basis` module can be grouped into two categories: -1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `EvalBSpline`. +1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `BSplineEval`. -2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `ConvBSpline`. +2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `BSplineConv`. Let's see how this two modalities operate. ```{code-cell} ipython3 -eval_mode = nmo.basis.EvalMSpline(n_basis_funcs=n_basis) -conv_mode = nmo.basis.ConvMSpline(n_basis_funcs=n_basis, window_size=100) +eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis) +conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100) # define an input angles = np.linspace(0, np.pi*4, 201) @@ -195,7 +121,6 @@ check out the tutorial on [1D convolutions](plot_03_1D_convolution). ::: - Plotting the Basis Function Elements: -------------------------------------- We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points @@ -218,6 +143,37 @@ plt.plot(equispaced_samples, eval_basis) plt.show() ``` + +## Setting the basis support (Eval only) +Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that +your basis covers the same range across multiple experimental sessions. +You can specify a range for the support of your basis by setting the `bounds` +parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions). +Evaluating the basis at any sample outside the bounds will result in a NaN. 
+ + +```{code-cell} ipython3 +bspline_range = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8)) + +print("Evaluated basis:") +# 0.5 is within the support, 0.1 is outside the support +print(np.round(bspline_range.compute_features([0.5, 0.1]), 3)) +``` + +Let's compare the default behavior of basis (estimating the range from the samples) with +the fixed range basis. + + +```{code-cell} ipython3 +fig, axs = plt.subplots(2,1, sharex=True) +plt.suptitle("B-spline basis ") +axs[0].plot(samples, bspline.compute_features(samples), color="k") +axs[0].set_title("default") +axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato") +axs[1].set_title("bounds=[0.2, 0.8]") +plt.tight_layout() +``` + Other Basis Types ----------------- Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, @@ -228,7 +184,7 @@ evaluate a log-spaced cosine raised function basis. ```{code-cell} ipython3 # Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter -raised_cosine_log = nmo.basis.EvalRaisedCosineLog(n_basis_funcs=10, width=1.5, time_scaling=50) +raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50) # Evaluate the raised cosine basis at the equi-spaced sample points # (same method in all Basis elements) diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/plot_02_ND_basis_function.md index 5cd4c1ec..03c0062d 100644 --- a/docs/background/plot_02_ND_basis_function.md +++ b/docs/background/plot_02_ND_basis_function.md @@ -130,8 +130,8 @@ import numpy as np import nemos as nmo # Define 1D basis objects -a_basis = nmo.basis.EvalMSpline(n_basis_funcs=15, order=3) -b_basis = nmo.basis.EvalRaisedCosineLog(n_basis_funcs=14) +a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3) +b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14) # Define the 2D additive basis object additive_basis = a_basis + b_basis @@ -340,9 +340,9 @@ will output a $K^N \times T$ matrix. 
T = 10 n_basis = 8 -a_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) -b_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) -c_basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=n_basis) +a_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis) +b_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis) +c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md index 18fed52b..17dfdb1c 100644 --- a/docs/background/plot_03_1D_convolution.md +++ b/docs/background/plot_03_1D_convolution.md @@ -82,7 +82,7 @@ see [jax.numpy.convolve](https://jax.readthedocs.io/en/latest/_autosummary/jax.n ```{code-cell} ipython3 # create three filters -basis_obj = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=3) +basis_obj = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=3) _, w = basis_obj.evaluate_on_grid(ws) plt.plot(w) @@ -198,18 +198,18 @@ Let's see how we can get the same results through [`Basis`](nemos.basis._basis.B ```{code-cell} ipython3 # define basis with different predictor causality -causal_basis = nmo.basis.ConvRaisedCosineLinear( +causal_basis = nmo.basis.RaisedCosineLinearConv( n_basis_funcs=3, window_size=ws, conv_kwargs=dict(predictor_causality="causal") ) -acausal_basis = nmo.basis.ConvRaisedCosineLinear( +acausal_basis = nmo.basis.RaisedCosineLinearConv( n_basis_funcs=3, window_size=ws, conv_kwargs=dict(predictor_causality="acausal") ) -anticausal_basis = nmo.basis.ConvRaisedCosineLinear( +anticausal_basis = nmo.basis.RaisedCosineLinearConv( n_basis_funcs=3, window_size=ws, conv_kwargs=dict(predictor_causality="anti-causal") ) diff --git a/docs/conf.py b/docs/conf.py index c8f5e3a2..966f7ddf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ 'inherited-members': True, 'undoc-members': True, 'show-inheritance': True, - 'special-members': '__call__, __add__, __mul__, __pow__' + 'special-members': ' __add__, __mul__, __pow__' } # # napolean configs diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index 13f89413..f7823a86 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -21,12 +21,12 @@ Abstract Class Basis │ ├─ Concrete Subclass RaisedCosineBasisLinear │ │ -│ └─ Concrete Subclass EvalRaisedCosineLog +│ └─ Concrete Subclass RaisedCosineLogEval │ └─ Concrete Subclass OrthExponentialBasis ``` -The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`__call__`](nemos.basis._basis.Basis.__call__) that is specific for each concrete class. See below for more details. +The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". 
They both make use of the abstract method [`_evaluate`](nemos.basis._basis.Basis._evaluate) that is specific for each concrete class. See below for more details. ## The Class `nemos.basis._basis.Basis` @@ -42,7 +42,7 @@ It accepts one or more NumPy array or pynapple `Tsd` object as input, and perfor 1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case. 2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. -3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). +3. In `"eval"` mode, calls the `_evaluate` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). 4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples. :::{admonition} Multiple epochs @@ -61,14 +61,14 @@ This method performs the following steps: 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the [`__call__`](nemos.basis._basis.Basis.__call__) method. +3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. [`__call__`](nemos.basis._basis.Basis.__call__): Evaluates a basis over some specified samples. +1. [`_evaluate`](nemos.basis._basis.Basis._evaluate): Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines @@ -77,7 +77,7 @@ The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the followi To write a usable (i.e., concrete, non-abstract) basis object, you - **Must** inherit the abstract superclass [`Basis`](nemos.basis._basis.Basis) -- **Must** define the [`__call__`](nemos.basis._basis.Basis.__call__) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. 
+- **Must** define the [`_evaluate`](nemos.basis._basis.Basis._evaluate) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. - **Should not** overwrite the [`compute_features`](nemos.basis._basis.Basis.compute_features) and [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis._basis.Basis). - **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis._spline_basis.SplineBasis)). diff --git a/docs/how_to_guide/plot_02_glm_demo.md b/docs/how_to_guide/plot_02_glm_demo.md index d89d10c9..83591da4 100644 --- a/docs/how_to_guide/plot_02_glm_demo.md +++ b/docs/how_to_guide/plot_02_glm_demo.md @@ -329,7 +329,7 @@ coupling_filter_bank *= 0.8 # define a basis function n_basis_funcs = 20 -basis = nmo.basis.EvalRaisedCosineLog(n_basis_funcs) +basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/docs/how_to_guide/plot_04_batch_glm.md b/docs/how_to_guide/plot_04_batch_glm.md index 84e58ad6..217de9ba 100644 --- a/docs/how_to_guide/plot_04_batch_glm.md +++ b/docs/how_to_guide/plot_04_batch_glm.md @@ -106,7 +106,7 @@ Here we instantiate the basis. `ws` is 40 time bins. It corresponds to a 200 ms ```{code-cell} ipython3 ws = 40 -basis = nmo.basis.ConvRaisedCosineLog(5, window_size=ws) +basis = nmo.basis.RaisedCosineLogConv(5, window_size=ws) ``` ## Batch definition diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index f3382de5..90fe1716 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -156,7 +156,7 @@ Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerB ```{code-cell} ipython3 -bas = nmo.basis.ConvRaisedCosineLinear(5, window_size=5) +bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5) trans_bas = bas.to_transformer() ``` @@ -188,7 +188,7 @@ pipeline = Pipeline( [ ( "transformerbasis", - nmo.basis.EvalRaisedCosineLinear(6).to_transformer(), + nmo.basis.RaisedCosineLinearEval(6).to_transformer(), ), ( "glm", @@ -324,7 +324,7 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds)) coeffs = {} # initialize basis and model -basis = nmo.basis.TransformerBasis(nmo.basis.EvalRaisedCosineLinear(6)) +basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineLinearEval(6)) model = nmo.glm.GLM(regularizer="Ridge") # loop over combinations @@ -451,12 +451,12 @@ Here we include `transformerbasis___basis` in the parameter grid to try differen param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis___basis=( - nmo.basis.EvalRaisedCosineLinear(5), - nmo.basis.EvalRaisedCosineLinear(10), - nmo.basis.EvalRaisedCosineLog(5), - nmo.basis.EvalRaisedCosineLog(10), - nmo.basis.EvalMSpline(5), - nmo.basis.EvalMSpline(10), + nmo.basis.RaisedCosineLinearEval(5), + nmo.basis.RaisedCosineLinearEval(10), + nmo.basis.RaisedCosineLogEval(5), + nmo.basis.RaisedCosineLogEval(10), + nmo.basis.MSplineEval(5), + nmo.basis.MSplineEval(10), ), ) ``` @@ -496,7 +496,7 @@ cvdf_wide = cvdf.pivot( doc_plots.plot_heatmap_cv_results(cvdf_wide) ``` -As shown in the table, the model with the highest score, highlighted in blue, used a EvalRaisedCosineLinear basis (as used above), which appears to be a 
suitable choice for our toy data. +As shown in the table, the model with the highest score, highlighted in blue, used a RaisedCosineLinearEval basis (as used above), which appears to be a suitable choice for our toy data. We can confirm that by plotting the firing rate predictions: @@ -537,12 +537,12 @@ param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), transformerbasis___basis=( - nmo.basis.EvalRaisedCosineLinear(5), - nmo.basis.EvalRaisedCosineLinear(10), - nmo.basis.EvalRaisedCosineLog(5), - nmo.basis.EvalRaisedCosineLog(10), - nmo.basis.EvalMSpline(5), - nmo.basis.EvalMSpline(10), + nmo.basis.RaisedCosineLinearEval(5), + nmo.basis.RaisedCosineLinearEval(10), + nmo.basis.RaisedCosineLogEval(5), + nmo.basis.RaisedCosineLogEval(10), + nmo.basis.MSplineEval(5), + nmo.basis.MSplineEval(10), ), ) ``` diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 1d6ca5f4..36910bc3 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -274,7 +274,7 @@ Okay, let's use unit number 7. Now let's set up our design matrix. First, let's fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. -[`EvalCyclicBSpline`](nemos.basis.basis.EvalCyclicBSpline) is one such basis. +[`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) is one such basis. Let's create our basis and then arrange our data properly. @@ -283,7 +283,7 @@ Let's create our basis and then arrange our data properly. unit_no = 7 spikes = nwb['units'][unit_no] -basis = nmo.basis.EvalCyclicBSpline(10, order=5) +basis = nmo.basis.CyclicBSplineEval(10, order=5) x = np.linspace(-np.pi, np.pi, 100) plt.figure() plt.plot(x, basis(x)) @@ -351,7 +351,7 @@ our data similarly. ```{code-cell} ipython3 -pos_basis = nmo.basis.EvalRaisedCosineLinear(10) * nmo.basis.EvalRaisedCosineLinear(10) +pos_basis = nmo.basis.RaisedCosineLinearEval(10) * nmo.basis.RaisedCosineLinearEval(10) spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data) X['spatial_position'] = pos_basis(*spatial_pos.values.T) diff --git a/docs/quickstart.md b/docs/quickstart.md index 8420d078..bdf3ffd4 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -165,7 +165,7 @@ you need to specify the number of basis functions. For some `basis` objects, add >>> import nemos as nmo >>> n_basis_funcs = 10 ->>> basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs) +>>> basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs) ``` @@ -205,7 +205,7 @@ number of sample points. >>> n_basis_funcs = 10 >>> # define a filter bank of 10 basis function, 200 samples long. 
->>> basis = nmo.basis.ConvBSpline(n_basis_funcs, window_size=200) +>>> basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200) ``` @@ -350,7 +350,7 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa >>> upsampled_head_dir = head_dir.bin_average(0.01) >>> # create your features ->>> X = nmo.basis.EvalCyclicBSpline(10).compute_features(upsampled_head_dir) +>>> X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir) >>> # add a neuron axis and fit model >>> model = nmo.glm.GLM().fit(X, counts) diff --git a/docs/tutorials/plot_02_head_direction.md b/docs/tutorials/plot_02_head_direction.md index e4402053..74f9ab89 100644 --- a/docs/tutorials/plot_02_head_direction.md +++ b/docs/tutorials/plot_02_head_direction.md @@ -419,7 +419,7 @@ the cost of adding additional parameters. ```{code-cell} ipython3 # a basis object can be instantiated in "conv" mode for convolving the input. -basis = nmo.basis.ConvRaisedCosineLog( +basis = nmo.basis.RaisedCosineLogConv( n_basis_funcs=8, window_size=window_size ) @@ -600,7 +600,7 @@ to get an array of predictors of shape, `(num_time_points, num_neurons * num_bas ```{code-cell} ipython3 # re-initialize basis -basis = nmo.basis.ConvRaisedCosineLog( +basis = nmo.basis.RaisedCosineLogConv( n_basis_funcs=8, window_size=window_size ) diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index a7f767ef..65bfbc43 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -146,9 +146,9 @@ We can define a two-dimensional basis for position by multiplying two one-dimens see [here](../../background/plot_02_ND_basis_function) for more details. ```{code-cell} ipython3 -basis_2d = nmo.basis.EvalRaisedCosineLinear( +basis_2d = nmo.basis.RaisedCosineLinearEval( n_basis_funcs=10 -) * nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=10) +) * nmo.basis.RaisedCosineLinearEval(n_basis_funcs=10) ``` Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid. 
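The renamed 2D basis exercised in the hunk above can be sanity-checked with a short doctest-style sketch. It relies only on public names introduced by this patch (`RaisedCosineLinearEval`, the `*` composition, and `evaluate_on_grid`); the variable names are illustrative, not part of the tutorial:

>>> import nemos as nmo
>>> # 10 x 10 = 100 basis functions for 2D position, as in the grid-cells tutorial
>>> basis_2d = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=10) * nmo.basis.RaisedCosineLinearEval(n_basis_funcs=10)
>>> # evaluate every basis element on a 100 x 100 grid of equi-spaced samples
>>> X, Y, Z = basis_2d.evaluate_on_grid(100, 100)
>>> Z.shape  # (grid, grid, n_basis_funcs), one slice per pair of 1D elements
(100, 100, 100)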
diff --git a/docs/tutorials/plot_04_v1_cells.md b/docs/tutorials/plot_04_v1_cells.md index aa479928..c9faaa82 100644 --- a/docs/tutorials/plot_04_v1_cells.md +++ b/docs/tutorials/plot_04_v1_cells.md @@ -345,7 +345,7 @@ GLM: ```{code-cell} ipython3 window_size = 100 -basis = nmo.basis.ConvRaisedCosineLog(8, window_size=window_size) +basis = nmo.basis.RaisedCosineLogConv(8, window_size=window_size) convolved_input = basis.compute_features(filtered_stimulus) ``` diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index 597959c3..ae8e38ac 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -335,15 +335,15 @@ print(count.shape) For each feature, we will use a different set of basis : - - position : [`EvalMSpline`](nemos.basis.basis.EvalMSpline) - - theta phase : [`EvalCyclicBSpline`](nemos.basis.basis.EvalCyclicBSpline) - - speed : [`EvalMSpline`](nemos.basis.basis.EvalMSpline) + - position : [`MSplineEval`](nemos.basis.basis.MSplineEval) + - theta phase : [`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) + - speed : [`MSplineEval`](nemos.basis.basis.MSplineEval) ```{code-cell} ipython3 -position_basis = nmo.basis.EvalMSpline(n_basis_funcs=10) -phase_basis = nmo.basis.EvalCyclicBSpline(n_basis_funcs=12) -speed_basis = nmo.basis.EvalMSpline(n_basis_funcs=15) +position_basis = nmo.basis.MSplineEval(n_basis_funcs=10) +phase_basis = nmo.basis.CyclicBSplineEval(n_basis_funcs=12) +speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15) ``` In addition, we will consider position and phase to be a joint variable. In NeMoS, we can combine basis by multiplying them and adding them. In this case the final basis object for our model can be made in one line : diff --git a/docs/tutorials/plot_06_calcium_imaging.md b/docs/tutorials/plot_06_calcium_imaging.md index e896affc..c95987a9 100644 --- a/docs/tutorials/plot_06_calcium_imaging.md +++ b/docs/tutorials/plot_06_calcium_imaging.md @@ -180,8 +180,8 @@ We can combine the two bases. ```{code-cell} ipython3 -heading_basis = nmo.basis.EvalCyclicBSpline(n_basis_funcs=12) -coupling_basis = nmo.basis.ConvRaisedCosineLog(3, window_size=10) +heading_basis = nmo.basis.CyclicBSplineEval(n_basis_funcs=12) +coupling_basis = nmo.basis.RaisedCosineLogConv(3, window_size=10) ``` We need to make sure the design matrix will be full-rank by applying identifiability constraints to the Cyclic Bspline, and then combine the bases (the resturned object will be an [`AdditiveBasis`](nemos.basis._basis.AdditiveBasis) object). diff --git a/src/nemos/_documentation_utils/plotting.py b/src/nemos/_documentation_utils/plotting.py index 58a94bc0..cfebac65 100644 --- a/src/nemos/_documentation_utils/plotting.py +++ b/src/nemos/_documentation_utils/plotting.py @@ -33,7 +33,7 @@ from matplotlib.patches import Rectangle from numpy.typing import NDArray -from ..basis import EvalRaisedCosineLog +from ..basis import RaisedCosineLogEval warnings.warn( "plotting functions contained within `_documentation_utils` are intended for nemos's documentation. 
" @@ -682,7 +682,7 @@ def plot_rates_and_smoothed_counts( def plot_basis(n_basis_funcs=8, window_size_sec=0.8): fig = plt.figure() - basis = EvalRaisedCosineLog(n_basis_funcs=n_basis_funcs) + basis = RaisedCosineLogEval(n_basis_funcs=n_basis_funcs) time, basis_kernels = basis.evaluate_on_grid(1000) time *= window_size_sec plt.plot(time, basis_kernels) diff --git a/src/nemos/basis/__init__.py b/src/nemos/basis/__init__.py index 3a08ad2e..94231346 100644 --- a/src/nemos/basis/__init__.py +++ b/src/nemos/basis/__init__.py @@ -1,18 +1,16 @@ -from ._basis import AdditiveBasis, Basis, MultiplicativeBasis -from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog -from ._spline_basis import BSplineBasis +from ._basis import AdditiveBasis, MultiplicativeBasis from ._transformer_basis import TransformerBasis from .basis import ( - ConvBSpline, - ConvCyclicBSpline, - ConvMSpline, - ConvOrthExponential, - ConvRaisedCosineLinear, - ConvRaisedCosineLog, - EvalBSpline, - EvalCyclicBSpline, - EvalMSpline, - EvalOrthExponential, - EvalRaisedCosineLinear, - EvalRaisedCosineLog, + BSplineConv, + BSplineEval, + CyclicBSplineConv, + CyclicBSplineEval, + MSplineConv, + MSplineEval, + OrthExponentialConv, + OrthExponentialEval, + RaisedCosineLinearConv, + RaisedCosineLinearEval, + RaisedCosineLogConv, + RaisedCosineLogEval, ) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 980b9657..698d4ce4 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -3,7 +3,7 @@ import abc import copy -from functools import partial, wraps +from functools import wraps from typing import Callable, Generator, Literal, Optional, Tuple, Union import jax @@ -19,7 +19,7 @@ from ._basis_mixin import BasisTransformerMixin -def add_docstring(method_name, cls=None): +def add_docstring(method_name, cls): """Prepend super-class docstrings.""" attr = getattr(cls, method_name, None) if attr is None: @@ -38,7 +38,7 @@ def check_transform_input(func: Callable) -> Callable: """Check input before calling basis. This decorator allows to raise an exception that is more readable - when the wrong number of input is provided to __call__. + when the wrong number of input is provided to _evaluate. """ @wraps(func) @@ -182,8 +182,6 @@ def n_basis_input(self) -> tuple | None: The number of inputs ``compute_feature`` expects. """ - if self._n_basis_input is None: - return return self._n_basis_input @property @@ -246,7 +244,7 @@ def add_constant(x): @check_transform_input def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ - Compute the basis functions and transform input data into model features. + Apply the basis transformation to the input data. This method is designed to be a high-level interface for transforming input data using the basis functions defined by the subclass. Depending on the basis' @@ -288,7 +286,7 @@ def _set_kernel(self): pass @abc.abstractmethod - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: """ Abstract method to evaluate the basis functions at given points. 
@@ -411,7 +409,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: Xs = np.meshgrid(*sample_tuple, indexing="ij") # evaluates the basis on a flat NDArray and reshape to match meshgrid output - Y = self.__call__(*tuple(grid_axis.flatten() for grid_axis in Xs)).reshape( + Y = self._evaluate(*tuple(grid_axis.flatten() for grid_axis in Xs)).reshape( (*n_samples, self.n_basis_funcs) ) @@ -823,10 +821,6 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis: return self -add_docstring_additive = partial(add_docstring, cls=Basis) -add_docstring_multiplicative = partial(add_docstring, cls=Basis) - - class AdditiveBasis(Basis): """ Class representing the addition of two Basis objects. @@ -851,13 +845,13 @@ class AdditiveBasis(Basis): >>> X = np.random.normal(size=(30, 2)) >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.EvalBSpline(10) - >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> basis_1 = nmo.basis.BSplineEval(10) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(15) >>> additive_basis = basis_1 + basis_2 >>> # can add another basis to the AdditiveBasis object >>> X = np.random.normal(size=(30, 3)) - >>> basis_3 = nmo.basis.EvalRaisedCosineLog(100) + >>> basis_3 = nmo.basis.RaisedCosineLogEval(100) >>> additive_basis_2 = additive_basis + basis_3 """ @@ -893,7 +887,7 @@ def _check_n_basis_min(self) -> None: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: """ Evaluate the basis at the input samples. @@ -916,32 +910,32 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix: >>> x, y = np.random.normal(size=(2, 30)) >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.EvalBSpline(10) - >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> basis_1 = nmo.basis.BSplineEval(10) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(15) >>> additive_basis = basis_1 + basis_2 >>> # call the basis. - >>> out = additive_basis(x, y) + >>> out = additive_basis._evaluate(x, y) """ X = np.hstack( ( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), + self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), ) ) return X - @add_docstring_additive("compute_features") + @add_docstring("compute_features", Basis) def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: r""" Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.basis import BSplineEval, RaisedCosineLogConv >>> from nemos.glm import GLM - >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") - >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis1 = BSplineEval(n_basis_funcs=5, label="one_input") + >>> basis2 = RaisedCosineLogConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> basis_add = basis1 + basis2 >>> X_multi = basis_add.compute_features(np.random.randn(20), np.random.randn(20, 2)) >>> print(X_multi.shape) # num_features: 17 = 5 + 2*6 @@ -1078,12 +1072,12 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvBSpline + >>> from nemos.basis import BSplineConv >>> from nemos.glm import GLM >>> # Define an additive basis >>> basis = ( - ... 
ConvBSpline(n_basis_funcs=5, window_size=10, label="feature_1") + - ... ConvBSpline(n_basis_funcs=6, window_size=10, label="feature_2") + ... BSplineConv(n_basis_funcs=5, window_size=10, label="feature_1") + + ... BSplineConv(n_basis_funcs=6, window_size=10, label="feature_2") ... ) >>> # Generate a sample input array and compute features >>> x1, x2 = np.random.randn(20), np.random.randn(20) @@ -1095,7 +1089,7 @@ def split_by_feature( feature_1: shape (20, 1, 5) feature_2: shape (20, 1, 6) >>> # If one of the basis components accepts multiple inputs, the resulting dictionary will be nested: - >>> multi_input_basis = ConvBSpline(n_basis_funcs=6, window_size=10, + >>> multi_input_basis = BSplineConv(n_basis_funcs=6, window_size=10, ... label="multi_input") >>> X_multi = multi_input_basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = multi_input_basis.split_by_feature(X_multi, axis=1) @@ -1158,8 +1152,8 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: >>> import nemos as nmo >>> # define two basis objects and add them - >>> basis_1 = nmo.basis.EvalBSpline(10) - >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> basis_1 = nmo.basis.BSplineEval(10) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(15) >>> additive_basis = basis_1 + basis_2 >>> # evaluate on a grid of 10 x 10 equi-spaced samples @@ -1193,13 +1187,13 @@ class MultiplicativeBasis(Basis): >>> X = np.random.normal(size=(30, 3)) >>> # define two basis and multiply - >>> basis_1 = nmo.basis.EvalBSpline(10) - >>> basis_2 = nmo.basis.EvalRaisedCosineLinear(15) + >>> basis_1 = nmo.basis.BSplineEval(10) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(15) >>> multiplicative_basis = basis_1 * basis_2 >>> # Can multiply or add another basis to the AdditiveBasis object >>> # This will cause the number of output features of the result basis to grow accordingly - >>> basis_3 = nmo.basis.EvalRaisedCosineLog(100) + >>> basis_3 = nmo.basis.RaisedCosineLogEval(100) >>> multiplicative_basis_2 = multiplicative_basis * basis_3 """ @@ -1241,7 +1235,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: """ Evaluate the basis at the input samples. 
@@ -1260,14 +1254,14 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix: -------- >>> import numpy as np >>> import nemos as nmo - >>> mult_basis = nmo.basis.EvalBSpline(5) * nmo.basis.EvalRaisedCosineLinear(6) + >>> mult_basis = nmo.basis.BSplineEval(5) * nmo.basis.RaisedCosineLinearEval(6) >>> x, y = np.random.randn(2, 30) - >>> X = mult_basis(x, y) + >>> X = mult_basis._evaluate(x, y) """ X = np.asarray( row_wise_kron( - self._basis1.__call__(*xi[: self._basis1._n_input_dimensionality]), - self._basis2.__call__(*xi[self._basis1._n_input_dimensionality :]), + self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), + self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -1292,7 +1286,7 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: -------- >>> import numpy as np >>> import nemos as nmo - >>> mult_basis = nmo.basis.EvalBSpline(5) * nmo.basis.EvalRaisedCosineLinear(6) + >>> mult_basis = nmo.basis.BSplineEval(5) * nmo.basis.RaisedCosineLinearEval(6) >>> x, y = np.random.randn(2, 30) >>> X = mult_basis.compute_features(x, y) """ @@ -1360,21 +1354,21 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: -------- >>> import numpy as np >>> import nemos as nmo - >>> mult_basis = nmo.basis.EvalBSpline(4) * nmo.basis.EvalRaisedCosineLinear(5) + >>> mult_basis = nmo.basis.BSplineEval(4) * nmo.basis.RaisedCosineLinearEval(5) >>> X, Y, Z = mult_basis.evaluate_on_grid(10, 10) """ return super().evaluate_on_grid(*n_samples) - @add_docstring_multiplicative("compute_features") + @add_docstring("compute_features", Basis) def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.basis import BSplineEval, RaisedCosineLogConv >>> from nemos.glm import GLM - >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") - >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis1 = BSplineEval(n_basis_funcs=5, label="one_input") + >>> basis2 = RaisedCosineLogConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> basis_mul = basis1 * basis2 >>> X_multi = basis_mul.compute_features(np.random.randn(20), np.random.randn(20, 2)) >>> print(X_multi.shape) # num_features: 60 = 5 * 2 * 6 @@ -1383,7 +1377,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ return super().compute_features(*xi) - @add_docstring_multiplicative("split_by_feature") + @add_docstring("split_by_feature", Basis) def split_by_feature( self, x: NDArray, @@ -1393,10 +1387,10 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalBSpline, ConvRaisedCosineLog + >>> from nemos.basis import BSplineEval, RaisedCosineLogConv >>> from nemos.glm import GLM - >>> basis1 = EvalBSpline(n_basis_funcs=5, label="one_input") - >>> basis2 = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis1 = BSplineEval(n_basis_funcs=5, label="one_input") + >>> basis2 = RaisedCosineLogConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> basis_mul = basis1 * basis2 >>> X_multi = basis_mul.compute_features(np.random.randn(20), np.random.randn(20, 2)) >>> print(X_multi.shape) # num_features: 60 = 5 * 2 * 6 diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 502ad791..3ae84376 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -38,7 +38,7 
@@ def _compute_features(self, *xi: ArrayLike): or a pynapple Tsd. """ - return self.__call__(*xi) + return self._evaluate(*xi) def _set_kernel(self) -> "EvalBasisMixin": """ @@ -141,7 +141,7 @@ def _set_kernel(self) -> "ConvBasisMixin": computed and how the input parameters are utilized. If the basis operates in 'eval' mode exclusively, this method should simply return `self` without modification. """ - self.kernel_ = self.__call__(np.linspace(0, 1, self.window_size)) + self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size)) return self @property @@ -241,7 +241,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.EvalRaisedCosineLinear(10).to_transformer() + >>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 15202e6d..a7a403fb 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -4,7 +4,6 @@ from __future__ import annotations import abc -from functools import partial from typing import Optional, Tuple import numpy as np @@ -15,7 +14,6 @@ from ..typing import FeatureMatrix from ._basis import ( Basis, - add_docstring, check_one_dimensional, check_transform_input, min_max_rescale_samples, @@ -134,7 +132,7 @@ def _check_sample_size(self, *sample_pts: NDArray) -> None: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__( + def _evaluate( self, sample_pts: NDArray, ) -> FeatureMatrix: @@ -193,6 +191,3 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: OrthExponential basis functions, shape (n_samples, n_basis_funcs). """ return super().evaluate_on_grid(n_samples) - - -add_orth_exp_decay_docstring = partial(add_docstring, cls=OrthExponentialBasis) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index bf7bda95..db442fb4 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -2,7 +2,6 @@ from __future__ import annotations import abc -from functools import partial from typing import Optional, Tuple import numpy as np @@ -12,7 +11,6 @@ from ..typing import FeatureMatrix from ._basis import ( Basis, - add_docstring, check_one_dimensional, check_transform_input, min_max_rescale_samples, @@ -101,7 +99,7 @@ def _check_width(width: float) -> None: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__( + def _evaluate( # call these _evaluate self, sample_pts: ArrayLike, ) -> FeatureMatrix: @@ -330,7 +328,7 @@ def _compute_peaks(self) -> NDArray: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__( + def _evaluate( self, sample_pts: ArrayLike, ) -> FeatureMatrix: @@ -351,9 +349,4 @@ def __call__( ValueError If the sample provided do not lie in [0,1]. 
""" - return super().__call__(self._transform_samples(sample_pts)) - - -add_raised_cosine_linear_docstring = partial(add_docstring, cls=RaisedCosineBasisLinear) - -add_raised_cosine_log_docstring = partial(add_docstring, cls=RaisedCosineBasisLog) + return super()._evaluate(self._transform_samples(sample_pts)) diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index 6342faad..f8e93bbe 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -3,7 +3,6 @@ import abc import copy -from functools import partial from typing import Literal, Optional, Tuple import numpy as np @@ -14,7 +13,6 @@ from ..typing import FeatureMatrix from ._basis import ( Basis, - add_docstring, check_one_dimensional, check_transform_input, min_max_rescale_samples, @@ -194,12 +192,12 @@ class MSplineBasis(SplineBasis, abc.ABC): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import EvalMSpline + >>> from nemos.basis import MSplineEval >>> n_basis_funcs = 5 >>> order = 3 - >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> mspline_basis = MSplineEval(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) - >>> basis_functions = mspline_basis(sample_points) + >>> basis_functions = mspline_basis.compute_features(sample_points) """ def __init__( @@ -207,7 +205,7 @@ def __init__( n_basis_funcs: int, mode: Literal["eval", "conv"] = "eval", order: int = 2, - label: Optional[str] = "EvalMSpline", + label: Optional[str] = "MSplineEval", ) -> None: super().__init__( n_basis_funcs, @@ -219,7 +217,7 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: + def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix: """ Evaluate the M-spline basis functions at given sample points. @@ -336,7 +334,7 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix: + def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix: """ Evaluate the B-spline basis functions with given sample points. 
@@ -445,7 +443,7 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def __call__( + def _evaluate( self, sample_pts: ArrayLike, ) -> FeatureMatrix: @@ -669,8 +667,3 @@ def bspline( ) return basis_eval.T - - -add_docstrings_mspline = partial(add_docstring, cls=MSplineBasis) -add_docstrings_bspline = partial(add_docstring, cls=BSplineBasis) -add_docstrings_cyclic_bspline = partial(add_docstring, cls=CyclicBSplineBasis) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 461a9399..f6abdad5 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -29,7 +29,7 @@ class TransformerBasis: Examples -------- - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> from nemos.basis import TransformerBasis >>> from nemos.glm import GLM >>> from sklearn.pipeline import Pipeline @@ -40,7 +40,7 @@ class TransformerBasis: >>> # Generate data >>> num_samples, num_features = 10000, 1 >>> x = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalBSpline(10) + >>> basis = BSplineEval(10) >>> features = basis.compute_features(x) # basis transformed time series >>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights >>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts @@ -104,13 +104,13 @@ def fit(self, X: FeatureMatrix, y=None): Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis + >>> from nemos.basis import MSplineEval, TransformerBasis >>> # Example input >>> X = np.random.normal(size=(100, 2)) >>> # Define and fit tranformation basis - >>> basis = EvalMSpline(10) + >>> basis = MSplineEval(10) >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ @@ -136,12 +136,12 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvMSpline, TransformerBasis + >>> from nemos.basis import MSplineConv, TransformerBasis >>> # Example input >>> X = np.random.normal(size=(10000, 2)) - >>> basis = ConvMSpline(10, window_size=200) + >>> basis = MSplineConv(10, window_size=200) >>> transformer = TransformerBasis(basis) >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ @@ -181,13 +181,13 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalMSpline, TransformerBasis + >>> from nemos.basis import MSplineEval, TransformerBasis >>> # Example input >>> X = np.random.normal(size=(100, 1)) >>> # Define tranformation basis - >>> basis = EvalMSpline(10) + >>> basis = MSplineEval(10) >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis @@ -223,7 +223,7 @@ def __getattr__(self, name: str): Examples -------- >>> from nemos import basis - >>> bas = basis.EvalRaisedCosineLinear(5) + >>> bas = basis.RaisedCosineLinearEval(5) >>> trans_bas = basis.TransformerBasis(bas) >>> bas.n_basis_funcs 5 @@ -250,9 +250,9 @@ def __setattr__(self, name: str, value) -> None: Examples -------- >>> import nemos as nmo - >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.EvalMSpline(10)) + >>> trans_bas = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(10)) >>> # allowed - >>> trans_bas._basis = nmo.basis.EvalBSpline(10) + >>> trans_bas._basis = nmo.basis.BSplineEval(10) >>> # allowed >>> trans_bas.n_basis_funcs = 20 >>> # not 
allowed @@ -296,19 +296,19 @@ def set_params(self, **parameters) -> TransformerBasis: Examples -------- - >>> from nemos.basis import EvalBSpline, EvalMSpline, TransformerBasis - >>> basis = EvalMSpline(10) + >>> from nemos.basis import BSplineEval, MSplineEval, TransformerBasis + >>> basis = MSplineEval(10) >>> transformer_basis = TransformerBasis(basis=basis) >>> # setting parameters of _basis is allowed >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs) 8 >>> # setting _basis directly is allowed - >>> print(type(transformer_basis.set_params(_basis=EvalBSpline(10))._basis)) - + >>> print(type(transformer_basis.set_params(_basis=BSplineEval(10))._basis)) + >>> # mixing is not allowed, this will raise an exception >>> try: - ... transformer_basis.set_params(_basis=EvalBSpline(10), n_basis_funcs=2) + ... transformer_basis.set_params(_basis=BSplineEval(10), n_basis_funcs=2) ... except ValueError as e: ... print(repr(e)) ValueError('Set either new _basis object or parameters for existing _basis, not both.') diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index eb8fa770..730aaee9 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -8,37 +8,26 @@ from numpy.typing import ArrayLike, NDArray from ..typing import FeatureMatrix +from ._basis import add_docstring from ._basis_mixin import BasisTransformerMixin, ConvBasisMixin, EvalBasisMixin -from ._decaying_exponential import OrthExponentialBasis, add_orth_exp_decay_docstring -from ._raised_cosine_basis import ( - RaisedCosineBasisLinear, - RaisedCosineBasisLog, - add_raised_cosine_linear_docstring, - add_raised_cosine_log_docstring, -) -from ._spline_basis import ( - BSplineBasis, - CyclicBSplineBasis, - MSplineBasis, - add_docstrings_bspline, - add_docstrings_cyclic_bspline, - add_docstrings_mspline, -) +from ._decaying_exponential import OrthExponentialBasis +from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog +from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis from ._transformer_basis import TransformerBasis __all__ = [ - "EvalMSpline", - "ConvMSpline", - "EvalBSpline", - "ConvBSpline", - "EvalCyclicBSpline", - "ConvCyclicBSpline", - "EvalRaisedCosineLinear", - "ConvRaisedCosineLinear", - "EvalRaisedCosineLog", - "ConvRaisedCosineLog", - "EvalOrthExponential", - "ConvOrthExponential", + "MSplineEval", + "MSplineConv", + "BSplineEval", + "BSplineConv", + "CyclicBSplineEval", + "CyclicBSplineConv", + "RaisedCosineLinearEval", + "RaisedCosineLinearConv", + "RaisedCosineLogEval", + "RaisedCosineLogConv", + "OrthExponentialEval", + "OrthExponentialConv", "TransformerBasis", ] @@ -47,7 +36,7 @@ def __dir__() -> list[str]: return __all__ -class EvalBSpline(EvalBasisMixin, BSplineBasis): +class BSplineEval(EvalBasisMixin, BSplineBasis): """ B-spline 1-dimensional basis functions. 
@@ -79,10 +68,10 @@ class EvalBSpline(EvalBasisMixin, BSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> n_basis_funcs = 5 >>> order = 3 - >>> bspline_basis = EvalBSpline(n_basis_funcs, order=order) + >>> bspline_basis = BSplineEval(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) >>> basis_functions = bspline_basis.compute_features(sample_points) """ @@ -92,7 +81,7 @@ def __init__( n_basis_funcs: int, order: int = 4, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalBSpline", + label: Optional[str] = "BSplineEval", ): EvalBasisMixin.__init__(self, bounds=bounds) BSplineBasis.__init__( @@ -103,7 +92,7 @@ def __init__( label=label, ) - @add_docstrings_bspline("split_by_feature") + @add_docstring("split_by_feature", BSplineBasis) def split_by_feature( self, x: NDArray, @@ -113,9 +102,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM - >>> basis = EvalBSpline(n_basis_funcs=6, label="one_input") + >>> basis = BSplineEval(n_basis_funcs=6, label="one_input") >>> X = basis.compute_features(np.random.randn(20,)) >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -123,28 +112,28 @@ def split_by_feature( one_input, shape (20, 1, 6) """ - return BSplineBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) - @add_docstrings_bspline("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalBSpline(10) + >>> basis = BSplineEval(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return BSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_bspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", BSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -153,8 +142,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalBSpline - >>> bspline_basis = EvalBSpline(n_basis_funcs=4, order=3) + >>> from nemos.basis import BSplineEval + >>> bspline_basis = BSplineEval(n_basis_funcs=4, order=3) >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -166,10 +155,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return BSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class ConvBSpline(ConvBasisMixin, BSplineBasis): +class BSplineConv(ConvBasisMixin, BSplineBasis): """ B-spline 1-dimensional basis functions. 
@@ -198,10 +187,10 @@ class ConvBSpline(ConvBasisMixin, BSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import ConvBSpline + >>> from nemos.basis import BSplineConv >>> n_basis_funcs = 5 >>> order = 3 - >>> bspline_basis = ConvBSpline(n_basis_funcs, order=order, window_size=10) + >>> bspline_basis = BSplineConv(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) >>> features = bspline_basis.compute_features(sample_points) """ @@ -211,7 +200,7 @@ def __init__( n_basis_funcs: int, window_size: int, order: int = 4, - label: Optional[str] = "ConvBSpline", + label: Optional[str] = "BSplineConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -223,7 +212,7 @@ def __init__( label=label, ) - @add_docstrings_bspline("split_by_feature") + @add_docstring("split_by_feature", BSplineBasis) def split_by_feature( self, x: NDArray, @@ -233,9 +222,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvBSpline + >>> from nemos.basis import BSplineConv >>> from nemos.glm import GLM - >>> basis = ConvBSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = BSplineConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -243,28 +232,28 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return BSplineBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) - @add_docstrings_bspline("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvBSpline + >>> from nemos.basis import BSplineConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvBSpline(10, window_size=11) + >>> basis = BSplineConv(10, window_size=11) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return BSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_bspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", BSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -273,8 +262,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvBSpline - >>> bspline_basis = ConvBSpline(n_basis_funcs=4, order=3, window_size=10) + >>> from nemos.basis import BSplineConv + >>> bspline_basis = BSplineConv(n_basis_funcs=4, order=3, window_size=10) >>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -286,10 +275,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return BSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): +class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis): """ B-spline 1-dimensional basis functions for cyclic splines. @@ -312,10 +301,10 @@ class EvalCyclicBSpline(EvalBasisMixin, CyclicBSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import EvalCyclicBSpline + >>> from nemos.basis import CyclicBSplineEval >>> n_basis_funcs = 5 >>> order = 3 - >>> cyclic_bspline_basis = EvalCyclicBSpline(n_basis_funcs, order=order) + >>> cyclic_bspline_basis = CyclicBSplineEval(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) >>> features = cyclic_bspline_basis.compute_features(sample_points) """ @@ -325,7 +314,7 @@ def __init__( n_basis_funcs: int, order: int = 4, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalCyclicBSpline", + label: Optional[str] = "CyclicBSplineEval", ): EvalBasisMixin.__init__(self, bounds=bounds) CyclicBSplineBasis.__init__( @@ -336,7 +325,7 @@ def __init__( label=label, ) - @add_docstrings_cyclic_bspline("split_by_feature") + @add_docstring("split_by_feature", CyclicBSplineBasis) def split_by_feature( self, x: NDArray, @@ -346,9 +335,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalCyclicBSpline + >>> from nemos.basis import CyclicBSplineEval >>> from nemos.glm import GLM - >>> basis = EvalCyclicBSpline(n_basis_funcs=6, label="one_input") + >>> basis = CyclicBSplineEval(n_basis_funcs=6, label="one_input") >>> X = basis.compute_features(np.random.randn(20,)) >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -356,28 +345,28 @@ def split_by_feature( one_input, shape (20, 1, 6) """ - return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) - @add_docstrings_cyclic_bspline("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalCyclicBSpline + >>> from nemos.basis import CyclicBSplineEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalCyclicBSpline(10) + >>> basis = CyclicBSplineEval(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return CyclicBSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_cyclic_bspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", CyclicBSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -386,8 +375,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalCyclicBSpline - >>> cbspline_basis = EvalCyclicBSpline(n_basis_funcs=4, order=3) + >>> from nemos.basis import CyclicBSplineEval + >>> cbspline_basis = CyclicBSplineEval(n_basis_funcs=4, order=3) >>> sample_points, basis_values = 
cbspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -399,10 +388,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): +class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis): """ B-spline 1-dimensional basis functions for cyclic splines. @@ -423,10 +412,10 @@ class ConvCyclicBSpline(ConvBasisMixin, CyclicBSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import ConvCyclicBSpline + >>> from nemos.basis import CyclicBSplineConv >>> n_basis_funcs = 5 >>> order = 3 - >>> cyclic_bspline_basis = ConvCyclicBSpline(n_basis_funcs, order=order, window_size=10) + >>> cyclic_bspline_basis = CyclicBSplineConv(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) >>> features = cyclic_bspline_basis.compute_features(sample_points) """ @@ -436,7 +425,7 @@ def __init__( n_basis_funcs: int, window_size: int, order: int = 4, - label: Optional[str] = "ConvCyclicBSpline", + label: Optional[str] = "CyclicBSplineConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -448,7 +437,7 @@ def __init__( label=label, ) - @add_docstrings_cyclic_bspline("split_by_feature") + @add_docstring("split_by_feature", CyclicBSplineBasis) def split_by_feature( self, x: NDArray, @@ -458,9 +447,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvCyclicBSpline + >>> from nemos.basis import CyclicBSplineConv >>> from nemos.glm import GLM - >>> basis = ConvCyclicBSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = CyclicBSplineConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -468,28 +457,28 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return CyclicBSplineBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) - @add_docstrings_cyclic_bspline("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvCyclicBSpline + >>> from nemos.basis import CyclicBSplineConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvCyclicBSpline(10, window_size=11) + >>> basis = CyclicBSplineConv(10, window_size=11) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return CyclicBSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_cyclic_bspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", CyclicBSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -498,8 +487,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvCyclicBSpline - >>> cbspline_basis = 
ConvCyclicBSpline(n_basis_funcs=4, order=3, window_size=10) + >>> from nemos.basis import CyclicBSplineConv + >>> cbspline_basis = CyclicBSplineConv(n_basis_funcs=4, order=3, window_size=10) >>> sample_points, basis_values = cbspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -511,10 +500,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return CyclicBSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class EvalMSpline(EvalBasisMixin, MSplineBasis): +class MSplineEval(EvalBasisMixin, MSplineBasis): r""" M-spline basis functions for modeling and data transformation. @@ -561,10 +550,10 @@ class EvalMSpline(EvalBasisMixin, MSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import EvalMSpline + >>> from nemos.basis import MSplineEval >>> n_basis_funcs = 5 >>> order = 3 - >>> mspline_basis = EvalMSpline(n_basis_funcs, order=order) + >>> mspline_basis = MSplineEval(n_basis_funcs, order=order) >>> sample_points = linspace(0, 1, 100) >>> features = mspline_basis.compute_features(sample_points) """ @@ -574,7 +563,7 @@ def __init__( n_basis_funcs: int, order: int = 4, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalMSpline", + label: Optional[str] = "MSplineEval", ): EvalBasisMixin.__init__(self, bounds=bounds) MSplineBasis.__init__( @@ -585,7 +574,7 @@ def __init__( label=label, ) - @add_docstrings_mspline("split_by_feature") + @add_docstring("split_by_feature", MSplineBasis) def split_by_feature( self, x: NDArray, @@ -595,9 +584,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalMSpline + >>> from nemos.basis import MSplineEval >>> from nemos.glm import GLM - >>> basis = EvalMSpline(n_basis_funcs=6, label="one_input") + >>> basis = MSplineEval(n_basis_funcs=6, label="one_input") >>> X = basis.compute_features(np.random.randn(20)) >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -607,26 +596,26 @@ def split_by_feature( """ return MSplineBasis.split_by_feature(self, x, axis=axis) - @add_docstrings_mspline("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalMSpline + >>> from nemos.basis import MSplineEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalMSpline(10) + >>> basis = MSplineEval(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return MSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_mspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", MSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -635,8 +624,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalMSpline - >>> mspline_basis = EvalMSpline(n_basis_funcs=4, order=3) + >>> from nemos.basis import MSplineEval + >>> mspline_basis = MSplineEval(n_basis_funcs=4, order=3) >>> sample_points, basis_values = 
mspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -648,10 +637,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return MSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class ConvMSpline(ConvBasisMixin, MSplineBasis): +class MSplineConv(ConvBasisMixin, MSplineBasis): r""" M-spline basis functions for modeling and data transformation. @@ -696,10 +685,10 @@ class ConvMSpline(ConvBasisMixin, MSplineBasis): Examples -------- >>> from numpy import linspace - >>> from nemos.basis import ConvMSpline + >>> from nemos.basis import MSplineConv >>> n_basis_funcs = 5 >>> order = 3 - >>> mspline_basis = ConvMSpline(n_basis_funcs, order=order, window_size=10) + >>> mspline_basis = MSplineConv(n_basis_funcs, order=order, window_size=10) >>> sample_points = linspace(0, 1, 100) >>> features = mspline_basis.compute_features(sample_points) """ @@ -709,7 +698,7 @@ def __init__( n_basis_funcs: int, window_size: int, order: int = 4, - label: Optional[str] = "ConvMSpline", + label: Optional[str] = "MSplineConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -721,7 +710,7 @@ def __init__( label=label, ) - @add_docstrings_mspline("split_by_feature") + @add_docstring("split_by_feature", MSplineBasis) def split_by_feature( self, x: NDArray, @@ -731,9 +720,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvMSpline + >>> from nemos.basis import MSplineConv >>> from nemos.glm import GLM - >>> basis = ConvMSpline(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = MSplineConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -741,28 +730,28 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return MSplineBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) - @add_docstrings_mspline("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvMSpline + >>> from nemos.basis import MSplineConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvMSpline(10, window_size=11) + >>> basis = MSplineConv(10, window_size=11) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return MSplineBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_docstrings_mspline("evaluate_on_grid") + @add_docstring("evaluate_on_grid", MSplineBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples @@ -771,8 +760,8 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvMSpline - >>> mspline_basis = ConvMSpline(n_basis_funcs=4, order=3, window_size=10) + >>> from nemos.basis import MSplineConv + >>> mspline_basis = MSplineConv(n_basis_funcs=4, order=3, window_size=10) >>> sample_points, basis_values = 
mspline_basis.evaluate_on_grid(100) >>> for i in range(4): ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') @@ -784,10 +773,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: Text(0, 0.5, 'Basis Function Value') >>> l = plt.legend() """ - return MSplineBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) -class EvalRaisedCosineLinear( +class RaisedCosineLinearEval( EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin ): """ @@ -820,9 +809,9 @@ class EvalRaisedCosineLinear( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearEval >>> n_basis_funcs = 5 - >>> raised_cosine_basis = EvalRaisedCosineLinear(n_basis_funcs) + >>> raised_cosine_basis = RaisedCosineLinearEval(n_basis_funcs) >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> features = raised_cosine_basis.compute_features(sample_points) @@ -833,7 +822,7 @@ def __init__( n_basis_funcs: int, width: float = 2.0, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalRaisedCosineLinear", + label: Optional[str] = "RaisedCosineLinearEval", ): EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLinear.__init__( @@ -844,43 +833,43 @@ def __init__( label=label, ) - @add_raised_cosine_linear_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", RaisedCosineBasisLinear) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearEval >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = EvalRaisedCosineLinear(n_basis_funcs) + >>> ortho_basis = RaisedCosineLinearEval(n_basis_funcs) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return RaisedCosineBasisLinear.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) - @add_raised_cosine_linear_docstring("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalRaisedCosineLinear(10) + >>> basis = RaisedCosineLinearEval(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return RaisedCosineBasisLinear.compute_features(self, xi) + return super().compute_features(xi) - @add_raised_cosine_linear_docstring("split_by_feature") + @add_docstring("split_by_feature", RaisedCosineBasisLinear) def split_by_feature( self, x: NDArray, @@ -890,9 +879,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearEval >>> from nemos.glm import GLM - >>> basis = EvalRaisedCosineLinear(n_basis_funcs=6, label="one_input") + >>> basis = RaisedCosineLinearEval(n_basis_funcs=6, label="one_input") >>> X = basis.compute_features(np.random.randn(20,)) >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in 
split_features_multi.items(): @@ -900,10 +889,10 @@ def split_by_feature( one_input, shape (20, 1, 6) """ - return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) -class ConvRaisedCosineLinear( +class RaisedCosineLinearConv( ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin ): """ @@ -934,9 +923,9 @@ class ConvRaisedCosineLinear( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearConv >>> n_basis_funcs = 5 - >>> raised_cosine_basis = ConvRaisedCosineLinear(n_basis_funcs, window_size=10) + >>> raised_cosine_basis = RaisedCosineLinearConv(n_basis_funcs, window_size=10) >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> features = raised_cosine_basis.compute_features(sample_points) @@ -947,7 +936,7 @@ def __init__( n_basis_funcs: int, window_size: int, width: float = 2.0, - label: Optional[str] = "ConvRaisedCosineLinear", + label: Optional[str] = "RaisedCosineLinearConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -959,43 +948,43 @@ def __init__( label=label, ) - @add_raised_cosine_linear_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", RaisedCosineBasisLinear) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearConv >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = ConvRaisedCosineLinear(n_basis_funcs, window_size) + >>> ortho_basis = RaisedCosineLinearConv(n_basis_funcs, window_size) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return RaisedCosineBasisLinear.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) - @add_raised_cosine_linear_docstring("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvRaisedCosineLinear(10, window_size=100) + >>> basis = RaisedCosineLinearConv(10, window_size=100) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return RaisedCosineBasisLinear.compute_features(self, xi) + return super().compute_features(xi) - @add_raised_cosine_linear_docstring("split_by_feature") + @add_docstring("split_by_feature", RaisedCosineBasisLinear) def split_by_feature( self, x: NDArray, @@ -1005,9 +994,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLinear + >>> from nemos.basis import RaisedCosineLinearConv >>> from nemos.glm import GLM - >>> basis = ConvRaisedCosineLinear(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = RaisedCosineLinearConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in 
split_features_multi.items(): @@ -1015,13 +1004,13 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return RaisedCosineBasisLinear.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) -class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): +class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. - Similar to ``EvalRaisedCosineLinear`` but the basis functions are log-spaced. + Similar to ``RaisedCosineLinearEval`` but the basis functions are log-spaced. This implementation is based on the cosine bumps used by Pillow et al. [1]_ to uniformly tile the internal points of the domain. @@ -1055,9 +1044,9 @@ class EvalRaisedCosineLog(EvalBasisMixin, RaisedCosineBasisLog): Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogEval >>> n_basis_funcs = 5 - >>> raised_cosine_basis = EvalRaisedCosineLog(n_basis_funcs) + >>> raised_cosine_basis = RaisedCosineLogEval(n_basis_funcs) >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> features = raised_cosine_basis.compute_features(sample_points) @@ -1070,7 +1059,7 @@ def __init__( time_scaling: float = None, enforce_decay_to_zero: bool = True, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalRaisedCosineLog", + label: Optional[str] = "RaisedCosineLogEval", ): EvalBasisMixin.__init__(self, bounds=bounds) RaisedCosineBasisLog.__init__( @@ -1083,43 +1072,43 @@ def __init__( label=label, ) - @add_raised_cosine_log_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", RaisedCosineBasisLog) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogEval >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = EvalRaisedCosineLog(n_basis_funcs) + >>> ortho_basis = RaisedCosineLogEval(n_basis_funcs) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) - @add_raised_cosine_log_docstring("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalRaisedCosineLog(10) + >>> basis = RaisedCosineLogEval(10) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return RaisedCosineBasisLog.compute_features(self, xi) + return super().compute_features(xi) - @add_raised_cosine_log_docstring("split_by_feature") + @add_docstring("split_by_feature", RaisedCosineBasisLog) def split_by_feature( self, x: NDArray, @@ -1129,9 +1118,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogEval >>> from nemos.glm import GLM - >>> basis = EvalRaisedCosineLog(n_basis_funcs=6, label="one_input") + >>> basis = RaisedCosineLogEval(n_basis_funcs=6, 
label="one_input") >>> X = basis.compute_features(np.random.randn(20,)) >>> split_features_multi = basis.split_by_feature(X, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -1139,13 +1128,13 @@ def split_by_feature( one_input, shape (20, 1, 6) """ - return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) -class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): +class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. - Similar to ``ConvRaisedCosineLinear`` but the basis functions are log-spaced. + Similar to ``RaisedCosineLinearConv`` but the basis functions are log-spaced. This implementation is based on the cosine bumps used by Pillow et al. [1]_ to uniformly tile the internal points of the domain. @@ -1179,9 +1168,9 @@ class ConvRaisedCosineLog(ConvBasisMixin, RaisedCosineBasisLog): Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogConv >>> n_basis_funcs = 5 - >>> raised_cosine_basis = ConvRaisedCosineLog(n_basis_funcs, window_size=10) + >>> raised_cosine_basis = RaisedCosineLogConv(n_basis_funcs, window_size=10) >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> features = raised_cosine_basis.compute_features(sample_points) @@ -1194,7 +1183,7 @@ def __init__( width: float = 2.0, time_scaling: float = None, enforce_decay_to_zero: bool = True, - label: Optional[str] = "ConvRaisedCosineLog", + label: Optional[str] = "RaisedCosineLogConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -1208,43 +1197,43 @@ def __init__( label=label, ) - @add_raised_cosine_log_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", RaisedCosineBasisLog) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogConv >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = ConvRaisedCosineLog(n_basis_funcs, window_size) + >>> ortho_basis = RaisedCosineLogConv(n_basis_funcs, window_size) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return RaisedCosineBasisLog.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) - @add_raised_cosine_log_docstring("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvRaisedCosineLog(10, window_size=100) + >>> basis = RaisedCosineLogConv(10, window_size=100) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return RaisedCosineBasisLog.compute_features(self, xi) + return super().compute_features(xi) - @add_raised_cosine_log_docstring("split_by_feature") + @add_docstring("split_by_feature", RaisedCosineBasisLog) def split_by_feature( self, x: NDArray, @@ -1254,9 +1243,9 @@ def split_by_feature( Examples -------- >>> import 
numpy as np - >>> from nemos.basis import ConvRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogConv >>> from nemos.glm import GLM - >>> basis = ConvRaisedCosineLog(n_basis_funcs=6, window_size=10, label="two_inputs") + >>> basis = RaisedCosineLogConv(n_basis_funcs=6, window_size=10, label="two_inputs") >>> X_multi = basis.compute_features(np.random.randn(20, 2)) >>> split_features_multi = basis.split_by_feature(X_multi, axis=1) >>> for feature, sub_dict in split_features_multi.items(): @@ -1264,10 +1253,10 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return RaisedCosineBasisLog.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) -class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): +class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. Parameters @@ -1288,12 +1277,12 @@ class EvalOrthExponential(EvalBasisMixin, OrthExponentialBasis): -------- >>> import numpy as np >>> from numpy import linspace - >>> from nemos.basis import EvalOrthExponential + >>> from nemos.basis import OrthExponentialEval >>> X = np.random.normal(size=(1000, 1)) >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size = 10 - >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates) + >>> ortho_basis = OrthExponentialEval(n_basis_funcs, decay_rates) >>> sample_points = linspace(0, 1, 100) >>> # evaluate the basis >>> features = ortho_basis.compute_features(sample_points) @@ -1305,7 +1294,7 @@ def __init__( n_basis_funcs: int, decay_rates: NDArray, bounds: Optional[Tuple[float, float]] = None, - label: Optional[str] = "EvalOrthExponential", + label: Optional[str] = "OrthExponentialEval", ): EvalBasisMixin.__init__(self, bounds=bounds) OrthExponentialBasis.__init__( @@ -1316,43 +1305,43 @@ def __init__( label=label, ) - @add_orth_exp_decay_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import EvalOrthExponential + >>> from nemos.basis import OrthExponentialEval >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = EvalOrthExponential(n_basis_funcs, decay_rates=decay_rates) + >>> ortho_basis = OrthExponentialEval(n_basis_funcs, decay_rates=decay_rates) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return OrthExponentialBasis.evaluate_on_grid(self, n_samples=n_samples) + return super().evaluate_on_grid(n_samples=n_samples) - @add_orth_exp_decay_docstring("compute_features") + @add_docstring("_compute_features", EvalBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalOrthExponential + >>> from nemos.basis import OrthExponentialEval >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = EvalOrthExponential(10, decay_rates=np.arange(1, 11)) + >>> basis = OrthExponentialEval(10, decay_rates=np.arange(1, 11)) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return OrthExponentialBasis.compute_features(self, xi) + return super().compute_features(xi) - 
@add_orth_exp_decay_docstring("split_by_feature") + @add_docstring("split_by_feature", OrthExponentialBasis) def split_by_feature( self, x: NDArray, @@ -1362,10 +1351,10 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import EvalOrthExponential + >>> from nemos.basis import OrthExponentialEval >>> from nemos.glm import GLM >>> # Define an additive basis - >>> basis = EvalOrthExponential(n_basis_funcs=5, decay_rates=np.arange(1, 6), label="feature") + >>> basis = OrthExponentialEval(n_basis_funcs=5, decay_rates=np.arange(1, 6), label="feature") >>> # Generate a sample input array and compute features >>> x = np.random.randn(20) >>> X = basis.compute_features(x) @@ -1376,10 +1365,10 @@ def split_by_feature( feature: shape (20, 1, 5) """ - return OrthExponentialBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) -class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): +class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. Parameters @@ -1397,12 +1386,12 @@ class ConvOrthExponential(ConvBasisMixin, OrthExponentialBasis): Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import OrthExponentialConv >>> X = np.random.normal(size=(1000, 1)) >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size = 10 - >>> ortho_basis = ConvOrthExponential(n_basis_funcs, window_size, decay_rates) + >>> ortho_basis = OrthExponentialConv(n_basis_funcs, window_size, decay_rates) >>> sample_points = np.random.randn(100) >>> # convolve the basis >>> features = ortho_basis.compute_features(sample_points) @@ -1413,7 +1402,7 @@ def __init__( n_basis_funcs: int, window_size: int, decay_rates: NDArray, - label: Optional[str] = "ConvOrthExponential", + label: Optional[str] = "OrthExponentialConv", conv_kwargs: Optional[dict] = None, ): ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs) @@ -1425,43 +1414,43 @@ def __init__( label=label, ) - @add_orth_exp_decay_docstring("evaluate_on_grid") + @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ Examples -------- >>> import numpy as np >>> import matplotlib.pyplot as plt - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import OrthExponentialConv >>> n_basis_funcs = 5 >>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates >>> window_size=10 - >>> ortho_basis = ConvOrthExponential(n_basis_funcs, window_size, decay_rates=decay_rates) + >>> ortho_basis = OrthExponentialConv(n_basis_funcs, window_size, decay_rates=decay_rates) >>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100) """ - return OrthExponentialBasis.evaluate_on_grid(self, n_samples) + return super().evaluate_on_grid(n_samples) - @add_orth_exp_decay_docstring("compute_features") + @add_docstring("_compute_features", ConvBasisMixin) def compute_features(self, xi: ArrayLike) -> FeatureMatrix: """ Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import OrthExponentialConv >>> # Generate data >>> num_samples = 1000 >>> X = np.random.normal(size=(num_samples, )) # raw time series - >>> basis = ConvOrthExponential(10, window_size=100, decay_rates=np.arange(1, 11)) + >>> basis = 
OrthExponentialConv(10, window_size=100, decay_rates=np.arange(1, 11)) >>> features = basis.compute_features(X) # basis transformed time series >>> features.shape (1000, 10) """ - return OrthExponentialBasis.compute_features(self, xi) + return super().compute_features(xi) - @add_orth_exp_decay_docstring("split_by_feature") + @add_docstring("split_by_feature", OrthExponentialBasis) def split_by_feature( self, x: NDArray, @@ -1471,9 +1460,9 @@ def split_by_feature( Examples -------- >>> import numpy as np - >>> from nemos.basis import ConvOrthExponential + >>> from nemos.basis import OrthExponentialConv >>> from nemos.glm import GLM - >>> basis = ConvOrthExponential( + >>> basis = OrthExponentialConv( ... n_basis_funcs=6, ... decay_rates=np.arange(1, 7), ... window_size=10, @@ -1486,4 +1475,4 @@ def split_by_feature( two_inputs, shape (20, 2, 6) """ - return OrthExponentialBasis.split_by_feature(self, x, axis=axis) + return super().split_by_feature(x, axis=axis) diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index 6098bfd0..b949b489 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -216,10 +216,10 @@ def apply_identifiability_constraints( -------- >>> import numpy as np >>> from nemos.identifiability_constraints import apply_identifiability_constraints - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM >>> # define a feature matrix - >>> bas = EvalBSpline(5) + EvalBSpline(6) + >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) >>> # apply constraints >>> constrained_x, kept_columns = apply_identifiability_constraints(feature_matrix) @@ -281,10 +281,10 @@ def apply_identifiability_constraints_by_basis_component( -------- >>> import numpy as np >>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component - >>> from nemos.basis import EvalBSpline + >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM >>> # define a feature matrix - >>> bas = EvalBSpline(5) + EvalBSpline(6) + >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) >>> # apply constraints >>> constrained_x, kept_columns = apply_identifiability_constraints_by_basis_component(bas, feature_matrix) diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py index cbfd674f..01d634b6 100644 --- a/src/nemos/simulation.py +++ b/src/nemos/simulation.py @@ -151,11 +151,11 @@ def regress_filter(coupling_filters: NDArray, eval_basis: NDArray) -> NDArray: >>> import numpy as np >>> import matplotlib.pyplot as plt >>> from nemos.simulation import regress_filter, difference_of_gammas - >>> from nemos.basis import EvalRaisedCosineLog + >>> from nemos.basis import RaisedCosineLogEval >>> filter_duration = 100 >>> n_basis_funcs = 20 >>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1) - >>> _, basis = EvalRaisedCosineLog(10).evaluate_on_grid(filter_duration) + >>> _, basis = RaisedCosineLogEval(10).evaluate_on_grid(filter_duration) >>> weights = regress_filter(filter_bank, basis)[0, 0] >>> print("Weights shape:", weights.shape) Weights shape: (10,) diff --git a/tests/conftest.py b/tests/conftest.py index 4b225125..eb88ed10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -286,7 +286,7 @@ def coupled_model_simulate(): ) # shrink the 
filters for simulation stability coupling_filter_bank *= 0.8 - basis = nmo.basis.EvalRaisedCosineLog(20) + basis = nmo.basis.RaisedCosineLogEval(20) # approximate the coupling filters in terms of the basis function _, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0]) diff --git a/tests/test_basis.py b/tests/test_basis.py index 3f86ebe7..15a9e5f7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -132,7 +132,7 @@ def test_all_basis_are_tested() -> None: ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), ( "compute_features", - "Compute the basis functions and transform input data into model features", + "Apply the basis transformation to the input data", ), ( "split_by_feature", @@ -167,7 +167,7 @@ def test_example_docstrings_add( continue if basis_name == basis_instance.__class__.__name__: continue - assert basis_name not in doc_components[1] + assert f" {basis_name}" not in doc_components[1] def test_add_docstring(): @@ -191,19 +191,19 @@ def method(self): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10), BSplineBasis), - (basis.ConvBSpline(10, window_size=11), BSplineBasis), - (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), - (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), - (basis.EvalMSpline(10), MSplineBasis), - (basis.ConvMSpline(10, window_size=11), MSplineBasis), - (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), - (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), - (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), - (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.BSplineEval(10), BSplineBasis), + (basis.BSplineConv(10, window_size=11), BSplineBasis), + (basis.CyclicBSplineEval(10), CyclicBSplineBasis), + (basis.CyclicBSplineConv(10, window_size=11), CyclicBSplineBasis), + (basis.MSplineEval(10), MSplineBasis), + (basis.MSplineConv(10, window_size=11), MSplineBasis), + (basis.RaisedCosineLinearEval(10), RaisedCosineBasisLinear), + (basis.RaisedCosineLinearConv(10, window_size=11), RaisedCosineBasisLinear), + (basis.RaisedCosineLogEval(10), RaisedCosineBasisLog), + (basis.RaisedCosineLogConv(10, window_size=11), RaisedCosineBasisLog), + (basis.OrthExponentialEval(10, np.arange(1, 11)), OrthExponentialBasis), ( - basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + basis.OrthExponentialConv(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis, ), ], @@ -218,19 +218,19 @@ def test_expected_output_eval_on_grid(basis_instance, super_class): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10), BSplineBasis), - (basis.ConvBSpline(10, window_size=11), BSplineBasis), - (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), - (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), - (basis.EvalMSpline(10), MSplineBasis), - (basis.ConvMSpline(10, window_size=11), MSplineBasis), - (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), - (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), - (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), - (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.BSplineEval(10), BSplineBasis), + (basis.BSplineConv(10, window_size=11), BSplineBasis), + (basis.CyclicBSplineEval(10), 
CyclicBSplineBasis), + (basis.CyclicBSplineConv(10, window_size=11), CyclicBSplineBasis), + (basis.MSplineEval(10), MSplineBasis), + (basis.MSplineConv(10, window_size=11), MSplineBasis), + (basis.RaisedCosineLinearEval(10), RaisedCosineBasisLinear), + (basis.RaisedCosineLinearConv(10, window_size=11), RaisedCosineBasisLinear), + (basis.RaisedCosineLogEval(10), RaisedCosineBasisLog), + (basis.RaisedCosineLogConv(10, window_size=11), RaisedCosineBasisLog), + (basis.OrthExponentialEval(10, np.arange(1, 11)), OrthExponentialBasis), ( - basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + basis.OrthExponentialConv(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis, ), ], @@ -246,31 +246,31 @@ def test_expected_output_compute_features(basis_instance, super_class): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10, label="label"), BSplineBasis), - (basis.ConvBSpline(10, window_size=11, label="label"), BSplineBasis), - (basis.EvalCyclicBSpline(10, label="label"), CyclicBSplineBasis), + (basis.BSplineEval(10, label="label"), BSplineBasis), + (basis.BSplineConv(10, window_size=11, label="label"), BSplineBasis), + (basis.CyclicBSplineEval(10, label="label"), CyclicBSplineBasis), ( - basis.ConvCyclicBSpline(10, window_size=11, label="label"), + basis.CyclicBSplineConv(10, window_size=11, label="label"), CyclicBSplineBasis, ), - (basis.EvalMSpline(10, label="label"), MSplineBasis), - (basis.ConvMSpline(10, window_size=11, label="label"), MSplineBasis), - (basis.EvalRaisedCosineLinear(10, label="label"), RaisedCosineBasisLinear), + (basis.MSplineEval(10, label="label"), MSplineBasis), + (basis.MSplineConv(10, window_size=11, label="label"), MSplineBasis), + (basis.RaisedCosineLinearEval(10, label="label"), RaisedCosineBasisLinear), ( - basis.ConvRaisedCosineLinear(10, window_size=11, label="label"), + basis.RaisedCosineLinearConv(10, window_size=11, label="label"), RaisedCosineBasisLinear, ), - (basis.EvalRaisedCosineLog(10, label="label"), RaisedCosineBasisLog), + (basis.RaisedCosineLogEval(10, label="label"), RaisedCosineBasisLog), ( - basis.ConvRaisedCosineLog(10, window_size=11, label="label"), + basis.RaisedCosineLogConv(10, window_size=11, label="label"), RaisedCosineBasisLog, ), ( - basis.EvalOrthExponential(10, np.arange(1, 11), label="label"), + basis.OrthExponentialEval(10, np.arange(1, 11), label="label"), OrthExponentialBasis, ), ( - basis.ConvOrthExponential( + basis.OrthExponentialConv( 10, decay_rates=np.arange(1, 11), window_size=12, label="label" ), OrthExponentialBasis, @@ -305,12 +305,12 @@ def cls(self): @pytest.mark.parametrize( "cls", [ - {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}, - {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear}, - {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline}, - {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline}, - {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline}, - {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential}, + {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv}, + {"eval": basis.RaisedCosineLinearEval, "conv": basis.RaisedCosineLinearConv}, + {"eval": basis.BSplineEval, "conv": basis.BSplineConv}, + {"eval": basis.CyclicBSplineEval, "conv": basis.CyclicBSplineConv}, + {"eval": basis.MSplineEval, "conv": basis.MSplineConv}, + {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv}, ], ) class TestSharedMethods: @@ -345,7 +345,7 @@ def 
test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): return bas = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) with expectation: - bas(samples) + bas._evaluate(samples) @pytest.mark.parametrize( "attribute, value", @@ -532,7 +532,7 @@ def test_call_basis_number(self, n_basis, mode, kwargs, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis + assert bas._evaluate(x).shape[1] == n_basis @pytest.mark.parametrize("n_basis", [6]) def test_call_equivalent_in_conv(self, n_basis, cls): @@ -545,7 +545,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) ) x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eval(x)) + assert np.all(bas_con._evaluate(x) == bas_eval._evaluate(x)) @pytest.mark.parametrize( "num_input, expectation", @@ -564,7 +564,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) + bas._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -582,7 +582,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with expectation: - bas(inp) + bas._evaluate(inp) @pytest.mark.parametrize( "samples, expectation", @@ -600,7 +600,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) ) # Only eval mode is relevant here with expectation: - bas(samples) + bas._evaluate(samples) @pytest.mark.parametrize( "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] @@ -609,7 +609,7 @@ def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x[3] = np.nan - assert all(np.isnan(bas(x)[3])) + assert all(np.isnan(bas._evaluate(x)[3])) @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( @@ -620,7 +620,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) + bas._evaluate(np.array([])) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( @@ -628,7 +628,10 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape + assert ( + bas._evaluate(np.linspace(0, 1, time_axis_shape)).shape[0] + == time_axis_shape + ) @pytest.mark.parametrize( "mn, mx, expectation", @@ -643,7 +646,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) with expectation: - bas(np.linspace(mn, mx, 10)) + bas._evaluate(np.linspace(mn, mx, 10)) @pytest.mark.parametrize( "kwargs, input1_shape, expectation", @@ -1020,9 +1023,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): # @pytest.mark.parametrize("mode, kwargs", 
[("eval", {}), ("conv", {"window_size": 2})]) # def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs, order, cls): # min_per_basis = { - # "EvalMSpline": (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs), - # "EvalRaisedCosineLog": lambda x: x < 2, - # "EvalBSpline": lambda x: order > x, + # "MSplineEval": (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs), + # "RaisedCosineLogEval": lambda x: x < 2, + # "BSplineEval": lambda x: order > x, # } # if n_basis_funcs < 2: # with pytest.raises( @@ -1085,8 +1088,8 @@ def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) + y = bas._evaluate(x) + y_nap = bas._evaluate(x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap.t) @@ -1258,7 +1261,7 @@ def test_transformer_get_params(self, cls): class TestRaisedCosineLogBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} + cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @pytest.mark.parametrize("width", [1.5, 2, 2.5]) def test_decay_to_zero_basis_number_match(self, width): @@ -1332,7 +1335,7 @@ def test_set_width(self, width, expectation, mode, kwargs): def test_time_scaling_property(self): time_scaling = [0.1, 10, 100] n_basis_funcs = 5 - _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) + _, lin_ev = basis.RaisedCosineLinearEval(n_basis_funcs).evaluate_on_grid(100) corr = np.zeros(len(time_scaling)) for idx, ts in enumerate(time_scaling): basis_log = self.cls["eval"]( @@ -1395,7 +1398,7 @@ def test_width_values(self, width, expectation, mode, kwargs): class TestRaisedCosineLinearBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} + cls = {"eval": basis.RaisedCosineLinearEval, "conv": basis.RaisedCosineLinearConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( @@ -1469,7 +1472,7 @@ def test_width_values(self, width, expectation, mode, kwargs): class TestMSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline} + cls = {"eval": basis.MSplineEval, "conv": basis.MSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @@ -1573,7 +1576,7 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential} + cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] @@ -1645,7 +1648,7 @@ def test_minimum_number_of_basis_required_is_matched( class TestBSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline} + cls = {"eval": basis.BSplineEval, "conv": basis.BSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @@ -1733,7 +1736,7 @@ def test_samples_range_matches_compute_features_requirements( class TestCyclicBSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline} + cls = {"eval": basis.CyclicBSplineEval, "conv": 
basis.CyclicBSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @@ -1872,23 +1875,23 @@ def instantiate_basis( if basis_class == AdditiveBasis: kwargs_mspline = trim_kwargs( - basis.EvalMSpline, kwargs, class_specific_params + basis.MSplineEval, kwargs, class_specific_params ) kwargs_raised_cosine = trim_kwargs( - basis.ConvRaisedCosineLinear, kwargs, class_specific_params + basis.RaisedCosineLinearConv, kwargs, class_specific_params ) - b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) basis_obj = b1 + b2 elif basis_class == MultiplicativeBasis: kwargs_mspline = trim_kwargs( - basis.EvalMSpline, kwargs, class_specific_params + basis.MSplineEval, kwargs, class_specific_params ) kwargs_raised_cosine = trim_kwargs( - basis.ConvRaisedCosineLinear, kwargs, class_specific_params + basis.RaisedCosineLinearConv, kwargs, class_specific_params ) - b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) basis_obj = b1 * b2 else: basis_obj = basis_class( @@ -1901,7 +1904,7 @@ class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) - @pytest.mark.parametrize("base_cls", [basis.EvalBSpline, basis.ConvBSpline]) + @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) def test_non_empty_samples(self, base_cls, samples, class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} kwargs = trim_kwargs(base_cls, kwargs, class_specific_params) @@ -1928,7 +1931,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. 
""" - basis_obj = basis.EvalMSpline(5) + basis.EvalMSpline(5) + basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -2183,7 +2186,7 @@ def test_call_input_num( TypeError, match="Input dimensionality mismatch" ) with expectation: - basis_obj(*([np.linspace(0, 1, 10)] * num_input)) + basis_obj._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -2216,7 +2219,7 @@ def test_call_input_shape( ) basis_obj = basis_a_obj + basis_b_obj with expectation: - basis_obj(*([inp] * basis_obj._n_input_dimensionality)) + basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize(" window_size", [3]) @@ -2242,7 +2245,7 @@ def test_call_sample_axis( ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality - assert basis_obj(*inp).shape[0] == time_axis_shape + assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2267,7 +2270,7 @@ def test_call_nan( inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality for x in inp: x[3] = np.nan - assert all(np.isnan(basis_obj(*inp)[3])) + assert all(np.isnan(basis_obj._evaluate(*inp)[3])) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2293,7 +2296,7 @@ def test_call_equivalent_in_conv( bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality - assert np.all(bas_con(*x) == bas_eva(*x)) + assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2313,8 +2316,8 @@ def test_pynapple_support( x = np.linspace(0, 1, 10) x_nap = [nap.Tsd(t=np.arange(10), d=x)] * bas._n_input_dimensionality x = [x] * bas._n_input_dimensionality - y = bas(*x) - y_nap = bas(*x_nap) + y = bas._evaluate(*x) + y_nap = bas._evaluate(*x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) @@ -2335,7 +2338,10 @@ def test_call_basis_number( ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality - assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs + assert ( + bas._evaluate(*x).shape[1] + == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs + ) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2353,7 +2359,7 @@ def test_call_non_empty( ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): - bas(*([np.array([])] * bas._n_input_dimensionality)) + bas._evaluate(*([np.array([])] * bas._n_input_dimensionality)) @pytest.mark.parametrize( "mn, mx, expectation", @@ -2398,7 +2404,7 @@ def test_call_sample_range( ) bas = basis_a_obj + basis_b_obj with expectation: - bas(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) + bas._evaluate(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2457,8 +2463,8 @@ def test_transform_fails( @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) 
@pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(11, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2469,8 +2475,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2488,8 +2494,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas = bas1 + bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2505,7 +2511,7 @@ class TestMultiplicativeBasis(CombinedBasis): ) @pytest.mark.parametrize(" ws", [3]) def test_non_empty_samples(self, samples, ws): - basis_obj = basis.EvalMSpline(5) * basis.EvalRaisedCosineLinear(5) + basis_obj = basis.MSplineEval(5) * basis.RaisedCosineLinearEval(5) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -2528,7 +2534,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. 
""" - basis_obj = basis.EvalMSpline(5) * basis.EvalMSpline(5) + basis_obj = basis.MSplineEval(5) * basis.MSplineEval(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -2821,7 +2827,7 @@ def test_call_input_num( TypeError, match="Input dimensionality mismatch" ) with expectation: - basis_obj(*([np.linspace(0, 1, 10)] * num_input)) + basis_obj._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -2854,7 +2860,7 @@ def test_call_input_shape( ) basis_obj = basis_a_obj * basis_b_obj with expectation: - basis_obj(*([inp] * basis_obj._n_input_dimensionality)) + basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize(" window_size", [3]) @@ -2880,7 +2886,7 @@ def test_call_sample_axis( ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality - assert basis_obj(*inp).shape[0] == time_axis_shape + assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2905,7 +2911,7 @@ def test_call_nan( inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality for x in inp: x[3] = np.nan - assert all(np.isnan(basis_obj(*inp)[3])) + assert all(np.isnan(basis_obj._evaluate(*inp)[3])) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2931,7 +2937,7 @@ def test_call_equivalent_in_conv( bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality - assert np.all(bas_con(*x) == bas_eva(*x)) + assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2951,8 +2957,8 @@ def test_pynapple_support( x = np.linspace(0, 1, 10) x_nap = [nap.Tsd(t=np.arange(10), d=x)] * bas._n_input_dimensionality x = [x] * bas._n_input_dimensionality - y = bas(*x) - y_nap = bas(*x_nap) + y = bas._evaluate(*x) + y_nap = bas._evaluate(*x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) @@ -2973,7 +2979,10 @@ def test_call_basis_number( ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality - assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs + assert ( + bas._evaluate(*x).shape[1] + == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs + ) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2991,7 +3000,7 @@ def test_call_non_empty( ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): - bas(*([np.array([])] * bas._n_input_dimensionality)) + bas._evaluate(*([np.array([])] * bas._n_input_dimensionality)) @pytest.mark.parametrize( "mn, mx, expectation", @@ -3036,7 +3045,7 @@ def test_call_sample_range( ) bas = basis_a_obj * basis_b_obj with expectation: - bas(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) + bas._evaluate(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -3095,8 +3104,8 @@ def test_transform_fails( @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) 
@pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(11, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -3107,8 +3116,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -3126,8 +3135,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas = bas1 * bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -3137,8 +3146,8 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_n_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) @@ -3147,7 +3156,7 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize( - "exponent", [-1, 0, 0.5, basis.EvalRaisedCosineLog(4), 1, 2, 3] + "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) def test_power_of_basis(exponent, basis_class, class_specific_params): diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index 0fda51e9..ca4f4be2 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from nemos.basis.basis import ConvBSpline, EvalBSpline, EvalRaisedCosineLinear +from nemos.basis.basis import BSplineConv, BSplineEval, RaisedCosineLinearEval from nemos.identifiability_constraints import ( _WARN_FLOAT32_MESSAGE, _find_drop_column, @@ -92,20 +92,20 @@ def test_apply_identifiability_constraints_add_constant(add_intercept, expected_ @pytest.mark.parametrize( "basis, input_shape, output_shape, expected_columns", [ - (EvalRaisedCosineLinear(10, width=4), (50,), (50, 10), jnp.arange(10)), + (RaisedCosineLinearEval(10, width=4), (50,), (50, 10), jnp.arange(10)), ( - EvalBSpline(5) + EvalBSpline(6), + BSplineEval(5) + BSplineEval(6), (20,), (20, 9), jnp.array([1, 2, 3, 4, 6, 7, 8, 9, 10]), ), ( - ConvBSpline(5, window_size=10) + EvalBSpline(6), + BSplineConv(5, window_size=10) + BSplineEval(6), (20,), (20, 10), jnp.array([0, 1, 2, 3, 4, 6, 7, 8, 9, 
10]), ), - (EvalBSpline(5), (10,), (10, 4), jnp.arange(1, 5)), + (BSplineEval(5), (10,), (10, 4), jnp.arange(1, 5)), ], ) def test_apply_identifiability_constraints_by_basis_component( @@ -207,7 +207,7 @@ def test_apply_constraint_with_invalid(invalid_entries): ) def test_apply_constraint_by_basis_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" - basis = EvalBSpline(5) + basis = BSplineEval(5) x = basis.compute_features( np.random.randn( 10, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 872cced8..5e4ce13d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,11 +12,11 @@ @pytest.mark.parametrize( "bas", [ - basis.EvalMSpline(5), - basis.EvalBSpline(5), - basis.EvalCyclicBSpline(5), - basis.EvalOrthExponential(5, decay_rates=np.arange(1, 6)), - basis.EvalRaisedCosineLinear(5), + basis.MSplineEval(5), + basis.BSplineEval(5), + basis.CyclicBSplineEval(5), + basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6)), + basis.RaisedCosineLinearEval(5), ], ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): @@ -30,11 +30,11 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): @pytest.mark.parametrize( "bas", [ - basis.EvalMSpline(5), - basis.EvalBSpline(5), - basis.EvalCyclicBSpline(5), - basis.EvalRaisedCosineLinear(5), - basis.EvalRaisedCosineLog(5), + basis.MSplineEval(5), + basis.BSplineEval(5), + basis.CyclicBSplineEval(5), + basis.RaisedCosineLinearEval(5), + basis.RaisedCosineLogEval(5), ], ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): @@ -49,11 +49,11 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): @pytest.mark.parametrize( "bas", [ - basis.EvalMSpline(5), - basis.EvalBSpline(5), - basis.EvalCyclicBSpline(5), - basis.EvalRaisedCosineLinear(5), - basis.EvalRaisedCosineLog(5), + basis.MSplineEval(5), + basis.BSplineEval(5), + basis.CyclicBSplineEval(5), + basis.RaisedCosineLinearEval(5), + basis.RaisedCosineLogEval(5), ], ) def test_sklearn_transformer_pipeline_cv_multiprocess( @@ -74,11 +74,11 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( @pytest.mark.parametrize( "bas_cls", [ - basis.EvalMSpline, - basis.EvalMSpline, - basis.EvalCyclicBSpline, - basis.EvalRaisedCosineLinear, - basis.EvalRaisedCosineLog, + basis.MSplineEval, + basis.MSplineEval, + basis.CyclicBSplineEval, + basis.RaisedCosineLinearEval, + basis.RaisedCosineLogEval, ], ) def test_sklearn_transformer_pipeline_cv_directly_over_basis( @@ -95,11 +95,11 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( @pytest.mark.parametrize( "bas_cls", [ - basis.EvalMSpline, - basis.EvalMSpline, - basis.EvalCyclicBSpline, - basis.EvalRaisedCosineLinear, - basis.EvalRaisedCosineLog, + basis.MSplineEval, + basis.MSplineEval, + basis.CyclicBSplineEval, + basis.RaisedCosineLinearEval, + basis.RaisedCosineLogEval, ], ) def test_sklearn_transformer_pipeline_cv_illegal_combination( @@ -123,35 +123,35 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( @pytest.mark.parametrize( "bas, expected_nans", [ - (basis.EvalMSpline(5), 0), - (basis.EvalBSpline(5), 0), - (basis.EvalCyclicBSpline(5), 0), - (basis.EvalOrthExponential(5, decay_rates=np.arange(1, 6)), 0), - (basis.EvalRaisedCosineLinear(5), 0), - (basis.EvalRaisedCosineLog(5), 0), - (basis.EvalRaisedCosineLog(5) + basis.EvalMSpline(5), 0), - (basis.ConvMSpline(5, window_size=3), 6), - (basis.ConvBSpline(5, window_size=3), 6), + 
(basis.MSplineEval(5), 0), + (basis.BSplineEval(5), 0), + (basis.CyclicBSplineEval(5), 0), + (basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6)), 0), + (basis.RaisedCosineLinearEval(5), 0), + (basis.RaisedCosineLogEval(5), 0), + (basis.RaisedCosineLogEval(5) + basis.MSplineEval(5), 0), + (basis.MSplineConv(5, window_size=3), 6), + (basis.BSplineConv(5, window_size=3), 6), ( - basis.ConvCyclicBSpline( + basis.CyclicBSplineConv( 5, window_size=3, conv_kwargs=dict(predictor_causality="acausal") ), 4, ), ( - basis.ConvOrthExponential( + basis.OrthExponentialConv( 5, decay_rates=np.linspace(0.1, 1, 5), window_size=7 ), 14, ), - (basis.ConvRaisedCosineLinear(5, window_size=3), 6), - (basis.ConvRaisedCosineLog(5, window_size=3), 6), + (basis.RaisedCosineLinearConv(5, window_size=3), 6), + (basis.RaisedCosineLogConv(5, window_size=3), 6), ( - basis.ConvRaisedCosineLog(5, window_size=3) + basis.EvalMSpline(5), + basis.RaisedCosineLogConv(5, window_size=3) + basis.MSplineEval(5), 6, ), ( - basis.ConvRaisedCosineLog(5, window_size=3) * basis.EvalMSpline(5), + basis.RaisedCosineLogConv(5, window_size=3) * basis.MSplineEval(5), 6, ), ], diff --git a/tests/test_simulation.py b/tests/test_simulation.py index e64072a7..e06f482b 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -213,7 +213,7 @@ def test_least_square_correctness(): # set up problem dimensionality ws, n_neurons_receiver, n_neurons_sender, n_basis_funcs = 100, 1, 2, 10 # evaluate a basis - _, eval_basis = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(ws) + _, eval_basis = basis.RaisedCosineLinearEval(n_basis_funcs).evaluate_on_grid(ws) # generate random weights to define filters weights = np.random.normal( size=(n_neurons_receiver, n_neurons_sender, n_basis_funcs) From 550835cf49846d1df201a4aa8b7d1950c85d3d60 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 3 Dec 2024 17:55:47 -0500 Subject: [PATCH 083/109] fixed warns --- tests/test_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 15a9e5f7..b2729465 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1073,7 +1073,7 @@ def test_number_of_required_inputs_compute_features( ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( - TypeError, match="takes 2 positional arguments but \d were given" + TypeError, match=r"takes 2 positional arguments but \d were given" ) else: expectation = does_not_raise() From fb3dd75e924591f089750c6c152f4fd75ae64e02 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 11:46:17 -0500 Subject: [PATCH 084/109] fix tutorials --- docs/background/plot_01_1D_basis_function.md | 39 +++++++++++++++++++- docs/background/plot_02_ND_basis_function.md | 23 ++++++------ 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 424be525..fa5c5879 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -38,6 +38,7 @@ warnings.filterwarnings( ), category=RuntimeWarning, ) + ``` (simple_basis_function)= @@ -58,6 +59,9 @@ import pynapple as nap import nemos as nmo +# configure plots some +plt.style.use(nmo.styles.plot_style) + # Initialize hyperparameters order = 4 n_basis = 10 @@ -66,6 +70,39 @@ n_basis = 10 bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order) ``` +We provide the convenience method `evaluate_on_grid` for evaluating the 
basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements. + +```{code-cell} ipython3 +# evaluate the basis on 100 sample points +x, y = bspline.evaluate_on_grid(100) + +fig = plt.figure(figsize=(5, 3)) +plt.plot(x, y, lw=2) +plt.title("B-Spline Basis") +``` + +```{code-cell} ipython3 +:tags: [hide-input] + +# save image for thumbnail +from pathlib import Path +import os + +root = os.environ.get("READTHEDOCS_OUTPUT") +if root: + path = Path(root) / "html/_static/thumbnails/background" +# if local store in ../_build/html/... +else: + path = Path("../_build/html/_static/thumbnails/background") + +# make sure the folder exists if run from build +if root or Path("../_build/html/_static").exists(): + path.mkdir(parents=True, exist_ok=True) + +if path.exists(): + fig.savefig(path / "plot_01_1D_basis_function.svg") +``` + ## Feature Computation The bases in the `nemos.basis` module can be grouped into two categories: @@ -75,7 +112,6 @@ The bases in the `nemos.basis` module can be grouped into two categories: Let's see how this two modalities operate. - ```{code-cell} ipython3 eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis) conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100) @@ -165,6 +201,7 @@ the fixed range basis. ```{code-cell} ipython3 +samples = np.linspace(0, 1, 200) fig, axs = plt.subplots(2,1, sharex=True) plt.suptitle("B-spline basis ") axs[0].plot(samples, bspline.compute_features(samples), color="k") diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/plot_02_ND_basis_function.md index 03c0062d..a9636285 100644 --- a/docs/background/plot_02_ND_basis_function.md +++ b/docs/background/plot_02_ND_basis_function.md @@ -150,7 +150,7 @@ x_coord = np.linspace(0, 1, 1000) y_coord = np.linspace(0, 1, 1000) # Evaluate the basis functions for the given trajectory. 
-eval_basis = additive_basis(x_coord, y_coord) +eval_basis = additive_basis.compute_features(x_coord, y_coord) print(f"Sum of two 1D splines with {eval_basis.shape[1]} " f"basis element and {eval_basis.shape[0]} samples:\n" @@ -169,13 +169,13 @@ basis_b_element = 1 fig, axs = plt.subplots(1, 2, figsize=(6, 3)) axs[0].set_title(f"$a_{{{basis_a_element}}}(x)$", color="b") -axs[0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3) -axs[0].plot(x_coord, a_basis(x_coord)[:, basis_a_element], "b") +axs[0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) +axs[0].plot(x_coord, a_basis.compute_features(x_coord)[:, basis_a_element], "b") axs[0].set_xlabel("x-coord") axs[1].set_title(f"$b_{{{basis_b_element}}}(x)$", color="b") -axs[1].plot(y_coord, b_basis(x_coord), "grey", alpha=.3) -axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b") +axs[1].plot(y_coord, b_basis.compute_features(x_coord), "grey", alpha=.3) +axs[1].plot(y_coord, b_basis.compute_features(x_coord)[:, basis_b_element], "b") axs[1].set_xlabel("y-coord") plt.tight_layout() ``` @@ -242,7 +242,7 @@ The number of elements of the product basis will be the product of the elements ```{code-cell} ipython3 # Evaluate the product basis at the x and y coordinates -eval_basis = prod_basis(x_coord, y_coord) +eval_basis = prod_basis.compute_features(x_coord, y_coord) # Output the number of elements and samples of the evaluated basis, # as well as the number of elements in the original 1D basis objects @@ -268,13 +268,13 @@ fig, axs = plt.subplots(3,3,figsize=(8, 6)) cc = 0 for i, j in element_pairs: # plot the element form a_basis - axs[cc, 0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3) - axs[cc, 0].plot(x_coord, a_basis(x_coord)[:, i], "b") + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b") axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b') # plot the element form b_basis - axs[cc, 1].plot(y_coord, b_basis(y_coord), "grey", alpha=.3) - axs[cc, 1].plot(y_coord, b_basis(y_coord)[:, j], "b") + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3) + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b") axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b') # select & plot the corresponding product basis element @@ -322,7 +322,6 @@ in a linear maze and the LFP phase angle. 
::: - N-Dimensional Basis ------------------- Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of @@ -346,7 +345,7 @@ c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) -eval_basis = prod_basis_3(samples, samples, samples) +eval_basis = prod_basis_3.compute_features(samples, samples, samples) print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} " f"basis elements.\nEvaluation output of shape {eval_basis.shape}") From 892c4120589eec1ba232920ad20a8bd926feedce Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 11:49:24 -0500 Subject: [PATCH 085/109] fix all tutorials --- docs/tutorials/plot_03_grid_cells.md | 2 +- docs/tutorials/plot_05_place_cells.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index 65bfbc43..b884bc68 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -175,7 +175,7 @@ Now we can "evaluate" the basis for each position of the animal ```{code-cell} ipython3 -position_basis = basis_2d(position["x"], position["y"]) +position_basis = basis_2d.compute_features(position["x"], position["y"]) ``` Now try to make sense of what it is diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index ae8e38ac..ce094f28 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -357,7 +357,7 @@ The object basis only tell us how each basis covers the feature space. For each ```{code-cell} ipython3 -X = basis(position, theta, speed) +X = basis.compute_features(position, theta, speed) ``` `X` is our design matrix. For each timestamps, it contains the information about the current position, @@ -455,7 +455,7 @@ predicted_rates = {} for m in models: print("1. Evaluating basis : ", m) - X = models[m](*features[m]) + X = models[m].compute_features(*features[m]) print("2. Fitting model : ", m) glm.fit( From 743bd8da50a3bea941f6a025c319e44b21fb40d6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 12:42:21 -0500 Subject: [PATCH 086/109] fixed links htmlproofer --- docs/background/plot_01_1D_basis_function.md | 2 +- docs/developers_notes/04-basis_module.md | 8 ++++---- docs/how_to_guide/plot_06_glm_pytree.md | 2 +- docs/tutorials/plot_03_grid_cells.md | 2 +- docs/tutorials/plot_05_place_cells.md | 6 +++--- src/nemos/basis/_basis_mixin.py | 7 ++++--- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index fa5c5879..e8857dc2 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -46,7 +46,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.basis.MSplineEval). +We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval). The hyperparameters required to initialize this class are: - The number of basis functions, which should be a positive integer. 
diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index f7823a86..45decfb1 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -26,7 +26,7 @@ Abstract Class Basis └─ Concrete Subclass OrthExponentialBasis ``` -The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`_evaluate`](nemos.basis._basis.Basis._evaluate) that is specific for each concrete class. See below for more details. +The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details. ## The Class `nemos.basis._basis.Basis` @@ -61,14 +61,14 @@ This method performs the following steps: 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method. +3. Calls the `_evaluate` method. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. [`_evaluate`](nemos.basis._basis.Basis._evaluate): Evaluates a basis over some specified samples. +1. `_evaluate`: Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines @@ -77,7 +77,7 @@ The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the followi To write a usable (i.e., concrete, non-abstract) basis object, you - **Must** inherit the abstract superclass [`Basis`](nemos.basis._basis.Basis) -- **Must** define the [`_evaluate`](nemos.basis._basis.Basis._evaluate) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. +- **Must** define the `_evaluate` and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics. - **Should not** overwrite the [`compute_features`](nemos.basis._basis.Basis.compute_features) and [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis._basis.Basis). 
- **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis._spline_basis.SplineBasis)). diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 36910bc3..e5949f58 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -274,7 +274,7 @@ Okay, let's use unit number 7. Now let's set up our design matrix. First, let's fit the head direction by itself. Head direction is a circular variable (pi and -pi are adjacent to each other), so we need to use a basis that has this property as well. -[`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) is one such basis. +[`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval) is one such basis. Let's create our basis and then arrange our data properly. diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index b884bc68..f59c04c6 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -143,7 +143,7 @@ position = position.interpolate(counts) ``` We can define a two-dimensional basis for position by multiplying two one-dimensional bases, -see [here](../../background/plot_02_ND_basis_function) for more details. +see [here](composing_basis_function) for more details. ```{code-cell} ipython3 basis_2d = nmo.basis.RaisedCosineLinearEval( diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index ce094f28..57ca4298 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -335,9 +335,9 @@ print(count.shape) For each feature, we will use a different set of basis : - - position : [`MSplineEval`](nemos.basis.basis.MSplineEval) - - theta phase : [`CyclicBSplineEval`](nemos.basis.basis.CyclicBSplineEval) - - speed : [`MSplineEval`](nemos.basis.basis.MSplineEval) + - position : [`MSplineEval`](nemos.basis.MSplineEval) + - theta phase : [`CyclicBSplineEval`](nemos.basis.CyclicBSplineEval) + - speed : [`MSplineEval`](nemos.basis.MSplineEval) ```{code-cell} ipython3 diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 3ae84376..8579aa09 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -1,4 +1,5 @@ """Mixin classes for basis.""" +from __future__ import annotations import copy import inspect @@ -95,10 +96,10 @@ def _compute_features(self, *xi: ArrayLike): samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions except for the sample-axis are flattened, so that the method always returns a matrix. For example, if samples are of shape (num_samples, 2, 3), the output will be - (num_samples, num_basis_funcs * 2 * 3). + ``(num_samples, num_basis_funcs * 2 * 3)``. The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. - For example, if ``axis == 1`` your samples should be (N1, num_samples N3, ...), the output of - transform will be (num_samples, num_basis_funcs * N1 * N3 *...). + For example, if ``axis == 1`` your samples should be ``(N1, num_samples N3, ...)``, the output of + transform will be ``(num_samples, num_basis_funcs * N1 * N3 *...)``. 
Parameters ---------- From fb883ce47bc039fb449a18678ae42d18a0a17e54 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 12:42:45 -0500 Subject: [PATCH 087/109] linted --- src/nemos/basis/_basis_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 8579aa09..16332e6e 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -1,4 +1,5 @@ """Mixin classes for basis.""" + from __future__ import annotations import copy From d0af9c826d5b5b64212da20a5bc10770d1eeb9ab Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 14:29:16 -0500 Subject: [PATCH 088/109] updated jax link --- docs/tutorials/plot_01_current_injection.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/plot_01_current_injection.md b/docs/tutorials/plot_01_current_injection.md index 5d0df7f5..bffd5a35 100644 --- a/docs/tutorials/plot_01_current_injection.md +++ b/docs/tutorials/plot_01_current_injection.md @@ -446,7 +446,7 @@ following properties: :::{admonition} What is jax? :class: note -[jax](https://github.com/google/jax) is a Google-supported python library +[jax](https://github.com/jax-ml/jax) is a Google-supported python library for automatic differentiation. It has all sorts of neat features, but the most relevant of which for NeMoS is its GPU-compatibility and just-in-time compilation (both of which make code faster with little From c0664cf062af5ab6e946c5c8c5214f6a4900debf Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:28:25 -0500 Subject: [PATCH 089/109] Update docs/api_reference.rst Co-authored-by: William F. Broderick --- docs/api_reference.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index f41c4c02..7ad62877 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -95,7 +95,7 @@ These classes are the building blocks for the concrete basis classes. AdditiveBasis MultiplicativeBasis -**Basis As `scikit-learn` Tranformers:** +**Basis As ``scikit-learn`` Tranformers:** .. currentmodule:: nemos.basis._transformer_basis From 629f8c57d1a65ee41c1c8da1945198ed5f64f414 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:30:13 -0500 Subject: [PATCH 090/109] Update docs/background/plot_01_1D_basis_function.md Co-authored-by: William F. Broderick --- docs/background/plot_01_1D_basis_function.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index e8857dc2..67d809ed 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -106,7 +106,7 @@ if path.exists(): ## Feature Computation The bases in the `nemos.basis` module can be grouped into two categories: -1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `BSplineEval`. +1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`. 2. 
**Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `BSplineConv`. From 1e59e590a590996f13a02b1172de00db4479c1c1 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:30:50 -0500 Subject: [PATCH 091/109] Update docs/developers_notes/04-basis_module.md Co-authored-by: William F. Broderick --- docs/developers_notes/04-basis_module.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index 45decfb1..7cf271e6 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -28,7 +28,7 @@ Abstract Class Basis The super-class [`Basis`](nemos.basis._basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `_evaluate` that is specific for each concrete class. See below for more details. -## The Class `nemos.basis._basis.Basis` +## The Abstract Super-class [`Basis`](nemos.basis._basis.Basis) (the-public-method-compute_features)= ### The Public Method `compute_features` From 5808b5027dd8d78cba0a039bfe4675f63a7c8a09 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:31:26 -0500 Subject: [PATCH 092/109] Update docs/background/plot_01_1D_basis_function.md Co-authored-by: William F. Broderick --- docs/background/plot_01_1D_basis_function.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 67d809ed..30d85d11 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -108,7 +108,7 @@ The bases in the `nemos.basis` module can be grouped into two categories: 1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`. -2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `BSplineConv`. +2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`. Let's see how this two modalities operate. From a683afa3271e6f2bbe614bb29b8975776eddd131 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:31:48 -0500 Subject: [PATCH 093/109] Update docs/background/plot_03_1D_convolution.md Co-authored-by: William F. 
Broderick --- docs/background/plot_03_1D_convolution.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md index 17dfdb1c..1967148d 100644 --- a/docs/background/plot_03_1D_convolution.md +++ b/docs/background/plot_03_1D_convolution.md @@ -188,7 +188,7 @@ if path.exists(): ## Convolve using [`Basis.compute_features`](nemos.basis._basis.Basis.compute_features) -Every basis in the `nemos.basis` module whose class name starts with "Conv" will perform a 1D convolution over the +Every basis in the `nemos.basis` module whose class name ends with "Conv" will perform a 1D convolution over the provided input when the `compute_features` method is called. The basis elements will be used as filters for the convolution. From c1773bcdc29dbe9d8bc7bb238ba1340f3ad2c81f Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 16:33:19 -0500 Subject: [PATCH 094/109] Update docs/developers_notes/04-basis_module.md Co-authored-by: William F. Broderick --- docs/developers_notes/04-basis_module.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index 7cf271e6..4803338e 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -42,7 +42,7 @@ It accepts one or more NumPy array or pynapple `Tsd` object as input, and perfor 1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case. 2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. -3. In `"eval"` mode, calls the `_evaluate` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). +3. In `"eval"` mode, calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). 4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples. :::{admonition} Multiple epochs From 66cfa7148e4b4750695e0630a3536d6d2abaa748 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:16:35 -0500 Subject: [PATCH 095/109] Update docs/developers_notes/04-basis_module.md Co-authored-by: William F. Broderick --- docs/developers_notes/04-basis_module.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index 4803338e..8c663fe4 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -61,7 +61,7 @@ This method performs the following steps: 1. 
Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the `_evaluate` method. +3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on these samples. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods From 59abe28bc5264f97786abcb78ba20fa7c01de210 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:17:02 -0500 Subject: [PATCH 096/109] Update docs/developers_notes/04-basis_module.md Co-authored-by: William F. Broderick --- docs/developers_notes/04-basis_module.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index 8c663fe4..c162f4cb 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -68,7 +68,7 @@ This method performs the following steps: The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. `_evaluate`: Evaluates a basis over some specified samples. +1. [`_evaluate`](nemos.basis._basis.Basis._evaluate) : Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines From f376bd66ba80cce3cf72141b9ed6c1dfd3753c07 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:20:08 -0500 Subject: [PATCH 097/109] Update src/nemos/basis/_basis_mixin.py Co-authored-by: William F. Broderick --- src/nemos/basis/_basis_mixin.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 16332e6e..e7a58212 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -137,11 +137,6 @@ def _set_kernel(self) -> "ConvBasisMixin": The instance itself, modified to include the computed kernel if applicable. This allows for method chaining and integration into transformation pipelines. - Notes - ----- - Subclasses implementing this method should detail the specifics of how the kernel is - computed and how the input parameters are utilized. If the basis operates in 'eval' - mode exclusively, this method should simply return `self` without modification. """ self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size)) return self From 1aba9a54ea345edbddb66335d62634672ec12e80 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:20:43 -0500 Subject: [PATCH 098/109] Update src/nemos/basis/_basis_mixin.py Co-authored-by: William F. Broderick --- src/nemos/basis/_basis_mixin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index e7a58212..677bdc9b 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -143,9 +143,7 @@ def _set_kernel(self) -> "ConvBasisMixin": @property def window_size(self): - """Window size as number of samples. 
- - Duration of the convolutional kernel in number of samples. + """Duration of the convolutional kernel in number of samples. """ return self._window_size From cb7e1e779b985fa01315cedbe14e70162eb33b44 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:23:10 -0500 Subject: [PATCH 099/109] Update src/nemos/basis/_basis_mixin.py Co-authored-by: William F. Broderick --- src/nemos/basis/_basis_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 677bdc9b..d32cbc1d 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -152,7 +152,7 @@ def window_size(self, window_size): """Setter for the window size parameter.""" if window_size is None: raise ValueError( - "If the basis is in `conv` mode, you must provide a window_size!" + "You must provide a window_size!" ) elif not (isinstance(window_size, int) and window_size > 0): From 8427d6eea8b75920bdfda2fa5725a9ee3c4c9406 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 17:23:32 -0500 Subject: [PATCH 100/109] Update src/nemos/basis/_transformer_basis.py Co-authored-by: William F. Broderick --- src/nemos/basis/_transformer_basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index f6abdad5..fd570214 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -64,7 +64,7 @@ def __init__(self, basis: Basis): @staticmethod def _unpack_inputs(X: FeatureMatrix): - """Unpack impute without using transpose. + """Unpack inputs without using transpose. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` From 2b0599446ee9719c4d4cb87b7b0abf416510bb97 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 17:46:31 -0500 Subject: [PATCH 101/109] edits to docstrings --- .../plot_05_sklearn_pipeline_cv_demo.md | 8 +++- src/nemos/basis/_basis.py | 3 -- src/nemos/basis/_basis_mixin.py | 21 +++++----- src/nemos/basis/_raised_cosine_basis.py | 2 +- src/nemos/basis/basis.py | 42 +++++++++++++++++++ tests/test_basis.py | 2 +- 6 files changed, 62 insertions(+), 16 deletions(-) diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index 90fe1716..166a5cbb 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -152,12 +152,18 @@ sns.despine(ax=ax) ### Converting NeMoS `Basis` to a transformer In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. 
-Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): ```{code-cell} ipython3 bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5) + +# initialize using the constructor +trans_bas = nmo.basis.TransformerBasis(bas) + +# equivalent initialization via "to_transformer" trans_bas = bas.to_transformer() + ``` [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 698d4ce4..dec2e331 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -550,9 +550,6 @@ def _get_feature_slicing( Calculate and return the slicing for features based on the input structure. This method determines how to slice the features for different basis types. - If the instance is of ``AdditiveBasis`` type, the slicing is calculated recursively - for each component basis. Otherwise, it determines the slicing based on - the number of basis functions and ``split_by_input`` flag. Parameters ---------- diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 16332e6e..09060ac2 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -91,21 +91,23 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): def _compute_features(self, *xi: ArrayLike): """ - Apply the basis transformation to the input data. + Convolve basis functions with input time series. A bank of basis filters (created by calling fit) is convolved with the - samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions + input data. Inputs can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions except for the sample-axis are flattened, so that the method always returns a matrix. - For example, if samples are of shape (num_samples, 2, 3), the output will be + + For example, if inputs are of shape (num_samples, 2, 3), the output will be ``(num_samples, num_basis_funcs * 2 * 3)``. + The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. - For example, if ``axis == 1`` your samples should be ``(N1, num_samples N3, ...)``, the output of - transform will be ``(num_samples, num_basis_funcs * N1 * N3 *...)``. + For example, if ``axis == 1`` your input should be of shape ``(N1, num_samples N3, ...)``, the output of - transform will be of shape ``(num_samples, num_basis_funcs * N1 * N3 *...)``. 
It computes a kernel based on the basis functions that will be used for convolution with the input data. The specifics of kernel computation depend on the subclass implementation and the nature of the basis functions. @@ -134,14 +136,13 @@ def _set_kernel(self) -> "ConvBasisMixin": Returns ------- self : - The instance itself, modified to include the computed kernel if applicable. This + The instance itself, modified to include the computed kernel. This allows for method chaining and integration into transformation pipelines. Notes ----- Subclasses implementing this method should detail the specifics of how the kernel is - computed and how the input parameters are utilized. If the basis operates in 'eval' - mode exclusively, this method should simply return `self` without modification. + computed and how the input parameters are utilized. """ self.kernel_ = self._evaluate(np.linspace(0, 1, self.window_size)) return self diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index db442fb4..cef8d28d 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -123,7 +123,7 @@ def _evaluate( # call these _evaluate # basis1 = nmo.basis.RaisedCosineBasisLinear(5) # basis2 = nmo.basis.RaisedCosineBasisLog(5) # additive_basis = basis1 + basis2 - # additive_basis(*([x] * 2)) would modify both inputs + # additive_basis._evaluate(*([x] * 2)) would modify both inputs sample_pts, _ = min_max_rescale_samples( np.copy(sample_pts), getattr(self, "bounds", None) ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 730aaee9..dc7bd4f8 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -177,6 +177,12 @@ class BSplineConv(ConvBasisMixin, BSplineBasis): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. References ---------- @@ -408,6 +414,12 @@ class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. Examples -------- @@ -668,6 +680,12 @@ class MSplineConv(ConvBasisMixin, MSplineBasis): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. 
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. References ---------- @@ -912,6 +930,12 @@ class RaisedCosineLinearConv( label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. References ---------- @@ -1033,6 +1057,12 @@ class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. References ---------- @@ -1157,6 +1187,12 @@ class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. References ---------- @@ -1382,6 +1418,12 @@ class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): label : The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts. + conv_kwargs: + Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor`; + These arguments are used to change the default behavior of the convolution. + For example, changing the ``predictor_causality``, which by default is set to ``"causal"``. + Note that one cannot change the default value for the ``axis`` parameter. Basis assumes + that the convolution axis is ``axis=0``. Examples -------- diff --git a/tests/test_basis.py b/tests/test_basis.py index b2729465..a5d7bae7 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -132,7 +132,7 @@ def test_all_basis_are_tested() -> None: ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), ( "compute_features", - "Apply the basis transformation to the input data", + "Apply the basis transformation to the input data|Convolve basis functions with input time series", ), ( "split_by_feature", From 98c2a742a82227f2121b65891867f01545510938 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Mon, 9 Dec 2024 18:35:37 -0500 Subject: [PATCH 102/109] Update src/nemos/basis/_basis_mixin.py Co-authored-by: William F. 
Broderick --- src/nemos/basis/_basis_mixin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index d32cbc1d..97069760 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -35,9 +35,7 @@ def _compute_features(self, *xi: ArrayLike): Returns ------- : - A matrix with the transformed features. The basis evaluated at the samples, - or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array - or a pynapple Tsd. + A matrix with the transformed features. """ return self._evaluate(*xi) From 4d4c70f76bf81e4f1a22206736b09e6cc8f757bb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 18:45:34 -0500 Subject: [PATCH 103/109] fixed inheritance and removed TransformerMixin init call --- src/nemos/basis/_basis.py | 122 +++++++++++++++-------- src/nemos/basis/_basis_mixin.py | 11 +- src/nemos/basis/_decaying_exponential.py | 5 +- src/nemos/basis/_raised_cosine_basis.py | 5 +- src/nemos/basis/_spline_basis.py | 7 +- src/nemos/basis/basis.py | 4 +- tests/test_basis.py | 2 +- 7 files changed, 100 insertions(+), 56 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index dec2e331..2cd7d035 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -9,7 +9,7 @@ import jax import numpy as np from numpy.typing import ArrayLike, NDArray -from pynapple import Tsd, TsdFrame +from pynapple import Tsd, TsdFrame, TsdTensor from ..base_class import Base from ..type_casting import support_pynapple @@ -242,13 +242,13 @@ def add_constant(x): return X @check_transform_input - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Apply the basis transformation to the input data. This method is designed to be a high-level interface for transforming input data using the basis functions defined by the subclass. Depending on the basis' - mode ('eval' or 'conv'), it either evaluates the basis functions at the sample + mode ('Eval' or 'Conv'), it either evaluates the basis functions at the sample points or performs a convolution operation between the input data and the basis functions. @@ -256,7 +256,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: ---------- *xi : Input data arrays to be transformed. The shape and content requirements - depend on the subclass and mode of operation ('eval' or 'conv'). + depend on the subclass and mode of operation ('Eval' or 'Conv'). Returns ------- @@ -276,7 +276,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: return self._compute_features(*xi) @abc.abstractmethod - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """Convolve or evaluate the basis.""" pass @@ -286,7 +286,7 @@ def _set_kernel(self): pass @abc.abstractmethod - def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Abstract method to evaluate the basis functions at given points. 
@@ -579,37 +579,14 @@ def _get_feature_slicing( n_inputs = n_inputs or self._n_basis_input start_slice = start_slice or 0 - # If the instance is of AdditiveBasis type, handle slicing for the additive components - if isinstance(self, AdditiveBasis): - split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], - start_slice, - split_by_input=split_by_input, - ) - sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], - start_slice, - split_by_input=split_by_input, - ) - split_dict = self._merge_slicing_dicts(split_dict, sp2) - else: - # Handle the default case for other basis types - split_dict, start_slice = self._get_default_slicing( - split_by_input, start_slice - ) + # Handle the default case for non-additive basis types + # See overwritten method for recursion logic + split_dict, start_slice = self._get_default_slicing( + split_by_input=split_by_input, start_slice=start_slice + ) return split_dict, start_slice - def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: - """Merge two slicing dictionaries, handling key conflicts.""" - for key, val in dict2.items(): - if key in dict1: - new_key = self._generate_unique_key(dict1, key) - dict1[new_key] = val - else: - dict1[key] = val - return dict1 - @staticmethod def _generate_unique_key(existing_dict: dict, key: str) -> str: """Generate a unique key if there is a conflict.""" @@ -884,7 +861,7 @@ def _check_n_basis_min(self) -> None: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Evaluate the basis at the input samples. @@ -924,7 +901,7 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: return X @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: r""" Examples -------- @@ -941,7 +918,7 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: """ return super().compute_features(*xi) - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Compute features for added bases and concatenate. @@ -1159,6 +1136,70 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """ return super().evaluate_on_grid(*n_samples) + def _get_feature_slicing( + self, + n_inputs: Optional[tuple] = None, + start_slice: Optional[int] = None, + split_by_input: bool = True, + ) -> Tuple[dict, int]: + """ + Calculate and return the slicing for features based on the input structure. + + This method determines how to slice the features for different basis types. + + Parameters + ---------- + n_inputs : + The number of input basis for each component, by default it uses ``self._n_basis_input``. + start_slice : + The starting index for slicing, by default it starts from 0. + split_by_input : + Flag indicating whether to split the slicing by individual inputs or not. + If ``False``, a single slice is generated for all inputs. + + Returns + ------- + split_dict : + Dictionary with keys as labels and values as slices representing + the slicing for each input or additive component, if split_by_input equals to + True or False respectively. + start_slice : + The updated starting index after slicing. 
+ + See Also + -------- + _get_default_slicing : Handles default slicing logic. + _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. + """ + # Set default values for n_inputs and start_slice if not provided + n_inputs = n_inputs or self._n_basis_input + start_slice = start_slice or 0 + + # If the instance is of AdditiveBasis type, handle slicing for the additive components + + split_dict, start_slice = self._basis1._get_feature_slicing( + n_inputs[: len(self._basis1._n_basis_input)], + start_slice, + split_by_input=split_by_input, + ) + sp2, start_slice = self._basis2._get_feature_slicing( + n_inputs[len(self._basis1._n_basis_input) :], + start_slice, + split_by_input=split_by_input, + ) + split_dict = self._merge_slicing_dicts(split_dict, sp2) + return split_dict, start_slice + + def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: + """Merge two slicing dictionaries, handling key conflicts.""" + for key, val in dict2.items(): + if key in dict1: + new_key = self._generate_unique_key(dict1, key) + dict1[new_key] = val + else: + dict1[key] = val + return dict1 + class MultiplicativeBasis(Basis): """ @@ -1205,7 +1246,6 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._label = "(" + basis1.label + " * " + basis2.label + ")" self._basis1 = basis1 self._basis2 = basis2 - BasisTransformerMixin.__init__(self) def _check_n_basis_min(self) -> None: pass @@ -1232,7 +1272,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis: @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: + def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Evaluate the basis at the input samples. @@ -1264,7 +1304,7 @@ def _evaluate(self, *xi: ArrayLike) -> FeatureMatrix: ) return X - def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Compute the features for the multiplied bases, and compute their outer product. @@ -1357,7 +1397,7 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike) -> FeatureMatrix: + def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Examples -------- diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 09060ac2..18b5eea4 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -7,7 +7,8 @@ from typing import Optional, Tuple, Union import numpy as np -from numpy.typing import ArrayLike +from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from ..convolve import create_convolutional_predictor from ._transformer_basis import TransformerBasis @@ -19,12 +20,12 @@ class EvalBasisMixin: def __init__(self, bounds: Optional[Tuple[float, float]] = None): self.bounds = bounds - def _compute_features(self, *xi: ArrayLike): + def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): """ - Apply the basis transformation to the input data. + Evaluate basis at sample points. The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a - basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. + basis element. xi[k] must be a one-dimensional array or a pynapple Tsd/TsdFrame/TsdTensor. 
Parameters ---------- @@ -89,7 +90,7 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs - def _compute_features(self, *xi: ArrayLike): + def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): """ Convolve basis functions with input time series. diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index a7a403fb..a474b819 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -8,7 +8,7 @@ import numpy as np import scipy.linalg -from numpy.typing import NDArray +from numpy.typing import ArrayLike, NDArray from ..type_casting import support_pynapple from ..typing import FeatureMatrix @@ -19,6 +19,7 @@ min_max_rescale_samples, ) +from pynapple import Tsd, TsdFrame, TsdTensor class OrthExponentialBasis(Basis, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -134,7 +135,7 @@ def _check_sample_size(self, *sample_pts: NDArray) -> None: @check_one_dimensional def _evaluate( self, - sample_pts: NDArray, + sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor, ) -> FeatureMatrix: """Generate basis functions with given spacing. diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index cef8d28d..7070e653 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -16,6 +16,7 @@ min_max_rescale_samples, ) +from pynapple import Tsd, TsdFrame, TsdTensor class RaisedCosineBasisLinear(Basis, abc.ABC): """Represent linearly-spaced raised cosine basis functions. @@ -101,7 +102,7 @@ def _check_width(width: float) -> None: @check_one_dimensional def _evaluate( # call these _evaluate self, - sample_pts: ArrayLike, + sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor, ) -> FeatureMatrix: """Generate basis functions with given samples. @@ -330,7 +331,7 @@ def _compute_peaks(self) -> NDArray: @check_one_dimensional def _evaluate( self, - sample_pts: ArrayLike, + sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor, ) -> FeatureMatrix: """Generate log-spaced raised cosine basis with given samples. diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index f8e93bbe..b5fde73c 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -18,6 +18,7 @@ min_max_rescale_samples, ) +from pynapple import Tsd, TsdFrame, TsdTensor class SplineBasis(Basis, abc.ABC): """ @@ -217,7 +218,7 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix: + def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Evaluate the M-spline basis functions at given sample points. @@ -334,7 +335,7 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike) -> FeatureMatrix: + def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: """ Evaluate the B-spline basis functions with given sample points. @@ -445,7 +446,7 @@ def __init__( @check_one_dimensional def _evaluate( self, - sample_pts: ArrayLike, + sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor, ) -> FeatureMatrix: """Evaluate the Cyclic B-spline basis functions with given sample points. 
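
The hunks above widen the `_evaluate` signatures from `ArrayLike` to `ArrayLike | Tsd | TsdFrame | TsdTensor`, so the public `compute_features` advertises both NumPy and pynapple inputs. A minimal sketch of the resulting usage, assuming the public `RaisedCosineLinearEval` constructor accepts `n_basis_funcs` like the other classes touched in this series:

```python
# Sketch only: class and argument names are assumed from the surrounding diffs.
import numpy as np
import pynapple as nap
import nemos as nmo

x = np.linspace(0, 1, 100)            # plain NumPy samples
tsd = nap.Tsd(t=np.arange(100), d=x)  # the same samples wrapped as a pynapple Tsd

basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5)

X_np = basis.compute_features(x)      # NumPy in  -> ndarray of shape (100, 5)
X_nap = basis.compute_features(tsd)   # pynapple in -> TsdFrame with matching time index
```
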
diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index dc7bd4f8..17b8f80e 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -795,7 +795,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: class RaisedCosineLinearEval( - EvalBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin + EvalBasisMixin, RaisedCosineBasisLinear ): """ Represent linearly-spaced raised cosine basis functions. @@ -911,7 +911,7 @@ def split_by_feature( class RaisedCosineLinearConv( - ConvBasisMixin, RaisedCosineBasisLinear, BasisTransformerMixin + ConvBasisMixin, RaisedCosineBasisLinear ): """ Represent linearly-spaced raised cosine basis functions. diff --git a/tests/test_basis.py b/tests/test_basis.py index a5d7bae7..3e67aad9 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -132,7 +132,7 @@ def test_all_basis_are_tested() -> None: ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), ( "compute_features", - "Apply the basis transformation to the input data|Convolve basis functions with input time series", + "Apply the basis transformation to the input data|Convolve basis functions with input time series|Evaluate basis at sample points", ), ( "split_by_feature", From ec3c4c23c758f058925fe17c277dc91e52e8f13d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 18:46:10 -0500 Subject: [PATCH 104/109] linted --- src/nemos/basis/_basis.py | 24 ++++++++++++++++++------ src/nemos/basis/_decaying_exponential.py | 2 +- src/nemos/basis/_raised_cosine_basis.py | 2 +- src/nemos/basis/_spline_basis.py | 10 +++++++--- src/nemos/basis/basis.py | 10 +++------- 5 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index 2cd7d035..f3f76d7b 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -242,7 +242,9 @@ def add_constant(x): return X @check_transform_input - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Apply the basis transformation to the input data. @@ -276,7 +278,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Feat return self._compute_features(*xi) @abc.abstractmethod - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """Convolve or evaluate the basis.""" pass @@ -901,7 +905,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri return X @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: r""" Examples -------- @@ -918,7 +924,9 @@ def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> Featu """ return super().compute_features(*xi) - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Compute features for added bases and concatenate. 
@@ -1304,7 +1312,9 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri ) return X - def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _compute_features( + self, *xi: NDArray | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Compute the features for the multiplied bases, and compute their outer product. @@ -1397,7 +1407,9 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: return super().evaluate_on_grid(*n_samples) @add_docstring("compute_features", Basis) - def compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def compute_features( + self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Examples -------- diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index a474b819..7762487b 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -9,6 +9,7 @@ import numpy as np import scipy.linalg from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from ..type_casting import support_pynapple from ..typing import FeatureMatrix @@ -19,7 +20,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class OrthExponentialBasis(Basis, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 7070e653..4d14a1a2 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from ..type_casting import support_pynapple from ..typing import FeatureMatrix @@ -16,7 +17,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class RaisedCosineBasisLinear(Basis, abc.ABC): """Represent linearly-spaced raised cosine basis functions. diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index b5fde73c..dda67ab1 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -7,6 +7,7 @@ import numpy as np from numpy.typing import ArrayLike, NDArray +from pynapple import Tsd, TsdFrame, TsdTensor from scipy.interpolate import splev from ..type_casting import support_pynapple @@ -18,7 +19,6 @@ min_max_rescale_samples, ) -from pynapple import Tsd, TsdFrame, TsdTensor class SplineBasis(Basis, abc.ABC): """ @@ -218,7 +218,9 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _evaluate( + self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Evaluate the M-spline basis functions at given sample points. @@ -335,7 +337,9 @@ def __init__( @support_pynapple(conv_type="numpy") @check_transform_input @check_one_dimensional - def _evaluate(self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix: + def _evaluate( + self, sample_pts: ArrayLike | Tsd | TsdFrame | TsdTensor + ) -> FeatureMatrix: """ Evaluate the B-spline basis functions with given sample points. 
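
The `AdditiveBasis` and `MultiplicativeBasis` hunks in this patch reformat how composite bases concatenate or take the outer product of their component features. A short sketch of that behaviour, assuming the two public `Eval` classes accept `n_basis_funcs` positionally as the `Conv` classes do elsewhere in these docs:

```python
# Sketch only: composition semantics as described in the docstrings above.
import numpy as np
import nemos as nmo

x, y = np.random.rand(100), np.random.rand(100)

add = nmo.basis.RaisedCosineLinearEval(5) + nmo.basis.RaisedCosineLogEval(6)
X_add = add.compute_features(x, y)  # concatenated features -> shape (100, 5 + 6)

mul = nmo.basis.RaisedCosineLinearEval(5) * nmo.basis.RaisedCosineLogEval(6)
X_mul = mul.compute_features(x, y)  # outer product of features -> shape (100, 5 * 6)
```
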
diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 17b8f80e..9caea358 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -9,7 +9,7 @@ from ..typing import FeatureMatrix from ._basis import add_docstring -from ._basis_mixin import BasisTransformerMixin, ConvBasisMixin, EvalBasisMixin +from ._basis_mixin import ConvBasisMixin, EvalBasisMixin from ._decaying_exponential import OrthExponentialBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis @@ -794,9 +794,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: return super().evaluate_on_grid(n_samples) -class RaisedCosineLinearEval( - EvalBasisMixin, RaisedCosineBasisLinear -): +class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear): """ Represent linearly-spaced raised cosine basis functions. @@ -910,9 +908,7 @@ def split_by_feature( return super().split_by_feature(x, axis=axis) -class RaisedCosineLinearConv( - ConvBasisMixin, RaisedCosineBasisLinear -): +class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear): """ Represent linearly-spaced raised cosine basis functions. From 442fbfbf3a12ba1746ec89c6792212760e764281 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 18:48:18 -0500 Subject: [PATCH 105/109] linted --- src/nemos/basis/_basis_mixin.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 0fe93a87..00bfbcfc 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -36,7 +36,7 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): Returns ------- : - A matrix with the transformed features. + A matrix with the transformed features. """ return self._evaluate(*xi) @@ -149,17 +149,14 @@ def _set_kernel(self) -> "ConvBasisMixin": @property def window_size(self): - """Duration of the convolutional kernel in number of samples. - """ + """Duration of the convolutional kernel in number of samples.""" return self._window_size @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" if window_size is None: - raise ValueError( - "You must provide a window_size!" 
- ) + raise ValueError("You must provide a window_size!") elif not (isinstance(window_size, int) and window_size > 0): raise ValueError( From a7c463bd87579a27a8d84caca763e3a8393ac45d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 18:50:32 -0500 Subject: [PATCH 106/109] fix tests --- tests/test_basis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 3e67aad9..616ce434 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -979,14 +979,14 @@ def test_fit_kernel_shape(self, cls): ( "conv", -1, - pytest.raises(ValueError, match="`window_size` must be a positive "), + pytest.raises(ValueError, match="You must provide a window_siz"), ), ( "conv", None, pytest.raises( ValueError, - match="If the basis is in `conv` mode, you must provide a ", + match="You must provide a window_siz", ), ), ( @@ -1229,7 +1229,7 @@ def test_set_window_size(self, mode, expectation, cls): bas = cls["conv"]( n_basis_funcs=10, window_size=10, **extra_decay_rates(cls["conv"], 10) ) - with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + with pytest.raises(ValueError, match="You must provide a window_siz"): bas.set_params(window_size=None) if mode == "eval": From 905b745e0a5e29897a6db6902f61ec3dfb2d9023 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Dec 2024 18:52:27 -0500 Subject: [PATCH 107/109] fix tests --- tests/test_basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_basis.py b/tests/test_basis.py index 616ce434..5b1d82a2 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -979,14 +979,14 @@ def test_fit_kernel_shape(self, cls): ( "conv", -1, - pytest.raises(ValueError, match="You must provide a window_siz"), + pytest.raises(ValueError, match="`window_size` must be a positive integer"), ), ( "conv", None, pytest.raises( ValueError, - match="You must provide a window_siz", + match="You must provide a window_size", ), ), ( From 6a6da119d4872b975913547b43bede27cf8c54f9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 10 Dec 2024 09:17:03 -0500 Subject: [PATCH 108/109] fixed links --- docs/developers_notes/04-basis_module.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md index c162f4cb..30463671 100644 --- a/docs/developers_notes/04-basis_module.md +++ b/docs/developers_notes/04-basis_module.md @@ -42,7 +42,7 @@ It accepts one or more NumPy array or pynapple `Tsd` object as input, and perfor 1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case. 2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. -3. In `"eval"` mode, calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). +3. In `"eval"` mode, calls the `_evaluate` method on the input, which is the subclass-specific implementation of the basis set evaluation. 
In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis._basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). 4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples. :::{admonition} Multiple epochs @@ -61,14 +61,14 @@ This method performs the following steps: 1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case. 2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis. -3. Calls the [`_evaluate`](nemos.basis._basis.Basis._evaluate) method on these samples. +3. Calls the `_evaluate` method on these samples. 4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid. ### Abstract Methods The [`nemos.basis._basis.Basis`](nemos.basis._basis.Basis) class has the following abstract methods, which every concrete subclass must implement: -1. [`_evaluate`](nemos.basis._basis.Basis._evaluate) : Evaluates a basis over some specified samples. +1. `_evaluate` : Evaluates a basis over some specified samples. 2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis. ## Contributors Guidelines From 00a24379f50a080f10bf5cff658e95276876a0d1 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Tue, 10 Dec 2024 15:50:43 -0500 Subject: [PATCH 109/109] updates some docstrings --- src/nemos/basis/_basis.py | 2 ++ src/nemos/basis/_basis_mixin.py | 27 ++++++++++++++------------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index f3f76d7b..07403e06 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -683,6 +683,7 @@ def split_by_feature( ------- dict A dictionary where: + - **Key**: Label of the basis. - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)`` """ @@ -1039,6 +1040,7 @@ def split_by_feature( ------- dict A dictionary where: + - **Keys**: Labels of the additive basis components. - **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape: diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 00bfbcfc..714a9141 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -21,11 +21,17 @@ def __init__(self, bounds: Optional[Tuple[float, float]] = None): self.bounds = bounds def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): - """ - Evaluate basis at sample points. + """Evaluate basis at sample points. + + The basis is evaluated at the locations specified in the inputs. For example, + ``compute_features(np.array([0, .5]))`` would return the array: + + .. code-block:: text - The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a - basis element. xi[k] must be a one-dimensional array or a pynapple Tsd/TsdFrame/TsdTensor. + b_1(0) ... b_n(0) + b_1(.5) ... b_n(.5) + + where ``b_i`` is the i-th basis. 
Parameters ---------- @@ -89,20 +95,15 @@ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): - """ - Convolve basis functions with input time series. + """Convolve basis functions with input time series. - A bank of basis filters (created by calling fit) is convolved with the - input data. Inputs can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions - except for the sample-axis are flattened, so that the method always returns a matrix. + A bank of basis filters is convolved with the input data. All the dimensions + except for the sample-axis are flattened, so that the method always returns a + matrix. For example, if inputs are of shape (num_samples, 2, 3), the output will be ``(num_samples, num_basis_funcs * 2 * 3)``. - The time-axis can be specified at basis initialization by setting the keyword argument ``axis``. - For example, if ``axis == 1`` your input should be of shape ``(N1, num_samples N3, ...)``, the output of - transform will be of shape ``(num_samples, num_basis_funcs * N1 * N3 *...)``. - Parameters ---------- *xi: