
Commit

edited docs and linted
BalzaniEdoardo committed Dec 14, 2024
1 parent 0340a3b commit 9629f7f
Showing 4 changed files with 31 additions and 33 deletions.
2 changes: 1 addition & 1 deletion docs/how_to_guide/README.md
@@ -91,7 +91,7 @@ plot_06_sklearn_pipeline_cv_demo.md
```{toctree}
:maxdepth: 2
plot_06_glm_pytree.md
plot_07_glm_pytree.md
```

:::
58 changes: 28 additions & 30 deletions docs/how_to_guide/plot_05_transformer_basis.md
@@ -16,27 +16,25 @@ kernelspec:
(tansformer-vs-nemos-basis)=
## scikit-learn Transformers and NeMoS Basis

`scikit-learn` is a great machine learning package that provides advanced tooling for creating data analysis pipelines, from input transformations to model fitting and cross-validation.
`scikit-learn` is a powerful machine learning library that provides advanced tools for creating data analysis pipelines, from input transformations to model fitting and cross-validation.

All of `scikit-learn` machinery relies on very strong assumptions on how one should structure the inputs to each processing step.
In particular, all `scikit-learn` objects requires inputs in the form of arrays of at most two-dimensions, where the first dimension always represents time (or samples) dimension, and the other features.
This may feel a bit rigid at first, but what this buys you is that any transformation can be chained to any other, greatly simplifying the process of building stable complex pipelines.
All of `scikit-learn`'s machinery relies on strict assumptions about input structure. In particular, all `scikit-learn`
objects require inputs as arrays of at most two dimensions, where the first dimension represents the time (or samples)
axis, and the second dimension represents features.
While this may feel rigid, it enables transformations to be seamlessly chained together, greatly simplifying the
process of building stable, complex pipelines.
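
To make that convention concrete, here is a minimal, self-contained sketch (not part of the changed file), using `StandardScaler` as a stand-in transformer; the array sizes are arbitrary.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Any scikit-learn transformer expects a single 2D array:
# rows are samples, columns are features.
X = np.random.randn(100, 3)                   # shape (n_samples, n_features)
X_scaled = StandardScaler().fit_transform(X)  # same convention on the output
print(X_scaled.shape)                         # (100, 3): the sample axis is preserved
```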

In `scikit-learn`, the data transformation steps are performed by objects called `transformers`.
On the other hand, `NeMoS` takes a different approach to feature construction. `NeMoS`' bases are composable constructors that allow for more flexibility in the required input structure.
Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can take any shape as long as the time (or sample) axis is the first of each array;
`NeMoS` design favours object composability: one can combine any two or more bases to compute complex features, with a user-friendly interface that can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.).


On the other hand, NeMoS basis are powerful feature constructors that allow a high degree of flexibility in terms of the required input structure.
Depending on the basis type, it can accept one or more input arrays or `pynapple` time series data, each of which can have any shape as long as the time (or sample) axis is the first of each array;
NeMoS design favours object composability, one can combine any two or more bases to compute complex features, and a user-friendly interface can accept a separate array/time series for each input type (e.g., an array with the spike counts, an array for the animal's position, etc.).

Both approaches to data transformations are valuable and have their own advantages.
Wouldn't it be great if one could combine them? Well, this is what NeMoS `TransformerBasis` are for!
Both approaches to data transformation are valuable and each has its own advantages. Wouldn't it be great if one could combine the two? Well, this is what NeMoS `TransformerBasis` is for!


## From Basis to TransformerBasis


With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process the neural activity as a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array with the speed of an animal of shape `(n_samples,)`.
With NeMoS, you can easily create a basis accepting two inputs. Let's assume that we want to process neural activity stored in a 2-dimensional spike count array of shape `(n_samples, n_neurons)` and a second array containing the speed of an animal, with shape `(n_samples,)`.
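
The construction cell below is collapsed in this diff view. For orientation, here is a hedged sketch of what a two-input composite basis could look like; the specific basis class names and parameters are assumptions and may not match the NeMoS version used in the guide.

```python
import numpy as np
import nemos as nmo

# Toy inputs matching the shapes described above (sizes are illustrative).
counts = np.random.poisson(2.0, size=(100, 5))   # (n_samples, n_neurons)
speed = np.random.uniform(0, 30, size=(100,))    # (n_samples,)

# Assumed class names; check your NeMoS version for the exact basis classes.
count_basis = nmo.basis.RaisedCosineLogConv(n_basis_funcs=5, window_size=10)
speed_basis = nmo.basis.BSplineEval(n_basis_funcs=6)

# Adding two bases composes them; the composite takes one input per component.
composite_basis = count_basis + speed_basis
X = composite_basis.compute_features(counts, speed)
```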

```{code-cell} ipython3
import numpy as np
@@ -60,9 +58,9 @@ X = composite_basis.compute_features(counts, speed)
### Converting NeMoS `Basis` to a transformer

Now, imagine that we want to use this basis as a step in a `scikit-learn` pipeline.
In this standard (for NeMoS) form, it would not be possible the `composite_basis` object requires two inputs. We need to convert it first into a compliant scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class.
In this standard (for NeMoS) form, it would not be possible as the `composite_basis` object requires two inputs. We need to convert it first into a compliant `scikit-learn` transformer. This can be achieved through the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) wrapper class.

Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either using by the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):
Instantiating a [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) can be done either by using the constructor directly or with [`Basis.to_transformer()`](nemos.basis._basis.Basis.to_transformer):


```{code-cell} ipython3
@@ -83,7 +81,7 @@ trans_bas = bas.to_transformer()
print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
```
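
Shown together for reference, a sketch of the two construction routes (reusing `bas` from the cell above; the public import path for `TransformerBasis` is an assumption, while `to_transformer()` is the method used throughout this guide):

```python
from nemos.basis import TransformerBasis  # assumed public import path

trans_bas_from_ctor = TransformerBasis(bas)  # wrap an existing basis via the constructor
trans_bas_from_meth = bas.to_transformer()   # or use the convenience method
```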

We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), and neither does changing the original [`Basis`](nemos.basis._basis.Basis) change [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created:
We can also set attributes of the underlying [`Basis`](nemos.basis._basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis._basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis._basis.Basis), nor does changing the original [`Basis`](nemos.basis._basis.Basis) modify the [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis) we created:


```{code-cell} ipython3
@@ -93,11 +91,11 @@ bas.n_basis_funcs = 100
print(bas.n_basis_funcs, trans_bas.n_basis_funcs)
```

As any `sckit-learn` tansformer, the `TransformerBasis` implements `fit`, a preparation step, `transform`, the actual feature computation, and `fit_transform` which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore all accepts a single 2D array.
As with any `scikit-learn` transformer, the `TransformerBasis` implements `fit` (a preparation step), `transform` (the actual feature computation), and `fit_transform`, which chains `fit` and `transform`. These methods comply with the `scikit-learn` input structure convention, and therefore they all accept a single 2D array.
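
Schematically, the call pattern looks like the sketch below, where the single 2D input is assembled by concatenating the guide's arrays column-wise (variable names follow the guide, and the concatenation itself is a hypothetical reconstruction). As the next section explains, these calls only succeed once the basis knows how to split that array back into its components.

```python
import numpy as np

# Hypothetical 2D input: counts (n_samples, n_neurons) and speed (n_samples,) side by side.
inp = np.concatenate([counts, speed[:, np.newaxis]], axis=1)  # (n_samples, n_neurons + 1)

composite_transformer = composite_basis.to_transformer()
composite_transformer.fit(inp)                       # preparation step
features = composite_transformer.transform(inp)      # feature computation
features = composite_transformer.fit_transform(inp)  # both chained in one call
```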

## Setting up the TransformerBasis

At this point we have an object equipped with the correct methods, so now all we have to do is concatenate the inputs into a unique array and call `fit_transform`, right?
At this point we have an object equipped with the correct methods, so now all we have to do is concatenate the inputs into a single array and call `fit_transform`, right?

```{code-cell} ipython3
@@ -115,28 +113,28 @@ except RuntimeError as e:
print(repr(e))
```

Unfortunately not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`.
...Unfortunately, not yet. The problem is that the basis has never interacted with the two separate inputs, and therefore doesn't know which columns of `inp` should be processed by `count_basis` and which by `speed_basis`.

There are several ways in which you can provide this information to the basis. The first one is by calling the method `set_input_shape`.
You can provide this information by calling the `set_input_shape` method of the basis.

This can be called before or after the transformer basis is defined. The method extracts and store the array shapes excluding the sample axis (which won't be affected in the concatenation).
This can be called before or after the transformer basis is defined. The method extracts and stores the array shapes excluding the sample axis (which won't be affected in the concatenation).

`set_input_shape` accepts directly the inputs,
`set_input_shape` directly accepts the inputs:

```{code-cell} ipython3
composite_basis.set_input_shape(counts, speed)
out = composite_basis.to_transformer().fit_transform(inp)
```

If the input is 1D or 2D, the number of columns,
If the input is 1D or 2D, it also accepts the number of columns:
```{code-cell} ipython3
composite_basis.set_input_shape(5, 1)
out = composite_basis.to_transformer().fit_transform(inp)
```

A tuple containing the shapes of all axis other than the first,
A tuple containing the shapes of all the axes other than the first:
```{code-cell} ipython3
composite_basis.set_input_shape((5,), (1,))
@@ -150,7 +148,7 @@ composite_basis.set_input_shape(counts, 1)
out = composite_basis.to_transformer().fit_transform(inp)
```

You can also invert the order and call `to_transform` first and set the input shapes after.
You can also invert the order of operations and call `to_transformer` first and then set the input shapes.
```{code-cell} ipython3
trans_bas = composite_basis.to_transformer()
@@ -160,12 +158,12 @@ out = trans_bas.fit_transform(inp)

:::{note}

If you define a NeMoS basis and call `compute_features` on your inputs, internally, the basis will store the
input shapes, and the `TransformerBasis` will be ready to process without any direct call to `set_input_shape`.
If you define a basis and call `compute_features` on your inputs, it will store the input shapes internally,
and the `TransformerBasis` will be ready to process data without any direct call to `set_input_shape`.
:::
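
A hedged sketch of that shortcut, reusing the variable names from this guide:

```python
# Calling compute_features first records the input shapes on the basis...
composite_basis.compute_features(counts, speed)
# ...so the derived transformer can be used right away on the concatenated 2D array.
out = composite_basis.to_transformer().fit_transform(inp)
```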

If for some reason you will need to provide an input of different shape to the transformer, you must setup the
`TransformerBasis` again.
If for some reason you need to provide an input of a different shape to an already set-up transformer, you must reset the
`TransformerBasis` with `set_input_shape`.

```{code-cell} ipython3
@@ -181,5 +179,5 @@ out2 = trans_bas.fit_transform(inp2)
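
The cell above is collapsed in this view; a hypothetical reconstruction of the reset workflow might look like the following (the new sizes are made up for illustration, and `trans_bas` is the composite transformer defined earlier):

```python
import numpy as np

# A new recording with a different number of neurons requires re-registering the shapes.
counts2 = np.random.poisson(2.0, size=(100, 7))                 # 7 neurons instead of 5
inp2 = np.concatenate([counts2, speed[:, np.newaxis]], axis=1)  # new concatenated input

trans_bas.set_input_shape(7, 1)       # re-register the new input shapes
out2 = trans_bas.fit_transform(inp2)  # now the transformer can split the columns again
```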

### Learn more

If you want to learn more about basis how to select basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this guide](sklearn-how-to).
If you want to learn more about how to select basis hyperparameters with `sklearn` pipelining and cross-validation, check out [this how-to guide](sklearn-how-to).

2 changes: 1 addition & 1 deletion docs/index.md
@@ -20,7 +20,7 @@ For Developers <developers_notes/README>
```


## __Neural ModelS__
# __Neural ModelS__


NeMoS (Neural ModelS) is a statistical modeling framework optimized for systems neuroscience and powered by [JAX](https://jax.readthedocs.io/en/latest/).
2 changes: 1 addition & 1 deletion src/nemos/basis/_basis_mixin.py
@@ -154,7 +154,6 @@ def set_input_shape(self, xi: int | tuple[int, ...] | NDArray):
self._n_basis_input_ = n_inputs
return self


def _check_input_shape_consistency(self, x: NDArray):
"""Check input consistency across calls."""
# remove sample axis and squeeze
@@ -176,6 +175,7 @@ def _check_input_shape_consistency(self, x: NDArray):
"different shape, please create a new basis instance."
)


class EvalBasisMixin:
"""Mixin class for evaluational basis."""

