Skip to content

Commit

Permalink
Merge pull request #273 from flatironinstitute/basis_refactor_pr1
Browse files Browse the repository at this point in the history
Basis refactor pr1
  • Loading branch information
BalzaniEdoardo authored Dec 11, 2024
2 parents 205f409 + 00a2437 commit 74bde2b
Show file tree
Hide file tree
Showing 36 changed files with 7,179 additions and 8,568 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build:
- gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9
post_build:
# Check everything except 403s and a jneurosci, which returns 404 but the link works when clicking.
- htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/"
- htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/"
# The auto-generated animation doesn't have a alt or src/srcset; I am able to ignore missing alt, but I cannot work around a missing src/srcset
# therefore for this file I am not checking the figures.
- htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003"
Expand Down
82 changes: 73 additions & 9 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,88 @@ Classes for creating Generalized Linear Models (GLMs) for both single neurons an
The ``nemos.basis`` module
--------------------------
Provides basis function classes to construct and transform features for model inputs.
Basis can be grouped according to the mode of operation into basis that performs convolution and basis that operates
as non-linear maps.

.. currentmodule:: nemos.basis

**The Abstract Classes:**

These classes are the building blocks for the concrete basis classes.

.. currentmodule:: nemos.basis._basis

.. autosummary::
:toctree: generated/basis
:toctree: generated/_basis
:recursive:
:nosignatures:

Basis

.. currentmodule:: nemos.basis._spline_basis
.. autosummary::
:toctree: generated/_basis
:recursive:
:nosignatures:

SplineBasis
BSplineBasis
CyclicBSplineBasis
MSplineBasis
OrthExponentialBasis
RaisedCosineBasisLinear
RaisedCosineBasisLog


**Bases For Convolution:**

.. currentmodule:: nemos.basis

.. autosummary::
:toctree: generated/basis
:recursive:
:nosignatures:


MSplineConv
BSplineConv
CyclicBSplineConv
RaisedCosineLinearConv
RaisedCosineLogConv
OrthExponentialConv

.. check for a config that prints only nemos.basis.Name
**Bases For Non-Linear Mapping:**

.. currentmodule:: nemos.basis

.. autosummary::
:toctree: generated/basis
:recursive:
:nosignatures:

MSplineEval
BSplineEval
CyclicBSplineEval
RaisedCosineLinearEval
RaisedCosineLogEval
OrthExponentialEval

**Composite Bases:**

.. currentmodule:: nemos.basis._basis

.. autosummary::
:toctree: generated/_basis
:recursive:
:nosignatures:

AdditiveBasis
MultiplicativeBasis

**Basis As ``scikit-learn`` Tranformers:**

.. currentmodule:: nemos.basis._transformer_basis

.. autosummary::
:toctree: generated/_transformer_basis
:recursive:
:nosignatures:

TransformerBasis

.. _observation_models:
Expand Down Expand Up @@ -130,7 +194,7 @@ These objects can be provided as input to nemos GLM methods.
.. currentmodule:: nemos.pytrees

.. autosummary::
:toctree: generated/identifiability_constraints
:toctree: generated/pytree
:recursive:
:nosignatures:

Expand Down
114 changes: 53 additions & 61 deletions docs/background/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ warnings.filterwarnings(
),
category=RuntimeWarning,
)
```

(simple_basis_function)=
# Simple Basis Function

## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineBasis`](nemos.basis.MSplineBasis).
We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
The hyperparameters required to initialize this class are:

