From 67479a475da1408a408df8ed28ffd382c1234fc2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 17:27:23 -0500 Subject: [PATCH] remove calls to basis from tutorial --- docs/api_reference.rst | 6 +++++- docs/background/basis/plot_01_1D_basis_function.md | 4 ++-- docs/conf.py | 5 ++++- docs/how_to_guide/plot_06_glm_pytree.md | 10 +++++----- src/nemos/basis/_basis_mixin.py | 5 ++--- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index f41c4c02..e507434c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -4,6 +4,7 @@ API Reference ============= .. _nemos_glm: + The ``nemos.glm`` module ------------------------ Classes for creating Generalized Linear Models (GLMs) for both single neurons and neural populations. @@ -19,6 +20,7 @@ Classes for creating Generalized Linear Models (GLMs) for both single neurons an PopulationGLM .. _nemos_basis: + The ``nemos.basis`` module -------------------------- Provides basis function classes to construct and transform features for model inputs. @@ -107,8 +109,9 @@ These classes are the building blocks for the concrete basis classes. TransformerBasis .. _observation_models: + The ``nemos.observation_models`` module --------------------------------------- +--------------------------------------- Statistical models to describe the distribution of neural responses or other predicted variables, given inputs. .. currentmodule:: nemos.observation_models @@ -123,6 +126,7 @@ Statistical models to describe the distribution of neural responses or other pre GammaObservations .. _regularizers: + The ``nemos.regularizer`` module -------------------------------- Implements various regularization techniques to constrain model parameters, which helps prevent overfitting. diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index c2437de8..55c5b80c 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -168,7 +168,7 @@ For N-dimensional input, with $N>1$, the method assumes that first axis is the s For "Eval" basis, `compute_features` is equivalent to "calling" the basis and then reshaping the input into a 2-dimensional feature matrix. ```{code-cell} ipython3 -basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=5) +basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5) # generate a 3D array inp = np.random.randn(50, 2, 3) @@ -185,7 +185,7 @@ For each of the `3 * 2 = 6` inputs, `n_basis_funcs = 5` features are computed. T For "Conv" type basis, `compute_features` is equivalent to convolving each input with `n_basis_funcs` kernels, and concatenate the output into a 2D design matrix. ```{code-cell} ipython3 -basis = nmo.basis.ConvRaisedCosineLinear(n_basis_funcs=5, window_size=6) +basis = nmo.basis.RaisedCosineLinearConv(n_basis_funcs=5, window_size=6) # compute_features to perform the convolution and concatenate out = basis.compute_features(inp) diff --git a/docs/conf.py b/docs/conf.py index 2e752238..240085ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -168,4 +168,7 @@ viewcode_follow_imported_members = True # option for mpl extension -plot_html_show_formats = False \ No newline at end of file +plot_html_show_formats = False + +# raise an error if exec error in notebooks +nb_execution_raise_on_error = True \ No newline at end of file diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index 4c82be40..d6608c45 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -286,7 +286,7 @@ spikes = nwb['units'][unit_no] basis = nmo.basis.CyclicBSplineEval(10, order=5) x = np.linspace(-np.pi, np.pi, 100) plt.figure() -plt.plot(x, basis(x)) +plt.plot(x, basis.compute_features(x)) # Find the interval on which head_dir has no NaNs head_dir = head_dir.dropna() @@ -300,7 +300,7 @@ spikes = spikes.count(bin_size=1/head_dir.rate, ep=valid_data) # center. head_dir = head_dir.interpolate(spikes) -X = nmo.pytrees.FeaturePytree(head_direction=basis(head_dir)) +X = nmo.pytrees.FeaturePytree(head_direction=basis.compute_features(head_dir)) ``` Now we'll fit our GLM and then see what our head direction tuning looks like: @@ -311,7 +311,7 @@ model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001) model.fit(X, spikes) print(model.coef_['head_direction']) -bs_vis = basis(x) +bs_vis = basis.compute_features(x) tuning = jnp.einsum('b, tb->t', model.coef_['head_direction'], bs_vis) plt.figure() plt.polar(x, tuning) @@ -354,7 +354,7 @@ our data similarly. pos_basis = nmo.basis.RaisedCosineLinearEval(10) * nmo.basis.RaisedCosineLinearEval(10) spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data) -X['spatial_position'] = pos_basis(*spatial_pos.values.T) +X['spatial_position'] = pos_basis.compute_features(*spatial_pos.values.T) ``` Running the GLM is identical to before, but we can see that our coef_ @@ -373,7 +373,7 @@ coefficients). ```{code-cell} ipython3 -bs_vis = basis(x) +bs_vis = basis.compute_features(x) tuning = jnp.einsum('b,nb->n', model.coef_['head_direction'], bs_vis) print(model.coef_['head_direction']) plt.figure() diff --git a/src/nemos/basis/_basis_mixin.py b/src/nemos/basis/_basis_mixin.py index db3b3b4c..d30c850e 100644 --- a/src/nemos/basis/_basis_mixin.py +++ b/src/nemos/basis/_basis_mixin.py @@ -24,7 +24,7 @@ def _compute_features(self, *xi: NDArray): Apply the basis transformation to the input data. The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a - basis element. xi[k] must be a one-dimensional array or a pynapple Tsd. + basis element. xi[k] must be a N-dimensional array (N >= 1) or a pynapple Tsd/TsdFrame/TsdTensor. Parameters ---------- @@ -36,8 +36,7 @@ def _compute_features(self, *xi: NDArray): ------- : A matrix with the transformed features. The basis evaluated at the samples, - or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array - or a pynapple Tsd. + or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. """ out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi))