Skip to content

Commit

Permalink
Merge branch 'multi_dim_eval_basis' into set_shape_basis_method
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 12, 2024
2 parents fbf0aeb + f3d44ac commit 8318716
Show file tree
Hide file tree
Showing 12 changed files with 340 additions and 191 deletions.
5 changes: 3 additions & 2 deletions docs/background/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@ plot_00_conceptual_intro.md

```{eval-rst}
.. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear
.. plot:: scripts/basis_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 100px
```

```{toctree}
:maxdepth: 2
:maxdepth: 3
basis/README.md
```

:::

:::{grid-item-card}
Expand Down
44 changes: 25 additions & 19 deletions docs/background/basis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,47 @@
- **Evaluation/Convolution**
- **Preferred Mode**
* - **B-Spline**
- .. plot:: scripts/basis_table_figs.py plot_bspline
- .. plot:: scripts/basis_figs.py plot_bspline
:show-source-link: False
:height: 80px
- :ref:`Grid cells <grid_cells_nemos>`
- :class:`~nemos.basis.BSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.BSplineConv`
- 🟢 Eval
* - **Cyclic B-Spline**
- .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline
- .. plot:: scripts/basis_figs.py plot_cyclic_bspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.CyclicBSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.CyclicBSplineConv`
- 🟢 Eval
* - **M-Spline**
- .. plot:: scripts/basis_table_figs.py plot_mspline
- .. plot:: scripts/basis_figs.py plot_mspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.MSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.MSplineConv`
- 🟢 Eval
* - **Linearly Spaced Raised Cosine**
- .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear
- .. plot:: scripts/basis_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 80px
-
- :class:`~nemos.basis.RaisedCosineLinearEval` :raw-html:`<br />`
:class:`~nemos.basis.RaisedCosineLinearConv`
- 🟢 Eval
* - **Log Spaced Raised Cosine**
- .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log
- .. plot:: scripts/basis_figs.py plot_raised_cosine_log
:show-source-link: False
:height: 80px
- :ref:`Head Direction <head_direction_reducing_dimensionality>`
- :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`<br />`
:class:`nemos.basis.RaisedCosineLogConv`
:class:`~nemos.basis.RaisedCosineLogConv`
- 🔵 Conv
* - **Orthogonalized Exponential Decays**
- .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis
- .. plot:: scripts/basis_figs.py plot_orth_exp_basis
:show-source-link: False
:height: 80px
-
Expand All @@ -83,27 +83,30 @@ $$

Here, $\approx$ means "approximately equal".

Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$.
Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. This preserves convexity, resulting in a much simpler optimization problem.


## Basis in NeMoS

NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two key uses described in the overview:
NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two uses described above:

- **Eval-basis objects**: For representing non-linear mappings between task variables and outputs. These objects are identified by names starting with `Eval`.
- **Conv-basis objects**: For linear temporal effects. These objects are identified by names starting with `Conv`.
- **Eval basis objects**: For representing non-linear mappings between task variables and outputs. These objects all have names ending with `Eval`.
- **Conv basis objects**: For linear temporal effects. These objects all have names ending with `Conv`.

`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling complex feature construction.
`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling [complex feature construction](composing_basis_function).

## Learn More

::::{grid} 1 2 2 2

:::{grid-item-card}

<figure>
<img src="../../_static/thumbnails/background/plot_01_1D_basis_function.svg" style="height: 100px", alt="One-Dimensional Basis."/>
</figure>
```{eval-rst}
.. plot:: scripts/basis_figs.py plot_1d_basis_thumbnail
:show-source-link: False
:height: 100px
```

```{toctree}
:maxdepth: 2
Expand All @@ -114,9 +117,12 @@ plot_01_1D_basis_function.md

:::{grid-item-card}

<figure>
<img src="../../_static/thumbnails/background/plot_02_ND_basis_function.svg" style="height: 100px", alt="N-Dimensional Basis."/>
</figure>
```{eval-rst}
.. plot:: scripts/basis_figs.py plot_nd_basis_thumbnail
:show-source-link: False
:height: 100px
```

```{toctree}
:maxdepth: 2
Expand All @@ -125,4 +131,4 @@ plot_02_ND_basis_function.md
```
:::

::::
::::
123 changes: 37 additions & 86 deletions docs/background/basis/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,18 @@ warnings.filterwarnings(
category=RuntimeWarning,
)
from nemos._documentation_utils._myst_nb_glue import glue_two_step_convolve
glue_two_step_convolve()
```

(simple_basis_function)=
# Simple Basis Function

## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
We'll start by defining a 1D basis function object of the type [`BSplineEval`](nemos.basis.BSplineEval).
The hyperparameters needed to initialize this class are:

- The number of basis functions, which should be a positive integer (required).
Expand Down Expand Up @@ -81,50 +85,26 @@ plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```

```{code-cell} ipython3
:tags: [hide-input]
# save image for thumbnail
from pathlib import Path
import os
root = os.environ.get("READTHEDOCS_OUTPUT")
if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../../_build/html/_static/thumbnails/background")
# make sure the folder exists if run from build
if root or Path("../../assets/stylesheets").exists():
path.mkdir(parents=True, exist_ok=True)
if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")

print(path.resolve(), path.exists())
```

## Feature Computation
All bases in the `nemos.basis` module transform one or more time series into a set of features. This transformation is performed out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features).
The bases are categorized into two types based on the transformation applied by [`compute_features`](nemos.basis._basis.Basis.compute_features):
## Computing Features
All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features).
We can group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies:

1. **Evaluation Bases**: TThese bases evaluate the basis functions directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval", such as `BSplineEval`.
1. **Evaluation Bases**: These bases use `compute_features` to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.

2. **Convolution Bases**: hese bases convolve the input with a kernel of basis elements using a user-specified `window_size`. Classes in this category have names ending with "Conv", such as `BSplineConv`.
2. **Convolution Bases**: These bases use `compute_features` to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`.

Let's see how this two modalities operate.
Let's see how these two categories operate:

```{code-cell} ipython3
eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100)
eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100)
# define an input
angles = np.linspace(0, np.pi*4, 201)
y = np.cos(angles)
# compute features in the two modalities
# compute features
eval_feature = eval_mode.compute_features(y)
conv_feature = conv_mode.compute_features(y)
Expand Down Expand Up @@ -161,65 +141,57 @@ check out the tutorial on [1D convolutions](convolution_background).
:::

### Multi-dimensional inputs
For N-dimensional input with $N>1$, `compute_features` assumes the first axis represents samples. This is always valid for `pynapple` time series. For arrays, you can use `numpy.transpose` to re-arrange the axis if needed.
For inputs with more than one dimension, `compute_features` assumes the first axis represents samples. This is always valid for `pynapple` time series. For arrays, you can use [`numpy.transpose`](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html) to re-arrange the axis if needed.

#### "Eval" Basis
#### Eval Basis

For "Eval" bases, `compute_features` evaluates the basis and then reshape the result into a 2D feature matrix.
For Eval bases, `compute_features` evaluates the basis and outputs a 2D feature matrix.

```{code-cell} ipython3
basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5)
# generate a 3D array
inp = np.random.randn(50, 2, 3)
inp = np.random.randn(50, 3, 2)
out = basis.compute_features(inp)
out.shape
```

For each of the $3 \times 2 = 6$ inputs, `n_basis_funcs = 5` features are computed. These are concatenated on the second axis of the feature matrix, for a total of
$3 \times 2 \times 5 = 30$ outputs concatenated on the second axis.
$3 \times 2 \times 5 = 30$ outputs.

#### "Conv" Basis
#### Conv Basis

For "Conv" bases, `compute_features` convolves each input with `n_basis_funcs` kernels, and reshaping the output into a 2D feature matrix.
For Conv bases, `compute_features` convolves each input with `n_basis_funcs` kernels and outputs a 2D feature matrix.

```{code-cell} ipython3
basis = nmo.basis.RaisedCosineLinearConv(n_basis_funcs=5, window_size=6)
# compute_features to perform the convolution and concatenate
out = basis.compute_features(inp)
print(f"`compute_features` output shape {out.shape}")
out.shape
```

This process is equivalent to performing the convolution separately usingS [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) and then reshaping the output.
:::{admonition} Note

```{code-cell} ipython3
# compute the kernels
basis.set_kernel()
print(f"Kernel shape (window_size, n_basis_funcs): {basis.kernel_.shape}")
This process is equivalent to performing the convolution separately with [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) and then reshaping the output.

# apply the convolution
out_two_steps = nmo.convolve.create_convolutional_predictor(basis.kernel_, inp)
print(f"Convolution output shape: {out_two_steps.shape}")
# then reshape to 2D
out_two_steps = out_two_steps.reshape(inp.shape[0], inp.shape[1] * inp.shape[2] * basis.n_basis_funcs)
```{glue} two-step-convolution-source-code
```

# check that this is equivalent to the output of compute_features
print(f"All matching: {np.array_equal(out_two_steps, out, equal_nan=True)}")
```{glue} two-step-convolution
```

:::

Plotting the Basis Function Elements
------------------------------------
We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
particularly evident when working with multidimensional basis functions. You can find more details and visual
background in the
[2D basis elements plotting section](plotting-2d-additive-basis-elements).
the equi-spaced samples along with the evaluated basis functions.

:::{admonition} Note

The array returned by `evaluate_on_grid(n_samples)` is the same as the kernel that is used by the Conv bases initialized with `window_sizes=n_samples`!

:::

```{code-cell} ipython3
# Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples
Expand All @@ -233,12 +205,13 @@ plt.plot(equispaced_samples, eval_basis)
plt.show()
```

The benefits of using `evaluate_on_grid` become particularly evident when working with multidimensional basis functions. You can find more details in the [2D basis elements plotting section](plotting-2d-additive-basis-elements).

## Setting the basis support (Eval only)
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions).
parameter at initialization of Eval bases.
Evaluating the basis at any sample outside the bounds will result in a NaN.


Expand All @@ -265,25 +238,3 @@ axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```

Other Basis Types
-----------------
Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
please refer to the [API Guide](nemos_basis). After instantiation, all classes
share the same syntax for basis evaluation. The following is an example of how to instantiate and
evaluate a log-spaced cosine raised function basis.


```{code-cell} ipython3
# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50)
# Evaluate the raised cosine basis at the equi-spaced sample points
# (same method in all Basis elements)
samples, eval_basis = raised_cosine_log.evaluate_on_grid(100)
# Plot the evaluated log-spaced raised cosine basis
plt.figure()
plt.title(f"Log-spaced Raised Cosine basis with {eval_basis.shape[1]} elements")
plt.plot(samples, eval_basis)
plt.show()
```
28 changes: 3 additions & 25 deletions docs/background/basis/plot_02_ND_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ combination of some multidimensional basis elements.

In this document, we introduce two strategies for defining a high-dimensional basis function by combining
two lower-dimensional bases. We refer to these strategies as "addition" and "multiplication" of bases,
and the resulting basis objects will be referred to as additive or multiplicative basis respectively.
and the resulting basis objects will be referred to as additive or multiplicative basis respectively: additive bases have their component bases operate *independently*, whereas multiplicative bases take the *outer product*. And these composite basis objects can be constructed using other composite bases, so that you can combine them as much as you'd like!


Consider we have two inputs $\mathbf{x} \in \mathbb{R}^N,\; \mathbf{y}\in \mathbb{R}^M$.
More precisely, let's say we have two inputs $\mathbf{x} \in \mathbb{R}^N,\; \mathbf{y}\in \mathbb{R}^M$.
Let's say we've defined two basis functions for these inputs:

- $ [ a_0 (\mathbf{x}), ..., a_{k-1} (\mathbf{x}) ] $ for $\mathbf{x}$
Expand Down Expand Up @@ -106,6 +105,7 @@ In the subsequent sections, we will:
1. Demonstrate the definition, evaluation, and visualization of 2D additive and multiplicative bases.
2. Illustrate how to iteratively apply addition and multiplication operations to extend to dimensions beyond two.

(composite_basis_2d)=
## 2D Basis Functions

Consider an instance where we want to capture a neuron's response to an animal's position within a given arena.
Expand Down Expand Up @@ -292,28 +292,6 @@ axs[2, 1].set_xlabel('y-coord')
plt.tight_layout()
```

```{code-cell} ipython3
:tags: [hide-input]
# save image for thumbnail
from pathlib import Path
import os
root = os.environ.get("READTHEDOCS_OUTPUT")
if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../../_build/html/_static/thumbnails/background")
# make sure the folder exists if run from build
if root or Path("../../assets/stylesheets").exists():
path.mkdir(parents=True, exist_ok=True)
if path.exists():
fig.savefig(path / "plot_02_ND_basis_function.svg")
```

:::{note}
Basis objects of different types can be combined through multiplication or addition.
This feature is particularly useful when one of the axes represents a periodic variable and another is non-periodic.
Expand Down
Loading

0 comments on commit 8318716

Please sign in to comment.