diff --git a/docs/assets/nemos_sklearn.svg b/docs/assets/nemos_sklearn.svg
new file mode 100644
index 00000000..8ea0a3e3
--- /dev/null
+++ b/docs/assets/nemos_sklearn.svg
@@ -0,0 +1,119 @@
+
+
+
+
diff --git a/docs/assets/pipeline.svg b/docs/assets/pipeline.svg
index 8c67c7ab..a38b6480 100644
--- a/docs/assets/pipeline.svg
+++ b/docs/assets/pipeline.svg
@@ -24,12 +24,12 @@
inkscape:deskcolor="#d1d1d1"
inkscape:document-units="mm"
inkscape:zoom="1.4142136"
- inkscape:cx="289.20667"
- inkscape:cy="-54.800776"
- inkscape:window-width="2488"
- inkscape:window-height="1262"
+ inkscape:cx="206.82873"
+ inkscape:cy="8.4852811"
+ inkscape:window-width="1800"
+ inkscape:window-height="1035"
inkscape:window-x="0"
- inkscape:window-y="25"
+ inkscape:window-y="44"
inkscape:window-maximized="0"
inkscape:current-layer="layer1" />
+ y="49.623287" />
Pipeline
+ x="26.058722"
+ y="84.056198">Pipeline
-
+
```{toctree}
:maxdepth: 2
-plot_05_sklearn_pipeline_cv_demo.md
+plot_05_transformer_basis.md
```
:::
:::{grid-item-card}
```{toctree}
:maxdepth: 2
-plot_06_glm_pytree.md
+plot_06_sklearn_pipeline_cv_demo.md
```
+
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_07_glm_pytree.md
+```
+
:::
::::
diff --git a/docs/how_to_guide/plot_05_transformer_basis.md b/docs/how_to_guide/plot_05_transformer_basis.md
new file mode 100644
index 00000000..9b8fe6b6
--- /dev/null
+++ b/docs/how_to_guide/plot_05_transformer_basis.md
@@ -0,0 +1,183 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+# Converting NeMoS Bases To scikit-learn Transformers
+
+(tansformer-vs-nemos-basis)=
+## scikit-learn Transformers and NeMoS Basis
+
+`scikit-learn` is a powerful machine learning library that provides advanced tools for creating data analysis pipelines, from input transformations to model fitting and cross-validation.
+
+All of `scikit-learn`'s machinery relies on strict assumptions about input structure. In particular, all `scikit-learn`
+objects require inputs as arrays of at most two dimensions, where the first dimension represents the time (or samples)
+axis, and the second dimension represents features.
+While this may feel rigid, it enables transformations to be seamlessly chained together, greatly simplifying the
+process of building stable, complex pipelines.
+
+On the other hand, `NeMoS` takes a different approach to feature construction. `NeMoS`' bases are composable constructors that allow for more flexibility in the required input structure.
+Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can take any shape as long as the time (or sample) axis is the first of each array;
+`NeMoS` design favours object composability: one can combine any two or more bases to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.).
+
+Both approaches to data transformation are valuable and each has its own advantages. Wouldn't it be great if one could combine the two? Well, this is what NeMoS `TransformerBasis` is for!
+
+
+## From Basis to TransformerBasis
+
+
+With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`.
+
+```{code-cell} ipython3
+import numpy as np
+import nemos as nmo
+
+# create the arrays
+n_samples, n_neurons = 100, 5
+counts = np.random.poisson(size=(100, 5))
+speed = np.random.normal(size=(100))
+
+# create a composite basis
+counts_basis = nmo.basis.RaisedCosineLogConv(5, window_size=10)
+speed_basis = nmo.basis.BSplineEval(5)
+composite_basis = counts_basis + speed_basis
+
+# compute the features
+X = composite_basis.compute_features(counts, speed)
+
+```
+
+### Converting NeMoS `Basis` to a transformer
+
+Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline.
+In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a compliant `scikit-learn` transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class.
+
+Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):
+
+
+```{code-cell} ipython3
+bas = nmo.basis.RaisedCosineLinearConv(5, window_size=5)
+
+# initalize using the constructor
+trans_bas = nmo.basis.TransformerBasis(bas)
+
+# equivalent initialization via "to_transformer"
+trans_bas = bas.to_transformer()
+
+```
+
+[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes:
+
+
+```{code-cell} ipython3
+print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
+```
+
+We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), nor does changing the original [`Basis`](nemos.basis._basis.Basis) modify the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created:
+
+
+```{code-cell} ipython3
+trans_bas.n_basis_funcs = 10
+bas.n_basis_funcs = 100
+
+print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
+```
+
+As with any `sckit-learn` transformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore they all accept a single 2D array.
+
+## Setting up the TransformerBasis
+
+At this point we have an object equipped with the correct methods, so now, all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right?
+
+```{code-cell} ipython3
+
+# reinstantiate the basis transformer for illustration porpuses
+composite_basis = counts_basis + speed_basis
+trans_bas = (composite_basis).to_transformer()
+
+# concatenate the inputs
+inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1)
+print(inp.shape)
+
+try:
+ trans_bas.fit_transform(inp)
+except RuntimeError as e:
+ print(repr(e))
+```
+
+...Unfortunately, not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`.
+
+You can provide this information by calling the `set_input_shape` method of the basis.
+
+This can be called before or after the transformer basis is defined. The method extracts and stores the array shapes excluding the sample axis (which won't be affected in the concatenation).
+
+`set_input_shape` directly accepts the inputs:
+
+```{code-cell} ipython3
+
+composite_basis.set_input_shape(counts, speed)
+out = composite_basis.to_transformer().fit_transform(inp)
+```
+
+If the input is 1D or 2D, it also accepts the number of columns:
+```{code-cell} ipython3
+
+composite_basis.set_input_shape(5, 1)
+out = composite_basis.to_transformer().fit_transform(inp)
+```
+
+A tuple containing the shapes of all the axes other than the first,
+```{code-cell} ipython3
+
+composite_basis.set_input_shape((5,), (1,))
+out = composite_basis.to_transformer().fit_transform(inp)
+```
+
+Or a mix of the above.
+```{code-cell} ipython3
+
+composite_basis.set_input_shape(counts, 1)
+out = composite_basis.to_transformer().fit_transform(inp)
+```
+
+You can also invert the order of operations and call `to_transform` first and then set the input shapes.
+```{code-cell} ipython3
+
+trans_bas = composite_basis.to_transformer()
+trans_bas.set_input_shape(5, 1)
+out = trans_bas.fit_transform(inp)
+```
+
+:::{note}
+
+If you define a basis and call `compute_features` on your inputs, internally, it will store its shapes,
+and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`.
+:::
+
+If for some reason you need to provide an input of different shape to an already set-up transformer, you must reset the
+`TransformerBasis` with `set_input_shape`.
+
+```{code-cell} ipython3
+
+# define inputs with different shapes and concatenate
+x, y = np.random.poisson(size=(10, 3)), np.random.randn(10, 2, 3)
+inp2 = np.concatenate([x, y.reshape(10, 6)], axis=1)
+
+trans_bas = composite_basis.to_transformer()
+trans_bas.set_input_shape(3, (2, 3))
+out2 = trans_bas.fit_transform(inp2)
+```
+
+
+### Learn more
+
+If you want to learn more about how to select basis' hyperparameters with `sklearn` pipelining and cross-validation, check out [this how-to guide](sklearn-how-to).
+
diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md
similarity index 92%
rename from docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md
rename to docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md
index 9f5a9652..2afa68e8 100644
--- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md
+++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md
@@ -71,7 +71,8 @@ To set up a scikit-learn [`Pipeline`](https://scikit-learn.org/1.5/modules/gener
Each transformation step takes a 2D array `X` of shape `(num_samples, num_original_features)` as input and outputs another 2D array of shape `(num_samples, num_transformed_features)`. The final step takes a pair `(X, y)`, where `X` is as before, and `y` is a 1D array of shape `(n_samples,)` containing the observations to be modeled.
You can define a pipeline as follows:
-```python
+
+```{code} ipython3
from sklearn.pipeline import Pipeline
# Assume transformer_i/predictor is a transformer/model object
@@ -92,7 +93,7 @@ Here we used a placeholder `"label_i"` for demonstration; you should choose a mo
:::
Calling `pipe.fit(X, y)` will perform the following computations:
-```python
+```{code} ipython3
# Chain of transformations
X1 = transformer_1.fit_transform(X)
X2 = transformer_2.fit_transform(X1)
@@ -111,6 +112,7 @@ Pipelines not only streamline and simplify your code but also offer several othe
In the following sections, we will showcase this approach with a concrete example: selecting the appropriate basis type and number of bases for a GLM regression in NeMoS.
## Combining basis transformations and GLM in a pipeline
+
Let's start by creating some toy data.
@@ -150,9 +152,7 @@ sns.despine(ax=ax)
```
### Converting NeMoS `Basis` to a transformer
-In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class.
-
-Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):
+In order to use NeMoS [`Basis`](nemos.basis._basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer.
```{code-cell} ipython3
@@ -164,24 +164,15 @@ trans_bas = nmo.basis.TransformerBasis(bas)
# equivalent initialization via "to_transformer"
trans_bas = bas.to_transformer()
+# setup the transformer
+trans_bas.set_input_shape(1)
```
-[`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis._basis.Basis) object's attributes:
-
-
-```{code-cell} ipython3
-print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
-```
-
-We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created:
+:::{admonition} Learn More about `TransformerBasis`
+:note:
-
-```{code-cell} ipython3
-trans_bas.n_basis_funcs = 10
-bas.n_basis_funcs = 100
-
-print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
-```
+To learn more about `sklearn` transformers and `TransforerBasis`, check out [this note](tansformer-vs-nemos-basis).
+:::
### Creating and fitting a pipeline
We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data.
@@ -194,7 +185,7 @@ pipeline = Pipeline(
[
(
"transformerbasis",
- nmo.basis.RaisedCosineLinearEval(6).to_transformer(),
+ nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1).to_transformer(),
),
(
"glm",
@@ -311,7 +302,7 @@ gridsearch.fit(X, y)
To appreciate how much boiler-plate code we are saving by calling scikit-learn cross-validation, below
we can see how this cross-validation will look like in a manual loop.
-```python
+```{code} ipython
from itertools import product
from copy import deepcopy
@@ -439,7 +430,7 @@ if root or Path("../assets/stylesheets").exists():
path.mkdir(parents=True, exist_ok=True)
if path.exists():
- fig.savefig(path / "plot_05_sklearn_pipeline_cv_demo.svg")
+ fig.savefig(path / "plot_06_sklearn_pipeline_cv_demo.svg")
```
🚀🚀🚀 **Success!** 🚀🚀🚀
@@ -457,12 +448,12 @@ Here we include `transformerbasis___basis` in the parameter grid to try differen
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis___basis=(
- nmo.basis.RaisedCosineLinearEval(5),
- nmo.basis.RaisedCosineLinearEval(10),
- nmo.basis.RaisedCosineLogEval(5),
- nmo.basis.RaisedCosineLogEval(10),
- nmo.basis.MSplineEval(5),
- nmo.basis.MSplineEval(10),
+ nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1),
+ nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1),
+ nmo.basis.RaisedCosineLogEval(5).set_input_shape(1),
+ nmo.basis.RaisedCosineLogEval(10).set_input_shape(1),
+ nmo.basis.MSplineEval(5).set_input_shape(1),
+ nmo.basis.MSplineEval(10).set_input_shape(1),
),
)
```
@@ -538,7 +529,7 @@ The plot confirms that the firing rate distribution is accurately captured by ou
:::{warning}
Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error:
-```python
+```{code} ipython
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_07_glm_pytree.md
similarity index 99%
rename from docs/how_to_guide/plot_06_glm_pytree.md
rename to docs/how_to_guide/plot_07_glm_pytree.md
index d6608c45..f980e75c 100644
--- a/docs/how_to_guide/plot_06_glm_pytree.md
+++ b/docs/how_to_guide/plot_07_glm_pytree.md
@@ -265,7 +265,7 @@ if root or Path("../assets/stylesheets").exists():
path.mkdir(parents=True, exist_ok=True)
if path.exists():
- fig.savefig(path / "plot_06_glm_pytree.svg")
+ fig.savefig(path / "plot_07_glm_pytree.svg")
```
diff --git a/docs/index.md b/docs/index.md
index 4da5ef06..0b61208c 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -20,7 +20,7 @@ For Developers
```
-## __Neural ModelS__
+# __Neural ModelS__
NeMoS (Neural ModelS) is a statistical modeling framework optimized for systems neuroscience and powered by [JAX](https://jax.readthedocs.io/en/latest/).
diff --git a/docs/quickstart.md b/docs/quickstart.md
index bdf3ffd4..f071839f 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -1,7 +1,18 @@
---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
hide:
- navigation
---
+
# Quickstart
## **Overview**
@@ -29,58 +40,56 @@ NeMoS provides two implementations of the GLM: one for fitting a single neuron,
You can define a single neuron GLM by instantiating an `GLM` object.
-```python
+```{code-cell} ipython3
->>> import nemos as nmo
+import nemos as nmo
->>> # Instantiate the single model
->>> model = nmo.glm.GLM()
+# Instantiate the single model
+model = nmo.glm.GLM()
```
The coefficients can be learned by invoking the `fit` method of `GLM`. The method requires a design
matrix of shape `(num_samples, num_features)`, and the output neural activity of shape `(num_samples, )`.
-```python
+```{code-cell} ipython3
->>> import numpy as np
->>> num_samples, num_features = 100, 3
+import numpy as np
+num_samples, num_features = 100, 3
->>> # Generate a design matrix
->>> X = np.random.normal(size=(num_samples, num_features))
->>> # generate some counts
->>> spike_counts = np.random.poisson(size=num_samples)
+# Generate a design matrix
+X = np.random.normal(size=(num_samples, num_features))
+# generate some counts
+spike_counts = np.random.poisson(size=num_samples)
->>> # define fit the model
->>> model = model.fit(X, spike_counts)
+# define fit the model
+model = model.fit(X, spike_counts)
```
Once the model is fit, you can retrieve the model parameters as shown below.
-```python
->>> # model coefficients shape is (num_features, )
->>> print(f"Model coefficients shape: {model.coef_.shape}")
-Model coefficients shape: (3,)
+```{code-cell} ipython3
+# model coefficients shape is (num_features, )
+print(f"Model coefficients shape: {model.coef_.shape}")
->>> # model intercept, shape (1,) since there is only one neuron.
->>> print(f"Model intercept shape: {model.intercept_.shape}")
-Model intercept shape: (1,)
+# model intercept, shape (1,) since there is only one neuron.
+print(f"Model intercept shape: {model.intercept_.shape}")
```
Additionally, you can predict the firing rate and call the compute the model log-likelihood by calling the `predict` and the `score` method respectively.
-```python
+```{code-cell} ipython3
+
+# predict the rate
+predicted_rate = model.predict(X)
+# firing rate has shape: (num_samples,)
+predicted_rate.shape
->>> # predict the rate
->>> predicted_rate = model.predict(X)
->>> # firing rate has shape: (num_samples,)
->>> predicted_rate.shape
-(100,)
->>> # compute the log-likelihood of the model
->>> log_likelihood = model.score(X, spike_counts)
+# compute the log-likelihood of the model
+log_likelihood = model.score(X, spike_counts)
```
@@ -88,49 +97,47 @@ Additionally, you can predict the firing rate and call the compute the model log
You can set up a population GLM by instantiating a `PopulationGLM`. The API for the `PopulationGLM` is the same as for the single-neuron `GLM`; the only difference you'll notice is that some of the methods' inputs and outputs have an additional dimension for the different neurons.
-```python
+```{code-cell}
->>> import nemos as nmo
->>> population_model = nmo.glm.PopulationGLM()
+import nemos as nmo
+population_model = nmo.glm.PopulationGLM()
```
As for the single neuron GLM, you can learn the model parameters by invoking the `fit` method: the input of `fit` are the design matrix (with shape `(num_samples, num_features)` ), and the population activity (with shape `(num_samples, num_neurons)`).
Once the model is fit, you can use `predict` and `score` to predict the firing rate and the log-likelihood.
-```python
+```{code-cell}
->>> import numpy as np
->>> num_samples, num_features, num_neurons = 100, 3, 5
+import numpy as np
+num_samples, num_features, num_neurons = 100, 3, 5
->>> # simulate a design matrix
->>> X = np.random.normal(size=(num_samples, num_features))
->>> # simulate some counts
->>> spike_counts = np.random.poisson(size=(num_samples, num_neurons))
+# simulate a design matrix
+X = np.random.normal(size=(num_samples, num_features))
+# simulate some counts
+spike_counts = np.random.poisson(size=(num_samples, num_neurons))
->>> # fit the model
->>> population_model = population_model.fit(X, spike_counts)
+# fit the model
+population_model = population_model.fit(X, spike_counts)
->>> # predict the rate of each neuron in the population
->>> predicted_rate = population_model.predict(X)
->>> predicted_rate.shape # expected shape: (num_samples, num_neurons)
-(100, 5)
+# predict the rate of each neuron in the population
+predicted_rate = population_model.predict(X)
+predicted_rate.shape # expected shape: (num_samples, num_neurons)
->>> # compute the log-likelihood of the model
->>> log_likelihood = population_model.score(X, spike_counts)
+
+# compute the log-likelihood of the model
+log_likelihood = population_model.score(X, spike_counts)
```
The learned coefficient and intercept will have shape `(num_features, num_neurons)` and `(num_neurons, )` respectively.
-```python
->>> # model coefficients shape is (num_features, num_neurons)
->>> print(f"Model coefficients shape: {population_model.coef_.shape}")
-Model coefficients shape: (3, 5)
+```{code-cell}
+# model coefficients shape is (num_features, num_neurons)
+print(f"Model coefficients shape: {population_model.coef_.shape}")
->>> # model intercept, (num_neurons,)
->>> print(f"Model intercept shape: {population_model.intercept_.shape}")
-Model intercept shape: (5,)
+# model intercept, (num_neurons,)
+print(f"Model intercept shape: {population_model.intercept_.shape}")
```
@@ -160,12 +167,12 @@ The `basis` module includes objects that perform two types of transformations on
Non-linear mapping is the default mode of operation of any `basis` object. To instantiate a basis for non-linear mapping,
you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Reference](nemos_basis) for detailed information).
-```python
+```{code-cell}
->>> import nemos as nmo
+import nemos as nmo
->>> n_basis_funcs = 10
->>> basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs)
+n_basis_funcs = 10
+basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs)
```
@@ -173,17 +180,16 @@ Once the basis is instantiated, you can apply it to your input data using the `c
This method takes an input array of shape `(n_samples, )` and transforms it into a two-dimensional array of
shape `(n_samples, n_basis_funcs)`, where each column represents a feature generated by the non-linear mapping.
-```python
+```{code-cell}
->>> import numpy as np
+import numpy as np
->>> # generate an input
->>> x = np.arange(100)
+# generate an input
+x = np.arange(100)
->>> # evaluate the basis
->>> X = basis.compute_features(x)
->>> X.shape
-(100, 10)
+# evaluate the basis
+X = basis.compute_features(x)
+X.shape
```
@@ -199,13 +205,13 @@ If you want to convolve a bank of basis functions with an input you must set the
`"conv"` and you must provide an integer `window_size` parameter, which defines the length of the filter bank in
number of sample points.
-```python
+```{code-cell} ipython3
->>> import nemos as nmo
+import nemos as nmo
->>> n_basis_funcs = 10
->>> # define a filter bank of 10 basis function, 200 samples long.
->>> basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200)
+n_basis_funcs = 10
+# define a filter bank of 10 basis function, 200 samples long.
+basis = nmo.basis.BSplineConv(n_basis_funcs, window_size=200)
```
@@ -219,23 +225,21 @@ Once the basis is initialized, you can call `compute_features` on an input of sh
The `window_size` must be shorter than the number of samples in the signal(s) being convolved.
:::
-```python
+```{code-cell} ipython3
->>> import numpy as np
+import numpy as np
->>> x = np.ones(500)
+x = np.ones(500)
->>> # convolve a single signal
->>> X = basis.compute_features(x)
->>> X.shape
-(500, 10)
+# convolve a single signal
+X = basis.compute_features(x)
+X.shape
->>> x_multi = np.ones((500, 3))
+x_multi = np.ones((500, 3))
->>> # convolve a multiple signals
->>> X_multi = basis.compute_features(x_multi)
->>> X_multi.shape
-(500, 30)
+# convolve a multiple signals
+X_multi = basis.set_input_shape(3).compute_features(x_multi)
+X_multi.shape
```
@@ -249,12 +253,12 @@ By default, NeMoS' GLM uses [Poisson observations](nemos.observation_models.Pois
To change the default observation model, set the `observation_model` argument during initialization:
-```python
+```{code-cell} ipython3
->>> import nemos as nmo
+import nemos as nmo
->>> # set up a Gamma GLM for modeling continuous non-negative data
->>> glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations())
+# set up a Gamma GLM for modeling continuous non-negative data
+glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations())
```
@@ -270,12 +274,12 @@ NeMoS supports various regularization schemes, including [Ridge](nemos.regulariz
You can specify the regularization scheme and its strength when initializing the GLM model:
-```python
+```{code-cell} ipython3
->>> import nemos as nmo
+import nemos as nmo
->>> # Instantiate a GLM with Ridge (L2) regularization
->>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
+# Instantiate a GLM with Ridge (L2) regularization
+glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
```
@@ -301,25 +305,26 @@ also be a `pynapple` time series.
A canonical example of this behavior is the `predict` method of `GLM`.
-```ipython
+```{code-cell} ipython3
+
+import numpy as np
+import pynapple as nap
->>> import numpy as np
->>> import pynapple as nap
+# suppress jax to numpy conversion warning
+nap.nap_config.suppress_conversion_warnings = True
->>> # create a TsdFrame with the features and a Tsd with the counts
->>> X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2)))
->>> y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, )))
+# create a TsdFrame with the features and a Tsd with the counts
+X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2)))
+y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, )))
->>> print(type(X)) # shape (num samples, num features)
-
+print(type(X)) # shape (num samples, num features)
->>> model = model.fit(X, y) # the following works
+model = model.fit(X, y) # the following works
->>> firing_rate = model.predict(X) # predict the firing rate of the neuron
+firing_rate = model.predict(X) # predict the firing rate of the neuron
->>> # this will still be a pynapple time series
->>> print(type(firing_rate)) # shape (num_samples, )
-
+# this will still be a pynapple time series
+print(type(firing_rate)) # shape (num_samples, )
```
@@ -331,29 +336,29 @@ Let's see how you can greatly streamline your analysis pipeline by integrating `
You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1).
:::
-```ipython
+```{code-cell} ipython3
->>> import nemos as nmo
->>> import pynapple as nap
+import nemos as nmo
+import pynapple as nap
->>> path = nmo.fetch.fetch_data("A2929-200711.nwb")
->>> data = nap.load_file(path)
+path = nmo.fetch.fetch_data("A2929-200711.nwb")
+data = nap.load_file(path)
->>> # load spikes and head direction
->>> spikes = data["units"]
->>> head_dir = data["ry"]
+# load spikes and head direction
+spikes = data["units"]
+head_dir = data["ry"]
->>> # restrict and bin
->>> counts = spikes[6].count(0.01, ep=head_dir.time_support)
+# restrict and bin
+counts = spikes[6].count(0.01, ep=head_dir.time_support)
->>> # down-sample head direction
->>> upsampled_head_dir = head_dir.bin_average(0.01)
+# down-sample head direction
+upsampled_head_dir = head_dir.bin_average(0.01)
->>> # create your features
->>> X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir)
+# create your features
+X = nmo.basis.CyclicBSplineEval(10).compute_features(upsampled_head_dir)
->>> # add a neuron axis and fit model
->>> model = nmo.glm.GLM().fit(X, counts)
+# add a neuron axis and fit model
+model = nmo.glm.GLM().fit(X, counts)
```
@@ -361,35 +366,31 @@ You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oa
Finally, let's compare the tuning curves
-```ipython
+```{code-cell} ipython3
->>> import numpy as np
->>> import matplotlib.pyplot as plt
+import numpy as np
+import matplotlib.pyplot as plt
->>> # tuning curves
->>> raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6]
+# tuning curves
+raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6]
->>> # model based tuning curve
->>> model_tuning = nap.compute_1d_tuning_curves_continuous(
-... model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate
-... head_dir,
-... nb_bins=100
-... )[0]
+# model based tuning curve
+model_tuning = nap.compute_1d_tuning_curves_continuous(
+ model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate
+ head_dir,
+ nb_bins=100
+ )[0]
->>> # plot results
->>> sub = plt.subplot(111, projection="polar")
->>> plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw")
->>> plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm")
->>> legend = plt.yticks([])
->>> xlab = plt.xlabel("heading angle")
+# plot results
+sub = plt.subplot(111, projection="polar")
+plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw")
+plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm")
+legend = plt.yticks([])
+xlab = plt.xlabel("heading angle")
```
-
-
-
-
## **Compatibility with `scikit-learn`**
@@ -400,34 +401,34 @@ For example, if we would like to tune the critical hyper-parameter `regularizer_
[^1]: For a detailed explanation and practical examples, refer to the [cross-validation page](https://scikit-learn.org/stable/modules/cross_validation.html) in the `scikit-learn` documentation.
-```ipython
+```{code-cell} ipython3
->>> # set up the model
->>> import nemos as nmo
->>> import numpy as np
+# set up the model
+import nemos as nmo
+import numpy as np
->>> # generate data
->>> X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100)
+# generate data
+X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100)
->>> # model definition
->>> model = nmo.glm.GLM(regularizer="Ridge")
+# model definition
+model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
```
Fit a 5-fold cross-validation scheme for comparing two different regularizer strengths:
-```ipython
+```{code-cell} ipython3
->>> from sklearn.model_selection import GridSearchCV
+from sklearn.model_selection import GridSearchCV
->>> # define the parameter grid
->>> param_grid = dict(regularizer_strength=(0.01, 0.001))
+# define the parameter grid
+param_grid = dict(regularizer_strength=(0.01, 0.001))
->>> # define the 5-fold cross-validation grid search from sklearn
->>> cls = GridSearchCV(model, param_grid=param_grid, cv=5)
+# define the 5-fold cross-validation grid search from sklearn
+cls = GridSearchCV(model, param_grid=param_grid, cv=5)
->>> # run the 5-fold cross-validation grid search
->>> cls = cls.fit(X, counts)
+# run the 5-fold cross-validation grid search
+cls = cls.fit(X, counts)
```
@@ -440,11 +441,10 @@ For more information and a practical example on how to construct a parameter gri
Finally, we can print the regularizer strength with the best cross-validated performance:
-```ipython
+```{code-cell} ipython3
->>> # print best regularizer strength
->>> print(cls.best_params_)
-{'regularizer_strength': 0.01}
+# print best regularizer strength
+print(cls.best_params_)
```
diff --git a/pyproject.toml b/pyproject.toml
index d20fd307..f57d3238 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,8 +125,12 @@ profile = "black"
testpaths = ["tests"] # Specify the directory where test files are located
filterwarnings = [
# note the use of single quote below to denote "raw" strings in TOML
+ # this is raised whenever one imports the plotting utils
'ignore:plotting functions contained within:UserWarning',
- 'ignore:Tolerance of \d\.\d+e-\d\d reached:RuntimeWarning',
+ # numerical inversion test reaches tolerance...
+ 'ignore:Tolerance of -?\d\.\d+e-\d\d reached:RuntimeWarning',
+ # mpl must be non-interctive for testing otherwise doctests will freeze
+ 'ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning',
]
[tool.coverage.run]
diff --git a/src/nemos/basis/_basis.py b/src/nemos/basis/_basis.py
index 46d9f5a5..b6a66a56 100644
--- a/src/nemos/basis/_basis.py
+++ b/src/nemos/basis/_basis.py
@@ -146,8 +146,6 @@ def __init__(
else:
self._label = str(label)
- self._check_n_basis_min()
-
# specified only after inputs/input shapes are provided
self._n_basis_input_ = getattr(self, "_n_basis_input_", None)
self._input_shape_ = getattr(self, "_input_shape_", None)
@@ -278,12 +276,13 @@ def _set_input_independent_states(self):
"""
Compute all the basis states that do not depend on the input.
- An example of such state is the kernel_ for Conv baisis, which can be computed
+ An example of such state is the kernel_ for Conv bases, which can be computed
without any input (it only depends on the basis type, the window size and the
number of basis elements).
"""
pass
+ @abc.abstractmethod
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Set the expected input shape for the basis object.
@@ -293,54 +292,8 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
from an array. The method also calculates the total number of input
features and output features based on the number of basis functions.
- Parameters
- ----------
- xi :
- The input shape specification.
- - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input.
- - A tuple: Represents the exact input shape excluding the first axis (sample axis).
- All elements must be integers.
- - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
-
- Raises
- ------
- ValueError
- If a tuple is provided and it contains non-integer elements.
-
- Returns
- -------
- self :
- Returns the instance itself to allow method chaining.
-
- Notes
- -----
- All state attributes that depends on the input must be set in this method in order for
- the API of basis to work correctly. In particular, this method is called by ``_basis_fit``,
- which is equivalent to ``fit`` for a transformer. If any input dependent state
- is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break.
-
- Separating states related to the input (settable with this method) and states that are unrelated
- from the input (settable with ``set_kernel`` for Conv bases) is a deliberate design choice
- that improves modularity.
-
"""
- if isinstance(xi, tuple):
- if not all(isinstance(i, int) for i in xi):
- raise ValueError(
- f"The tuple provided contains non integer values. Tuple: {xi}."
- )
- shape = xi
- elif isinstance(xi, int):
- shape = () if xi == 1 else (xi,)
- else:
- shape = xi.shape[1:]
-
- n_inputs = (int(np.prod(shape)),)
-
- self._input_shape_ = shape
-
- self._n_basis_input_ = n_inputs
- return self
+ pass
@abc.abstractmethod
def _evaluate(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor) -> FeatureMatrix:
@@ -510,20 +463,6 @@ def _check_samples_consistency(*xi: NDArray) -> None:
"Sample size mismatch. Input elements have inconsistent sample sizes."
)
- @abc.abstractmethod
- def _check_n_basis_min(self) -> None:
- """Check that the user required enough basis elements.
-
- Most of the basis work with at least 1 element, but some
- such as the RaisedCosineBasisLog requires a minimum of 2 basis to be well defined.
-
- Raises
- ------
- ValueError
- If an insufficient number of basis element is requested for the basis type
- """
- pass
-
def __add__(self, other: Basis) -> AdditiveBasis:
"""
Add two Basis objects together.
@@ -625,8 +564,7 @@ def _get_feature_slicing(
_get_default_slicing : Handles default slicing logic.
_merge_slicing_dicts : Merges multiple slicing dictionaries, handling keys conflicts.
"""
- # Set default values for n_inputs and start_slice if not provided
- n_inputs = n_inputs or self._n_basis_input_
+ # Set default values for start_slice if not provided
start_slice = start_slice or 0
# Handle the default case for non-additive basis types
# See overwritten method for recursion logic
@@ -772,72 +710,6 @@ def is_leaf(val):
)
return reshaped_out
- def _check_input_shape_consistency(self, x: NDArray):
- """Check input consistency across calls."""
- # remove sample axis and squeeze
- shape = x.shape[1:]
-
- initialized = self._input_shape_ is not None
- is_shape_match = self._input_shape_ == shape
- if initialized and not is_shape_match:
- expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:]
- expected_shape_str = expected_shape_str.replace(",)", ")")
- raise ValueError(
- f"Input shape mismatch detected.\n\n"
- f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with "
- f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n"
- f" Expected: {expected_shape_str}\n"
- f" But got: {x.shape}.\n\n"
- "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, "
- "but all other dimensions must remain the same. If you need to process inputs with a "
- "different shape, please create a new basis instance."
- )
-
- def _list_components(self):
- """List all basis components.
-
- This is re-implemented for composite basis in the mixin class.
-
- Returns
- -------
- A list with all 1d basis components.
-
- Raises
- ------
- RuntimeError
- If the basis has multiple components. This would only happen if there is an
- implementation issue, for example, if a composite basis is implemented but the
- mixin class is not initialized, or if the _list_components method of the composite mixin
- class is accidentally removed.
- """
- if hasattr(self, "basis1"):
- raise RuntimeError(
- "Composite basis must implement the _list_components method."
- )
- return [self]
-
- def __sklearn_clone__(self) -> Basis:
- """Clone the basis while preserving attributes related to input shapes.
-
- This method ensures that input shape attributes (e.g., `_n_basis_input_`,
- `_input_shape_`) are preserved during cloning. Reinitializing the class
- as in the regular sklearn clone would drop these attributes, rendering
- cross-validation unusable.
- The method also handles recursive cloning for composite basis structures.
- """
- # clone recursively
- if hasattr(self, "_basis1") and hasattr(self, "_basis2"):
- basis1 = self._basis1.__sklearn_clone__()
- basis2 = self._basis2.__sklearn_clone__()
- klass = self.__class__(basis1, basis2)
-
- else:
- klass = self.__class__(**self.get_params())
-
- for attr_name in ["_n_basis_input_", "_input_shape_"]:
- setattr(klass, attr_name, getattr(self, attr_name))
- return klass
-
class AdditiveBasis(CompositeBasisMixin, Basis):
"""
@@ -895,34 +767,9 @@ def n_output_features(self):
return None
return out1 + out2
+ @add_docstring("set_input_shape", CompositeBasisMixin)
def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
"""
- Set the expected input shape for the basis object.
-
- This method sets the input shape for each component basis in the ``AdditiveBasis``.
- One ``xi`` must be provided for each basis component, specified as an integer,
- a tuple of integers, or an array. The method calculates and stores the total number of output features
- based on the number of basis functions in each component and the provided input shapes.
-
- Parameters
- ----------
- *xi :
- The input shape specifications. For every k, ``xi[k]`` can be:
- - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input.
- - A tuple: Represents the exact input shape excluding the first axis (sample axis).
- All elements must be integers.
- - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
-
- Raises
- ------
- ValueError
- If a tuple is provided and it contains non-integer elements.
-
- Returns
- -------
- self :
- Returns the instance itself to allow method chaining.
-
Examples
--------
>>> # Generate sample data
@@ -944,15 +791,7 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
181
"""
- self._n_basis_input_ = (
- *self._basis1.set_input_shape(
- *xi[: self._basis1._n_input_dimensionality]
- )._n_basis_input_,
- *self._basis2.set_input_shape(
- *xi[self._basis1._n_input_dimensionality :]
- )._n_basis_input_,
- )
- return self
+ return super().set_input_shape(*xi)
@support_pynapple(conv_type="numpy")
@check_transform_input
@@ -1406,64 +1245,6 @@ def _compute_features(
)
return X
- def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
- """
- Set the expected input shape for the basis object.
-
- This method sets the input shape for each component basis in the ``MultiplicativeBasis``.
- One ``xi`` must be provided for each basis component, specified as an integer,
- a tuple of integers, or an array. The method calculates and stores the total number of output features
- based on the number of basis functions in each component and the provided input shapes.
-
- Parameters
- ----------
- *xi :
- The input shape specifications. For every k,``xi[k]`` can be:
- - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input.
- - A tuple: Represents the exact input shape excluding the first axis (sample axis).
- All elements must be integers.
- - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
-
- Raises
- ------
- ValueError
- If a tuple is provided and it contains non-integer elements.
-
- Returns
- -------
- self :
- Returns the instance itself to allow method chaining.
- Examples
- --------
- >>> # Generate sample data
- >>> import numpy as np
- >>> import nemos as nmo
-
- >>> # define an additive basis
- >>> basis_1 = nmo.basis.BSplineEval(5)
- >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6)
- >>> basis_3 = nmo.basis.MSplineEval(7)
- >>> multiplicative_basis = basis_1 * basis_2 * basis_3
-
- Specify the input shape using all 3 allowed ways: integer, tuple, array
- >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5)))
-
- Expected output features are:
- (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200
- >>> multiplicative_basis.n_output_features
- 25200
-
- """
- self._n_basis_input_ = (
- *self._basis1.set_input_shape(
- *xi[: self._basis1._n_input_dimensionality]
- )._n_basis_input_,
- *self._basis2.set_input_shape(
- *xi[self._basis1._n_input_dimensionality :]
- )._n_basis_input_,
- )
- return self
-
def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
"""Evaluate the basis set on a grid of equi-spaced sample points.
@@ -1558,3 +1339,29 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
+
+ @add_docstring("set_input_shape", CompositeBasisMixin)
+ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
+ """
+ Examples
+ --------
+ >>> # Generate sample data
+ >>> import numpy as np
+ >>> import nemos as nmo
+
+ >>> # define an additive basis
+ >>> basis_1 = nmo.basis.BSplineEval(5)
+ >>> basis_2 = nmo.basis.RaisedCosineLinearEval(6)
+ >>> basis_3 = nmo.basis.MSplineEval(7)
+ >>> multiplicative_basis = basis_1 * basis_2 * basis_3
+
+ Specify the input shape using all 3 allowed ways: integer, tuple, array
+ >>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5)))
+
+ Expected output features are:
+ (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200
+ >>> multiplicative_basis.n_output_features
+ 25200
+
+ """
+ return super().set_input_shape(*xi)
diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py
index 9b208e31..d0c80f81 100644
--- a/src/nemos/basis/_basis_mixin.py
+++ b/src/nemos/basis/_basis_mixin.py
@@ -6,6 +6,7 @@
import copy
import inspect
import warnings
+from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
@@ -19,14 +20,167 @@
from ._basis import Basis
+def set_input_shape_state(method):
+ """
+ Decorator to preserve input shape-related attributes during method execution.
+
+ This decorator ensures that the attributes `_n_basis_input_` and `_input_shape_`
+ are copied from the original object (`self`) to the returned object (`klass`)
+ after the wrapped method executes. It is intended to be used with methods that
+ clone or create a new instance of the class, ensuring these critical attributes
+ are retained for functionality such as cross-validation.
+
+ Parameters
+ ----------
+ method :
+ The method to be wrapped. This method is expected to return an object
+ (`klass`) that requires the `_n_basis_input_` and `_input_shape_` attributes.
+
+ Returns
+ -------
+ :
+ The wrapped method that copies `_n_basis_input_` and `_input_shape_` from
+ the original object (`self`) to the new object (`klass`).
+
+ Examples
+ --------
+ Applying the decorator to a method:
+
+ >>> from functools import wraps
+ >>> @set_input_shape_state
+ ... def __sklearn_clone__(self):
+ ... klass = self.__class__(**self.get_params())
+ ... return klass
+
+ The `_n_basis_input_` and `_input_shape_` attributes of `self` will be
+ copied to `klass` after the method executes.
+ """
+
+ @wraps(method)
+ def wrapper(self, *args, **kwargs):
+ klass: Basis = method(self, *args, **kwargs)
+ for attr_name in ["_n_basis_input_", "_input_shape_"]:
+ setattr(klass, attr_name, getattr(self, attr_name))
+ return klass
+
+ return wrapper
+
+
+class AtomicBasisMixin:
+ """Mixin class for atomic bases (i.e. non-composite)."""
+
+ def __init__(self, n_basis_funcs: int):
+ self._n_basis_funcs = n_basis_funcs
+ self._check_n_basis_min()
+
+ @set_input_shape_state
+ def __sklearn_clone__(self) -> Basis:
+ """Clone the basis while preserving attributes related to input shapes.
+
+ This method ensures that input shape attributes (e.g., `_n_basis_input_`,
+ `_input_shape_`) are preserved during cloning. Reinitializing the class
+ as in the regular sklearn clone would drop these attributes, rendering
+ cross-validation unusable.
+ """
+ klass = self.__class__(**self.get_params())
+
+ for attr_name in ["_n_basis_input_", "_input_shape_"]:
+ setattr(klass, attr_name, getattr(self, attr_name))
+ return klass
+
+ def _list_components(self):
+ """List all basis components.
+
+ For atomic bases, the list is just [self].
+
+ Returns
+ -------
+ A list with the basis components.
+
+ """
+ return [self]
+
+ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
+ """
+ Set the expected input shape for the basis object.
+
+ This method configures the shape of the input data that the basis object expects.
+ ``xi`` can be specified as an integer, a tuple of integers, or derived
+ from an array. The method also calculates the total number of input
+ features and output features based on the number of basis functions.
+
+ Parameters
+ ----------
+ xi :
+ The input shape specification.
+ - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input.
+ - A tuple: Represents the exact input shape excluding the first axis (sample axis).
+ All elements must be integers.
+ - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
+
+ Raises
+ ------
+ ValueError
+ If a tuple is provided and it contains non-integer elements.
+
+ Returns
+ -------
+ self :
+ Returns the instance itself to allow method chaining.
+
+ Notes
+ -----
+ All state attributes that depends on the input must be set in this method in order for
+ the API of basis to work correctly. In particular, this method is called by ``setup_basis``,
+ which is equivalent to ``fit`` for a transformer. If any input dependent state
+ is not set in this method, then ``compute_features`` (equivalent to ``fit_transform``) will break.
+
+ """
+ if isinstance(xi, tuple):
+ if not all(isinstance(i, int) for i in xi):
+ raise ValueError(
+ f"The tuple provided contains non integer values. Tuple: {xi}."
+ )
+ shape = xi
+ elif isinstance(xi, int):
+ shape = () if xi == 1 else (xi,)
+ else:
+ shape = xi.shape[1:]
+
+ n_inputs = (int(np.prod(shape)),)
+
+ self._input_shape_ = shape
+
+ self._n_basis_input_ = n_inputs
+ return self
+
+ def _check_input_shape_consistency(self, x: NDArray):
+ """Check input consistency across calls."""
+ # remove sample axis and squeeze
+ shape = x.shape[1:]
+
+ initialized = self._input_shape_ is not None
+ is_shape_match = self._input_shape_ == shape
+ if initialized and not is_shape_match:
+ expected_shape_str = "(n_samples, " + f"{self._input_shape_}"[1:]
+ expected_shape_str = expected_shape_str.replace(",)", ")")
+ raise ValueError(
+ f"Input shape mismatch detected.\n\n"
+ f"The basis `{self.__class__.__name__}` with label '{self.label}' expects inputs with "
+ f"a consistent shape (excluding the sample axis). Specifically, the shape should be:\n"
+ f" Expected: {expected_shape_str}\n"
+ f" But got: {x.shape}.\n\n"
+ "Note: The number of samples (`n_samples`) can vary between calls of `compute_features`, "
+ "but all other dimensions must remain the same. If you need to process inputs with a "
+ "different shape, please create a new basis instance."
+ )
+
+
class EvalBasisMixin:
"""Mixin class for evaluational basis."""
- def __init__(
- self, n_basis_funcs: int, bounds: Optional[Tuple[float, float]] = None
- ):
+ def __init__(self, bounds: Optional[Tuple[float, float]] = None):
self.bounds = bounds
- self._n_basis_funcs = n_basis_funcs
def _compute_features(self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor):
"""Evaluate basis at sample points.
@@ -61,7 +215,7 @@ def setup_basis(self, *xi: NDArray) -> Basis:
Set all basis states.
This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and
- it must set all basis states, i.e. kernel_ and all the states relative to the input shape.
+ it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape.
The difference between this method and the transformer ``fit`` is in the expected input structure,
where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here
each input is provided as a separate time series for each basis element.
@@ -122,13 +276,10 @@ def bounds(self, values: Union[None, Tuple[float, float]]):
class ConvBasisMixin:
"""Mixin class for convolutional basis."""
- def __init__(
- self, n_basis_funcs: int, window_size: int, conv_kwargs: Optional[dict] = None
- ):
+ def __init__(self, window_size: int, conv_kwargs: Optional[dict] = None):
self.kernel_ = None
self.window_size = window_size
self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs
- self._n_basis_funcs = n_basis_funcs
def _compute_features(self, *xi: NDArray | Tsd | TsdFrame | TsdTensor):
"""Convolve basis functions with input time series.
@@ -172,7 +323,7 @@ def setup_basis(self, *xi: NDArray) -> Basis:
Set all basis states.
This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and
- it must set all basis states, i.e. kernel_ and all the states relative to the input shape.
+ it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape.
The difference between this method and the transformer ``fit`` is in the expected input structure,
where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here
each input is provided as a separate time series for each basis element.
@@ -383,15 +534,12 @@ def n_basis_funcs(self):
"""Read only property for composite bases."""
pass
- def _check_n_basis_min(self) -> None:
- pass
-
def setup_basis(self, *xi: NDArray) -> Basis:
"""
Set all basis states.
This method corresponds sklearn transformer ``fit``. As fit, it must receive the input and
- it must set all basis states, i.e. kernel_ and all the states relative to the input shape.
+ it must set all basis states, i.e. ``kernel_`` and all the states relative to the input shape.
The difference between this method and the transformer ``fit`` is in the expected input structure,
where the transformer ``fit`` method requires the inputs to be concatenated in a 2D array, while here
each input is provided as a separate time series for each basis element.
@@ -462,3 +610,60 @@ def _list_components(self):
A list with all 1d basis components.
"""
return self._basis1._list_components() + self._basis2._list_components()
+
+ @set_input_shape_state
+ def __sklearn_clone__(self) -> Basis:
+ """Clone the basis while preserving attributes related to input shapes.
+
+ This method ensures that input shape attributes (e.g., `_n_basis_input_`,
+ `_input_shape_`) are preserved during cloning. Reinitializing the class
+ as in the regular sklearn clone would drop these attributes, rendering
+ cross-validation unusable.
+ The method also handles recursive cloning for composite basis structures.
+ """
+ # clone recursively
+ basis1 = self._basis1.__sklearn_clone__()
+ basis2 = self._basis2.__sklearn_clone__()
+ klass = self.__class__(basis1, basis2)
+
+ for attr_name in ["_n_basis_input_", "_input_shape_"]:
+ setattr(klass, attr_name, getattr(self, attr_name))
+ return klass
+
+ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
+ """
+ Set the expected input shape for the basis object.
+
+ This method sets the input shape for each component basis in the basis.
+ One ``xi`` must be provided for each basis component, specified as an integer,
+ a tuple of integers, or an array. The method calculates and stores the total number of output features
+ based on the number of basis functions in each component and the provided input shapes.
+
+ Parameters
+ ----------
+ *xi :
+ The input shape specifications. For every k,``xi[k]`` can be:
+ - An integer: Represents the dimensionality of the input. A value of ``1`` is treated as scalar input.
+ - A tuple: Represents the exact input shape excluding the first axis (sample axis).
+ All elements must be integers.
+ - An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
+
+ Raises
+ ------
+ ValueError
+ If a tuple is provided and it contains non-integer elements.
+
+ Returns
+ -------
+ self :
+ Returns the instance itself to allow method chaining.
+ """
+ self._n_basis_input_ = (
+ *self._basis1.set_input_shape(
+ *xi[: self._basis1._n_input_dimensionality]
+ )._n_basis_input_,
+ *self._basis2.set_input_shape(
+ *xi[self._basis1._n_input_dimensionality :]
+ )._n_basis_input_,
+ )
+ return self
diff --git a/src/nemos/basis/_decaying_exponential.py b/src/nemos/basis/_decaying_exponential.py
index 7df05947..a1fd4a24 100644
--- a/src/nemos/basis/_decaying_exponential.py
+++ b/src/nemos/basis/_decaying_exponential.py
@@ -14,9 +14,10 @@
from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
from ._basis import Basis, check_transform_input, min_max_rescale_samples
+from ._basis_mixin import AtomicBasisMixin
-class OrthExponentialBasis(Basis, abc.ABC):
+class OrthExponentialBasis(Basis, AtomicBasisMixin, abc.ABC):
"""Set of 1D basis decaying exponential functions numerically orthogonalized.
Parameters
@@ -37,6 +38,7 @@ def __init__(
mode="eval",
label: Optional[str] = "OrthExponentialBasis",
):
+ AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs)
super().__init__(
mode=mode,
label=label,
diff --git a/src/nemos/basis/_raised_cosine_basis.py b/src/nemos/basis/_raised_cosine_basis.py
index 0521a683..c964f1dc 100644
--- a/src/nemos/basis/_raised_cosine_basis.py
+++ b/src/nemos/basis/_raised_cosine_basis.py
@@ -11,9 +11,10 @@
from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
from ._basis import Basis, check_transform_input, min_max_rescale_samples
+from ._basis_mixin import AtomicBasisMixin
-class RaisedCosineBasisLinear(Basis, abc.ABC):
+class RaisedCosineBasisLinear(Basis, AtomicBasisMixin, abc.ABC):
"""Represent linearly-spaced raised cosine basis functions.
This implementation is based on the cosine bumps used by Pillow et al. [1]_
@@ -44,6 +45,7 @@ def __init__(
width: float = 2.0,
label: Optional[str] = "RaisedCosineBasisLinear",
) -> None:
+ AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs)
super().__init__(
mode=mode,
label=label,
diff --git a/src/nemos/basis/_spline_basis.py b/src/nemos/basis/_spline_basis.py
index 5fc4c38e..78cc34a6 100644
--- a/src/nemos/basis/_spline_basis.py
+++ b/src/nemos/basis/_spline_basis.py
@@ -13,9 +13,10 @@
from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
from ._basis import Basis, check_transform_input, min_max_rescale_samples
+from ._basis_mixin import AtomicBasisMixin
-class SplineBasis(Basis, abc.ABC):
+class SplineBasis(Basis, AtomicBasisMixin, abc.ABC):
"""
SplineBasis class inherits from the Basis class and represents spline basis functions.
@@ -43,6 +44,7 @@ def __init__(
mode: Literal["conv", "eval"] = "eval",
) -> None:
self.order = order
+ AtomicBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs)
super().__init__(
label=label,
mode=mode,
diff --git a/src/nemos/basis/_transformer_basis.py b/src/nemos/basis/_transformer_basis.py
index a8ce37fa..db2d5676 100644
--- a/src/nemos/basis/_transformer_basis.py
+++ b/src/nemos/basis/_transformer_basis.py
@@ -100,7 +100,6 @@ def basis(self):
@basis.setter
def basis(self, basis):
- self._check_initialized(basis)
self._basis = basis
def _unpack_inputs(self, X: FeatureMatrix) -> List:
@@ -121,14 +120,13 @@ def _unpack_inputs(self, X: FeatureMatrix) -> List:
"""
n_samples = X.shape[0]
- out = []
- cc = 0
- for i, bas in enumerate(self._list_components()):
- n_input = self._n_basis_input_[i]
- out.append(
- np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_))
+ out = [
+ np.reshape(X[:, cc : cc + n_input], (n_samples, *bas._input_shape_))
+ for i, (bas, n_input) in enumerate(
+ zip(self._list_components(), self._n_basis_input_)
)
- cc += n_input
+ for cc in [sum(self._n_basis_input_[:i])]
+ ]
return out
def fit(self, X: FeatureMatrix, y=None):
@@ -166,9 +164,9 @@ def fit(self, X: FeatureMatrix, y=None):
>>> # Example input
>>> X = np.random.normal(size=(100, 2))
- >>> # Define, setup and fit transformer basis
- >>> basis = MSplineEval(10)
- >>> transformer = TransformerBasis(basis).set_input_shape(2)
+ >>> # Define and fit tranformation basis
+ >>> basis = MSplineEval(10).set_input_shape(2)
+ >>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X)
"""
self._check_initialized(self._basis)
@@ -200,12 +198,12 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> # Example input
>>> X = np.random.normal(size=(10000, 2))
- >>> basis = MSplineConv(10, window_size=200)
+ >>> basis = MSplineConv(10, window_size=200).set_input_shape(2)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_
- >>> transformer_fitted = transformer.set_input_shape(2).fit(X)
+ >>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)
@@ -244,15 +242,10 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
>>> from nemos.basis import MSplineEval, TransformerBasis
>>> # Example input
- >>> n_inputs = 2
- >>> X = np.random.normal(size=(100, 2))
+ >>> X = np.random.normal(size=(100, 1))
>>> # Define tranformation basis
- >>> basis = MSplineEval(10)
- >>> # Prepare basis to process 2 inputs
- >>> # This step must be done before
- >>> basis = basis.set_input_shape(n_inputs)
-
+ >>> basis = MSplineEval(10).set_input_shape(1)
>>> transformer = TransformerBasis(basis)
>>> # Fit and transform basis
@@ -346,10 +339,10 @@ def __setattr__(self, name: str, value) -> None:
>>> trans_bas.n_basis_funcs = 20
>>> # not allowed
>>> try:
- ... trans_bas.rand_attr = "some value"
+ ... trans_bas.random_attribute_name = "some value"
... except ValueError as e:
... print(repr(e))
- ValueError('Only setting _basis or existing attributes of _basis is allowed. Attempt to set `rand_attr`.')
+ ValueError('Only setting _basis or existing attributes of _basis is allowed.')
"""
# allow self._basis = basis and other attrs of self to be retrievable
if name in ["_basis", "basis", "_wrapped_methods"]:
diff --git a/src/nemos/basis/basis.py b/src/nemos/basis/basis.py
index 702f7b56..2ea0bd86 100644
--- a/src/nemos/basis/basis.py
+++ b/src/nemos/basis/basis.py
@@ -9,7 +9,7 @@
from ..typing import FeatureMatrix
from ._basis import add_docstring
-from ._basis_mixin import ConvBasisMixin, EvalBasisMixin
+from ._basis_mixin import AtomicBasisMixin, ConvBasisMixin, EvalBasisMixin
from ._decaying_exponential import OrthExponentialBasis
from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog
from ._spline_basis import BSplineBasis, CyclicBSplineBasis, MSplineBasis
@@ -83,13 +83,15 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "BSplineEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+
BSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="eval",
order=order,
label=label,
)
+ EvalBasisMixin.__init__(self, bounds=bounds)
@add_docstring("split_by_feature", BSplineBasis)
def split_by_feature(
@@ -156,7 +158,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", BSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -164,25 +166,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.BSplineEval(5)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class BSplineConv(ConvBasisMixin, BSplineBasis):
@@ -236,14 +235,10 @@ def __init__(
label: Optional[str] = "BSplineConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
BSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="conv",
order=order,
label=label,
@@ -314,7 +309,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", BSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -322,25 +317,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.BSplineConv(5, 10)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis):
@@ -381,9 +373,10 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "CyclicBSplineEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+ EvalBasisMixin.__init__(self, bounds=bounds)
CyclicBSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="eval",
order=order,
label=label,
@@ -454,7 +447,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", CyclicBSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -462,25 +455,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.CyclicBSplineEval(5)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis):
@@ -526,14 +516,10 @@ def __init__(
label: Optional[str] = "CyclicBSplineConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
CyclicBSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="conv",
order=order,
label=label,
@@ -604,7 +590,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", CyclicBSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -612,25 +598,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.CyclicBSplineConv(5, 10)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class MSplineEval(EvalBasisMixin, MSplineBasis):
@@ -695,9 +678,10 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "MSplineEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+ EvalBasisMixin.__init__(self, bounds=bounds)
MSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="eval",
order=order,
label=label,
@@ -768,7 +752,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", MSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -776,25 +760,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.MSplineEval(5)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class MSplineConv(ConvBasisMixin, MSplineBasis):
@@ -864,14 +845,10 @@ def __init__(
label: Optional[str] = "MSplineConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
MSplineBasis.__init__(
self,
+ n_basis_funcs,
mode="conv",
order=order,
label=label,
@@ -942,7 +919,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
"""
return super().evaluate_on_grid(n_samples)
- @add_docstring("set_input_shape", MSplineBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -950,25 +927,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.MSplineConv(5, 10)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear):
@@ -1017,9 +991,10 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "RaisedCosineLinearEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+ EvalBasisMixin.__init__(self, bounds=bounds)
RaisedCosineBasisLinear.__init__(
self,
+ n_basis_funcs,
width=width,
mode="eval",
label=label,
@@ -1083,7 +1058,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", RaisedCosineBasisLinear)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1091,25 +1066,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.RaisedCosineLinearEval(5)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear):
@@ -1163,14 +1135,10 @@ def __init__(
label: Optional[str] = "RaisedCosineLinearConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
RaisedCosineBasisLinear.__init__(
self,
+ n_basis_funcs,
mode="conv",
width=width,
label=label,
@@ -1234,7 +1202,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", RaisedCosineBasisLinear)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1242,25 +1210,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.RaisedCosineLinearConv(5, 10)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog):
@@ -1323,9 +1288,10 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "RaisedCosineLogEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+ EvalBasisMixin.__init__(self, bounds=bounds)
RaisedCosineBasisLog.__init__(
self,
+ n_basis_funcs,
width=width,
time_scaling=time_scaling,
enforce_decay_to_zero=enforce_decay_to_zero,
@@ -1391,7 +1357,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", RaisedCosineBasisLog)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1399,25 +1365,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.RaisedCosineLogEval(5)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog):
@@ -1481,14 +1444,10 @@ def __init__(
label: Optional[str] = "RaisedCosineLogConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
RaisedCosineBasisLog.__init__(
self,
+ n_basis_funcs,
mode="conv",
width=width,
time_scaling=time_scaling,
@@ -1554,7 +1513,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", RaisedCosineBasisLog)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1562,25 +1521,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.RaisedCosineLogConv(5, 10)
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis):
@@ -1623,9 +1579,10 @@ def __init__(
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "OrthExponentialEval",
):
- EvalBasisMixin.__init__(self, n_basis_funcs=n_basis_funcs, bounds=bounds)
+ EvalBasisMixin.__init__(self, bounds=bounds)
OrthExponentialBasis.__init__(
self,
+ n_basis_funcs,
decay_rates=decay_rates,
mode="eval",
label=label,
@@ -1693,7 +1650,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", OrthExponentialBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1701,25 +1658,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.OrthExponentialEval(5, decay_rates=np.arange(1, 6))
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis):
@@ -1765,14 +1719,10 @@ def __init__(
label: Optional[str] = "OrthExponentialConv",
conv_kwargs: Optional[dict] = None,
):
- ConvBasisMixin.__init__(
- self,
- n_basis_funcs=n_basis_funcs,
- window_size=window_size,
- conv_kwargs=conv_kwargs,
- )
+ ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
OrthExponentialBasis.__init__(
self,
+ n_basis_funcs,
mode="conv",
decay_rates=decay_rates,
label=label,
@@ -1844,7 +1794,7 @@ def split_by_feature(
"""
return super().split_by_feature(x, axis=axis)
- @add_docstring("set_input_shape", OrthExponentialBasis)
+ @add_docstring("set_input_shape", AtomicBasisMixin)
def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
"""
Examples
@@ -1852,25 +1802,22 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
>>> import nemos as nmo
>>> import numpy as np
>>> basis = nmo.basis.OrthExponentialConv(5, window_size=10, decay_rates=np.arange(1, 6))
-
- Configure with an integer input:
+ >>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
-
- Configure with a tuple:
+ >>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
-
- Configure with an array:
+ >>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
"""
- return super().set_input_shape(xi)
+ return AtomicBasisMixin.set_input_shape(self, xi)
def _check_window_size(self, window_size: int):
"""OrthExponentialBasis specific window size check."""
diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py
index b949b489..d0f7709e 100644
--- a/src/nemos/identifiability_constraints.py
+++ b/src/nemos/identifiability_constraints.py
@@ -218,6 +218,8 @@ def apply_identifiability_constraints(
>>> from nemos.identifiability_constraints import apply_identifiability_constraints
>>> from nemos.basis import BSplineEval
>>> from nemos.glm import GLM
+ >>> import jax
+ >>> jax.config.update('jax_enable_x64', True)
>>> # define a feature matrix
>>> bas = BSplineEval(5) + BSplineEval(6)
>>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100))
@@ -280,9 +282,11 @@ def apply_identifiability_constraints_by_basis_component(
Examples
--------
>>> import numpy as np
+ >>> import jax
>>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component
>>> from nemos.basis import BSplineEval
>>> from nemos.glm import GLM
+ >>> jax.config.update('jax_enable_x64', True)
>>> # define a feature matrix
>>> bas = BSplineEval(5) + BSplineEval(6)
>>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100))
diff --git a/tests/test_basis.py b/tests/test_basis.py
index 02a0332d..f6f033e1 100644
--- a/tests/test_basis.py
+++ b/tests/test_basis.py
@@ -1327,6 +1327,20 @@ def test_set_input_value_types(self, inp_shape, expectation, cls):
with expectation:
bas.set_input_shape(inp_shape)
+ @pytest.mark.parametrize(
+ "mode, kwargs", [("eval", {}), ("conv", {"window_size": 6})]
+ )
+ def test_list_component(self, mode, kwargs, cls):
+ basis_obj = cls[mode](
+ n_basis_funcs=5,
+ **kwargs,
+ **extra_decay_rates(cls[mode], 5),
+ )
+
+ out = basis_obj._list_components()
+ assert len(out) == 1
+ assert id(out[0]) == id(basis_obj)
+
class TestRaisedCosineLogBasis(BasisFuncsTesting):
cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv}
@@ -1957,6 +1971,33 @@ def test_samples_range_matches_compute_features_requirements(
class TestAdditiveBasis(CombinedBasis):
cls = {"eval": AdditiveBasis, "conv": AdditiveBasis}
+ @pytest.mark.parametrize("basis_a", list_all_basis_classes())
+ @pytest.mark.parametrize("basis_b", list_all_basis_classes())
+ def test_list_component(self, basis_a, basis_b, basis_class_specific_params):
+ basis_a_obj = self.instantiate_basis(
+ 5, basis_a, basis_class_specific_params, window_size=10
+ )
+ basis_b_obj = self.instantiate_basis(
+ 6, basis_b, basis_class_specific_params, window_size=10
+ )
+ add = basis_a_obj + basis_b_obj
+ out = add._list_components()
+
+ assert len(out) == add._n_input_dimensionality
+
+ def get_ids(bas):
+
+ if hasattr(bas, "basis1"):
+ ids = get_ids(bas.basis1)
+ ids += get_ids(bas.basis2)
+ else:
+ ids = [id(bas)]
+ return ids
+
+ id_list = get_ids(add)
+
+ assert tuple(id(o) for o in out) == tuple(id_list)
+
@pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]])
@pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv])
def test_non_empty_samples(self, base_cls, samples, basis_class_specific_params):
diff --git a/tests/test_identifiability_constraints.py b/tests/test_identifiability_constraints.py
index ca4f4be2..f40ad214 100644
--- a/tests/test_identifiability_constraints.py
+++ b/tests/test_identifiability_constraints.py
@@ -190,6 +190,7 @@ def test_feature_matrix_dtype(dtype, expected_dtype):
)
def test_apply_constraint_with_invalid(invalid_entries):
"""Test if the matrix retains its dtype after applying constraints."""
+ jax.config.update("jax_enable_x64", True)
x = np.random.randn(10, 5)
# add invalid
x[:2, 2] = invalid_entries
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index cf75ad3e..6c667851 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -174,6 +174,7 @@ def test_sklearn_transformer_pipeline_pynapple(
X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep)
y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep)
bas = TransformerBasis(bas).set_input_shape(*([1] * bas._n_input_dimensionality))
+
# fit a pipeline & predict from pynapple
pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])
pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap)