Skip to content

Commit

Permalink
fix tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 4, 2024
1 parent f8d64ea commit fb3dd75
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
39 changes: 38 additions & 1 deletion docs/background/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ warnings.filterwarnings(
),
category=RuntimeWarning,
)
```

(simple_basis_function)=
Expand All @@ -58,6 +59,9 @@ import pynapple as nap
import nemos as nmo
# configure plots some
plt.style.use(nmo.styles.plot_style)
# Initialize hyperparameters
order = 4
n_basis = 10
Expand All @@ -66,6 +70,39 @@ n_basis = 10
bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order)
```

We provide the convenience method `evaluate_on_grid` for evaluating the basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements.

```{code-cell} ipython3
# evaluate the basis on 100 sample points
x, y = bspline.evaluate_on_grid(100)
fig = plt.figure(figsize=(5, 3))
plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```

```{code-cell} ipython3
:tags: [hide-input]
# save image for thumbnail
from pathlib import Path
import os
root = os.environ.get("READTHEDOCS_OUTPUT")
if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../_build/html/_static/thumbnails/background")
# make sure the folder exists if run from build
if root or Path("../_build/html/_static").exists():
path.mkdir(parents=True, exist_ok=True)
if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")
```

## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:

Expand All @@ -75,7 +112,6 @@ The bases in the `nemos.basis` module can be grouped into two categories:

Let's see how this two modalities operate.


```{code-cell} ipython3
eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100)
Expand Down Expand Up @@ -165,6 +201,7 @@ the fixed range basis.


```{code-cell} ipython3
samples = np.linspace(0, 1, 200)
fig, axs = plt.subplots(2,1, sharex=True)
plt.suptitle("B-spline basis ")
axs[0].plot(samples, bspline.compute_features(samples), color="k")
Expand Down
23 changes: 11 additions & 12 deletions docs/background/plot_02_ND_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ x_coord = np.linspace(0, 1, 1000)
y_coord = np.linspace(0, 1, 1000)
# Evaluate the basis functions for the given trajectory.
eval_basis = additive_basis(x_coord, y_coord)
eval_basis = additive_basis.compute_features(x_coord, y_coord)
print(f"Sum of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
Expand All @@ -169,13 +169,13 @@ basis_b_element = 1
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
axs[0].set_title(f"$a_{{{basis_a_element}}}(x)$", color="b")
axs[0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis(x_coord)[:, basis_a_element], "b")
axs[0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis.compute_features(x_coord)[:, basis_a_element], "b")
axs[0].set_xlabel("x-coord")
axs[1].set_title(f"$b_{{{basis_b_element}}}(x)$", color="b")
axs[1].plot(y_coord, b_basis(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b")
axs[1].plot(y_coord, b_basis.compute_features(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis.compute_features(x_coord)[:, basis_b_element], "b")
axs[1].set_xlabel("y-coord")
plt.tight_layout()
```
Expand Down Expand Up @@ -242,7 +242,7 @@ The number of elements of the product basis will be the product of the elements

```{code-cell} ipython3
# Evaluate the product basis at the x and y coordinates
eval_basis = prod_basis(x_coord, y_coord)
eval_basis = prod_basis.compute_features(x_coord, y_coord)
# Output the number of elements and samples of the evaluated basis,
# as well as the number of elements in the original 1D basis objects
Expand All @@ -268,13 +268,13 @@ fig, axs = plt.subplots(3,3,figsize=(8, 6))
cc = 0
for i, j in element_pairs:
# plot the element form a_basis
axs[cc, 0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis(x_coord)[:, i], "b")
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b")
axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b')
# plot the element form b_basis
axs[cc, 1].plot(y_coord, b_basis(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis(y_coord)[:, j], "b")
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b")
axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b')
# select & plot the corresponding product basis element
Expand Down Expand Up @@ -322,7 +322,6 @@ in a linear maze and the LFP phase angle.
:::



N-Dimensional Basis
-------------------
Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of
Expand All @@ -346,7 +345,7 @@ c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
prod_basis_3 = a_basis * b_basis * c_basis
samples = np.linspace(0, 1, T)
eval_basis = prod_basis_3(samples, samples, samples)
eval_basis = prod_basis_3.compute_features(samples, samples, samples)
print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} "
f"basis elements.\nEvaluation output of shape {eval_basis.shape}")
Expand Down

0 comments on commit fb3dd75

Please sign in to comment.