Skip to content

Commit

Permalink
remove calls to basis from tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 4, 2024
1 parent c9ec5ad commit 67479a4
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 12 deletions.
6 changes: 5 additions & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ API Reference
=============

.. _nemos_glm:

The ``nemos.glm`` module
------------------------
Classes for creating Generalized Linear Models (GLMs) for both single neurons and neural populations.
Expand All @@ -19,6 +20,7 @@ Classes for creating Generalized Linear Models (GLMs) for both single neurons an
PopulationGLM

.. _nemos_basis:

The ``nemos.basis`` module
--------------------------
Provides basis function classes to construct and transform features for model inputs.
Expand Down Expand Up @@ -107,8 +109,9 @@ These classes are the building blocks for the concrete basis classes.
TransformerBasis

.. _observation_models:

The ``nemos.observation_models`` module
--------------------------------------
---------------------------------------
Statistical models to describe the distribution of neural responses or other predicted variables, given inputs.

.. currentmodule:: nemos.observation_models
Expand All @@ -123,6 +126,7 @@ Statistical models to describe the distribution of neural responses or other pre
GammaObservations

.. _regularizers:

The ``nemos.regularizer`` module
--------------------------------
Implements various regularization techniques to constrain model parameters, which helps prevent overfitting.
Expand Down
4 changes: 2 additions & 2 deletions docs/background/basis/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ For N-dimensional input, with $N>1$, the method assumes that first axis is the s
For "Eval" basis, `compute_features` is equivalent to "calling" the basis and then reshaping the input into a 2-dimensional feature matrix.

```{code-cell} ipython3
basis = nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=5)
basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5)
# generate a 3D array
inp = np.random.randn(50, 2, 3)
Expand All @@ -185,7 +185,7 @@ For each of the `3 * 2 = 6` inputs, `n_basis_funcs = 5` features are computed. T
For "Conv" type basis, `compute_features` is equivalent to convolving each input with `n_basis_funcs` kernels, and concatenate the output into a 2D design matrix.

```{code-cell} ipython3
basis = nmo.basis.ConvRaisedCosineLinear(n_basis_funcs=5, window_size=6)
basis = nmo.basis.RaisedCosineLinearConv(n_basis_funcs=5, window_size=6)
# compute_features to perform the convolution and concatenate
out = basis.compute_features(inp)
Expand Down
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,7 @@
viewcode_follow_imported_members = True

# option for mpl extension
plot_html_show_formats = False
plot_html_show_formats = False

# raise an error if exec error in notebooks
nb_execution_raise_on_error = True
10 changes: 5 additions & 5 deletions docs/how_to_guide/plot_06_glm_pytree.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ spikes = nwb['units'][unit_no]
basis = nmo.basis.CyclicBSplineEval(10, order=5)
x = np.linspace(-np.pi, np.pi, 100)
plt.figure()
plt.plot(x, basis(x))
plt.plot(x, basis.compute_features(x))
# Find the interval on which head_dir has no NaNs
head_dir = head_dir.dropna()
Expand All @@ -300,7 +300,7 @@ spikes = spikes.count(bin_size=1/head_dir.rate, ep=valid_data)
# center.
head_dir = head_dir.interpolate(spikes)
X = nmo.pytrees.FeaturePytree(head_direction=basis(head_dir))
X = nmo.pytrees.FeaturePytree(head_direction=basis.compute_features(head_dir))
```

Now we'll fit our GLM and then see what our head direction tuning looks like:
Expand All @@ -311,7 +311,7 @@ model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)
model.fit(X, spikes)
print(model.coef_['head_direction'])
bs_vis = basis(x)
bs_vis = basis.compute_features(x)
tuning = jnp.einsum('b, tb->t', model.coef_['head_direction'], bs_vis)
plt.figure()
plt.polar(x, tuning)
Expand Down Expand Up @@ -354,7 +354,7 @@ our data similarly.
pos_basis = nmo.basis.RaisedCosineLinearEval(10) * nmo.basis.RaisedCosineLinearEval(10)
spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data)
X['spatial_position'] = pos_basis(*spatial_pos.values.T)
X['spatial_position'] = pos_basis.compute_features(*spatial_pos.values.T)
```

Running the GLM is identical to before, but we can see that our coef_
Expand All @@ -373,7 +373,7 @@ coefficients).


```{code-cell} ipython3
bs_vis = basis(x)
bs_vis = basis.compute_features(x)
tuning = jnp.einsum('b,nb->n', model.coef_['head_direction'], bs_vis)
print(model.coef_['head_direction'])
plt.figure()
Expand Down
5 changes: 2 additions & 3 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _compute_features(self, *xi: NDArray):
Apply the basis transformation to the input data.
The basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i` is a
basis element. xi[k] must be a one-dimensional array or a pynapple Tsd.
basis element. xi[k] must be a N-dimensional array (N >= 1) or a pynapple Tsd/TsdFrame/TsdTensor.
Parameters
----------
Expand All @@ -36,8 +36,7 @@ def _compute_features(self, *xi: NDArray):
-------
:
A matrix with the transformed features. The basis evaluated at the samples,
or :math:`b_i(*xi)`, where :math:`b_i` is a basis element. xi[k] must be a one-dimensional array
or a pynapple Tsd.
or :math:`b_i(*xi)`, where :math:`b_i` is a basis element.
"""
out = self._evaluate(*(np.reshape(x, (x.shape[0], -1)) for x in xi))
Expand Down

0 comments on commit 67479a4

Please sign in to comment.