- The number of basis functions, which should be a positive integer.
Expand All @@ -58,35 +59,26 @@ import pynapple as nap
import nemos as nmo
# configure plots some
plt.style.use(nmo.styles.plot_style)
# Initialize hyperparameters
order = 4
n_basis = 10
# Define the 1D basis function object
bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order)
bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order)
```

## Evaluating a Basis

The [`Basis`](nemos.basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis
is defined by the samples that we input to the [`__call__`](nemos.basis.Basis.__call__) method, and covers from the smallest to the largest value.

We provide the convenience method `evaluate_on_grid` for evaluating the basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements.

```{code-cell} ipython3
# evaluate the basis on 100 sample points
x, y = bspline.evaluate_on_grid(100)
# Generate a time series of sample points
samples = nap.Tsd(t=np.arange(1001), d=np.linspace(0, 1,1001))
# Evaluate the basis at the sample points
eval_basis = bspline(samples)
# Output information about the evaluated basis
print(f"Evaluated B-spline of order {order} with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples.")
fig = plt.figure()
plt.title("B-spline basis")
plt.plot(samples, eval_basis);
fig = plt.figure(figsize=(5, 3))
plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```

```{code-cell} ipython3
Expand All @@ -111,49 +103,18 @@ if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")
```

## Setting the basis support
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN.


```{code-cell} ipython3
bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))
print("Evaluated basis:")
# 0.5 is within the support, 0.1 is outside the support
print(np.round(bspline_range([0.5, 0.1]), 3))
```

Let's compare the default behavior of basis (estimating the range from the samples) with
the fixed range basis.

## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:

```{code-cell} ipython3
fig, axs = plt.subplots(2,1, sharex=True)
plt.suptitle("B-spline basis ")
axs[0].plot(samples, bspline(samples), color="k")
axs[0].set_title("default")
axs[1].plot(samples, bspline_range(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```
1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.

## Basis `mode`
In constructing features, [`Basis`](nemos.basis.Basis) objects can be used in two modalities: `"eval"` for evaluate or `"conv"`
for convolve. These two modalities change the behavior of the [`compute_features`](nemos.basis.Basis.compute_features) method of [`Basis`](nemos.basis.Basis), in particular,

- If a basis is in mode `"eval"`, then [`compute_features`](nemos.basis.Basis.compute_features) simply returns the evaluated basis.
- If a basis is in mode `"conv"`, then [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a kernel of basis
with `window_size` specified by the user.
2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`.

Let's see how this two modalities operate.


```{code-cell} ipython3
eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval")
conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100)
eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100)
# define an input
angles = np.linspace(0, np.pi*4, 201)
Expand Down Expand Up @@ -196,11 +157,10 @@ check out the tutorial on [1D convolutions](plot_03_1D_convolution).
:::



Plotting the Basis Function Elements:
--------------------------------------
We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
particularly evident when working with multidimensional basis functions. You can find more details and visual
background in the
Expand All @@ -219,6 +179,38 @@ plt.plot(equispaced_samples, eval_basis)
plt.show()
```


## Setting the basis support (Eval only)
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions).
Evaluating the basis at any sample outside the bounds will result in a NaN.


```{code-cell} ipython3
bspline_range = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))
print("Evaluated basis:")
# 0.5 is within the support, 0.1 is outside the support
print(np.round(bspline_range.compute_features([0.5, 0.1]), 3))
```

Let's compare the default behavior of basis (estimating the range from the samples) with
the fixed range basis.


```{code-cell} ipython3
samples = np.linspace(0, 1, 200)
fig, axs = plt.subplots(2,1, sharex=True)
plt.suptitle("B-spline basis ")
axs[0].plot(samples, bspline.compute_features(samples), color="k")
axs[0].set_title("default")
axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```

Other Basis Types
-----------------
Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
Expand All @@ -228,8 +220,8 @@ evaluate a log-spaced cosine raised function basis.


```{code-cell} ipython3
# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50)
# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50)
# Evaluate the raised cosine basis at the equi-spaced sample points
# (same method in all Basis elements)
Expand Down
36 changes: 17 additions & 19 deletions docs/background/plot_02_ND_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ Here, we simply add two basis objects, `a_basis` and `b_basis`, together to defi
```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
import nemos as nmo
# Define 1D basis objects
a_basis = nmo.basis.MSplineBasis(n_basis_funcs=15, order=3)
b_basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=14)
a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3)
b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14)
# Define the 2D additive basis object
additive_basis = a_basis + b_basis
Expand All @@ -151,7 +150,7 @@ x_coord = np.linspace(0, 1, 1000)
y_coord = np.linspace(0, 1, 1000)
# Evaluate the basis functions for the given trajectory.
eval_basis = additive_basis(x_coord, y_coord)
eval_basis = additive_basis.compute_features(x_coord, y_coord)
print(f"Sum of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
Expand All @@ -170,13 +169,13 @@ basis_b_element = 1
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
axs[0].set_title(f"$a_{{{basis_a_element}}}(x)$", color="b")
axs[0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis(x_coord)[:, basis_a_element], "b")
axs[0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis.compute_features(x_coord)[:, basis_a_element], "b")
axs[0].set_xlabel("x-coord")
axs[1].set_title(f"$b_{{{basis_b_element}}}(x)$", color="b")
axs[1].plot(y_coord, b_basis(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b")
axs[1].plot(y_coord, b_basis.compute_features(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis.compute_features(x_coord)[:, basis_b_element], "b")
axs[1].set_xlabel("y-coord")
plt.tight_layout()
```
Expand Down Expand Up @@ -243,7 +242,7 @@ The number of elements of the product basis will be the product of the elements

```{code-cell} ipython3
# Evaluate the product basis at the x and y coordinates
eval_basis = prod_basis(x_coord, y_coord)
eval_basis = prod_basis.compute_features(x_coord, y_coord)
# Output the number of elements and samples of the evaluated basis,
# as well as the number of elements in the original 1D basis objects
Expand All @@ -269,19 +268,19 @@ fig, axs = plt.subplots(3,3,figsize=(8, 6))
cc = 0
for i, j in element_pairs:
# plot the element form a_basis
axs[cc, 0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis(x_coord)[:, i], "b")
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b")
axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b')
# plot the element form b_basis
axs[cc, 1].plot(y_coord, b_basis(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis(y_coord)[:, j], "b")
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b")
axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b')
# select & plot the corresponding product basis element
k = i * b_basis.n_basis_funcs + j
axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap='Blues')
axs[cc, 2].set_title(f"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b')
axs[cc, 2].set_title(fr"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b')
axs[cc, 2].set_xlabel('x-coord')
axs[cc, 2].set_ylabel('y-coord')
axs[cc, 2].set_aspect("equal")
Expand Down Expand Up @@ -323,7 +322,6 @@ in a linear maze and the LFP phase angle.
:::



N-Dimensional Basis
-------------------
Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of
Expand All @@ -341,13 +339,13 @@ will output a $K^N \times T$ matrix.
T = 10
n_basis = 8
a_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
b_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
c_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
a_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
b_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
prod_basis_3 = a_basis * b_basis * c_basis
samples = np.linspace(0, 1, T)
eval_basis = prod_basis_3(samples, samples, samples)
eval_basis = prod_basis_3.compute_features(samples, samples, samples)
print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} "
f"basis elements.\nEvaluation output of shape {eval_basis.shape}")
Expand Down
Loading

0 comments on commit 74bde2b

Please sign in to comment.