diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 37e28a3a..5eca7438 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,7 +15,7 @@ build: - gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9 post_build: # Check everything except 403s and a jneurosci, which returns 404 but the link works when clicking. - - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" + - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003,https://www.nature.com/articles/s41467-017-01908-3,https://doi.org/10.1038/s41467-017-01908-3" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/" # The auto-generated animation doesn't have a alt or src/srcset; I am able to ignore missing alt, but I cannot work around a missing src/srcset # therefore for this file I am not checking the figures. - htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003" diff --git a/docs/assets/nemos_sklearn.svg b/docs/assets/nemos_sklearn.svg new file mode 100644 index 00000000..8ea0a3e3 --- /dev/null +++ b/docs/assets/nemos_sklearn.svg @@ -0,0 +1,119 @@ + + + +scikit diff --git a/docs/assets/pipeline.svg b/docs/assets/pipeline.svg index 8c67c7ab..a38b6480 100644 --- a/docs/assets/pipeline.svg +++ b/docs/assets/pipeline.svg @@ -24,12 +24,12 @@ inkscape:deskcolor="#d1d1d1" inkscape:document-units="mm" inkscape:zoom="1.4142136" - inkscape:cx="289.20667" - inkscape:cy="-54.800776" - inkscape:window-width="2488" - inkscape:window-height="1262" + inkscape:cx="206.82873" + inkscape:cy="8.4852811" + inkscape:window-width="1800" + inkscape:window-height="1035" inkscape:window-x="0" - inkscape:window-y="25" + inkscape:window-y="44" inkscape:window-maximized="0" inkscape:current-layer="layer1" /> + y="49.623287" /> Pipeline + x="26.058722" + y="84.056198">Pipeline -Pipelining and cross-validation. +NeMoS vs sklearn. ```{toctree} :maxdepth: 2 -plot_05_sklearn_pipeline_cv_demo.md +plot_05_transformer_basis.md ``` ::: :::{grid-item-card}
-PyTrees. +PyTrees.
```{toctree} :maxdepth: 2 -plot_06_glm_pytree.md +plot_06_sklearn_pipeline_cv_demo.md ``` + +::: + +:::{grid-item-card} + +
+PyTrees. +
+ +```{toctree} +:maxdepth: 2 + +plot_07_glm_pytree.md +``` + ::: :::: diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md new file mode 100644 index 00000000..5cdd10ee --- /dev/null +++ b/docs/how_to_guide/plot_05_transformer_basis.md @@ -0,0 +1,184 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Using bases as scikit-learn transformers + +(tansformer-vs-nemos-basis)= +## scikit-learn Transformers and NeMoS Basis + +`scikit-learn` is a powerful machine learning library that provides advanced tools for creating data analysis pipelines, from input transformations to model fitting and cross-validation. + +All of `scikit-learn`'s machinery relies on strict assumptions about input structure. In particular, all `scikit-learn` +objects require inputs to be arrays of at most two dimensions, where the first dimension represents the time (or samples) +axis, and the second dimension represents features. +While this may feel rigid, it enables transformations to be seamlessly chained together, greatly simplifying the +process of building stable, complex pipelines. + +They can accept arrays or `pynapple` time series data, which can take any shape as long as the time (or sample) axis is the first of each array. +Furthermore, `NeMoS` design favours object composability: one can combine bases into [`CompositeBasis`](composing_basis_function) objects to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.). + +Both approaches to data transformation are valuable and each has its own advantages. Wouldn't it be great if one could combine the two? Well, this is what NeMoS `TransformerBasis` is for! + + +## From Basis to TransformerBasis + +:::{admonition} Composite Basis +:class: note + +To learn more on composite basis, take a look at [this note](composing_basis_function). +::: + +With NeMoS, you can easily create a basis which accepts two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`. + +```{code-cell} ipython3 +import numpy as np +import nemos as nmo + +# create the arrays +n_samples, n_neurons = 100, 5 +counts = np.random.poisson(size=(100, 5)) +speed = np.random.normal(size=(100)) + +# create a composite basis +counts_basis = nmo.basis.RaisedCosineLogConv(5, window_size=10) +speed_basis = nmo.basis.BSplineEval(5) +composite_basis = counts_basis + speed_basis + +# compute the features +X = composite_basis.compute_features(counts, speed) + +``` + +### Converting NeMoS `Basis` to a transformer + +Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline. +In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a `scikit-learn`-compliant transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. + +Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): + + +```{code-cell} ipython3 +bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5) + +# initalize using the constructor +trans_bas = nmo.basis.TransformerBasis(bas) + +# equivalent initialization via "to_transformer" +trans_bas = bas.to_transformer() + +``` + +[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: + + +```{code-cell} ipython3 +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) +``` + +We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), nor does changing the original [`Basis`](nemos.basis._basis.Basis) modify the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: + + +```{code-cell} ipython3 +trans_bas.n_basis_funcs = 10 +bas.n_basis_funcs = 100 + +print(bas.n_basis_funcs, trans_bas.n_basis_funcs) +``` + +As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore they all accept a single 2D array. + +## Setting up the TransformerBasis + +At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right? + +```{code-cell} ipython3 + +# reinstantiate the basis transformer for illustration porpuses +composite_basis = counts_basis + speed_basis +trans_bas = (composite_basis).to_transformer() +# concatenate the inputs +inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1) +print(inp.shape) + +try: + trans_bas.fit_transform(inp) +except RuntimeError as e: + print(repr(e)) + +``` + +...Unfortunately, not yet. The problem is that the basis doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`. + +You can provide this information by calling the `set_input_shape` method of the basis. + +This can be called before or after the transformer basis is defined. The method extracts and stores the number of columns for each input. There are multiple ways to call this method: + +- It directly accepts the input: `composite_basis.set_input_shape(counts, speed)`. +- If the input is 1D or 2D, it also accepts the number of columns: `composite_basis.set_input_shape(5, 1)`. +- A tuple containing the shapes of all except the first: `composite_basis.set_input_shape((5,), (1,))`. +- A mix of the above methods: `composite_basis.set_input_shape(counts, 1)`. + +:::{note} + +Note that what `set_input_shapes` requires are the dimensions of the input stimuli, with the exception of the sample +axis. For example, if the input is a 4D tensor, one needs to provide the last 3 dimensions: + +```{code} ipython3 +# generate a 4D input +x = np.random.randn(10, 3, 2, 1) + +# define and setup the basis +basis = nmo.basis.BSplineEval(5).set_input_shape((3, 2, 1)) + +X = basis.to_transformer().transform( + x.reshape(10, -1) # reshape to 2D +) +``` +::: + +You can also invert the order of operations and call `to_transform` first and then set the input shapes. +```{code-cell} ipython3 + +trans_bas = composite_basis.to_transformer() +trans_bas.set_input_shape(5, 1) +out = trans_bas.fit_transform(inp) +``` + +:::{note} + +If you define a basis and call `compute_features` on your inputs, internally, it will store its shapes, +and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`. +::: + +:::{warning} + +If for some reason you need to provide an input of different shape to an already set-up transformer, you must reset the +`TransformerBasis` with `set_input_shape`. +::: + +```{code-cell} ipython3 + +# define inputs with different shapes and concatenate +x, y = np.random.poisson(size=(10, 3)), np.random.randn(10, 2, 3) +inp2 = np.concatenate([x, y.reshape(10, 6)], axis=1) + +trans_bas = composite_basis.to_transformer() +trans_bas.set_input_shape(3, (2, 3)) +out2 = trans_bas.fit_transform(inp2) +``` + + +### Learn more + +If you want to learn more about how to select basis' hyperparameters with `sklearn` pipelining and cross-validation, check out [this how-to guide](sklearn-how-to). + diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md similarity index 90% rename from docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md rename to docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md index 9f5a9652..4073b928 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md @@ -71,7 +71,8 @@ To set up a scikit-learn [`Pipeline`](https://scikit-learn.org/1.5/modules/gener Each transformation step takes a 2D array `X` of shape `(num_samples, num_original_features)` as input and outputs another 2D array of shape `(num_samples, num_transformed_features)`. The final step takes a pair `(X, y)`, where `X` is as before, and `y` is a 1D array of shape `(n_samples,)` containing the observations to be modeled. You can define a pipeline as follows: -```python + +```{code} ipython3 from sklearn.pipeline import Pipeline # Assume transformer_i/predictor is a transformer/model object @@ -92,7 +93,7 @@ Here we used a placeholder `"label_i"` for demonstration; you should choose a mo ::: Calling `pipe.fit(X, y)` will perform the following computations: -```python +```{code} ipython3 # Chain of transformations X1 = transformer_1.fit_transform(X) X2 = transformer_2.fit_transform(X1) @@ -111,6 +112,7 @@ Pipelines not only streamline and simplify your code but also offer several othe In the following sections, we will showcase this approach with a concrete example: selecting the appropriate basis type and number of bases for a GLM regression in NeMoS. ## Combining basis transformations and GLM in a pipeline + Let's start by creating some toy data. @@ -150,9 +152,7 @@ sns.despine(ax=ax) ``` ### Converting NeMoS `Basis` to a transformer -In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class. - -Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer): +In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. ```{code-cell} ipython3 @@ -164,24 +164,15 @@ trans_bas = nmo.basis.TransformerBasis(bas) # equivalent initialization via "to_transformer" trans_bas = bas.to_transformer() +# setup the transformer +trans_bas.set_input_shape(1) ``` -[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes: - - -```{code-cell} ipython3 -print(bas.n_basis_funcs, trans_bas.n_basis_funcs) -``` - -We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created: +:::{admonition} Learn More about `TransformerBasis` +:class: note - -```{code-cell} ipython3 -trans_bas.n_basis_funcs = 10 -bas.n_basis_funcs = 100 - -print(bas.n_basis_funcs, trans_bas.n_basis_funcs) -``` +To learn more about `sklearn` transformers and `TransforerBasis`, check out [this note](tansformer-vs-nemos-basis). +::: ### Creating and fitting a pipeline We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data. @@ -194,7 +185,7 @@ pipeline = Pipeline( [ ( "transformerbasis", - nmo.basis.RaisedCosineLinearEval(6).to_transformer(), + nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1).to_transformer(), ), ( "glm", @@ -311,7 +302,7 @@ gridsearch.fit(X, y) To appreciate how much boiler-plate code we are saving by calling scikit-learn cross-validation, below we can see how this cross-validation will look like in a manual loop. -```python +```{code} ipython from itertools import product from copy import deepcopy @@ -439,7 +430,7 @@ if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): - fig.savefig(path / "plot_05_sklearn_pipeline_cv_demo.svg") + fig.savefig(path / "plot_06_sklearn_pipeline_cv_demo.svg") ``` 🚀🚀🚀 **Success!** 🚀🚀🚀 @@ -457,12 +448,12 @@ Here we include `transformerbasis___basis` in the parameter grid to try differen param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis___basis=( - nmo.basis.RaisedCosineLinearEval(5), - nmo.basis.RaisedCosineLinearEval(10), - nmo.basis.RaisedCosineLogEval(5), - nmo.basis.RaisedCosineLogEval(10), - nmo.basis.MSplineEval(5), - nmo.basis.MSplineEval(10), + nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(10).set_input_shape(1), + nmo.basis.MSplineEval(5).set_input_shape(1), + nmo.basis.MSplineEval(10).set_input_shape(1), ), ) ``` @@ -538,17 +529,21 @@ The plot confirms that the firing rate distribution is accurately captured by ou :::{warning} Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error: -```python + + + +```{code} ipython + param_grid = dict( glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100), transformerbasis___basis=( - nmo.basis.RaisedCosineLinearEval(5), - nmo.basis.RaisedCosineLinearEval(10), - nmo.basis.RaisedCosineLogEval(5), - nmo.basis.RaisedCosineLogEval(10), - nmo.basis.MSplineEval(5), - nmo.basis.MSplineEval(10), + nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(5).set_input_shape(1), + nmo.basis.RaisedCosineLogEval(10).set_input_shape(1), + nmo.basis.MSplineEval(5).set_input_shape(1), + nmo.basis.MSplineEval(10).set_input_shape(1), ), ) ``` diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_07_glm_pytree.md similarity index 99% rename from docs/how_to_guide/plot_06_glm_pytree.md rename to docs/how_to_guide/plot_07_glm_pytree.md index d6608c45..f980e75c 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_07_glm_pytree.md @@ -265,7 +265,7 @@ if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): - fig.savefig(path / "plot_06_glm_pytree.svg") + fig.savefig(path / "plot_07_glm_pytree.svg") ``` diff --git a/docs/index.md b/docs/index.md index 4da5ef06..0b61208c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ For Developers ``` -## __Neural ModelS__ +# __Neural ModelS__ NeMoS (Neural ModelS) is a statistical modeling framework optimized for systems neuroscience and powered by [JAX](https://jax.readthedocs.io/en/latest/). diff --git a/docs/quickstart.md b/docs/quickstart.md index bdf3ffd4..f071839f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -1,7 +1,18 @@ --- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 hide: - navigation --- + # Quickstart ## **Overview** @@ -29,58 +40,56 @@ NeMoS provides two implementations of the GLM: one for fitting a single neuron, You can define a single neuron GLM by instantiating an `GLM` object. -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # Instantiate the single model ->>> model = nmo.glm.GLM() +# Instantiate the single model +model = nmo.glm.GLM() ``` The coefficients can be learned by invoking the `fit` method of `GLM`. The method requires a design matrix of shape `(num_samples, num_features)`, and the output neural activity of shape `(num_samples, )`. -```python +```{code-cell} ipython3 ->>> import numpy as np ->>> num_samples, num_features = 100, 3 +import numpy as np +num_samples, num_features = 100, 3 ->>> # Generate a design matrix ->>> X = np.random.normal(size=(num_samples, num_features)) ->>> # generate some counts ->>> spike_counts = np.random.poisson(size=num_samples) +# Generate a design matrix +X = np.random.normal(size=(num_samples, num_features)) +# generate some counts +spike_counts = np.random.poisson(size=num_samples) ->>> # define fit the model ->>> model = model.fit(X, spike_counts) +# define fit the model +model = model.fit(X, spike_counts) ``` Once the model is fit, you can retrieve the model parameters as shown below. -```python ->>> # model coefficients shape is (num_features, ) ->>> print(f"Model coefficients shape: {model.coef_.shape}") -Model coefficients shape: (3,) +```{code-cell} ipython3 +# model coefficients shape is (num_features, ) +print(f"Model coefficients shape: {model.coef_.shape}") ->>> # model intercept, shape (1,) since there is only one neuron. ->>> print(f"Model intercept shape: {model.intercept_.shape}") -Model intercept shape: (1,) +# model intercept, shape (1,) since there is only one neuron. +print(f"Model intercept shape: {model.intercept_.shape}") ``` Additionally, you can predict the firing rate and call the compute the model log-likelihood by calling the `predict` and the `score` method respectively. -```python +```{code-cell} ipython3 + +# predict the rate +predicted_rate = model.predict(X) +# firing rate has shape: (num_samples,) +predicted_rate.shape ->>> # predict the rate ->>> predicted_rate = model.predict(X) ->>> # firing rate has shape: (num_samples,) ->>> predicted_rate.shape -(100,) ->>> # compute the log-likelihood of the model ->>> log_likelihood = model.score(X, spike_counts) +# compute the log-likelihood of the model +log_likelihood = model.score(X, spike_counts) ``` @@ -88,49 +97,47 @@ Additionally, you can predict the firing rate and call the compute the model log You can set up a population GLM by instantiating a `PopulationGLM`. The API for the `PopulationGLM` is the same as for the single-neuron `GLM`; the only difference you'll notice is that some of the methods' inputs and outputs have an additional dimension for the different neurons. -```python +```{code-cell} ->>> import nemos as nmo ->>> population_model = nmo.glm.PopulationGLM() +import nemos as nmo +population_model = nmo.glm.PopulationGLM() ``` As for the single neuron GLM, you can learn the model parameters by invoking the `fit` method: the input of `fit` are the design matrix (with shape `(num_samples, num_features)` ), and the population activity (with shape `(num_samples, num_neurons)`). Once the model is fit, you can use `predict` and `score` to predict the firing rate and the log-likelihood. -```python +```{code-cell} ->>> import numpy as np ->>> num_samples, num_features, num_neurons = 100, 3, 5 +import numpy as np +num_samples, num_features, num_neurons = 100, 3, 5 ->>> # simulate a design matrix ->>> X = np.random.normal(size=(num_samples, num_features)) ->>> # simulate some counts ->>> spike_counts = np.random.poisson(size=(num_samples, num_neurons)) +# simulate a design matrix +X = np.random.normal(size=(num_samples, num_features)) +# simulate some counts +spike_counts = np.random.poisson(size=(num_samples, num_neurons)) ->>> # fit the model ->>> population_model = population_model.fit(X, spike_counts) +# fit the model +population_model = population_model.fit(X, spike_counts) ->>> # predict the rate of each neuron in the population ->>> predicted_rate = population_model.predict(X) ->>> predicted_rate.shape # expected shape: (num_samples, num_neurons) -(100, 5) +# predict the rate of each neuron in the population +predicted_rate = population_model.predict(X) +predicted_rate.shape # expected shape: (num_samples, num_neurons) ->>> # compute the log-likelihood of the model ->>> log_likelihood = population_model.score(X, spike_counts) + +# compute the log-likelihood of the model +log_likelihood = population_model.score(X, spike_counts) ``` The learned coefficient and intercept will have shape `(num_features, num_neurons)` and `(num_neurons, )` respectively. -```python ->>> # model coefficients shape is (num_features, num_neurons) ->>> print(f"Model coefficients shape: {population_model.coef_.shape}") -Model coefficients shape: (3, 5) +```{code-cell} +# model coefficients shape is (num_features, num_neurons) +print(f"Model coefficients shape: {population_model.coef_.shape}") ->>> # model intercept, (num_neurons,) ->>> print(f"Model intercept shape: {population_model.intercept_.shape}") -Model intercept shape: (5,) +# model intercept, (num_neurons,) +print(f"Model intercept shape: {population_model.intercept_.shape}") ``` @@ -160,12 +167,12 @@ The `basis` module includes objects that perform two types of transformations on Non-linear mapping is the default mode of operation of any `basis` object. To instantiate a basis for non-linear mapping, you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Reference](nemos_basis) for detailed information). -```python +```{code-cell} ->>> import nemos as nmo +import nemos as nmo ->>> n_basis_funcs = 10 ->>> basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs) +n_basis_funcs = 10 +basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs) ``` @@ -173,17 +180,16 @@ Once the basis is instantiated, you can apply it to your input data using the `c This method takes an input array of shape `(n_samples, )` and transforms it into a two-dimensional array of shape `(n_samples, n_basis_funcs)`, where each column represents a feature generated by the non-linear mapping. -```python +```{code-cell} ->>> import numpy as np +import numpy as np ->>> # generate an input ->>> x = np.arange(100) +# generate an input +x = np.arange(100) ->>> # evaluate the basis ->>> X = basis.compute_features(x) ->>> X.shape -(100, 10) +# evaluate the basis +X = basis.compute_features(x) +X.shape ``` @@ -199,13 +205,13 @@ If you want to convolve a bank of basis functions with an input you must set the `"conv"` and you must provide an integer `window_size` parameter, which defines the length of the filter bank in number of sample points. -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> n_basis_funcs = 10 ->>> # define a filter bank of 10 basis function, 200 samples long. ->>> basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200) +n_basis_funcs = 10 +# define a filter bank of 10 basis function, 200 samples long. +basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200) ``` @@ -219,23 +225,21 @@ Once the basis is initialized, you can call `compute_features` on an input of sh The `window_size` must be shorter than the number of samples in the signal(s) being convolved. ::: -```python +```{code-cell} ipython3 ->>> import numpy as np +import numpy as np ->>> x = np.ones(500) +x = np.ones(500) ->>> # convolve a single signal ->>> X = basis.compute_features(x) ->>> X.shape -(500, 10) +# convolve a single signal +X = basis.compute_features(x) +X.shape ->>> x_multi = np.ones((500, 3)) +x_multi = np.ones((500, 3)) ->>> # convolve a multiple signals ->>> X_multi = basis.compute_features(x_multi) ->>> X_multi.shape -(500, 30) +# convolve a multiple signals +X_multi = basis.set_input_shape(3).compute_features(x_multi) +X_multi.shape ``` @@ -249,12 +253,12 @@ By default, NeMoS' GLM uses [Poisson observations](nemos.observation_models.Pois To change the default observation model, set the `observation_model` argument during initialization: -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # set up a Gamma GLM for modeling continuous non-negative data ->>> glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations()) +# set up a Gamma GLM for modeling continuous non-negative data +glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations()) ``` @@ -270,12 +274,12 @@ NeMoS supports various regularization schemes, including [Ridge](nemos.regulariz You can specify the regularization scheme and its strength when initializing the GLM model: -```python +```{code-cell} ipython3 ->>> import nemos as nmo +import nemos as nmo ->>> # Instantiate a GLM with Ridge (L2) regularization ->>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) +# Instantiate a GLM with Ridge (L2) regularization +glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) ``` @@ -301,25 +305,26 @@ also be a `pynapple` time series. A canonical example of this behavior is the `predict` method of `GLM`. -```ipython +```{code-cell} ipython3 + +import numpy as np +import pynapple as nap ->>> import numpy as np ->>> import pynapple as nap +# suppress jax to numpy conversion warning +nap.nap_config.suppress_conversion_warnings = True ->>> # create a TsdFrame with the features and a Tsd with the counts ->>> X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2))) ->>> y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, ))) +# create a TsdFrame with the features and a Tsd with the counts +X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2))) +y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, ))) ->>> print(type(X)) # shape (num samples, num features) - +print(type(X)) # shape (num samples, num features) ->>> model = model.fit(X, y) # the following works +model = model.fit(X, y) # the following works ->>> firing_rate = model.predict(X) # predict the firing rate of the neuron +firing_rate = model.predict(X) # predict the firing rate of the neuron ->>> # this will still be a pynapple time series ->>> print(type(firing_rate)) # shape (num_samples, ) - +# this will still be a pynapple time series +print(type(firing_rate)) # shape (num_samples, ) ``` @@ -331,29 +336,29 @@ Let's see how you can greatly streamline your analysis pipeline by integrating ` You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1). ::: -```ipython +```{code-cell} ipython3 ->>> import nemos as nmo ->>> import pynapple as nap +import nemos as nmo +import pynapple as nap ->>> path = nmo.fetch.fetch_data("A2929-200711.nwb") ->>> data = nap.load_file(path) +path = nmo.fetch.fetch_data("A2929-200711.nwb") +data = nap.load_file(path) ->>> # load spikes and head direction ->>> spikes = data["units"] ->>> head_dir = data["ry"] +# load spikes and head direction +spikes = data["units"] +head_dir = data["ry"] ->>> # restrict and bin ->>> counts = spikes[6].count(0.01, ep=head_dir.time_support) +# restrict and bin +counts = spikes[6].count(0.01, ep=head_dir.time_support) ->>> # down-sample head direction ->>> upsampled_head_dir = head_dir.bin_average(0.01) +# down-sample head direction +upsampled_head_dir = head_dir.bin_average(0.01) ->>> # create your features ->>> X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir) +# create your features +X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir) ->>> # add a neuron axis and fit model ->>> model = nmo.glm.GLM().fit(X, counts) +# add a neuron axis and fit model +model = nmo.glm.GLM().fit(X, counts) ``` @@ -361,35 +366,31 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa Finally, let's compare the tuning curves -```ipython +```{code-cell} ipython3 ->>> import numpy as np ->>> import matplotlib.pyplot as plt +import numpy as np +import matplotlib.pyplot as plt ->>> # tuning curves ->>> raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6] +# tuning curves +raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6] ->>> # model based tuning curve ->>> model_tuning = nap.compute_1d_tuning_curves_continuous( -... model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate -... head_dir, -... nb_bins=100 -... )[0] +# model based tuning curve +model_tuning = nap.compute_1d_tuning_curves_continuous( + model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate + head_dir, + nb_bins=100 + )[0] ->>> # plot results ->>> sub = plt.subplot(111, projection="polar") ->>> plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw") ->>> plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm") ->>> legend = plt.yticks([]) ->>> xlab = plt.xlabel("heading angle") +# plot results +sub = plt.subplot(111, projection="polar") +plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw") +plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm") +legend = plt.yticks([]) +xlab = plt.xlabel("heading angle") ``` - - - - ## **Compatibility with `scikit-learn`** @@ -400,34 +401,34 @@ For example, if we would like to tune the critical hyper-parameter `regularizer_ [^1]: For a detailed explanation and practical examples, refer to the [cross-validation page](https://scikit-learn.org/stable/modules/cross_validation.html) in the `scikit-learn` documentation. -```ipython +```{code-cell} ipython3 ->>> # set up the model ->>> import nemos as nmo ->>> import numpy as np +# set up the model +import nemos as nmo +import numpy as np ->>> # generate data ->>> X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100) +# generate data +X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100) ->>> # model definition ->>> model = nmo.glm.GLM(regularizer="Ridge") +# model definition +model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) ``` Fit a 5-fold cross-validation scheme for comparing two different regularizer strengths: -```ipython +```{code-cell} ipython3 ->>> from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import GridSearchCV ->>> # define the parameter grid ->>> param_grid = dict(regularizer_strength=(0.01, 0.001)) +# define the parameter grid +param_grid = dict(regularizer_strength=(0.01, 0.001)) ->>> # define the 5-fold cross-validation grid search from sklearn ->>> cls = GridSearchCV(model, param_grid=param_grid, cv=5) +# define the 5-fold cross-validation grid search from sklearn +cls = GridSearchCV(model, param_grid=param_grid, cv=5) ->>> # run the 5-fold cross-validation grid search ->>> cls = cls.fit(X, counts) +# run the 5-fold cross-validation grid search +cls = cls.fit(X, counts) ``` @@ -440,11 +441,10 @@ For more information and a practical example on how to construct a parameter gri Finally, we can print the regularizer strength with the best cross-validated performance: -```ipython +```{code-cell} ipython3 ->>> # print best regularizer strength ->>> print(cls.best_params_) -{'regularizer_strength': 0.01} +# print best regularizer strength +print(cls.best_params_) ``` diff --git a/pyproject.toml b/pyproject.toml index 4c6134ef..f57d3238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,15 @@ profile = "black" # Configure pytest [tool.pytest.ini_options] testpaths = ["tests"] # Specify the directory where test files are located +filterwarnings = [ + # note the use of single quote below to denote "raw" strings in TOML + # this is raised whenever one imports the plotting utils + 'ignore:plotting functions contained within:UserWarning', + # numerical inversion test reaches tolerance... + 'ignore:Tolerance of -?\d\.\d+e-\d\d reached:RuntimeWarning', + # mpl must be non-interctive for testing otherwise doctests will freeze + 'ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning', +] [tool.coverage.run] omit = [ diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py index df37dc0a..6e9cbbca 100644 --- a/src/nemos/basis/_basis.py +++ b/src/nemos/basis/_basis.py @@ -16,7 +16,7 @@ from ..typing import FeatureMatrix from ..utils import row_wise_kron from ..validation import check_fraction_valid_samples -from ._basis_mixin import BasisTransformerMixin +from ._basis_mixin import BasisTransformerMixin, CompositeBasisMixin def add_docstring(method_name, cls): @@ -53,7 +53,7 @@ def check_one_dimensional(func: Callable) -> Callable: """Check if the input is one-dimensional.""" @wraps(func) - def wrapper(self: Basis, *xi: ArrayLike, **kwargs): + def wrapper(self: Basis, *xi: NDArray, **kwargs): if any(x.ndim != 1 for x in xi): raise ValueError("Input sample must be one dimensional!") return func(self, *xi, **kwargs) @@ -111,8 +111,6 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): Parameters ---------- - n_basis_funcs : - The number of basis functions. mode : The mode of operation. 'eval' for evaluation at sample points, 'conv' for convolutional operation. @@ -135,28 +133,25 @@ class Basis(Base, abc.ABC, BasisTransformerMixin): def __init__( self, - n_basis_funcs: int, - mode: Literal["eval", "conv"] = "eval", + mode: Literal["eval", "conv", "composite"] = "eval", label: Optional[str] = None, ) -> None: - self.n_basis_funcs = n_basis_funcs - self._n_input_dimensionality = 0 + self._n_input_dimensionality = getattr(self, "_n_input_dimensionality", 0) self._mode = mode - self._n_basis_input = None - - # these parameters are going to be set at the first call of `compute_features` - # since we cannot know a-priori how many features may be convolved - self._n_output_features = None - self._input_shape = None - if label is None: self._label = self.__class__.__name__ else: self._label = str(label) - self.kernel_ = None + # specified only after inputs/input shapes are provided + self._n_basis_input_ = None + self._input_shape_ = None + + # initialize parent to None. This should not end in "_" because it is + # a permanent property of a basis, defined at composite basis init + self._parent = None @property def n_output_features(self) -> int | None: @@ -167,9 +162,12 @@ def n_output_features(self) -> int | None: ----- The number of output features can be determined only when the number of inputs provided to the basis is known. Therefore, before the first call to ``compute_features``, - this property will return ``None``. After that call, ``n_output_features`` will be available. + this property will return ``None``. After that call, or after setting the input shape with + ``set_input_shape``, ``n_output_features`` will be available. """ - return self._n_output_features + if self._n_basis_input_ is not None: + return self.n_basis_funcs * self._n_basis_input_[0] + return None @property def label(self) -> str: @@ -177,12 +175,12 @@ def label(self) -> str: return self._label @property - def n_basis_input(self) -> tuple | None: + def n_basis_input_(self) -> tuple | None: """Number of expected inputs. The number of inputs ``compute_feature`` expects. """ - return self._n_basis_input + return self._n_basis_input_ @property def n_basis_funcs(self): @@ -204,43 +202,6 @@ def mode(self): """Mode of operation, either ``"conv"`` or ``"eval"``.""" return self._mode - @staticmethod - def _apply_identifiability_constraints(X: NDArray): - """Apply identifiability constraints to a design matrix `X`. - - Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness - of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly - crucial for models using bases like BSplines and CyclicBspline, which, due to their - construction, sum to 1 and can cause rank deficiency when combined with an intercept. - - For GLMs, this rank deficiency means that different sets of coefficients might yield - identical predicted rates and log-likelihood, complicating parameter learning, especially - in the absence of regularization. - - Parameters - ---------- - X: - The design matrix before applying the identifiability constraints. - - Returns - ------- - : - The adjusted design matrix with redundant columns dropped and columns mean-centered. - """ - - def add_constant(x): - return np.hstack((np.ones((x.shape[0], 1)), x)) - - rank = np.linalg.matrix_rank(add_constant(X)) - # mean center - X = X - np.nanmean(X, axis=0) - while rank < X.shape[1] + 1: - # drop a column - X = X[:, :-1] - # recompute rank - rank = np.linalg.matrix_rank(add_constant(X)) - return X - @check_transform_input def compute_features( self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor @@ -273,20 +234,65 @@ def compute_features( Subclasses should implement how to handle the transformation specific to their basis function types and operation modes. """ - self._set_num_output_features(*xi) - self.set_kernel() + if self._n_basis_input_ is None: + self.set_input_shape(*xi) + self._check_input_shape_consistency(*xi) + self._set_input_independent_states() return self._compute_features(*xi) @abc.abstractmethod def _compute_features( self, *xi: NDArray | Tsd | TsdFrame | TsdTensor ) -> FeatureMatrix: - """Convolve or evaluate the basis.""" + """Convolve or evaluate the basis. + + This method is intended to be equivalent to the sklearn transformer ``transform`` method. + As the latter, it computes the transformation assuming that all the states are already + pre-computed by ``_fit_basis``, a method corresponding to ``fit``. + + The method differs from transformer's ``transform`` for the structure of the input that it accepts. + In particular, ``_compute_features`` accepts a number of different time series, one per 1D basis component, + while ``transform`` requires all inputs to be concatenated in a single array. + """ + pass + + @abc.abstractmethod + def setup_basis(self, *xi: ArrayLike) -> FeatureMatrix: + """Pre-compute all basis state variables. + + This method is intended to be equivalent to the sklearn transformer ``fit`` method. + As the latter, it computes all the state attributes, and store it with the convention + that the attribute name **must** end with "_", for example ``self.kernel_``, + ``self._input_shape_``. + + The method differs from transformer's ``fit`` for the structure of the input that it accepts. + In particular, ``_fit_basis`` accepts a number of different time series, one per 1D basis component, + while ``fit`` requires all inputs to be concatenated in a single array. + """ + pass + + @abc.abstractmethod + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + An example of such state is the kernel_ for Conv bases, which can be computed + without any input (it only depends on the basis type, the window size and the + number of basis elements). + """ pass @abc.abstractmethod - def set_kernel(self): - """Set kernel for conv basis and return self or just return self for eval.""" + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + """ pass @abc.abstractmethod @@ -379,13 +385,6 @@ def _check_transform_input( return xi - def _check_has_kernel(self) -> None: - """Check that the kernel is pre-computed.""" - if self.mode == "conv" and self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features` when mode =`conv`." - ) - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -464,20 +463,6 @@ def _check_samples_consistency(*xi: NDArray) -> None: "Sample size mismatch. Input elements have inconsistent sample sizes." ) - @abc.abstractmethod - def _check_n_basis_min(self) -> None: - """Check that the user required enough basis elements. - - Most of the basis work with at least 1 element, but some - such as the RaisedCosineBasisLog requires a minimum of 2 basis to be well defined. - - Raises - ------ - ValueError - If an insufficient number of basis element is requested for the basis type - """ - pass - def __add__(self, other: Basis) -> AdditiveBasis: """ Add two Basis objects together. @@ -558,7 +543,7 @@ def _get_feature_slicing( Parameters ---------- n_inputs : - The number of input basis for each component, by default it uses ``self._n_basis_input``. + The number of input basis for each component, by default it uses ``self._n_basis_input_``. start_slice : The starting index for slicing, by default it starts from 0. split_by_input : @@ -579,10 +564,8 @@ def _get_feature_slicing( _get_default_slicing : Handles default slicing logic. _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ - # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + # Set default values for start_slice if not provided start_slice = start_slice or 0 - # Handle the default case for non-additive basis types # See overwritten method for recursion logic split_dict, start_slice = self._get_default_slicing( @@ -607,11 +590,9 @@ def _get_default_slicing( """Handle default slicing logic.""" if split_by_input: # should we remove this option? - if self._n_basis_input[0] == 1 or isinstance(self, MultiplicativeBasis): + if self._n_basis_input_[0] == 1 or isinstance(self, MultiplicativeBasis): split_dict = { - self.label: slice( - start_slice, start_slice + self._n_output_features - ) + self.label: slice(start_slice, start_slice + self.n_output_features) } else: split_dict = { @@ -620,14 +601,14 @@ def _get_default_slicing( start_slice + i * self.n_basis_funcs, start_slice + (i + 1) * self.n_basis_funcs, ) - for i in range(self._n_basis_input[0]) + for i in range(self._n_basis_input_[0]) } } else: split_dict = { - self.label: slice(start_slice, start_slice + self._n_output_features) + self.label: slice(start_slice, start_slice + self.n_output_features) } - start_slice += self._n_output_features + start_slice += self.n_output_features return split_dict, start_slice def split_by_feature( @@ -719,88 +700,18 @@ def is_leaf(val): # Apply the slicing using the custom leaf function out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf) - # reshape the arrays to spilt by n_basis_input + # reshape the arrays to spilt by n_basis_input_ reshaped_out = dict() for i, vals in enumerate(out.items()): key, val = vals shape = list(val.shape) reshaped_out[key] = val.reshape( - shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :] + shape[:axis] + [self._n_basis_input_[i], -1] + shape[axis + 1 :] ) return reshaped_out - def _check_input_shape_consistency(self, x: NDArray): - """Check input consistency across calls.""" - # remove sample axis - shape = x.shape[1:] - if self._input_shape is not None and self._input_shape != shape: - expected_shape_str = "(n_samples, " + f"{self._input_shape}"[1:] - expected_shape_str = expected_shape_str.replace(",)", ")") - raise ValueError( - f"Input shape mismatch detected.\n\n" - f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " - f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" - f" Expected: {expected_shape_str}\n" - f" But got: {x.shape}.\n\n" - "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " - "but all other dimensions must remain the same. If you need to process inputs with a " - "different shape, please create a new basis instance." - ) - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - """ - Pre-compute the number of inputs and output features. - - This function computes the number of inputs that are provided to the basis and uses - that number, and the n_basis_funcs to calculate the number of output features that - ``self.compute_features`` will return. These quantities and the input shape (excluding the sample axis) - are stored in ``self._n_basis_input`` and ``self._n_output_features``, and ``self._input_shape`` - respectively. - - Parameters - ---------- - xi: - The input arrays. - - Returns - ------- - : - The basis itself, for chaining. - - Raises - ------ - ValueError: - If the number of inputs do not match ``self._n_basis_input``, if ``self._n_basis_input`` was - not None. - - Notes - ----- - Once a ``compute_features`` is called, we enforce that for all subsequent calls of the method, - the input that the basis receives preserves the shape of all axes, except for the sample axis. - This condition guarantees the consistency of the feature axis, and therefore that - ``self.split_by_feature`` behaves appropriately. - - """ - # Check that the input shape matches expectation - # Note that this method is reimplemented in AdditiveBasis and MultiplicativeBasis - # so we can assume that len(xi) == 1 - xi = xi[0] - self._check_input_shape_consistency(xi) - - # remove sample axis (samples are allowed to vary) - shape = xi.shape[1:] - self._input_shape = shape - - # remove sample axis & get the total input number - n_inputs = (1,) if xi.ndim == 1 else (np.prod(shape),) - - self._n_basis_input = n_inputs - self._n_output_features = self.n_basis_funcs * self._n_basis_input[0] - return self - - -class AdditiveBasis(Basis): +class AdditiveBasis(CompositeBasisMixin, Basis): """ Class representing the addition of two Basis objects. @@ -811,11 +722,6 @@ class AdditiveBasis(Basis): basis2 : Second basis object to add. - Attributes - ---------- - n_basis_funcs : int - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -835,33 +741,57 @@ class AdditiveBasis(Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs + basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="composite") + self._label = "(" + basis1.label + " + " + basis2.label + ")" + self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " + " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features + self._basis2.n_output_features - ) - return self - def _check_n_basis_min(self) -> None: - pass + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. + + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. + """ + return self.basis1.n_basis_funcs + self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self.basis1, "n_output_features", None) + out2 = getattr(self.basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 + out2 + + @add_docstring("set_input_shape", CompositeBasisMixin) + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: + """ + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + + >>> # define an additive basis + >>> basis_1 = nmo.basis.BSplineEval(5) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) + >>> basis_3 = nmo.basis.RaisedCosineLinearEval(7) + >>> additive_basis = basis_1 + basis_2 + basis_3 + + Specify the input shape using all 3 allowed ways: integer, tuple, array + >>> _ = additive_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) + + Expected output features are: + (5 bases * 1 input) + (6 bases * 6 inputs) + (7 bases * 20 inputs) = 181 + >>> additive_basis.n_output_features + 181 + + """ + return super().set_input_shape(*xi) @support_pynapple(conv_type="numpy") @check_transform_input @@ -899,8 +829,8 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri """ X = np.hstack( ( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._evaluate(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._evaluate(*xi[self.basis1._n_input_dimensionality :]), ) ) return X @@ -948,35 +878,16 @@ def _compute_features( hstack_pynapple = support_pynapple(conv_type="numpy")(np.hstack) X = hstack_pynapple( ( - self._basis1._compute_features( - *xi[: self._basis1._n_input_dimensionality] + self.basis1._compute_features( + *xi[: self.basis1._n_input_dimensionality] ), - self._basis2._compute_features( - *xi[self._basis1._n_input_dimensionality :] + self.basis2._compute_features( + *xi[self.basis1._n_input_dimensionality :] ), ), ) return X - def set_kernel(self, *xi: ArrayLike) -> Basis: - """Call fit on the added basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. - - Returns - ------- - : - The AdditiveBasis ready to be evaluated. - """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self - def split_by_feature( self, x: NDArray, @@ -1182,18 +1093,18 @@ def _get_feature_slicing( _merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts. """ # Set default values for n_inputs and start_slice if not provided - n_inputs = n_inputs or self._n_basis_input + n_inputs = n_inputs or self._n_basis_input_ start_slice = start_slice or 0 # If the instance is of AdditiveBasis type, handle slicing for the additive components - split_dict, start_slice = self._basis1._get_feature_slicing( - n_inputs[: len(self._basis1._n_basis_input)], + split_dict, start_slice = self.basis1._get_feature_slicing( + n_inputs[: len(self.basis1._n_basis_input_)], start_slice, split_by_input=split_by_input, ) - sp2, start_slice = self._basis2._get_feature_slicing( - n_inputs[len(self._basis1._n_basis_input) :], + sp2, start_slice = self.basis2._get_feature_slicing( + n_inputs[len(self.basis1._n_basis_input_) :], start_slice, split_by_input=split_by_input, ) @@ -1211,7 +1122,7 @@ def _merge_slicing_dicts(self, dict1: dict, dict2: dict) -> dict: return dict1 -class MultiplicativeBasis(Basis): +class MultiplicativeBasis(CompositeBasisMixin, Basis): """ Class representing the multiplication (external product) of two Basis objects. @@ -1222,11 +1133,6 @@ class MultiplicativeBasis(Basis): basis2 : Second basis object to multiply. - Attributes - ---------- - n_basis_funcs : - Number of basis functions. - Examples -------- >>> # Generate sample data @@ -1246,38 +1152,30 @@ class MultiplicativeBasis(Basis): """ def __init__(self, basis1: Basis, basis2: Basis) -> None: - self.n_basis_funcs = basis1.n_basis_funcs * basis2.n_basis_funcs - super().__init__(self.n_basis_funcs, mode="eval") + CompositeBasisMixin.__init__(self, basis1, basis2) + Basis.__init__(self, mode="composite") + self._label = "(" + basis1.label + " * " + basis2.label + ")" self._n_input_dimensionality = ( basis1._n_input_dimensionality + basis2._n_input_dimensionality ) - self._n_basis_input = None - self._n_output_features = None - self._label = "(" + basis1.label + " * " + basis2.label + ")" - self._basis1 = basis1 - self._basis2 = basis2 - - def _check_n_basis_min(self) -> None: - pass - def set_kernel(self, *xi: NDArray) -> Basis: - """Call fit on the multiplied basis. - - If any of the added basis is in "conv" mode, it will prepare its kernels for the convolution. - - Parameters - ---------- - *xi: - The sample inputs. Unused, necessary to conform to ``scikit-learn`` API. + @property + def n_basis_funcs(self): + """Compute the n-basis function runtime. - Returns - ------- - : - The MultiplicativeBasis ready to be evaluated. + This plays well with cross-validation where the number of basis function of the + underlying bases can be changed. It must be read-only since the number of basis + is determined by the two basis elements and the type of composition. """ - self._basis1.set_kernel() - self._basis2.set_kernel() - return self + return self.basis1.n_basis_funcs * self.basis2.n_basis_funcs + + @property + def n_output_features(self): + out1 = getattr(self.basis1, "n_output_features", None) + out2 = getattr(self.basis2, "n_output_features", None) + if out1 is None or out2 is None: + return None + return out1 * out2 @support_pynapple(conv_type="numpy") @check_transform_input @@ -1307,8 +1205,8 @@ def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatri """ X = np.asarray( row_wise_kron( - self._basis1._evaluate(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._evaluate(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._evaluate(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._evaluate(*xi[self.basis1._n_input_dimensionality :]), transpose=False, ) ) @@ -1341,26 +1239,12 @@ def _compute_features( """ kron = support_pynapple(conv_type="numpy")(row_wise_kron) X = kron( - self._basis1._compute_features(*xi[: self._basis1._n_input_dimensionality]), - self._basis2._compute_features(*xi[self._basis1._n_input_dimensionality :]), + self.basis1._compute_features(*xi[: self.basis1._n_input_dimensionality]), + self.basis2._compute_features(*xi[self.basis1._n_input_dimensionality :]), transpose=False, ) return X - def _set_num_output_features(self, *xi: NDArray) -> Basis: - self._n_basis_input = ( - *self._basis1._set_num_output_features( - *xi[: self._basis1._n_input_dimensionality] - )._n_basis_input, - *self._basis2._set_num_output_features( - *xi[self._basis1._n_input_dimensionality :] - )._n_basis_input, - ) - self._n_output_features = ( - self._basis1.n_output_features * self._basis2.n_output_features - ) - return self - def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]: """Evaluate the basis set on a grid of equi-spaced sample points. @@ -1455,3 +1339,29 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + + @add_docstring("set_input_shape", CompositeBasisMixin) + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: + """ + Examples + -------- + >>> # Generate sample data + >>> import numpy as np + >>> import nemos as nmo + + >>> # define an additive basis + >>> basis_1 = nmo.basis.BSplineEval(5) + >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6) + >>> basis_3 = nmo.basis.MSplineEval(7) + >>> multiplicative_basis = basis_1 * basis_2 * basis_3 + + Specify the input shape using all 3 allowed ways: integer, tuple, array + >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5))) + + Expected output features are: + (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200 + >>> multiplicative_basis.n_output_features + 25200 + + """ + return super().set_input_shape(*xi) diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index 9aac6faf..c65710bd 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -2,9 +2,12 @@ from __future__ import annotations +import abc import copy import inspect -from typing import Optional, Tuple, Union +from functools import wraps +from itertools import chain +from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union import numpy as np from numpy.typing import ArrayLike, NDArray @@ -13,6 +16,166 @@ from ..convolve import create_convolutional_predictor from ._transformer_basis import TransformerBasis +if TYPE_CHECKING: + from ._basis import Basis + + +def set_input_shape_state( + method, states: Tuple[str] = ("_n_basis_input_", "_input_shape_") +): + """ + Decorator to preserve input shape-related attributes during method execution. + + This decorator ensures that the attributes `_n_basis_input_` and `_input_shape_` + are copied from the original object (`self`) to the returned object (`klass`) + after the wrapped method executes. It is intended to be used with methods that + clone or create a new instance of the class, ensuring these critical attributes + are retained for functionality such as cross-validation. + + Parameters + ---------- + method : + The method to be wrapped. This method is expected to return an object + (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes. + attr_list + + Returns + ------- + : + The wrapped method that copies `_n_basis_input_` and `_input_shape_` from + the original object (`self`) to the new object (`klass`). + + Examples + -------- + Applying the decorator to a method: + + >>> from functools import wraps + >>> @set_input_shape_state + ... def __sklearn_clone__(self): + ... klass = self.__class__(**self.get_params()) + ... return klass + + The `_n_basis_input_` and `_input_shape_` attributes of `self` will be + copied to `klass` after the method executes. + """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + klass: Basis = method(self, *args, **kwargs) + for attr_name in states: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + + return wrapper + + +class AtomicBasisMixin: + """Mixin class for atomic bases (i.e. non-composite).""" + + def __init__(self, n_basis_funcs: int): + self._n_basis_funcs = n_basis_funcs + self._check_n_basis_min() + + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + """ + klass = self.__class__(**self.get_params()) + return klass + + def _iterate_over_components(self) -> Generator: + """Return a generator that iterates over all basis components. + + For atomic bases, the list is just [self]. + + Returns + ------- + A generator returning self, it will be chained in composite bases. + + """ + return (x for x in [self]) + + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Set the expected input shape for the basis object. + + This method configures the shape of the input data that the basis object expects. + ``xi`` can be specified as an integer, a tuple of integers, or derived + from an array. The method also calculates the total number of input + features and output features based on the number of basis functions. + + Parameters + ---------- + xi : + The input shape specification. + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + + Notes + ----- + All state attributes that depends on the input must be set in this method in order for + the API of basis to work correctly. In particular, this method is called by ``setup_basis``, + which is equivalent to ``fit`` for a transformer. If any input dependent state + is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break. + + """ + if isinstance(xi, tuple): + if not all(isinstance(i, int) for i in xi): + raise ValueError( + f"The tuple provided contains non integer values. Tuple: {xi}." + ) + shape = xi + elif isinstance(xi, int): + shape = () if xi == 1 else (xi,) + else: + shape = xi.shape[1:] + + n_inputs = (int(np.prod(shape)),) + + self._input_shape_ = shape + + self._n_basis_input_ = n_inputs + return self + + def _check_input_shape_consistency(self, x: NDArray): + """Check input consistency across calls.""" + # remove sample axis and squeeze + shape = x.shape[1:] + + initialized = self._input_shape_ is not None + is_shape_match = self._input_shape_ == shape + if initialized and not is_shape_match: + expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:] + expected_shape_str = expected_shape_str.replace(",)", ")") + raise ValueError( + f"Input shape mismatch detected.\n\n" + f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with " + f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n" + f" Expected: {expected_shape_str}\n" + f" But got: {x.shape}.\n\n" + "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, " + "but all other dimensions must remain the same. If you need to process inputs with a " + "different shape, please create a new basis instance, or set a new input shape by calling " + "`set_input_shape`." + ) + class EvalBasisMixin: """Mixin class for evaluational basis.""" @@ -48,9 +211,32 @@ def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor): out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi)) return np.reshape(out, (out.shape[0], -1)) - def set_kernel(self) -> "EvalBasisMixin": + def setup_basis(self, *xi: NDArray) -> Basis: """ - Prepare or compute the convolutional kernel for the basis functions. + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self) -> "EvalBasisMixin": + """ + Compute all the basis states that do not depend on the input. For EvalBasisMixin, this method might not perform any operation but simply return the instance itself, as no kernel preparation is necessary. @@ -92,6 +278,7 @@ class ConvBasisMixin: """Mixin class for convolutional basis.""" def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None): + self.kernel_ = None self.window_size = window_size self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs @@ -111,12 +298,16 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): The input data over which to apply the basis transformation. The samples can be passed as multiple arguments, each representing a different dimension for multivariate inputs. + Notes + ----- + This method is intended to be 1-to-1 mappable to sklearn ``transform`` method of transformer. This + means that for the method to be callable, all the state attributes have to be pre-computed in a + method that is mappable to ``fit``, which for us is ``_fit_basis``. It is fundamental that both + methods behaves like the corresponding transformer method, with the only difference being the input + structure: a single (X, y) pair for the transformer, a number of time series for the Basis. + """ - if self.kernel_ is None: - raise ValueError( - "You must call `_set_kernel` before `_compute_features`! " - "Convolution kernel is not set." - ) + self._check_has_kernel() # before calling the convolve, check that the input matches # the expectation. We can check xi[0] only, since convolution # is applied at the end of the recursion on the 1D basis, ensuring len(xi) == 1. @@ -124,6 +315,38 @@ def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor): # make sure to return a matrix return np.reshape(conv, newshape=(conv.shape[0], -1)) + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + self.set_kernel() + self.set_input_shape(*xi) + return self + + def _set_input_independent_states(self): + """ + Compute all the basis states that do not depend on the input. + + For Conv mixin the only attribute is the kernel. + """ + return self.set_kernel() + def set_kernel(self) -> "ConvBasisMixin": """ Prepare or compute the convolutional kernel for the basis functions. @@ -157,6 +380,11 @@ def window_size(self): @window_size.setter def window_size(self, window_size): """Setter for the window size parameter.""" + self._check_window_size(window_size) + + self._window_size = window_size + + def _check_window_size(self, window_size): if window_size is None: raise ValueError("You must provide a window_size!") @@ -165,8 +393,6 @@ def window_size(self, window_size): f"`window_size` must be a positive integer. {window_size} provided instead!" ) - self._window_size = window_size - @property def conv_kwargs(self): """The convolutional kwargs. @@ -224,6 +450,13 @@ def _check_convolution_kwargs(conv_kwargs: dict): f"Allowed convolution keyword arguments are: {convolve_configs}." ) + def _check_has_kernel(self) -> None: + """Check that the kernel is pre-computed.""" + if self.kernel_ is None: + raise RuntimeError( + "You must call `setup_basis` before `_compute_features` for Conv basis." + ) + class BasisTransformerMixin: """Mixin class for constructing a transformer.""" @@ -241,7 +474,7 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.model_selection import GridSearchCV >>> # load some data >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) - >>> basis = nmo.basis.RaisedCosineLinearEval(10).to_transformer() + >>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer() >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( @@ -255,4 +488,160 @@ def to_transformer(self) -> TransformerBasis: ... ) >>> gridsearch = gridsearch.fit(X, y) """ - return TransformerBasis(copy.deepcopy(self)) + return TransformerBasis(self) + + +class CompositeBasisMixin: + """Mixin class for composite basis. + + Add overwrites concrete methods or defines abstract methods for composite basis + (AdditiveBasis and MultiplicativeBasis). + """ + + def __init__(self, basis1: Basis, basis2: Basis): + # deep copy to avoid changes directly to the 1d basis to be reflected + # in the composite basis. + self.basis1 = copy.deepcopy(basis1) + self.basis2 = copy.deepcopy(basis2) + + # set parents + self.basis1._parent = self + self.basis2._parent = self + + shapes = ( + *(bas1._input_shape_ for bas1 in basis1._iterate_over_components()), + *(bas2._input_shape_ for bas2 in basis2._iterate_over_components()), + ) + # if all bases where set, then set input for composition. + set_bases = [s is not None for s in shapes] + + if all(set_bases): + # pass down the input shapes + self.set_input_shape(*shapes) + + @property + @abc.abstractmethod + def n_basis_funcs(self): + """Read only property for composite bases.""" + pass + + def setup_basis(self, *xi: NDArray) -> Basis: + """ + Set all basis states. + + This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and + it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape. + The difference between this method and the transformer ``fit`` is in the expected input structure, + where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here + each input is provided as a separate time series for each basis element. + + Parameters + ---------- + xi: + Input arrays. + + Returns + ------- + : + The basis with ready for evaluation. + """ + # setup both input independent + self._set_input_independent_states() + + # and input dependent states + self.set_input_shape(*xi) + + return self + + def _set_input_independent_states(self): + """ + Compute the input dependent states for traversing the composite basis. + + Returns + ------- + : + The basis with the states stored as attributes of each component. + """ + self.basis1._set_input_independent_states() + self.basis2._set_input_independent_states() + + def _check_input_shape_consistency(self, *xi: NDArray): + """Check the input shape consistency for all basis elements.""" + self.basis1._check_input_shape_consistency( + *xi[: self.basis1._n_input_dimensionality] + ) + self.basis2._check_input_shape_consistency( + *xi[self.basis1._n_input_dimensionality :] + ) + + def _iterate_over_components(self): + """Return a generator that iterates over all basis components. + + Reimplements the default behavior by iteratively calling _iterate_over_components of the + elements. + + Returns + ------- + A generator looping on each individual input. + """ + return chain( + self.basis1._iterate_over_components(), + self.basis2._iterate_over_components(), + ) + + @set_input_shape_state + def __sklearn_clone__(self) -> Basis: + """Clone the basis while preserving attributes related to input shapes. + + This method ensures that input shape attributes (e.g., `_n_basis_input_`, + `_input_shape_`) are preserved during cloning. Reinitializing the class + as in the regular sklearn clone would drop these attributes, rendering + cross-validation unusable. + The method also handles recursive cloning for composite basis structures. + """ + # clone recursively + basis1 = self.basis1.__sklearn_clone__() + basis2 = self.basis2.__sklearn_clone__() + klass = self.__class__(basis1, basis2) + + for attr_name in ["_n_basis_input_", "_input_shape_"]: + setattr(klass, attr_name, getattr(self, attr_name)) + return klass + + def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis: + """ + Set the expected input shape for the basis object. + + This method sets the input shape for each component basis in the basis. + One ``xi`` must be provided for each basis component, specified as an integer, + a tuple of integers, or an array. The method calculates and stores the total number of output features + based on the number of basis functions in each component and the provided input shapes. + + Parameters + ---------- + *xi : + The input shape specifications. For every k,``xi[k]`` can be: + - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input. + - A tuple: Represents the exact input shape excluding the first axis (sample axis). + All elements must be integers. + - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis). + + Raises + ------ + ValueError + If a tuple is provided and it contains non-integer elements. + + Returns + ------- + self : + Returns the instance itself to allow method chaining. + """ + self._n_basis_input_ = ( + *self.basis1.set_input_shape( + *xi[: self.basis1._n_input_dimensionality] + )._n_basis_input_, + *self.basis2.set_input_shape( + *xi[self.basis1._n_input_dimensionality :] + )._n_basis_input_, + ) + return self diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py index 65a71a3e..5f80df58 100644 --- a/src/nemos/basis/_decaying_exponential.py +++ b/src/nemos/basis/_decaying_exponential.py @@ -14,9 +14,10 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin -class OrthExponentialBasis(Basis, abc.ABC): +class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC): """Set of 1D basis decaying exponential functions numerically orthogonalized. Parameters @@ -40,8 +41,8 @@ def __init__( mode="eval", label: Optional[str] = "OrthExponentialBasis", ): + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( - n_basis_funcs, mode=mode, label=label, ) diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py index 07c3ae0a..dbf039eb 100644 --- a/src/nemos/basis/_raised_cosine_basis.py +++ b/src/nemos/basis/_raised_cosine_basis.py @@ -11,9 +11,10 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin -class RaisedCosineBasisLinear(Basis, abc.ABC): +class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC): """Represent linearly-spaced raised cosine basis functions. This implementation is based on the cosine bumps used by Pillow et al. [1]_ @@ -47,8 +48,8 @@ def __init__( width: float = 2.0, label: Optional[str] = "RaisedCosineBasisLinear", ) -> None: + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( - n_basis_funcs, mode=mode, label=label, ) diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py index d9969029..c8f42d90 100644 --- a/src/nemos/basis/_spline_basis.py +++ b/src/nemos/basis/_spline_basis.py @@ -13,9 +13,10 @@ from ..type_casting import support_pynapple from ..typing import FeatureMatrix from ._basis import Basis, check_transform_input, min_max_rescale_samples +from ._basis_mixin import AtomicBasisMixin -class SplineBasis(Basis, abc.ABC): +class SplineBasis(Basis, AtomicBasisMixin, abc.ABC): """ SplineBasis class inherits from the Basis class and represents spline basis functions. @@ -46,8 +47,8 @@ def __init__( mode: Literal["conv", "eval"] = "eval", ) -> None: self.order = order + AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs) super().__init__( - n_basis_funcs, label=label, mode=mode, ) diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py index 83f4f2e3..4420eb40 100644 --- a/src/nemos/basis/_transformer_basis.py +++ b/src/nemos/basis/_transformer_basis.py @@ -1,7 +1,9 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generator + +import numpy as np from ..typing import FeatureMatrix @@ -63,12 +65,27 @@ def __init__(self, basis: Basis): self._basis = copy.deepcopy(basis) @staticmethod - def _unpack_inputs(X: FeatureMatrix): - """Unpack inputs without using transpose. + def _check_initialized(basis): + if basis._n_basis_input_ is None: + raise RuntimeError( + "Cannot apply TransformerBasis: the provided basis has no defined input shape. " + "Please call `set_input_shape` before calling `fit`, `transform`, or " + "`fit_transform`." + ) + + @property + def basis(self): + return self._basis + + @basis.setter + def basis(self, basis): + self._basis = basis + + def _unpack_inputs(self, X: FeatureMatrix) -> Generator: + """Unpack inputs. Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``, - returning a list of Tsd objects. Attempt to unpack using *X.T will raise a ``pynapple`` - exception since ``pynapple`` assumes that the time axis is the first axis. + returning a list of Tsd objects. Parameters ---------- @@ -78,10 +95,18 @@ def _unpack_inputs(X: FeatureMatrix): Returns ------- : - A tuple of each individual input. + A list of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) + n_samples = X.shape[0] + out = ( + np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_)) + for i, (bas, n_input) in enumerate( + zip(self._iterate_over_components(), self._n_basis_input_) + ) + for cc in [sum(self._n_basis_input_[:i])] + ) + return out def fit(self, X: FeatureMatrix, y=None): """ @@ -110,11 +135,13 @@ def fit(self, X: FeatureMatrix, y=None): >>> X = np.random.normal(size=(100, 2)) >>> # Define and fit tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> transformer_fitted = transformer.fit(X) """ - self._basis.set_kernel() + self._check_initialized(self._basis) + self._check_input(X, y) + self._basis.setup_basis(*self._unpack_inputs(X)) return self def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: @@ -141,7 +168,7 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> # Example input >>> X = np.random.normal(size=(10000, 2)) - >>> basis = MSplineConv(10, window_size=200) + >>> basis = MSplineConv(10, window_size=200).set_input_shape(2) >>> transformer = TransformerBasis(basis) >>> # Before calling `fit` the convolution kernel is not set >>> transformer.kernel_ @@ -152,8 +179,9 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: (200, 10) >>> # Transform basis - >>> feature_transformed = transformer.transform(X[:, 0:1]) + >>> feature_transformed = transformer.transform(X) """ + self._check_initialized(self._basis) # transpose does not work with pynapple # can't use func(*X.T) to unwrap return self._basis._compute_features(*self._unpack_inputs(X)) @@ -187,13 +215,14 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix: >>> X = np.random.normal(size=(100, 1)) >>> # Define tranformation basis - >>> basis = MSplineEval(10) + >>> basis = MSplineEval(10).set_input_shape(1) >>> transformer = TransformerBasis(basis) >>> # Fit and transform basis >>> feature_transformed = transformer.fit_transform(X) """ - return self._basis.compute_features(*self._unpack_inputs(X)) + self.fit(X, y=y) + return self.transform(X) def __getstate__(self): """ @@ -283,7 +312,7 @@ def __sklearn_clone__(self) -> TransformerBasis: For more info: https://scikit-learn.org/stable/developers/develop.html#cloning """ - cloned_obj = TransformerBasis(copy.deepcopy(self._basis)) + cloned_obj = TransformerBasis(self._basis.__sklearn_clone__()) cloned_obj._basis.kernel_ = None return cloned_obj @@ -390,3 +419,41 @@ def __pow__(self, exponent: int) -> TransformerBasis: """ # errors are handled by Basis.__pow__ return TransformerBasis(self._basis**exponent) + + def _check_input(self, X: FeatureMatrix, y=None): + """Check that the input structure. + + TransformerBasis expects a 2-d array as an input. The number of columns should match the number of inputs + the basis expects. This number can be set before the TransformerBasis is initialized, by calling + ``Basis.set_input_shape``. + + Parameters + ---------- + X: + The input FeatureMatrix. + + Raises + ------ + ValueError: + If the input is not a 2-d array or if the number of columns does not match the expected number of inputs. + """ + ndim = getattr(X, "ndim", None) + if ndim is None: + raise ValueError("The input must be a 2-dimensional array.") + + elif ndim != 2: + raise ValueError( + f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead." + ) + + if X.shape[1] != sum(self.n_basis_input_): + raise ValueError( + f"Input mismatch: expected {sum(self.n_basis_input_)} inputs, but got {X.shape[1]} columns in X.\n" + "To modify the required number of inputs, call `set_input_shape` before using `fit` or `fit_transform`." + ) + + if y is not None and y.shape[0] != X.shape[0]: + raise ValueError( + "X and y must have the same number of samples. " + f"X has {X.shpae[0]} samples, while y has {y.shape[0]} samples." + ) diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py index 9caea358..8601a101 100644 --- a/src/nemos/basis/basis.py +++ b/src/nemos/basis/basis.py @@ -9,7 +9,7 @@ from ..typing import FeatureMatrix from ._basis import add_docstring -from ._basis_mixin import ConvBasisMixin, EvalBasisMixin +from ._basis_mixin import AtomicBasisMixin, ConvBasisMixin, EvalBasisMixin from ._decaying_exponential import OrthExponentialBasis from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis @@ -83,7 +83,7 @@ def __init__( bounds: Optional[Tuple[float, float]] = None, label: Optional[str] = "BSplineEval", ): - EvalBasisMixin.__init__(self, bounds=bounds) + BSplineBasis.__init__( self, n_basis_funcs, @@ -91,6 +91,7 @@ def __init__( order=order, label=label, ) + EvalBasisMixin.__init__(self, bounds=bounds) @add_docstring("split_by_feature", BSplineBasis) def split_by_feature( @@ -157,6 +158,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.BSplineEval(5) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class BSplineConv(ConvBasisMixin, BSplineBasis): """ @@ -283,6 +309,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.BSplineConv(5, 10) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis): """ @@ -396,6 +447,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.CyclicBSplineEval(5) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis): """ @@ -514,6 +590,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.CyclicBSplineConv(5, 10) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class MSplineEval(EvalBasisMixin, MSplineBasis): r""" @@ -651,6 +752,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.MSplineEval(5) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class MSplineConv(ConvBasisMixin, MSplineBasis): r""" @@ -793,6 +919,31 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: """ return super().evaluate_on_grid(n_samples) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.MSplineConv(5, 10) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear): """ @@ -907,6 +1058,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLinearEval(5) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear): """ @@ -1026,6 +1202,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLinearConv(5, 10) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. @@ -1156,6 +1357,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLogEval(5) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog): """Represent log-spaced raised cosine basis functions. @@ -1287,6 +1513,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.RaisedCosineLogConv(5, 10) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1399,6 +1650,31 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6)) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis): """Set of 1D basis decaying exponential functions numerically orthogonalized. @@ -1451,6 +1727,9 @@ def __init__( decay_rates=decay_rates, label=label, ) + # re-check window size because n_basis_funcs is not set yet when the + # property setter runs the first check. + self._check_window_size(self.window_size) @add_docstring("evaluate_on_grid", OrthExponentialBasis) def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: @@ -1514,3 +1793,40 @@ def split_by_feature( """ return super().split_by_feature(x, axis=axis) + + @add_docstring("set_input_shape", AtomicBasisMixin) + def set_input_shape(self, xi: int | tuple[int, ...] | NDArray): + """ + Examples + -------- + >>> import nemos as nmo + >>> import numpy as np + >>> basis = nmo.basis.OrthExponentialConv(5, window_size=10, decay_rates=np.arange(1, 6)) + >>> # Configure with an integer input: + >>> _ = basis.set_input_shape(3) + >>> basis.n_output_features + 15 + >>> # Configure with a tuple: + >>> _ = basis.set_input_shape((4, 5)) + >>> basis.n_output_features + 100 + >>> # Configure with an array: + >>> x = np.ones((10, 4, 5)) + >>> _ = basis.set_input_shape(x) + >>> basis.n_output_features + 100 + + """ + return AtomicBasisMixin.set_input_shape(self, xi) + + def _check_window_size(self, window_size: int): + """OrthExponentialBasis specific window size check.""" + super()._check_window_size(window_size) + # if n_basis_funcs is not yet initialized, skip check + n_basis = getattr(self, "n_basis_funcs", None) + if n_basis and window_size < n_basis: + raise ValueError( + "OrthExponentialConv basis requires at least a window_size larger then the number " + f"of basis functions. window_size is {window_size}, n_basis_funcs while" + f"is {self.n_basis_funcs}." + ) diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py index b949b489..d0f7709e 100644 --- a/src/nemos/identifiability_constraints.py +++ b/src/nemos/identifiability_constraints.py @@ -218,6 +218,8 @@ def apply_identifiability_constraints( >>> from nemos.identifiability_constraints import apply_identifiability_constraints >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM + >>> import jax + >>> jax.config.update('jax_enable_x64', True) >>> # define a feature matrix >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) @@ -280,9 +282,11 @@ def apply_identifiability_constraints_by_basis_component( Examples -------- >>> import numpy as np + >>> import jax >>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component >>> from nemos.basis import BSplineEval >>> from nemos.glm import GLM + >>> jax.config.update('jax_enable_x64', True) >>> # define a feature matrix >>> bas = BSplineEval(5) + BSplineEval(6) >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100)) diff --git a/tests/conftest.py b/tests/conftest.py index eb88ed10..3daba960 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ and loading predefined parameters for testing various functionalities of the NeMoS library. """ +import abc + import jax import jax.numpy as jnp import numpy as np @@ -16,11 +18,112 @@ import pytest import nemos as nmo +import nemos._inspect_utils as inspect_utils +import nemos.basis.basis as basis +from nemos.basis import AdditiveBasis, MultiplicativeBasis +from nemos.basis._basis import Basis # shut-off conversion warnings nap.nap_config.suppress_conversion_warnings = True +@pytest.fixture() +def basis_class_specific_params(): + """Returns all the params for each class.""" + all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") + return {cls.__name__: cls._get_param_names() for cls in all_cls} + + +class BasisFuncsTesting(abc.ABC): + """ + An abstract base class that sets the foundation for individual basis function testing. + This class requires an implementation of a 'cls' method, which is utilized by the meta-test + that verifies if all basis functions are properly tested. + """ + + @abc.abstractmethod + def cls(self): + pass + + +class CombinedBasis(BasisFuncsTesting): + """ + This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. + + Properties: + - cls: Class (default = None) + """ + + cls = None + + @staticmethod + def instantiate_basis( + n_basis, basis_class, class_specific_params, window_size=10, **kwargs + ): + """Instantiate and return two basis of the type specified.""" + + # Set non-optional args + default_kwargs = { + "n_basis_funcs": n_basis, + "window_size": window_size, + "decay_rates": np.arange(1, 1 + n_basis), + } + repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) + if repeated_keys: + raise ValueError( + "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" + ) + + # Merge with provided extra kwargs + kwargs = {**default_kwargs, **kwargs} + + if basis_class == AdditiveBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 + b2 + elif basis_class == MultiplicativeBasis: + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params + ) + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params + ) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) + basis_obj = b1 * b2 + else: + basis_obj = basis_class( + **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) + ) + return basis_obj + + +# automatic define user accessible basis and check the methods +def list_all_basis_classes(filter_basis="all") -> list[type]: + """ + Return all the classes in nemos.basis which are a subclass of Basis, + which should be all concrete classes except TransformerBasis. + """ + all_basis = [ + class_obj + for _, class_obj in inspect_utils.get_non_abstract_classes(basis) + if issubclass(class_obj, Basis) + ] + [ + bas + for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) + if bas != basis.TransformerBasis + ] + if filter_basis != "all": + all_basis = [a for a in all_basis if filter_basis in a.__name__] + return all_basis + + # Sample subclass to test instantiation and methods class MockRegressor(nmo.base_regressor.BaseRegressor): """ diff --git a/tests/test_basis.py b/tests/test_basis.py index d794db53..8daeb683 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,4 +1,3 @@ -import abc import inspect import itertools import pickle @@ -11,9 +10,8 @@ import numpy as np import pynapple as nap import pytest -from sklearn.base import clone as sk_clone +from conftest import BasisFuncsTesting, CombinedBasis, list_all_basis_classes -import nemos as nmo import nemos._inspect_utils as inspect_utils import nemos.basis.basis as basis import nemos.convolve as convolve @@ -34,33 +32,6 @@ def extra_decay_rates(cls, n_basis): return {} -# automatic define user accessible basis and check the methods -def list_all_basis_classes(filter_basis="all") -> list[type]: - """ - Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. - """ - all_basis = [ - class_obj - for _, class_obj in inspect_utils.get_non_abstract_classes(basis) - if issubclass(class_obj, Basis) - ] + [ - bas - for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) - if bas != basis.TransformerBasis - ] - if filter_basis != "all": - all_basis = [a for a in all_basis if filter_basis in a.__name__] - return all_basis - - -@pytest.fixture() -def class_specific_params(): - """Returns all the params for each class.""" - all_cls = list_all_basis_classes("Conv") + list_all_basis_classes("Eval") - return {cls.__name__: cls._get_param_names() for cls in all_cls} - - def test_all_basis_are_tested() -> None: """Meta-test. @@ -130,14 +101,18 @@ def test_all_basis_are_tested() -> None: "split_by_feature", "Decompose an array along a specified axis into sub-arrays", ), + ( + "set_input_shape", + "Set the expected input shape for the basis object", + ), ], ) def test_example_docstrings_add( - basis_cls, method_name, descr_match, class_specific_params + basis_cls, method_name, descr_match, basis_class_specific_params ): basis_instance = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 + 5, basis_cls, basis_class_specific_params, window_size=10 ) method = getattr(basis_instance, method_name) doc = method.__doc__ @@ -178,6 +153,15 @@ def method(self): pass assert CustomSubClass().method.__doc__ == "My extra text.\nMy custom method." + with pytest.raises(AttributeError, match="CustomClass has no attribute"): + + class CustomSubClass2(CustomClass): + @custom_add_docstring("unknown", cls=CustomClass) + def method(self): + """My custom method.""" + pass + + CustomSubClass2() @pytest.mark.parametrize( @@ -267,33 +251,36 @@ def test_expected_output_compute_features(basis_instance, super_class): ), OrthExponentialBasis, ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + * basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), + ( + basis.OrthExponentialConv( + 10, decay_rates=np.arange(1, 11), window_size=12, label="a" + ) + + basis.RaisedCosineLogConv(10, window_size=11, label="b"), + OrthExponentialBasis, + ), ], ) def test_expected_output_split_by_feature(basis_instance, super_class): - x = super_class.compute_features(basis_instance, np.linspace(0, 1, 100)) + inp = [np.linspace(0, 1, 100)] * basis_instance._n_input_dimensionality + x = super_class.compute_features(basis_instance, *inp) xdict = super_class.split_by_feature(basis_instance, x) xxdict = basis_instance.split_by_feature(x) assert xdict.keys() == xxdict.keys() - xx = xxdict["label"] - x = xdict["label"] - nans = np.isnan(x.sum(axis=(1, 2))) - assert np.all(np.isnan(xx[nans])) - np.testing.assert_array_equal(xx[~nans], x[~nans]) + for k in xdict.keys(): + xx = xxdict[k] + x = xdict[k] + nans = np.isnan(x.sum(axis=(1, 2))) + assert np.all(np.isnan(xx[nans])) + np.testing.assert_array_equal(xx[~nans], x[~nans]) -class BasisFuncsTesting(abc.ABC): - """ - An abstract base class that sets the foundation for individual basis function testing. - This class requires an implementation of a 'cls' method, which is utilized by the meta-test - that verifies if all basis functions are properly tested. - """ - - @abc.abstractmethod - def cls(self): - pass - - -# Auto-generated file with stripped classes and shared methods @pytest.mark.parametrize( "cls", [ @@ -339,12 +326,44 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): with expectation: bas._evaluate(samples) + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("vmin, vmax", [(0, 1), (-1, 1)]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_eval(self, cls, n_basis, vmin, vmax, inp_num): + bas = cls["eval"]( + n_basis, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + + @pytest.mark.parametrize("n_basis", [5, 6]) + @pytest.mark.parametrize("ws", [10, 20]) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone_conv(self, cls, n_basis, ws, inp_num): + bas = cls["conv"]( + n_basis, window_size=ws, **extra_decay_rates(cls["eval"], n_basis) + ) + bas.set_input_shape(inp_num) + bas2 = bas.__sklearn_clone__() + assert id(bas) != id(bas2) + assert np.all( + bas.__dict__.pop("decay_rates", True) + == bas2.__dict__.pop("decay_rates", True) + ) + assert bas.__dict__ == bas2.__dict__ + @pytest.mark.parametrize( "attribute, value", [ ("label", None), ("label", "label"), - ("n_basis_input", 1), + ("n_basis_input_", 1), ("n_output_features", 5), ], ) @@ -438,10 +457,10 @@ def test_set_num_basis_input(self, n_input, cls): bas = cls["conv"]( n_basis_funcs=5, window_size=10, **extra_decay_rates(cls["conv"], 5) ) - assert bas.n_basis_input is None + assert bas.n_basis_input_ is None bas.compute_features(np.random.randn(20, n_input)) - assert bas.n_basis_input == (n_input,) - assert bas._n_basis_input == (n_input,) + assert bas.n_basis_input_ == (n_input,) + assert bas._n_basis_input_ == (n_input,) @pytest.mark.parametrize( "bounds, samples, nan_idx, mn, mx", @@ -516,7 +535,7 @@ def test_vmin_vmax_init(self, bounds, expectation, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_basis_number(self, n_basis, mode, kwargs, cls): @@ -548,7 +567,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) @pytest.mark.parametrize("n_basis", [6]) def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls): @@ -568,7 +587,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls ) @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): bas = cls[mode]( @@ -582,7 +601,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): @pytest.mark.parametrize("n_basis", [6]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan_location(self, mode, kwargs, n_basis, cls): bas = cls[mode]( @@ -615,7 +634,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): bas._evaluate(samples) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -625,7 +644,7 @@ def test_call_nan(self, mode, kwargs, cls): @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_non_empty(self, n_basis, mode, kwargs, cls): bas = cls[mode]( @@ -636,7 +655,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -653,7 +672,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -718,7 +737,7 @@ def test_compute_features_conv_input( order, width, cls, - class_specific_params, + basis_class_specific_params, ): x = np.ones(input_shape) @@ -733,7 +752,9 @@ def test_compute_features_conv_input( ) # figure out which kwargs needs to be removed - kwargs = inspect_utils.trim_kwargs(cls["conv"], kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + cls["conv"], kwargs, basis_class_specific_params + ) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -909,7 +930,7 @@ def test_convolution_is_performed(self, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -928,7 +949,7 @@ def test_evaluate_on_grid_basis_size(self, sample_size, mode, kwargs, cls): @pytest.mark.parametrize("n_input", [0, 1, 2]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): basis_obj = cls[mode]( @@ -953,7 +974,7 @@ def test_evaluate_on_grid_input_number(self, n_input, mode, kwargs, cls): @pytest.mark.parametrize("sample_size", [-1, 0, 1, 10, 11, 100]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_evaluate_on_grid_meshgrid_size(self, sample_size, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -987,7 +1008,7 @@ def test_fit_kernel_shape(self, cls): @pytest.mark.parametrize( "mode, ws, expectation", [ - ("conv", 2, does_not_raise()), + ("conv", 5, does_not_raise()), ( "conv", -1, @@ -1032,9 +1053,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): n_basis_funcs=5, window_size=ws, **extra_decay_rates(cls[mode], 5) ) - @pytest.mark.parametrize("samples", [[], [0], [0, 0]]) + @pytest.mark.parametrize("samples", [[], [0] * 10, [0] * 11]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_non_empty_samples(self, samples, mode, kwargs, cls): if "OrthExp" in cls["eval"].__name__: @@ -1079,7 +1100,7 @@ def test_number_of_required_inputs_compute_features( basis_obj.compute_features(*inputs) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 8})] ) def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) @@ -1177,7 +1198,7 @@ def test_set_params( decay_rates, conv_kwargs, cls, - class_specific_params, + basis_class_specific_params, ): """Test the read-only and read/write property of the parameters.""" pars = dict( @@ -1194,7 +1215,7 @@ def test_set_params( pars = { key: value for key, value in pars.items() - if key in class_specific_params[cls[mode].__name__] + if key in basis_class_specific_params[cls[mode].__name__] } keys = list(pars.keys()) @@ -1238,15 +1259,16 @@ def test_set_window_size(self, mode, expectation, cls): def test_transform_fails(self, cls): bas = cls["conv"]( - n_basis_funcs=5, window_size=3, **extra_decay_rates(cls["conv"], 5) + n_basis_funcs=5, window_size=5, **extra_decay_rates(cls["conv"], 5) ) with pytest.raises( - ValueError, match="You must call `_set_kernel` before `_compute_features`" + RuntimeError, match="You must call `setup_basis` before `_compute_features`" ): bas._compute_features(np.linspace(0, 1, 10)) def test_transformer_get_params(self, cls): bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas_transformer = bas.to_transformer() params_transf = bas_transformer.get_params() params_transf.pop("_basis") @@ -1256,6 +1278,95 @@ def test_transformer_get_params(self, cls): assert params_transf == params_basis assert np.all(rates_1 == rates_2) + @pytest.mark.parametrize( + "x, inp_shape, expectation", + [ + (np.ones((10,)), 1, does_not_raise()), + ( + np.ones((10, 1)), + 1, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + (np.ones((10, 2)), 2, does_not_raise()), + ( + np.ones((10, 1)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 2, 1)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1, 2)), + 2, + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + (np.ones((10, 1)), (1,), does_not_raise()), + (np.ones((10,)), tuple(), does_not_raise()), + (np.ones((10,)), np.zeros((12,)), does_not_raise()), + ( + np.ones((10,)), + (1,), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1)), + (), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10, 1)), + np.zeros((12,)), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ( + np.ones((10)), + np.zeros((12, 1)), + pytest.raises(ValueError, match="Input shape mismatch detected"), + ), + ], + ) + def test_input_shape_validity(self, x, inp_shape, expectation, cls): + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + bas.set_input_shape(inp_shape) + with expectation: + bas.compute_features(x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + ((1, 1), does_not_raise()), + ( + (1, 1.0), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + (np.ones((1,)), does_not_raise()), + (np.ones((1, 1)), does_not_raise()), + ], + ) + def test_set_input_value_types(self, inp_shape, expectation, cls): + bas = cls["eval"](n_basis_funcs=5, **extra_decay_rates(cls["eval"], 5)) + with expectation: + bas.set_input_shape(inp_shape) + + @pytest.mark.parametrize( + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})] + ) + def test_iterate_over_component(self, mode, kwargs, cls): + basis_obj = cls[mode]( + n_basis_funcs=5, + **kwargs, + **extra_decay_rates(cls[mode], 5), + ) + + out = tuple(basis_obj._iterate_over_components()) + assert len(out) == 1 + assert id(out[0]) == id(basis_obj) + class TestRaisedCosineLogBasis(BasisFuncsTesting): cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @@ -1276,7 +1387,7 @@ def test_decay_to_zero_basis_number_match(self, width): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1387,7 +1498,7 @@ def test_time_scaling_values(self, time_scaling, expectation, mode, kwargs): ], ) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_width_values(self, width, expectation, mode, kwargs): with expectation: @@ -1399,7 +1510,7 @@ class TestRaisedCosineLinearBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, mode, kwargs @@ -1474,7 +1585,7 @@ class TestMSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1494,6 +1605,15 @@ def test_minimum_number_of_basis_required_is_matched( n_basis_funcs=n_basis_funcs, order=order, **kwargs ) basis_obj.compute_features(np.linspace(0, 1, 10)) + + # test the setter valuerror + if (order > 1) & (n_basis_funcs > 1): + basis_obj = self.cls[mode](n_basis_funcs=20, order=order, **kwargs) + with pytest.raises( + ValueError, + match=rf"{self.cls[mode].__name__} `order` parameter cannot be larger than", + ): + basis_obj.n_basis_funcs = n_basis_funcs else: basis_obj = self.cls[mode]( n_basis_funcs=n_basis_funcs, order=order, **kwargs @@ -1575,6 +1695,60 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + with expectation: + self.cls["conv"](n_basis, decay_rates=decay_rates, window_size=window_size) + + def test_check_window_size_after_init(self): + decay_rates = np.asarray(np.arange(1, 5 + 1), dtype=float) + expectation = pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ) + bas = self.cls["conv"](5, decay_rates=decay_rates, window_size=10) + with expectation: + bas.window_size = 4 + + @pytest.mark.parametrize( + "window_size, n_basis, expectation", + [ + ( + 4, + 5, + pytest.raises( + ValueError, + match="OrthExponentialConv basis requires at least a window_size", + ), + ), + (5, 5, does_not_raise()), + ], + ) + def test_window_size_at_init(self, window_size, n_basis, expectation): + decay_rates = np.asarray(np.arange(1, n_basis + 1), dtype=float) + obj = self.cls["conv"]( + n_basis, decay_rates=decay_rates, window_size=n_basis + 1 + ) + with expectation: + obj.window_size = window_size + + with expectation: + obj.set_params(window_size=window_size) + @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] ) @@ -1650,7 +1824,7 @@ class TestBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1738,7 +1912,7 @@ class TestCyclicBSplineBasis(BasisFuncsTesting): @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @pytest.mark.parametrize( - "mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})] + "mode, kwargs", [("eval", {}), ("conv", {"window_size": 5})] ) def test_minimum_number_of_basis_required_is_matched( self, n_basis_funcs, order, mode, kwargs @@ -1839,72 +2013,44 @@ def test_samples_range_matches_compute_features_requirements( basis_obj.compute_features(np.linspace(*sample_range, 100)) -class CombinedBasis(BasisFuncsTesting): - """ - This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions. - - Properties: - - cls: Class (default = None) - """ - - cls = None +class TestAdditiveBasis(CombinedBasis): + cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} - @staticmethod - def instantiate_basis( - n_basis, basis_class, class_specific_params, window_size=10, **kwargs + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_iterate_over_component( + self, basis_a, basis_b, basis_class_specific_params ): - """Instantiate and return two basis of the type specified.""" - - # Set non-optional args - default_kwargs = { - "n_basis_funcs": n_basis, - "window_size": window_size, - "decay_rates": np.arange(1, 1 + n_basis), - } - repeated_keys = set(default_kwargs.keys()).intersection(kwargs.keys()) - if repeated_keys: - raise ValueError( - "Cannot set `n_basis_funcs, window_size, decay_rates` with kwargs" - ) + basis_a_obj = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b_obj = self.instantiate_basis( + 6, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a_obj + basis_b_obj + out = tuple(add._iterate_over_components()) + assert len(out) == add._n_input_dimensionality - # Merge with provided extra kwargs - kwargs = {**default_kwargs, **kwargs} + def get_ids(bas): - if basis_class == AdditiveBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 + b2 - elif basis_class == MultiplicativeBasis: - kwargs_mspline = inspect_utils.trim_kwargs( - basis.MSplineEval, kwargs, class_specific_params - ) - kwargs_raised_cosine = inspect_utils.trim_kwargs( - basis.RaisedCosineLinearConv, kwargs, class_specific_params - ) - b1 = basis.MSplineEval(**kwargs_mspline) - b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) - basis_obj = b1 * b2 - else: - basis_obj = basis_class( - **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) - ) - return basis_obj + if hasattr(bas, "basis1"): + ids = get_ids(bas.basis1) + ids += get_ids(bas.basis2) + else: + ids = [id(bas)] + return ids + id_list = get_ids(add) -class TestAdditiveBasis(CombinedBasis): - cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} + assert tuple(id(o) for o in out) == tuple(id_list) @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) - def test_non_empty_samples(self, base_cls, samples, class_specific_params): + def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} - kwargs = inspect_utils.trim_kwargs(base_cls, kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs( + base_cls, kwargs, basis_class_specific_params + ) basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -1931,6 +2077,64 @@ def test_compute_features_input(self, eval_input): basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [6]) + @pytest.mark.parametrize("n_basis_b", [5]) + @pytest.mark.parametrize("vmin, vmax", [(-1, 1)]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("inp_num", [1, 2]) + def test_sklearn_clone( + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + vmin, + vmax, + inp_num, + basis_class_specific_params, + ): + """Recursively check cloning.""" + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a_obj = basis_a_obj.set_input_shape( + *([inp_num] * basis_a_obj._n_input_dimensionality) + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, basis_class_specific_params, window_size=15 + ) + basis_b_obj = basis_b_obj.set_input_shape( + *([inp_num] * basis_b_obj._n_input_dimensionality) + ) + add = basis_a_obj + basis_b_obj + + def filter_attributes(obj, exclude_keys): + return { + key: val for key, val in obj.__dict__.items() if key not in exclude_keys + } + + def compare(b1, b2): + assert id(b1) != id(b2) + assert b1.__class__.__name__ == b2.__class__.__name__ + if hasattr(b1, "basis1"): + compare(b1.basis1, b2.basis1) + compare(b1.basis2, b2.basis2) + # add all params that are not parent or basis1,basis2 + d1 = filter_attributes(b1, exclude_keys=["basis1", "basis2", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["basis1", "basis2", "_parent"]) + assert d1 == d2 + else: + decay_rates_b1 = b1.__dict__.get("_decay_rates", -1) + decay_rates_b2 = b2.__dict__.get("_decay_rates", -1) + assert np.array_equal(decay_rates_b1, decay_rates_b2) + d1 = filter_attributes(b1, exclude_keys=["_decay_rates", "_parent"]) + d2 = filter_attributes(b2, exclude_keys=["_decay_rates", "_parent"]) + assert d1 == d2 + + add2 = add.__sklearn_clone__() + compare(add, add2) + @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000]) @@ -1945,7 +2149,7 @@ def test_compute_features_returns_expected_number_of_basis( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the evaluation of the `AdditiveBasis` results in a number of basis @@ -1953,10 +2157,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj @@ -1984,16 +2188,16 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from `AdditiveBasis` compute_features function matches input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.compute_features( @@ -2021,17 +2225,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj required_dim = ( @@ -2053,16 +2257,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2077,16 +2287,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2100,17 +2316,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj + basis_b_obj inputs = [20] * n_input @@ -2132,7 +2354,13 @@ def test_evaluate_on_grid_input_number( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2141,9 +2369,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_add = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) + self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) # compute_features the basis over pynapple Tsd objects out = basis_add.compute_features(*([inp] * basis_add._n_input_dimensionality)) @@ -2158,7 +2386,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -2167,13 +2395,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2192,7 +2420,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2206,20 +2434,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2232,25 +2460,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -2258,10 +2492,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2274,40 +2508,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=3 + n_basis_a, basis_a, basis_class_specific_params, window_size=9 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=3 + n_basis_b, basis_b, basis_class_specific_params, window_size=9 ) bas_eva = basis_a_obj + basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = np.linspace(0, 1, 10) @@ -2319,19 +2559,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2340,19 +2586,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -2366,7 +2618,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2381,7 +2633,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -2394,10 +2646,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj + basis_b_obj with expectation: @@ -2408,22 +2660,22 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj - bas.set_kernel() + bas.setup_basis(*([np.ones(10)] * bas._n_input_dimensionality)) def check_kernel(basis_obj): has_kern = [] - if hasattr(basis_obj, "_basis1"): - has_kern += check_kernel(basis_obj._basis1) - has_kern += check_kernel(basis_obj._basis2) + if hasattr(basis_obj, "basis1"): + has_kern += check_kernel(basis_obj.basis1) + has_kern += check_kernel(basis_obj.basis2) else: has_kern += [ basis_obj.kernel_ is not None if basis_obj.mode == "conv" else True @@ -2437,21 +2689,21 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj + basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: context = does_not_raise() else: context = pytest.raises( - ValueError, - match="You must call `_set_kernel` before `_compute_features`", + RuntimeError, + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2475,11 +2727,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -2499,56 +2751,295 @@ def test_expected_input_number(self, n_input, expectation): with expectation: bas.compute_features(np.random.randn(30, 2), np.random.randn(30, n_input)) - -class TestMultiplicativeBasis(CombinedBasis): - cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} - @pytest.mark.parametrize( - "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - @pytest.mark.parametrize(" ws", [3]) - def test_non_empty_samples(self, samples, ws): - basis_obj = basis.MSplineEval(5) * basis.RaisedCosineLinearEval(5) - if any(tuple(len(s) == 0 for s in samples)): - with pytest.raises( - ValueError, match="All sample provided must be non empty" - ): - basis_obj.compute_features(*samples) - else: - basis_obj.compute_features(*samples) - @pytest.mark.parametrize( - "eval_input", - [ - [0, 0], - [[0], [0]], - [(0,), (0,)], - [np.array([0]), [0]], - [jax.numpy.array([0]), [0]], - ], + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") ) - def test_compute_features_input(self, eval_input): - """ - Checks that the sample size of the output from the compute_features() method matches the input sample size. - """ - basis_obj = basis.MSplineEval(5) * basis.MSplineEval(5) - basis_obj.compute_features(*eval_input) - - @pytest.mark.parametrize("n_basis_a", [5, 6]) - @pytest.mark.parametrize("n_basis_b", [5, 6]) - @pytest.mark.parametrize("sample_size", [10, 1000]) - @pytest.mark.parametrize("basis_a", list_all_basis_classes()) - @pytest.mark.parametrize("basis_b", list_all_basis_classes()) - @pytest.mark.parametrize("window_size", [10]) - def test_compute_features_returns_expected_number_of_basis( + @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) + @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_1d_arrays( self, - n_basis_a, - n_basis_b, - sample_size, basis_a, basis_b, - window_size, - class_specific_params, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) + @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_2d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) + @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_nd_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + + add.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + add.compute_features(*x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + (((1, 1), (1, 1)), does_not_raise()), + ( + ((1, 1.0), (1, 1)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ( + ((1, 1), (1, 1.0)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ], + ) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_set_input_value_types( + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + with expectation: + add.set_input_shape(*inp_shape) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + # test pointing to different objects + assert id(add.basis1) != id(basis_a) + assert id(add.basis1) != id(basis_b) + assert id(add.basis2) != id(basis_a) + assert id(add.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert add.basis1.n_basis_funcs == 5 + assert add.basis2.n_basis_funcs == 5 + + add.basis1.n_basis_funcs = 6 + add.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + add = basis_a + basis_b + add.basis1.n_basis_funcs = 10 + assert add.n_basis_funcs == 15 + add.basis2.n_basis_funcs = 10 + assert add.n_basis_funcs == 20 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + add = basis_a + basis_b + inps_a = [2] * basis_a._n_input_dimensionality + add.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * add.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * add.basis1.n_basis_funcs + assert add.n_output_features == new_out_num + add.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * add.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * add.basis2.n_basis_funcs + add.basis2.set_input_shape(*inps_b) + assert add.n_output_features == new_out_num + new_out_num_b + + +class TestMultiplicativeBasis(CombinedBasis): + cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis} + + @pytest.mark.parametrize( + "samples", [[[0], []], [[], [0]], [[0], [0]], [[0, 0], [0, 0]]] + ) + @pytest.mark.parametrize(" ws", [3]) + def test_non_empty_samples(self, samples, ws): + basis_obj = basis.MSplineEval(5) * basis.RaisedCosineLinearEval(5) + if any(tuple(len(s) == 0 for s in samples)): + with pytest.raises( + ValueError, match="All sample provided must be non empty" + ): + basis_obj.compute_features(*samples) + else: + basis_obj.compute_features(*samples) + + @pytest.mark.parametrize( + "eval_input", + [ + [0, 0], + [[0], [0]], + [(0,), (0,)], + [np.array([0]), [0]], + [jax.numpy.array([0]), [0]], + ], + ) + def test_compute_features_input(self, eval_input): + """ + Checks that the sample size of the output from the compute_features() method matches the input sample size. + """ + basis_obj = basis.MSplineEval(5) * basis.MSplineEval(5) + basis_obj.compute_features(*eval_input) + + @pytest.mark.parametrize("n_basis_a", [5, 6]) + @pytest.mark.parametrize("n_basis_b", [5, 6]) + @pytest.mark.parametrize("sample_size", [10, 1000]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + @pytest.mark.parametrize("window_size", [10]) + def test_compute_features_returns_expected_number_of_basis( + self, + n_basis_a, + n_basis_b, + sample_size, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): """ Test whether the evaluation of the `MultiplicativeBasis` results in a number of basis @@ -2556,10 +3047,10 @@ def test_compute_features_returns_expected_number_of_basis( """ # define the two basis basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj @@ -2588,17 +3079,17 @@ def test_sample_size_of_compute_features_matches_that_of_input( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the output sample size from the `MultiplicativeBasis` fit_transform function matches the input sample size. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.compute_features( @@ -2625,17 +3116,17 @@ def test_number_of_required_inputs_compute_features( basis_a, basis_b, window_size, - class_specific_params, + basis_class_specific_params, ): """ Test whether the number of required inputs for the `compute_features` function matches the sum of the number of input samples from the two bases. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj required_dim = ( @@ -2657,16 +3148,22 @@ def test_number_of_required_inputs_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_meshgrid_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the resulting meshgrid size matches the sample size input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj res = basis_obj.evaluate_on_grid( @@ -2681,16 +3178,22 @@ def test_evaluate_on_grid_meshgrid_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_basis_size( - self, sample_size, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, + sample_size, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + basis_class_specific_params, ): """ Test whether the number sample size output by evaluate_on_grid matches the sample size of the input. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj eval_basis = basis_obj.evaluate_on_grid( @@ -2704,17 +3207,23 @@ def test_evaluate_on_grid_basis_size( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [6]) def test_evaluate_on_grid_input_number( - self, n_input, basis_a, basis_b, n_basis_a, n_basis_b, class_specific_params + self, + n_input, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + basis_class_specific_params, ): """ Test whether the number of inputs provided to `evaluate_on_grid` matches the sum of the number of input samples required from each of the basis objects. """ basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) basis_obj = basis_a_obj * basis_b_obj inputs = [20] * n_input @@ -2744,15 +3253,15 @@ def test_inconsistent_sample_sizes( n_basis_b, sample_size_a, sample_size_b, - class_specific_params, + basis_class_specific_params, ): """Test that the inputs of inconsistent sample sizes result in an exception when compute_features is called""" raise_exception = sample_size_a != sample_size_b basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) input_a = [ np.linspace(0, 1, sample_size_a) @@ -2776,7 +3285,13 @@ def test_inconsistent_sample_sizes( @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) def test_pynapple_support_compute_features( - self, basis_a, basis_b, n_basis_a, n_basis_b, sample_size, class_specific_params + self, + basis_a, + basis_b, + n_basis_a, + n_basis_b, + sample_size, + basis_class_specific_params, ): iset = nap.IntervalSet(start=[0, 0.5], end=[0.49999, 1]) inp = nap.Tsd( @@ -2785,9 +3300,9 @@ def test_pynapple_support_compute_features( time_support=iset, ) basis_prod = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) * self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) out = basis_prod.compute_features(*([inp] * basis_prod._n_input_dimensionality)) assert isinstance(out, nap.TsdFrame) @@ -2799,7 +3314,7 @@ def test_pynapple_support_compute_features( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) @pytest.mark.parametrize("num_input", [0, 1, 2, 3, 4, 5]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) def test_call_input_num( self, n_basis_a, @@ -2808,13 +3323,13 @@ def test_call_input_num( basis_b, num_input, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj if num_input == basis_obj._n_input_dimensionality: @@ -2833,7 +3348,7 @@ def test_call_input_num( (np.linspace(0, 1, 10)[:, None], pytest.raises(ValueError)), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2847,20 +3362,20 @@ def test_call_input_shape( inp, window_size, expectation, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj with expectation: basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -2873,25 +3388,31 @@ def test_call_sample_axis( basis_b, time_axis_shape, window_size, - class_specific_params, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_nan( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): if ( basis_a == basis.OrthExponentialBasis @@ -2899,10 +3420,10 @@ def test_call_nan( ): return basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality @@ -2915,40 +3436,46 @@ def test_call_nan( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_equivalent_in_conv( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas_eva = basis_a_obj * basis_b_obj basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=8 + n_basis_a, basis_a, basis_class_specific_params, window_size=8 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=8 + n_basis_b, basis_b, basis_class_specific_params, window_size=8 ) bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_pynapple_support( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = np.linspace(0, 1, 10) @@ -2960,19 +3487,25 @@ def test_pynapple_support( assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [6, 7]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_basis_number( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -2981,19 +3514,25 @@ def test_call_basis_number( == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_call_non_empty( - self, n_basis_a, n_basis_b, basis_a, basis_b, window_size, class_specific_params + self, + n_basis_a, + n_basis_b, + basis_a, + basis_b, + window_size, + basis_class_specific_params, ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): @@ -3007,7 +3546,7 @@ def test_call_non_empty( (0.1, 2, does_not_raise()), ], ) - @pytest.mark.parametrize(" window_size", [3]) + @pytest.mark.parametrize(" window_size", [8]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @pytest.mark.parametrize("n_basis_a", [5]) @@ -3022,7 +3561,7 @@ def test_call_sample_range( mx, expectation, window_size, - class_specific_params, + basis_class_specific_params, ): if expectation == "check": if ( @@ -3035,10 +3574,10 @@ def test_call_sample_range( else: expectation = does_not_raise() basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=window_size + n_basis_a, basis_a, basis_class_specific_params, window_size=window_size ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=window_size + n_basis_b, basis_b, basis_class_specific_params, window_size=window_size ) bas = basis_a_obj * basis_b_obj with expectation: @@ -3049,22 +3588,22 @@ def test_call_sample_range( @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_fit_kernel( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj - bas.set_kernel() + bas._set_input_independent_states() def check_kernel(basis_obj): has_kern = [] - if hasattr(basis_obj, "_basis1"): - has_kern += check_kernel(basis_obj._basis1) - has_kern += check_kernel(basis_obj._basis2) + if hasattr(basis_obj, "basis1"): + has_kern += check_kernel(basis_obj.basis1) + has_kern += check_kernel(basis_obj.basis2) else: has_kern += [ basis_obj.kernel_ is not None if basis_obj.mode == "conv" else True @@ -3078,21 +3617,21 @@ def check_kernel(basis_obj): @pytest.mark.parametrize("n_basis_a", [5]) @pytest.mark.parametrize("n_basis_b", [5]) def test_transform_fails( - self, n_basis_a, n_basis_b, basis_a, basis_b, class_specific_params + self, n_basis_a, n_basis_b, basis_a, basis_b, basis_class_specific_params ): basis_a_obj = self.instantiate_basis( - n_basis_a, basis_a, class_specific_params, window_size=10 + n_basis_a, basis_a, basis_class_specific_params, window_size=10 ) basis_b_obj = self.instantiate_basis( - n_basis_b, basis_b, class_specific_params, window_size=10 + n_basis_b, basis_b, basis_class_specific_params, window_size=10 ) bas = basis_a_obj * basis_b_obj if "Eval" in basis_a.__name__ and "Eval" in basis_b.__name__: context = does_not_raise() else: context = pytest.raises( - ValueError, - match="You must call `_set_kernel` before `_compute_features`", + RuntimeError, + match="You must call `setup_basis` before `_compute_features`", ) with context: x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality @@ -3116,11 +3655,11 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 - assert bas_add.n_basis_input is None + assert bas_add.n_basis_input_ is None bas_add.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_add.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_add.n_basis_input_ == (n_basis_input1, n_basis_input2) @pytest.mark.parametrize( "n_input, expectation", @@ -3142,21 +3681,260 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) - def test_n_basis_input(self, n_basis_input1, n_basis_input2): + def test_n_basis_input_(self, n_basis_input1, n_basis_input2): bas1 = basis.RaisedCosineLinearConv(10, window_size=10) bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) ) - assert bas_prod.n_basis_input == (n_basis_input1, n_basis_input2) + assert bas_prod.n_basis_input_ == (n_basis_input1, n_basis_input2) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize("shape_a", [1, (), np.ones(3)]) + @pytest.mark.parametrize("shape_b", [1, (), np.ones(3)]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_1d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, *add_shape_a)), np.ones((10, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize("shape_a", [2, (2,), np.ones((3, 2))]) + @pytest.mark.parametrize("shape_b", [3, (3,), np.ones((3, 3))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_2d_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, 2, *add_shape_a)), np.ones((10, 3, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize("shape_a", [(2, 2), np.ones((3, 2, 2))]) + @pytest.mark.parametrize("shape_b", [(3, 1), np.ones((3, 3, 1))]) + @pytest.mark.parametrize("add_shape_a", [(), (1,)]) + @pytest.mark.parametrize("add_shape_b", [(), (1,)]) + def test_set_input_shape_type_nd_arrays( + self, + basis_a, + basis_b, + shape_a, + shape_b, + basis_class_specific_params, + add_shape_a, + add_shape_b, + ): + x = (np.ones((10, 2, 2, *add_shape_a)), np.ones((10, 3, 1, *add_shape_b))) + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + + mul.set_input_shape(shape_a, shape_b) + if add_shape_a == () and add_shape_b == (): + expectation = does_not_raise() + else: + expectation = pytest.raises( + ValueError, match="Input shape mismatch detected" + ) + with expectation: + mul.compute_features(*x) + + @pytest.mark.parametrize( + "inp_shape, expectation", + [ + (((1, 1), (1, 1)), does_not_raise()), + ( + ((1, 1.0), (1, 1)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ( + ((1, 1), (1, 1.0)), + pytest.raises( + ValueError, match="The tuple provided contains non integer" + ), + ), + ], + ) + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_set_input_value_types( + self, inp_shape, expectation, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + with expectation: + mul.set_input_shape(*inp_shape) + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_deep_copy_basis(self, basis_a, basis_b, basis_class_specific_params): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + # test pointing to different objects + assert id(mul.basis1) != id(basis_a) + assert id(mul.basis1) != id(basis_b) + assert id(mul.basis2) != id(basis_a) + assert id(mul.basis2) != id(basis_b) + + # test attributes are not related + basis_a.n_basis_funcs = 10 + basis_b.n_basis_funcs = 10 + assert mul.basis1.n_basis_funcs == 5 + assert mul.basis2.n_basis_funcs == 5 + + mul.basis1.n_basis_funcs = 6 + mul.basis2.n_basis_funcs = 6 + assert basis_a.n_basis_funcs == 10 + assert basis_b.n_basis_funcs == 10 + + @pytest.mark.parametrize( + "basis_a", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + @pytest.mark.parametrize( + "basis_b", list_all_basis_classes("Eval") + list_all_basis_classes("Conv") + ) + def test_compute_n_basis_runtime( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + mul = basis_a * basis_b + mul.basis1.n_basis_funcs = 10 + assert mul.n_basis_funcs == 50 + mul.basis2.n_basis_funcs = 10 + assert mul.n_basis_funcs == 100 + + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_runtime_n_basis_out_compute( + self, basis_a, basis_b, basis_class_specific_params + ): + basis_a = self.instantiate_basis( + 5, basis_a, basis_class_specific_params, window_size=10 + ) + basis_a.set_input_shape( + *([1] * basis_a._n_input_dimensionality) + ).to_transformer() + basis_b = self.instantiate_basis( + 5, basis_b, basis_class_specific_params, window_size=10 + ) + basis_b.set_input_shape( + *([1] * basis_b._n_input_dimensionality) + ).to_transformer() + mul = basis_a * basis_b + inps_a = [2] * basis_a._n_input_dimensionality + mul.basis1.set_input_shape(*inps_a) + if isinstance(basis_a, MultiplicativeBasis): + new_out_num = np.prod(inps_a) * mul.basis1.n_basis_funcs + else: + new_out_num = inps_a[0] * mul.basis1.n_basis_funcs + assert mul.n_output_features == new_out_num * mul.basis2.n_basis_funcs + inps_b = [3] * basis_b._n_input_dimensionality + if isinstance(basis_b, MultiplicativeBasis): + new_out_num_b = np.prod(inps_b) * mul.basis2.n_basis_funcs + else: + new_out_num_b = inps_b[0] * mul.basis2.n_basis_funcs + mul.basis2.set_input_shape(*inps_b) + assert mul.n_output_features == new_out_num * new_out_num_b @pytest.mark.parametrize( "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) -def test_power_of_basis(exponent, basis_class, class_specific_params): +def test_power_of_basis(exponent, basis_class, basis_class_specific_params): """Test if the power behaves as expected.""" raise_exception_type = not type(exponent) is int @@ -3166,7 +3944,7 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): raise_exception_value = False basis_obj = CombinedBasis.instantiate_basis( - 5, basis_class, class_specific_params, window_size=10 + 5, basis_class, basis_class_specific_params, window_size=10 ) if raise_exception_type: @@ -3202,13 +3980,14 @@ def test_power_of_basis(exponent, basis_class, class_specific_params): "basis_cls", list_all_basis_classes(), ) -def test_basis_to_transformer(basis_cls, class_specific_params): +def test_basis_to_transformer(basis_cls, basis_class_specific_params): n_basis_funcs = 5 bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 + n_basis_funcs, basis_cls, basis_class_specific_params, window_size=10 ) - - trans_bas = bas.to_transformer() + trans_bas = bas.set_input_shape( + *([1] * bas._n_input_dimensionality) + ).to_transformer() assert isinstance(trans_bas, basis.TransformerBasis) @@ -3220,386 +3999,6 @@ def test_basis_to_transformer(basis_cls, class_specific_params): assert np.all(getattr(bas, k) == getattr(trans_bas, k)) -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformer_has_the_same_public_attributes_as_basis( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - public_attrs_basis = {attr for attr in dir(bas) if not attr.startswith("_")} - public_attrs_transformerbasis = { - attr for attr in dir(bas.to_transformer()) if not attr.startswith("_") - } - - assert public_attrs_transformerbasis - public_attrs_basis == { - "fit", - "fit_transform", - "transform", - } - - assert public_attrs_basis - public_attrs_transformerbasis == set() - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_to_transformer_and_constructor_are_equivalent( - basis_cls, class_specific_params -): - n_basis_funcs = 5 - bas = CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - - trans_bas_a = bas.to_transformer() - trans_bas_b = basis.TransformerBasis(bas) - - # they both just have a _basis - assert ( - list(trans_bas_a.__dict__.keys()) - == list(trans_bas_b.__dict__.keys()) - == ["_basis"] - ) - # and those bases are the same - assert np.all( - trans_bas_a._basis.__dict__.pop("_decay_rates", 1) - == trans_bas_b._basis.__dict__.pop("_decay_rates", 1) - ) - assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__ - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_basis_to_transformer_makes_a_copy(basis_cls, class_specific_params): - bas_a = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = bas_a.to_transformer() - - # changing an attribute in bas should not change trans_bas - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - bas_a._basis1.n_basis_funcs = 10 - assert trans_bas_a._basis._basis1.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b._basis._basis1.n_basis_funcs = 100 - assert bas_b._basis1.n_basis_funcs == 5 - else: - bas_a.n_basis_funcs = 10 - assert trans_bas_a.n_basis_funcs == 5 - - # changing an attribute in the transformer basis should not change the original - bas_b = CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_b = bas_b.to_transformer() - trans_bas_b.n_basis_funcs = 100 - assert bas_b.n_basis_funcs == 5 - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5, 10, 20]) -def test_transformerbasis_getattr(basis_cls, n_basis_funcs, class_specific_params): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_basis._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_basis.n_basis_funcs == n_basis_funcs - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -@pytest.mark.parametrize("n_basis_funcs_init", [5]) -@pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20]) -def test_transformerbasis_set_params( - basis_cls, n_basis_funcs_init, n_basis_funcs_new, class_specific_params -): - trans_basis = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_init, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_basis.set_params(n_basis_funcs=n_basis_funcs_new) - - assert trans_basis.n_basis_funcs == n_basis_funcs_new - assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis(basis_cls, class_specific_params): - # setting the _basis attribute should change it - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas._basis = CombinedBasis().instantiate_basis( - 20, basis_cls, class_specific_params, window_size=10 - ) - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_setattr_basis_attribute(basis_cls, class_specific_params): - # setting an attribute that is an attribute of the underlying _basis - # should propagate setting it on _basis itself - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas.n_basis_funcs = 20 - - assert trans_bas.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv") + list_all_basis_classes("Eval"), -) -def test_transformerbasis_copy_basis_on_contsruct(basis_cls, class_specific_params): - # modifying the transformerbasis's attributes shouldn't - # touch the original basis that was used to create it - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - trans_bas = basis.TransformerBasis(orig_bas) - trans_bas.n_basis_funcs = 20 - - assert orig_bas.n_basis_funcs == 10 - assert trans_bas._basis.n_basis_funcs == 20 - assert trans_bas._basis.n_basis_funcs == 20 - assert isinstance(trans_bas._basis, basis_cls) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_setattr_illegal_attribute(basis_cls, class_specific_params): - # changing an attribute that is not _basis or an attribute of _basis - # is not allowed - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=10 - ) - ) - - with pytest.raises( - ValueError, - match="Only setting _basis or existing attributes of _basis is allowed.", - ): - trans_bas.random_attr = "random value" - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_addition(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - bas_a = CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - bas_b = CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - trans_bas_a = basis.TransformerBasis(bas_a) - trans_bas_b = basis.TransformerBasis(bas_b) - trans_bas_sum = trans_bas_a + trans_bas_b - assert isinstance(trans_bas_sum, basis.TransformerBasis) - assert isinstance(trans_bas_sum._basis, AdditiveBasis) - assert ( - trans_bas_sum.n_basis_funcs - == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_sum._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_multiplication(basis_cls, class_specific_params): - n_basis_funcs_a = 5 - n_basis_funcs_b = n_basis_funcs_a * 2 - trans_bas_a = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_a, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_b = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs_b, basis_cls, class_specific_params, window_size=10 - ) - ) - trans_bas_prod = trans_bas_a * trans_bas_b - assert isinstance(trans_bas_prod, basis.TransformerBasis) - assert isinstance(trans_bas_prod._basis, MultiplicativeBasis) - assert ( - trans_bas_prod.n_basis_funcs - == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs - ) - assert ( - trans_bas_prod._n_input_dimensionality - == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality - ) - if basis_cls not in [AdditiveBasis, MultiplicativeBasis]: - assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a - assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize( - "exponent, error_type, error_message", - [ - (2, does_not_raise, None), - (5, does_not_raise, None), - (0.5, TypeError, "Exponent should be an integer"), - (-1, ValueError, "Exponent should be a non-negative integer"), - ], -) -def test_transformerbasis_exponentiation( - basis_cls, exponent: int, error_type, error_message, class_specific_params -): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - - if not isinstance(exponent, int): - with pytest.raises(error_type, match=error_message): - trans_bas_exp = trans_bas**exponent - assert isinstance(trans_bas_exp, basis.TransformerBasis) - assert isinstance(trans_bas_exp._basis, MultiplicativeBasis) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -def test_transformerbasis_dir(basis_cls, class_specific_params): - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - 5, basis_cls, class_specific_params, window_size=10 - ) - ) - for attr_name in ( - "fit", - "transform", - "fit_transform", - "n_basis_funcs", - "mode", - "window_size", - ): - if ( - attr_name == "window_size" - and "Conv" not in trans_bas._basis.__class__.__name__ - ): - continue - assert attr_name in dir(trans_bas) - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes("Conv"), -) -def test_transformerbasis_sk_clone_kernel_noned(basis_cls, class_specific_params): - orig_bas = CombinedBasis().instantiate_basis( - 10, basis_cls, class_specific_params, window_size=20 - ) - trans_bas = basis.TransformerBasis(orig_bas) - - # kernel should be saved in the object after fit - trans_bas.fit(np.random.randn(100, 20)) - assert isinstance(trans_bas.kernel_, np.ndarray) - - # cloning should set kernel_ to None - trans_bas_clone = sk_clone(trans_bas) - - # the original object should still have kernel_ - assert isinstance(trans_bas.kernel_, np.ndarray) - # but the clone should not have one - assert trans_bas_clone.kernel_ is None - - -@pytest.mark.parametrize( - "basis_cls", - list_all_basis_classes(), -) -@pytest.mark.parametrize("n_basis_funcs", [5]) -def test_transformerbasis_pickle( - tmpdir, basis_cls, n_basis_funcs, class_specific_params -): - # the test that tries cross-validation with n_jobs = 2 already should test this - trans_bas = basis.TransformerBasis( - CombinedBasis().instantiate_basis( - n_basis_funcs, basis_cls, class_specific_params, window_size=10 - ) - ) - filepath = tmpdir / "transformerbasis.pickle" - with open(filepath, "wb") as f: - pickle.dump(trans_bas, f) - with open(filepath, "rb") as f: - trans_bas2 = pickle.load(f) - - assert isinstance(trans_bas2, basis.TransformerBasis) - if basis_cls in [AdditiveBasis, MultiplicativeBasis]: - for bas in [ - getattr(trans_bas2._basis, attr) for attr in ("_basis1", "_basis2") - ]: - assert bas.n_basis_funcs == n_basis_funcs - else: - assert trans_bas2.n_basis_funcs == n_basis_funcs - - @pytest.mark.parametrize( "tsd", [ @@ -3638,7 +4037,7 @@ def test_multi_epoch_pynapple_basis( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -3652,7 +4051,11 @@ def test_multi_epoch_pynapple_basis( else: nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality @@ -3705,7 +4108,7 @@ def test_multi_epoch_pynapple_basis_transformer( shift, predictor_causality, nan_index, - class_specific_params, + basis_class_specific_params, ): """Test nan location in multi-epoch pynapple tsd.""" kwargs = dict( @@ -3719,18 +4122,22 @@ def test_multi_epoch_pynapple_basis_transformer( nbasis = 5 bas = CombinedBasis().instantiate_basis( - nbasis, basis_cls, class_specific_params, window_size=window_size, **kwargs + nbasis, + basis_cls, + basis_class_specific_params, + window_size=window_size, + **kwargs, ) n_input = bas._n_input_dimensionality - # pass through transformer - bas = basis.TransformerBasis(bas) - # concat input X = pynapple_concatenate_numpy([tsd[:, None]] * n_input, axis=1) # run convolutions + # pass through transformer + bas.set_input_shape(X) + bas = basis.TransformerBasis(bas) res = bas.fit_transform(X) # check nans @@ -3753,18 +4160,18 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__add__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), "3": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, ), @@ -3772,13 +4179,13 @@ def test_multi_epoch_pynapple_basis_transformer( "__add__", "__mul__", lambda bas1, bas2, bas3: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "(2 * 3)": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -3790,11 +4197,11 @@ def test_multi_epoch_pynapple_basis_transformer( # note that it doesn't respect algebra order but execute right to left (first add then multiplies) "(1 * (2 + 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs * ( - bas2._n_basis_input[0] * bas2.n_basis_funcs - + bas3._n_basis_input[0] * bas3.n_basis_funcs + bas2._n_basis_input_[0] * bas2.n_basis_funcs + + bas3._n_basis_input_[0] * bas3.n_basis_funcs ), ), }, @@ -3805,11 +4212,11 @@ def test_multi_epoch_pynapple_basis_transformer( lambda bas1, bas2, bas3: { "(1 * (2 * 3))": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs - * bas3._n_basis_input[0] + * bas3._n_basis_input_[0] * bas3.n_basis_funcs, ), }, @@ -3817,7 +4224,7 @@ def test_multi_epoch_pynapple_basis_transformer( ], ) def test__get_splitter( - bas1, bas2, bas3, operator1, operator2, compute_slice, class_specific_params + bas1, bas2, bas3, operator1, operator2, compute_slice, basis_class_specific_params ): # skip nested if any( @@ -3831,13 +4238,22 @@ def test__get_splitter( combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" + ) + bas1_instance.set_input_shape( + *([n_input_basis[0]] * bas1_instance._n_input_dimensionality) ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis[1]] * bas2_instance._n_input_dimensionality) ) bas3_instance = combine_basis.instantiate_basis( - n_basis[2], bas3, class_specific_params, window_size=10, label="3" + n_basis[2], bas3, basis_class_specific_params, window_size=10, label="3" + ) + bas3_instance.set_input_shape( + *([n_input_basis[2]] * bas3_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator1) @@ -3845,7 +4261,7 @@ def test__get_splitter( bas23 = func2(bas3_instance) bas123 = func1(bas23) inps = [np.zeros((1, n)) if n > 1 else np.zeros((1,)) for n in n_input_basis] - bas123._set_num_output_features(*inps) + bas123.set_input_shape(*inps) splitter_dict, _ = bas123._get_feature_slicing(split_by_input=False) exp_slices = compute_slice(bas1_instance, bas2_instance, bas3_instance) assert exp_slices == splitter_dict @@ -3863,11 +4279,11 @@ def test__get_splitter( 1, 1, lambda bas1, bas2: { - "1": slice(0, bas1._n_basis_input[0] * bas1.n_basis_funcs), + "1": slice(0, bas1._n_basis_input_[0] * bas1.n_basis_funcs), "2": slice( - bas1._n_basis_input[0] * bas1.n_basis_funcs, - bas1._n_basis_input[0] * bas1.n_basis_funcs - + bas2._n_basis_input[0] * bas2.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs, + bas1._n_basis_input_[0] * bas1.n_basis_funcs + + bas2._n_basis_input_[0] * bas2.n_basis_funcs, ), }, ), @@ -3878,9 +4294,9 @@ def test__get_splitter( lambda bas1, bas2: { "(1 * 2)": slice( 0, - bas1._n_basis_input[0] + bas1._n_basis_input_[0] * bas1.n_basis_funcs - * bas2._n_basis_input[0] + * bas2._n_basis_input_[0] * bas2.n_basis_funcs, ) }, @@ -3905,7 +4321,7 @@ def test__get_splitter( 1, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas1._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas1._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -3932,7 +4348,7 @@ def test__get_splitter( 2, lambda bas1, bas2: { "(1 * 2)": slice( - 0, bas2._n_basis_input[0] * bas1.n_basis_funcs * bas2.n_basis_funcs + 0, bas2._n_basis_input_[0] * bas1.n_basis_funcs * bas2.n_basis_funcs ) }, ), @@ -3974,7 +4390,7 @@ def test__get_splitter_split_by_input( n_input_basis_1, n_input_basis_2, compute_slice, - class_specific_params, + basis_class_specific_params, ): # skip nested if any( @@ -3986,10 +4402,17 @@ def test__get_splitter_split_by_input( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) + bas1_instance.set_input_shape( + *([n_input_basis_1] * bas1_instance._n_input_dimensionality) + ) + bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" + ) + bas2_instance.set_input_shape( + *([n_input_basis_2] * bas1_instance._n_input_dimensionality) ) func1 = getattr(bas1_instance, operator) @@ -3999,7 +4422,7 @@ def test__get_splitter_split_by_input( np.zeros((1, n)) if n > 1 else np.zeros((1,)) for n in (n_input_basis_1, n_input_basis_2) ] - bas12._set_num_output_features(*inps) + bas12.set_input_shape(*inps) splitter_dict, _ = bas12._get_feature_slicing() exp_slices = compute_slice(bas1_instance, bas2_instance) assert exp_slices == splitter_dict @@ -4009,7 +4432,7 @@ def test__get_splitter_split_by_input( "bas1, bas2, bas3", list(itertools.product(*[list_all_basis_classes()] * 3)), ) -def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): +def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params): # skip nested if any( bas in (AdditiveBasis, MultiplicativeBasis, basis.TransformerBasis) @@ -4019,18 +4442,18 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - 5, bas1, class_specific_params, window_size=10, label="label" + 5, bas1, basis_class_specific_params, window_size=10, label="label" ) bas2_instance = combine_basis.instantiate_basis( - 5, bas2, class_specific_params, window_size=10, label="label" + 5, bas2, basis_class_specific_params, window_size=10, label="label" ) bas3_instance = combine_basis.instantiate_basis( - 5, bas3, class_specific_params, window_size=10, label="label" + 5, bas3, basis_class_specific_params, window_size=10, label="label" ) bas_obj = bas1_instance + bas2_instance + bas3_instance inps = [np.zeros((1,)) for n in range(3)] - bas_obj._set_num_output_features(*inps) + bas_obj.set_input_shape(*inps) slice_dict = bas_obj._get_feature_slicing()[0] assert tuple(slice_dict.keys()) == ("label", "label-1", "label-2") @@ -4056,7 +4479,7 @@ def test_duplicate_keys(bas1, bas2, bas3, class_specific_params): ], ) def test_split_feature_axis( - bas1, bas2, x, axis, expectation, exp_shapes, class_specific_params + bas1, bas2, x, axis, expectation, exp_shapes, basis_class_specific_params ): # skip nested if any( @@ -4068,14 +4491,14 @@ def test_split_feature_axis( n_basis = [5, 6] combine_basis = CombinedBasis() bas1_instance = combine_basis.instantiate_basis( - n_basis[0], bas1, class_specific_params, window_size=10, label="1" + n_basis[0], bas1, basis_class_specific_params, window_size=10, label="1" ) bas2_instance = combine_basis.instantiate_basis( - n_basis[1], bas2, class_specific_params, window_size=10, label="2" + n_basis[1], bas2, basis_class_specific_params, window_size=10, label="2" ) bas = bas1_instance + bas2_instance - bas._set_num_output_features(np.zeros((1, 2)), np.zeros((1, 3))) + bas.set_input_shape(np.zeros((1, 2)), np.zeros((1, 3))) with expectation: out = bas.split_by_feature(x, axis=axis) for i, itm in enumerate(out.items()): diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py index ca4f4be2..f40ad214 100644 --- a/tests/test_identifiability_constraints.py +++ b/tests/test_identifiability_constraints.py @@ -190,6 +190,7 @@ def test_feature_matrix_dtype(dtype, expected_dtype): ) def test_apply_constraint_with_invalid(invalid_entries): """Test if the matrix retains its dtype after applying constraints.""" + jax.config.update("jax_enable_x64", True) x = np.random.randn(10, 5) # add invalid x[:2, 2] = invalid_entries diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5e4ce13d..9e52a4f2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -21,7 +21,7 @@ ) def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("eval", bas), ("fit", model)]) pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y) @@ -39,7 +39,7 @@ def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation): ) def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") @@ -60,7 +60,7 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( bas, poissonGLM_model_instantiation ): X, y, model, _, _ = poissonGLM_model_instantiation - bas = TransformerBasis(bas) + bas = TransformerBasis(bas.set_input_shape(*([1] * bas._n_input_dimensionality))) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) gridsearch = GridSearchCV( @@ -86,8 +86,15 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) - param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) + param_grid = dict( + transformerbasis___basis=( + bas_cls(5).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(10).set_input_shape(*([1] * bas._n_input_dimensionality)), + bas_cls(20).set_input_shape(*([1] * bas._n_input_dimensionality)), + ) + ) gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score="raise") gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -107,6 +114,7 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( ): X, y, model, _, _ = poissonGLM_model_instantiation bas = TransformerBasis(bas_cls(5)) + bas.set_input_shape(*([1] * bas._n_input_dimensionality)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), @@ -165,6 +173,7 @@ def test_sklearn_transformer_pipeline_pynapple( ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]]) X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep) y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep) + bas = bas.set_input_shape(*([1] * bas._n_input_dimensionality)) bas = TransformerBasis(bas) # fit a pipeline & predict from pynapple pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])