diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md index 424be525..fa5c5879 100644 --- a/docs/background/plot_01_1D_basis_function.md +++ b/docs/background/plot_01_1D_basis_function.md @@ -38,6 +38,7 @@ warnings.filterwarnings( ), category=RuntimeWarning, ) + ``` (simple_basis_function)= @@ -58,6 +59,9 @@ import pynapple as nap import nemos as nmo +# configure plots some +plt.style.use(nmo.styles.plot_style) + # Initialize hyperparameters order = 4 n_basis = 10 @@ -66,6 +70,39 @@ n_basis = 10 bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order) ``` +We provide the convenience method `evaluate_on_grid` for evaluating the basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements. + +```{code-cell} ipython3 +# evaluate the basis on 100 sample points +x, y = bspline.evaluate_on_grid(100) + +fig = plt.figure(figsize=(5, 3)) +plt.plot(x, y, lw=2) +plt.title("B-Spline Basis") +``` + +```{code-cell} ipython3 +:tags: [hide-input] + +# save image for thumbnail +from pathlib import Path +import os + +root = os.environ.get("READTHEDOCS_OUTPUT") +if root: + path = Path(root) / "html/_static/thumbnails/background" +# if local store in ../_build/html/... +else: + path = Path("../_build/html/_static/thumbnails/background") + +# make sure the folder exists if run from build +if root or Path("../_build/html/_static").exists(): + path.mkdir(parents=True, exist_ok=True) + +if path.exists(): + fig.savefig(path / "plot_01_1D_basis_function.svg") +``` + ## Feature Computation The bases in the `nemos.basis` module can be grouped into two categories: @@ -75,7 +112,6 @@ The bases in the `nemos.basis` module can be grouped into two categories: Let's see how this two modalities operate. - ```{code-cell} ipython3 eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis) conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100) @@ -165,6 +201,7 @@ the fixed range basis. ```{code-cell} ipython3 +samples = np.linspace(0, 1, 200) fig, axs = plt.subplots(2,1, sharex=True) plt.suptitle("B-spline basis ") axs[0].plot(samples, bspline.compute_features(samples), color="k") diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/plot_02_ND_basis_function.md index 03c0062d..a9636285 100644 --- a/docs/background/plot_02_ND_basis_function.md +++ b/docs/background/plot_02_ND_basis_function.md @@ -150,7 +150,7 @@ x_coord = np.linspace(0, 1, 1000) y_coord = np.linspace(0, 1, 1000) # Evaluate the basis functions for the given trajectory. -eval_basis = additive_basis(x_coord, y_coord) +eval_basis = additive_basis.compute_features(x_coord, y_coord) print(f"Sum of two 1D splines with {eval_basis.shape[1]} " f"basis element and {eval_basis.shape[0]} samples:\n" @@ -169,13 +169,13 @@ basis_b_element = 1 fig, axs = plt.subplots(1, 2, figsize=(6, 3)) axs[0].set_title(f"$a_{{{basis_a_element}}}(x)$", color="b") -axs[0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3) -axs[0].plot(x_coord, a_basis(x_coord)[:, basis_a_element], "b") +axs[0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) +axs[0].plot(x_coord, a_basis.compute_features(x_coord)[:, basis_a_element], "b") axs[0].set_xlabel("x-coord") axs[1].set_title(f"$b_{{{basis_b_element}}}(x)$", color="b") -axs[1].plot(y_coord, b_basis(x_coord), "grey", alpha=.3) -axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b") +axs[1].plot(y_coord, b_basis.compute_features(x_coord), "grey", alpha=.3) +axs[1].plot(y_coord, b_basis.compute_features(x_coord)[:, basis_b_element], "b") axs[1].set_xlabel("y-coord") plt.tight_layout() ``` @@ -242,7 +242,7 @@ The number of elements of the product basis will be the product of the elements ```{code-cell} ipython3 # Evaluate the product basis at the x and y coordinates -eval_basis = prod_basis(x_coord, y_coord) +eval_basis = prod_basis.compute_features(x_coord, y_coord) # Output the number of elements and samples of the evaluated basis, # as well as the number of elements in the original 1D basis objects @@ -268,13 +268,13 @@ fig, axs = plt.subplots(3,3,figsize=(8, 6)) cc = 0 for i, j in element_pairs: # plot the element form a_basis - axs[cc, 0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3) - axs[cc, 0].plot(x_coord, a_basis(x_coord)[:, i], "b") + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b") axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b') # plot the element form b_basis - axs[cc, 1].plot(y_coord, b_basis(y_coord), "grey", alpha=.3) - axs[cc, 1].plot(y_coord, b_basis(y_coord)[:, j], "b") + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3) + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b") axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b') # select & plot the corresponding product basis element @@ -322,7 +322,6 @@ in a linear maze and the LFP phase angle. ::: - N-Dimensional Basis ------------------- Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of @@ -346,7 +345,7 @@ c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis) prod_basis_3 = a_basis * b_basis * c_basis samples = np.linspace(0, 1, T) -eval_basis = prod_basis_3(samples, samples, samples) +eval_basis = prod_basis_3.compute_features(samples, samples, samples) print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} " f"basis elements.\nEvaluation output of shape {eval_basis.shape}")