diff --git a/.gitignore b/.gitignore
index 970ea314..eb0dc77e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -117,8 +117,8 @@ venv.bak/
# Rope project settings
.ropeproject
-# mkdocs documentation
-/site
+# sphinx build documentation
+/docs/_build
# mypy
.mypy_cache/
@@ -143,3 +143,14 @@ docs/generated/
# nwb cahce
nwb-cache/
+
+# scripting folder
+_scripts/
+
+# env variable for nwbs
+docs/data/
+
+# rst generated files
+docs/stubs
+
+
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index b1e575be..ecaca19d 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -14,13 +14,14 @@ build:
pre_build:
- gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9
post_build:
- # Check everything but the reference (because mkdocstrings do not set href=)
- - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/reference\/.+/"
- # Check the reference allowing missing href
- - htmlproofer $READTHEDOCS_OUTPUT/html/reference --assume-extension --check-external-hash --ignore-urls "https://fonts.gstatic.com" --allow-missing-href --ignore-status-codes 403
-
-mkdocs:
- configuration: mkdocs.yml
+      # Check everything except 403s and one jneurosci link, which returns a 404 even though it works when clicked.
+ - htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/"
+      # The auto-generated animation has no alt or src/srcset attribute; missing alt can be ignored,
+      # but there is no workaround for a missing src/srcset, so figures are not checked for this file.
+ - htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003"
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+ configuration: docs/conf.py
# Optionally declare the Python requirements required to build your docs
python:
diff --git a/docs/CCN-logo-wText.png b/docs/CCN-logo-wText.png
deleted file mode 100644
index 51fc110d..00000000
Binary files a/docs/CCN-logo-wText.png and /dev/null differ
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..9d97911b
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,21 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+# Set an environ variable available during sphinx build
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst
new file mode 100644
index 00000000..068175b2
--- /dev/null
+++ b/docs/_templates/autosummary/class.rst
@@ -0,0 +1,34 @@
+{{ fullname | escape | underline }}
+
+.. currentmodule:: {{ module }}
+
+.. autoclass:: {{ objname }}
+ :members:
+ :inherited-members:
+
+{% block attributes %}
+ {% if attributes %}
+ .. rubric:: Attributes
+
+ .. autosummary::
+ :toctree: ./
+ {% for item in attributes %}
+ ~{{ objname }}.{{ item }}
+ {%- endfor %}
+ {% endif %}
+{% endblock %}
+
+{% block methods %}
+ .. automethod:: __init__
+
+ {% if methods %}
+ .. rubric:: Methods
+
+ .. autosummary::
+ :toctree: ./
+ {% for item in methods %}
+ ~{{ objname }}.{{ item }}
+ {%- endfor %}
+ {% endif %}
+{% endblock %}
+
diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
new file mode 100644
index 00000000..9a987f9b
--- /dev/null
+++ b/docs/_templates/layout.html
@@ -0,0 +1,16 @@
+{% extends "!layout.html" %}
+
+{%- block footer %}
+
+{%- endblock %}
\ No newline at end of file
diff --git a/docs/api_reference.rst b/docs/api_reference.rst
new file mode 100644
index 00000000..779e31a6
--- /dev/null
+++ b/docs/api_reference.rst
@@ -0,0 +1,137 @@
+.. _api_ref:
+
+API Reference
+=============
+
+.. _nemos_glm:
+The ``nemos.glm`` module
+------------------------
+Classes for creating Generalized Linear Models (GLMs) for both single neurons and neural populations.
+
+.. currentmodule:: nemos.glm
+
+.. autosummary::
+ :toctree: generated/glm
+ :recursive:
+ :nosignatures:
+
+ GLM
+ PopulationGLM
+
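+A minimal usage sketch (synthetic data and default settings, for illustration only):
+
+.. code-block:: python
+
+    import numpy as np
+    import nemos as nmo
+
+    X = np.random.normal(size=(1000, 3))        # design matrix (n_samples, n_features)
+    y = np.random.poisson(lam=1.0, size=1000)   # spike counts
+    glm = nmo.glm.GLM().fit(X, y)
+    rate = glm.predict(X)                       # predicted firing rate per sample
+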
+.. _nemos_basis:
+The ``nemos.basis`` module
+--------------------------
+Provides basis function classes to construct and transform features for model inputs.
+
+.. currentmodule:: nemos.basis
+
+.. autosummary::
+ :toctree: generated/basis
+ :recursive:
+ :nosignatures:
+
+ Basis
+ SplineBasis
+ BSplineBasis
+ CyclicBSplineBasis
+ MSplineBasis
+ OrthExponentialBasis
+ RaisedCosineBasisLinear
+ RaisedCosineBasisLog
+ AdditiveBasis
+ MultiplicativeBasis
+ TransformerBasis
+
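+A minimal usage sketch (parameter values are illustrative only):
+
+.. code-block:: python
+
+    import numpy as np
+    import nemos as nmo
+
+    basis = nmo.basis.BSplineBasis(n_basis_funcs=5)
+    X = basis.compute_features(np.linspace(0, 1, 100))  # feature matrix, shape (100, 5)
+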
+.. _observation_models:
+The ``nemos.observation_models`` module
+---------------------------------------
+Statistical models to describe the distribution of neural responses or other predicted variables, given inputs.
+
+.. currentmodule:: nemos.observation_models
+
+.. autosummary::
+ :toctree: generated/observation_models
+ :recursive:
+ :nosignatures:
+
+ Observations
+ PoissonObservations
+ GammaObservations
+
+.. _regularizers:
+The ``nemos.regularizer`` module
+--------------------------------
+Implements various regularization techniques to constrain model parameters, which helps prevent overfitting.
+
+.. currentmodule:: nemos.regularizer
+
+.. autosummary::
+ :toctree: generated/regularizer
+ :recursive:
+ :nosignatures:
+
+ Regularizer
+ UnRegularized
+ Ridge
+ Lasso
+ GroupLasso
+
+The ``nemos.simulation`` module
+-------------------------------
+Utility functions for simulating spiking activity in recurrently connected neural populations.
+
+.. currentmodule:: nemos.simulation
+
+.. autosummary::
+ :toctree: generated/simulation
+ :recursive:
+ :nosignatures:
+
+ simulate_recurrent
+ difference_of_gammas
+ regress_filter
+
+
+The ``nemos.convolve`` module
+-----------------------------
+Utility functions for running convolution over the sample axis.
+
+.. currentmodule:: nemos.convolve
+
+.. autosummary::
+    :toctree: generated/convolve
+    :recursive:
+    :nosignatures:
+
+ create_convolutional_predictor
+ tensor_convolve
+
+
+The ``nemos.identifiability_constraints`` module
+------------------------------------------------
+Functions to apply identifiability constraints to rank-deficient feature matrices, ensuring the uniqueness of model
+solutions.
+
+.. currentmodule:: nemos.identifiability_constraints
+
+.. autosummary::
+ :toctree: generated/identifiability_constraints
+ :recursive:
+ :nosignatures:
+
+ apply_identifiability_constraints
+ apply_identifiability_constraints_by_basis_component
+
+The ``nemos.pytrees.FeaturePytree`` class
+-----------------------------------------
+Class for storing the input arrays in a dictionary. Keys are usually variable names.
+These objects can be provided as input to nemos GLM methods.
+
+.. currentmodule:: nemos.pytrees
+
+.. autosummary::
+    :toctree: generated/pytrees
+    :recursive:
+    :nosignatures:
+
+ FeaturePytree
diff --git a/docs/assets/NeMoS_Logo_CMYK_White.svg b/docs/assets/NeMoS_Logo_CMYK_White.svg
new file mode 100755
index 00000000..cb453c01
--- /dev/null
+++ b/docs/assets/NeMoS_Logo_CMYK_White.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/developers_notes/classes_nemos.png b/docs/assets/classes_nemos.png
similarity index 100%
rename from docs/developers_notes/classes_nemos.png
rename to docs/assets/classes_nemos.png
diff --git a/docs/developers_notes/classes_nemos.svg b/docs/assets/classes_nemos.svg
similarity index 100%
rename from docs/developers_notes/classes_nemos.svg
rename to docs/assets/classes_nemos.svg
diff --git a/docs/assets/extra.css b/docs/assets/extra.css
deleted file mode 100644
index 288364bd..00000000
--- a/docs/assets/extra.css
+++ /dev/null
@@ -1,3 +0,0 @@
-.notes {
- display: none;
-}
diff --git a/docs/head_dir_tuning.jpg b/docs/assets/head_dir_tuning.jpg
similarity index 100%
rename from docs/head_dir_tuning.jpg
rename to docs/assets/head_dir_tuning.jpg
diff --git a/docs/assets/logo_flatiron_white.svg b/docs/assets/logo_flatiron_white.svg
new file mode 100644
index 00000000..9509e67d
--- /dev/null
+++ b/docs/assets/logo_flatiron_white.svg
@@ -0,0 +1,206 @@
+
+
+
+
diff --git a/docs/assets/stylesheets/custom.css b/docs/assets/stylesheets/custom.css
new file mode 100644
index 00000000..f7dc7a81
--- /dev/null
+++ b/docs/assets/stylesheets/custom.css
@@ -0,0 +1,110 @@
+.bd-main .bd-content .bd-article-container{
+ max-width:100%;
+ flex-grow: 1;
+}
+
+html[data-theme=light]{
+ --pst-color-primary: rgb(52, 54, 99);
+ --pst-color-secondary: rgb(107, 161, 174);
+ --pst-color-link: rgb(74, 105, 145);
+ --pst-color-inline-code: rgb(96, 141, 130);
+}
+
+:root {
+ --pst-font-size-h1: 38px;
+ --pst-font-size-h2: 32px;
+ --pst-font-size-h3: 27px;
+ --pst-font-size-h4: 22px;
+ --pst-font-size-h5: 18px;
+ --pst-font-size-h6: 15px;
+}
+
+.iconify {
+ display: inline-block;
+ width: 2em;
+ height: 2em;
+ vertical-align: middle;
+}
+
+.cards {
+ display: grid;
+ gap: 2rem; /* Adjust as needed */
+}
+
+@media (min-width: 1024px) {
+ .cards {
+ grid-template-columns: repeat(3, 1fr); /* 3 columns on large screens */
+ }
+}
+
+@media (min-width: 768px) and (max-width: 1023px) {
+ .cards {
+ grid-template-columns: repeat(2, 1fr); /* 2 columns on medium screens */
+ }
+}
+
+@media (max-width: 767px) {
+ .cards {
+ grid-template-columns: 1fr; /* 1 column on small screens */
+ }
+}
+
+.sd-card-title {
+ height: 2.5rem; /* Adjust to a height that fits your content */
+ display: flex;
+ align-items: center;
+ margin-bottom: 0.5rem; /* Spacing between title and separator */
+}
+
+.sd-card-body hr {
+ width: 100%;
+ border: 0;
+ border-top: 2px solid #ddd; /* Light gray for a subtle line */
+  margin: 5%; /* Remove any default margin around the hr */
+}
+
+.sd-card-body {
+ padding-top: .5rem; /* Adjust padding for consistent layout */
+}
+
+.card-footer-content {
+ margin-top: auto; /* Pushes the footer content to the bottom */
+}
+
+
+.sphinxsidebar .globaltoc {
+ display: none;
+}
+
+
+/* Add vertical spacing between cards */
+.sd-col {
+ margin-bottom: 20px; /* Adjust this value as needed for desired spacing */
+}
+
+/* Style level-1 ToC entries */
+.toctree-l1 {
+ font-weight: bold;
+}
+
+/* Keep level-2 and deeper ToC entries unbolded */
+.toctree-l2,
+.toctree-l3,
+.toctree-l4 {
+ font-weight: normal;
+}
+
+/*!* Style the brackets *!*/
+/*span.fn-bracket {*/
+/* color: #666; !* Dim the brackets *!*/
+/*}*/
+
+/*!* Style the links within the footnotes *!*/
+/*aside.footnote a {*/
+/* color: #007BFF; !* Blue link color *!*/
+/* text-decoration: none; !* Remove underline *!*/
+/*}*/
+
+/*aside.footnote a:hover {*/
+/* text-decoration: underline; !* Add underline on hover *!*/
+/*}*/
diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css
deleted file mode 100644
index ae328858..00000000
--- a/docs/assets/stylesheets/extra.css
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Flexbox layout for the list items */
-.grid.cards ul > li {
- display: flex !important; /* Ensure flexbox is applied */
- flex-direction: column;
- justify-content: space-between;
- height: 100%;
- box-sizing: border-box;
-}
-
-/* Ensure the link stays at the bottom */
-.grid.cards ul > li > p:last-of-type {
- margin-top: auto;
- padding-top: 10px; /* Optional spacing */
-}
-
-/* Adjust the spacing for the hr element */
-.grid.cards ul > li > hr {
- margin: 10px 0; /* Adjust this value to match the original spacing */
- flex-shrink: 0; /* Prevent the hr from shrinking */
-}
-
-/* Center the icon and title within the first paragraph */
-.grid.cards ul > li > p:first-of-type {
- display: flex;
- align-items: center;
- justify-content: center;
- text-align: center;
-}
-
-/* Add spacing between the icon and the title */
-.grid.cards ul > li > p:first-of-type img {
- margin-right: 8px;
-}
-
-/* this sets up the figure in the quickstart to have the appropriate size*/
-.custom-figure {
- width: 70% !important; /* Set the desired width */
-}
-
-/* Apply this style to the image within the figure */
-.custom-figure img {
- width: 100%; /* Ensure the image fits the figure */
-}
\ No newline at end of file
diff --git a/docs/background/README.md b/docs/background/README.md
index 471ad027..3215c329 100644
--- a/docs/background/README.md
+++ b/docs/background/README.md
@@ -2,9 +2,72 @@
These notes aim to provide the essential background knowledge needed to understand the models and data processing techniques implemented in NeMoS.
-??? attention "Additional requirements"
- To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
- You can install all of the required packages with the following command:
- ```
- pip install nemos[examples]
- ```
+:::{dropdown} Additional requirements
+:color: warning
+:icon: alert
+
+To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
+You can install all of the required packages with the following command:
+```
+pip install nemos[examples]
+```
+
+:::
+
+
+::::{grid} 1 2 3 3
+
+:::{grid-item-card}
+
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_00_conceptual_intro.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_01_1D_basis_function.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_02_ND_basis_function.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_03_1D_convolution.md
+```
+:::
+
+::::
diff --git a/docs/background/_plot_04_modeling.py b/docs/background/_plot_04_modeling.py
deleted file mode 100644
index b18d386e..00000000
--- a/docs/background/_plot_04_modeling.py
+++ /dev/null
@@ -1 +0,0 @@
-"""GLM models."""
\ No newline at end of file
diff --git a/docs/background/plot_00_conceptual_intro.md b/docs/background/plot_00_conceptual_intro.md
new file mode 100644
index 00000000..49c4a250
--- /dev/null
+++ b/docs/background/plot_00_conceptual_intro.md
@@ -0,0 +1,258 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+(glm_intro_background)=
+# Generalized Linear Models: An Introduction
+
+Before we dive into using NeMoS, you might wonder: why model at all? Why not
+ just make a bunch of tuning curves and submit to *Science*? Modeling is
+ helpful because:
+
+- The tuning curve reflects the correlation between neuronal spiking and
+  a feature of interest, but activity might be driven by some other highly
+ correlated input (after all, [correlation does not imply
+ causation](https://xkcd.com/552/)). How do you identify what's driving
+ activity?
+
+- Your model instantiates specific hypotheses about the system (e.g., that
+ only instantaneous current matters for firing rate) and makes specific
+ quantitative predictions that can be used to compare among hypotheses.
+
+:::{attention}
+We are not claiming that the GLM will allow you to uniquely determine
+causation! Like any statistical model or method, the GLM will not solve
+causation for you (causation being a notoriously difficult problem in
+science), but it will allow you to see the effect of adding and removing
+different inputs on the predicted firing rate, which can facilitate
+causal inferences. For more reading on causation and explanation in
+neuroscience, the work of [Carl Craver](https://philosophy.wustl.edu/people/carl-f-craver)
+is a good place to start.
+:::
+
+Now that we've convinced you that modeling is worthwhile, let's get started!
+How should we begin?
+
+When modeling, it's generally a good idea to start simple and add complexity
+as needed. Simple models are:
+
+- Easier to understand, so you can more easily reason through why a model is
+ capturing or not capturing some feature of your data.
+
+- Easier to fit, so you can more quickly see how you did.
+
+- Surprisingly powerful, so you might not actually need all the bells and
+ whistles you expected.
+
+Therefore, let's start with the simplest possible model: the only input is the
+instantaneous value of some input. This is equivalent to saying that the only
+input influencing the firing rate of this neuron at time $t$ is the input it
+received at that same time. As neuroscientists, we know this isn't true, but
+given the data exploration we did above, it looks like a reasonable starting
+place. We can always build in more complications later.
+
+### GLM components
+
+The Generalized Linear Model in neuroscience can also be thought of as a LNP
+model: a linear-nonlinear-Poisson model.
+
+
+
+The model receives some input and then:
+
+- sends it through a linear filter or transformation of some sort.
+- passes that through a nonlinearity to get the *firing rate*.
+- uses the firing rate as the mean of a Poisson process to generate *spikes*.
+
+Let's step through each of those in turn.
+
+Our input feature(s) are first passed through a linear transformation, which
+rescales and shifts the input: $ \boldsymbol{W X} + \boldsymbol{c} $. In the one-dimensional case, as
+in this example, this is equivalent to scaling it by a constant and adding an
+intercept.
+
+:::{note}
+In geometry, this is more correctly referred to as an [affine
+transformation](https://en.wikipedia.org/wiki/Affine_transformation),
+which includes translations, scaling, and rotations. *Linear*
+transformations are the subset of affine transformations that do not
+include translations.
+
+In neuroscience, "linear" is the more common term, and we will use it
+throughout.
+:::
+
+This means that, in the 1d case, we have two knobs to transform the input: we
+can make it bigger or smaller, or we can shift it up or down. That is, we
+compute:
+
+$$
+L(x(t)) = w x(t) + c \tag{1}
+$$
+
+for some value of $w$ and $c$. Let's visualize some possible transformations
+that our model can make with three cartoon neurons:
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+
+# first import things
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+# some helper plotting functions
+from nemos import _documentation_utils as doc_plots
+
+# configure plots some
+plt.style.use(nmo.styles.plot_style)
+```
+
+To simplify things, we will look at three simple LNP neuron models as
+described above, working through each step of the transform. First, we will
+plot the linear transformation of the input x:
+
+
+```{code-cell} ipython3
+weights = np.asarray([.5, 4, -4])
+intercepts = np.asarray([.5, -3, -2])
+
+# make a step function with some noise riding on top
+input_feature = np.zeros(100)
+input_feature[50:] = 1
+input_feature *= np.random.rand(100)
+input_feature = nap.Tsd(np.linspace(0, 100, 100), input_feature)
+
+fig = doc_plots.lnp_schematic(input_feature, weights, intercepts)
+```
+
+With these linear transformations, we see that we can stretch or shrink the
+input and move its baseline up or down. Remember that the goal of this
+model is to predict the firing rate of the neuron. Thus, changing what
+happens when there's zero input is equivalent to changing the baseline firing
+rate of the neuron, so that's how we should think about the intercept.
+
+However, if this is meant to be the firing rate, there's something odd ---
+the output of the linear transformation is often negative, but firing rates
+have to be non-negative! That's what the nonlinearity handles: making sure our
+firing rate is always positive. We can visualize this second stage of the LNP model
+by adding the `plot_nonlinear` keyword to our `lnp_schematic()` plotting function:
+
+
+```{code-cell} ipython3
+fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
+ plot_nonlinear=True)
+```
+
+:::{note}
+In NeMoS, the non-linearity is kept fixed. We default to the exponential,
+but a small number of other choices, such as soft-plus, are allowed. The
+allowed choices guarantee both the non-negativity constraint described
+above, as well as convexity, i.e. a single optimal solution. In
+principle, one could choose a more complex non-linearity, but convexity
+is not guaranteed in general.
+:::
+
+Specifically, our firing rate is:
+
+$$
+\lambda (t) = \exp (L(x(t))) = \exp (w x(t) + c) \tag{2}
+$$
+
+We can see that the output of the nonlinear transformation is always
+positive, though note that the y-values have changed drastically.
+
+Now we're ready to look at the third step of the LNP model, and see what
+the generated spikes look like!
+
+
+```{code-cell} ipython3
+# mkdocs_gallery_thumbnail_number = 3
+fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
+ plot_nonlinear=True, plot_spikes=True)
+```
+
+Remember, spiking is a stochastic process. That means that a given firing
+rate can give rise to a variety of different spike trains; the plot above
+shows three possibilities for each neuron. Each spike train is a sample from
+a Poisson process with the mean equal to the firing rate, i.e., output of
+the linear-nonlinear parts of the model.
+
+Given that this is a stochastic process that could produce an infinite number
+of possible spike trains, how do we compare our model against the single
+observed spike train we have? We use the _log-likelihood_. This quantifies how
+likely it is to observe the given spike train for the computed firing rate:
+if $y(t)$ is the spike counts and $\lambda(t)$ the firing rate, the equation
+for the log-likelihood is
+
+$$ \sum_t \log P(y(t) | \lambda(t)) = \sum_t y(t) \log(\lambda(t)) -
+\lambda(t) - \log (y(t)!)\tag{3}$$
+
+Note that this last $\log(y(t)!)$ term does not depend on $\lambda(t)$ and
+thus is independent of the model, so it is normally ignored.
+
+$$ \sum_t \log P(y(t) | \lambda(t)) \propto \sum_t y(t) \log(\lambda(t)) -
+\lambda(t)\tag{4}$$
+
+This is the objective function of the GLM model: we are trying to find the
+firing rate that maximizes the likelihood of the observed spike train.
+
+:::{note}
+
+In NeMoS, the log-likelihood can be computed directly by calling the
+`score` method, passing the predictors and the counts. The method first
+computes the rate $\lambda(t)$ using (2) and then the likelihood using
+(4). This method is used under the hood during optimization.
+
+:::
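+
+To make this concrete, here is a minimal sketch with synthetic data (the
+`score_type` argument name is an assumption to double-check against the
+[`GLM`](nemos.glm.GLM) documentation):
+
+```{code-cell} ipython3
+# minimal sketch: fit a GLM to synthetic data and evaluate its log-likelihood
+rng = np.random.default_rng(0)
+X_demo = rng.normal(size=(100, 1))                     # one predictor, 100 time points
+counts_demo = rng.poisson(np.exp(0.5 * X_demo[:, 0]))  # synthetic spike counts
+demo_glm = nmo.glm.GLM().fit(X_demo, counts_demo)
+print(demo_glm.score(X_demo, counts_demo, score_type="log-likelihood"))
+```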
+
+
+
+## More general GLMs
+So far, we have focused on the relatively simple LNP model of spike generation, which is a special case of a GLM. The LNP model has some known shortcomings[$^{[1]}$](#ref-1). For instance, LNP ignores things like refractory periods and other history-dependent features of spiking in a neuron. As we will show in other demos, such _spike history filters_ can be built into GLMs to give more accurate results. We will also show how, if you have recordings from a large _population_ of neurons simultaneously, you can build connections between the neurons into the GLM in the form of _coupling filters_. This can help determine the degree to which activity is driven primarily by the input X, or by network influences in the population.
+
+## References
+(ref-1)=
+[1] [Pillow, JW, Shlens, J, Paninski, L, Sher, A, Litke, AM, Chichilnisky, EJ, Simoncelli, EP (2008), "Spatio-temporal correlations and visual signalling in a complete neuronal population." Nature 454: 995-9.](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2684455/)
diff --git a/docs/background/plot_00_conceptual_intro.py b/docs/background/plot_00_conceptual_intro.py
deleted file mode 100644
index 30f7c5ff..00000000
--- a/docs/background/plot_00_conceptual_intro.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# -*- coding: utf-8 -*-
-
-r"""# Generalized Linear Models: An Introduction
-
-Before we dive into using NeMoS, you might wonder: why model at all? Why not
- just make a bunch of tuning curves and submit to *Science*? Modeling is
- helpful because:
-
-- The tuning curve reflects the correlation between neuronal spiking and
- feature of interest, but activity might be driven by some other highly
- correlated input (after all, [correlation does not imply
- causation](https://xkcd.com/552/)). How do you identify what's driving
- activity?
-
-- Your model instantiates specific hypotheses about the system (e.g., that
- only instantaneous current matters for firing rate) and makes specific
- quantitative predictions that can be used to compare among hypotheses.
-
-!!! warning
-
- We are not claiming that the GLM will allow you to uniquely determine
- causation! Like any statistical model or method, the GLM will not solve
- causation for you (causation being a notoriously difficult problem in
- science), but it will allow you to see the effect of adding and removing
- different inputs on the predicted firing rate, which can facilitate
- causal inferences. For more reading on causation and explanation in
- neuroscience, the work of [Carl
- Craver](https://philosophy.wustl.edu/people/carl-f-craver) is a good
- place to start.
-
-Now that we've convinced you that modeling is worthwhile, let's get started!
-How should we begin?
-
-When modeling, it's generally a good idea to start simple and add complexity
-as needed. Simple models are:
-
-- Easier to understand, so you can more easily reason through why a model is
- capturing or not capturing some feature of your data.
-
-- Easier to fit, so you can more quickly see how you did.
-
-- Surprisingly powerful, so you might not actually need all the bells and
- whistles you expected.
-
-Therefore, let's start with the simplest possible model: the only input is the
-instantaneous value of some input. This is equivalent to saying that the only
-input influencing the firing rate of this neuron at time $t$ is the input it
-received at that same time. As neuroscientists, we know this isn't true, but
-given the data exploration we did above, it looks like a reasonable starting
-place. We can always build in more complications later.
-
-### GLM components
-
-The Generalized Linear Model in neuroscience can also be thought of as a LNP
-model: a linear-nonlinear-Poisson model.
-
-
-
-The model receives some input and then:
-
-- sends it through a linear filter or transformation of some sort.
-- passes that through a nonlinearity to get the *firing rate*.
-- uses the firing rate as the mean of a Poisson process to generate *spikes*.
-
-Let's step through each of those in turn.
-
-Our input feature(s) are first passed through a linear transformation, which
-rescales and shifts the input: $\bm{WX}+\bm{c}$. In the one-dimensional case, as
-in this example, this is equivalent to scaling it by a constant and adding an
-intercept.
-
-!!! note
-
- In geometry, this is more correctly referred to as an [affine
- transformation](https://en.wikipedia.org/wiki/Affine_transformation),
- which includes translations, scaling, and rotations. *Linear*
- transformations are the subset of affine transformations that do not
- include translations.
-
- In neuroscience, "linear" is the more common term, and we will use it
- throughout.
-
-This means that, in the 1d case, we have two knobs to transform the input: we
-can make it bigger or smaller, or we can shift it up or down. That is, we
-compute:
-
-$$L(x(t)) = w x(t) + c \tag{1}$$
-
-for some value of $w$ and $c$. Let's visualize some possible transformations
-that our model can make with three cartoon neurons:
-
-"""
-
-import matplotlib.pyplot as plt
-
-# first import things
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-# some helper plotting functions
-from nemos import _documentation_utils as doc_plots
-
-# configure plots some
-plt.style.use(nmo.styles.plot_style)
-
-# %%
-# to simplify things, we will look at three simple LNP neuron models as
-# described above, working through each step of the transform. First, we will
-# plot the linear transformation of the input x:
-
-weights = np.asarray([.5, 4, -4])
-intercepts = np.asarray([.5, -3, -2])
-
-# make a step function with some noise riding on top
-input_feature = np.zeros(100)
-input_feature[50:] = 1
-input_feature *= np.random.rand(100)
-input_feature = nap.Tsd(np.linspace(0, 100, 100), input_feature)
-
-fig = doc_plots.lnp_schematic(input_feature, weights, intercepts)
-
-# %%
-#
-# With these linear transformations, we see that we can stretch or shrink the
-# input and move its baseline up or down. Remember that the goal of this
-# model is to predict the firing rate of the neuron. Thus, changing what
-# happens when there's zero input is equivalent to changing the baseline firing
-# rate of the neuron, so that's how we should think about the intercept.
-#
-# However, if this is meant to be the firing rate, there's something odd ---
-# the output of the linear transformation is often negative, but firing rates
-# have to be non-negative! That's what the nonlinearity handles: making sure our
-# firing rate is always positive. We can visualize this second stage of the LNP model
-# by adding the `plot_nonlinear` keyword to our `lnp_schematic()` plotting function:
-
-fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
- plot_nonlinear=True)
-
-# %%
-# !!! info
-#
-# In NeMoS, the non-linearity is kept fixed. We default to the exponential,
-# but a small number of other choices, such as soft-plus, are allowed. The
-# allowed choices guarantee both the non-negativity constraint described
-# above, as well as convexity, i.e. a single optimal solution. In
-# principle, one could choose a more complex non-linearity, but convexity
-# is not guaranteed in general.
-#
-# Specifically, our firing rate is:
-# $$ \lambda (t) = \exp (L(x(t)) = \exp (w x(t) + c) \tag{2}$$
-#
-# We can see that the output of the nonlinear transformation is always
-# positive, though note that the y-values have changed drastically.
-#
-# Now we're ready to look at the third step of the LNP model, and see what
-# the generated spikes spikes look like!
-
-# mkdocs_gallery_thumbnail_number = 3
-fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
- plot_nonlinear=True, plot_spikes=True)
-
-# %%
-#
-# Remember, spiking is a stochastic process. That means that a given firing
-# rate can give rise to a variety of different spike trains; the plot above
-# shows three possibilities for each neuron. Each spike train is a sample from
-# a Poisson process with the mean equal to the firing rate, i.e., output of
-# the linear-nonlinear parts of the model.
-#
-# Given that this is a stochastic process that could produce an infinite number
-# of possible spike trains, how do we compare our model against the single
-# observed spike train we have? We use the _log-likelihood_. This quantifies how
-# likely it is to observe the given spike train for the computed firing rate:
-# if $y(t)$ is the spike counts and $\lambda(t)$ the firing rate, the equation
-# for the log-likelihood is
-#
-# $$ \sum\_t \log P(y(t) | \lambda(t)) = \sum\_t y(t) \log(\lambda(t)) -
-# \lambda(t) - \log (y(t)!)\tag{3}$$
-#
-# Note that this last $\log(y(t)!)$ term does not depend on $\lambda(t)$ and
-# thus is independent of the model, so it is normally ignored.
-#
-# $$ \sum\_t \log P(y(t) | \lambda(t)) \propto \sum\_t y(t) \log(\lambda(t)) -
-# \lambda(t))\tag{4}$$
-#
-# This is the objective function of the GLM model: we are trying to find the
-# firing rate that maximizes the likelihood of the observed spike train.
-#
-# !!! info
-#
-# In NeMoS, the log-likelihood can be computed directly by calling the
-# `score` method, passing the predictors and the counts. The method first
-# computes the rate $\lambda(t)$ using (2) and then the likelihood using
-# (4). This method is used under the hood during optimization.
-
-# %%
-# ## More general GLMs
-# So far, we have focused on the relatively simple LNP model of spike generation, which is a special case of a GLM. The LNP model has some known shortcomings[$^{[1]}$](#ref-1). For instance, LNP ignores things like refactory periods and other history-dependent features of spiking in a neuron. As we will show in other demos, such _spike history filters_ can be built into GLMs to give more accurate results. We will also show how, if you have recordings from a large _population_ of neurons simultaneously, you can build connections between the neurons into the GLM in the form of _coupling filters_. This can help answer the degree to which activity is driven primarily by the input X, or by network influences in the population.
-#
-# ## References
-#
-# [1] Pillow, JW, Shlens, J, Paninski, L, Sher, A, Litke, AM, Chichilnisky, EJ, Simoncelli, EP (2008), "Spatio-temporal correlations and visual signalling in a complete neuronal population." Nature 454: 995-9.
diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/plot_01_1D_basis_function.md
new file mode 100644
index 00000000..4d823717
--- /dev/null
+++ b/docs/background/plot_01_1D_basis_function.md
@@ -0,0 +1,243 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+(simple_basis_function)=
+# Simple Basis Function
+
+## Defining a 1D Basis Object
+
+We'll start by defining a 1D basis function object of the type [`MSplineBasis`](nemos.basis.MSplineBasis).
+The hyperparameters required to initialize this class are:
+
+- The number of basis functions, which should be a positive integer.
+- The order of the spline, which should be an integer greater than 1.
+
+```{code-cell} ipython3
+import matplotlib.pylab as plt
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+# Initialize hyperparameters
+order = 4
+n_basis = 10
+
+# Define the 1D basis function object
+bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order)
+```
+
+## Evaluating a Basis
+
+The [`Basis`](nemos.basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis
+is defined by the samples that we input to the [`__call__`](nemos.basis.Basis.__call__) method, and covers from the smallest to the largest value.
+
+
+```{code-cell} ipython3
+
+# Generate a time series of sample points
+samples = nap.Tsd(t=np.arange(1001), d=np.linspace(0, 1,1001))
+
+# Evaluate the basis at the sample points
+eval_basis = bspline(samples)
+
+# Output information about the evaluated basis
+print(f"Evaluated B-spline of order {order} with {eval_basis.shape[1]} "
+      f"basis elements and {eval_basis.shape[0]} samples.")
+
+fig = plt.figure()
+plt.title("B-spline basis")
+plt.plot(samples, eval_basis);
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/background"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/background")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_01_1D_basis_function.svg")
+```
+
+## Setting the basis support
+Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
+your basis covers the same range across multiple experimental sessions.
+You can specify a range for the support of your basis by setting the `bounds`
+parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN.
+
+
+```{code-cell} ipython3
+bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))
+
+print("Evaluated basis:")
+# 0.5 is within the support, 0.1 is outside the support
+print(np.round(bspline_range([0.5, 0.1]), 3))
+```
+
+Let's compare the default behavior of basis (estimating the range from the samples) with
+the fixed range basis.
+
+
+```{code-cell} ipython3
+fig, axs = plt.subplots(2,1, sharex=True)
+plt.suptitle("B-spline basis ")
+axs[0].plot(samples, bspline(samples), color="k")
+axs[0].set_title("default")
+axs[1].plot(samples, bspline_range(samples), color="tomato")
+axs[1].set_title("bounds=[0.2, 0.8]")
+plt.tight_layout()
+```
+
+## Basis `mode`
+In constructing features, [`Basis`](nemos.basis.Basis) objects can be used in two modalities: `"eval"` for evaluate or `"conv"`
+for convolve. These two modalities change the behavior of the [`compute_features`](nemos.basis.Basis.compute_features) method of [`Basis`](nemos.basis.Basis), in particular,
+
+- If a basis is in mode `"eval"`, then [`compute_features`](nemos.basis.Basis.compute_features) simply returns the evaluated basis.
+- If a basis is in mode `"conv"`, then [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a bank of kernels
+  built from the basis elements, each of length `window_size` as specified by the user.
+
+Let's see how these two modalities operate.
+
+
+```{code-cell} ipython3
+eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval")
+conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100)
+
+# define an input
+angles = np.linspace(0, np.pi*4, 201)
+y = np.cos(angles)
+
+# compute features in the two modalities
+eval_feature = eval_mode.compute_features(y)
+conv_feature = conv_mode.compute_features(y)
+
+# plot results
+fig, axs = plt.subplots( 3, 1, sharex="all", figsize=(6, 4))
+
+# plot signal
+axs[0].set_title("Input")
+axs[0].plot(y)
+axs[0].set_xticks([])
+axs[0].set_ylabel("signal", fontsize=12)
+
+# plot eval results
+axs[1].set_title("eval features")
+axs[1].imshow(eval_feature.T, aspect="auto")
+axs[1].set_xticks([])
+axs[1].set_ylabel("basis", fontsize=12)
+
+# plot conv results
+axs[2].set_title("convolutional features")
+axs[2].imshow(conv_feature.T, aspect="auto")
+axs[2].set_xlabel("time", fontsize=12)
+axs[2].set_ylabel("basis", fontsize=12)
+plt.tight_layout()
+```
+
+:::{admonition} NaN-Padding
+:class: note
+Convolution is performed in "valid" mode, and then NaN-padded. The default behavior
+is padding left, which makes the output feature causal.
+This is why the first half of the `conv_feature` is full of NaNs and appears as white.
+If you want to learn more about convolutions, as well as how and when to change defaults
+check out the tutorial on [1D convolutions](plot_03_1D_convolution).
+:::
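+
+To make the padding concrete, here is a small NumPy-only sketch (independent of the
+NeMoS API) of a "valid" convolution followed by left NaN-padding, which is what keeps
+the resulting feature causal:
+
+```{code-cell} ipython3
+# "valid" convolution drops samples at the edges; padding the start with NaNs
+# realigns the output so each sample depends only on current and past inputs.
+signal = np.arange(10, dtype=float)
+kernel = np.ones(3) / 3                            # 3-sample boxcar kernel
+valid = np.convolve(signal, kernel, mode="valid")  # length 10 - 3 + 1 = 8
+causal = np.concatenate([np.full(len(signal) - len(valid), np.nan), valid])
+print(causal)                                      # the first two samples are NaN
+```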
+
+
+
+Plotting the Basis Function Elements:
+--------------------------------------
+We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
+and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
+the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
+particularly evident when working with multidimensional basis functions. You can find more details and visual
+background in the
+[2D basis elements plotting section](plotting-2d-additive-basis-elements).
+
+
+```{code-cell} ipython3
+# Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples
+n_samples = 100
+equispaced_samples, eval_basis = bspline.evaluate_on_grid(n_samples)
+
+# Plot each basis element
+plt.figure()
+plt.title(f"B-spline basis with {eval_basis.shape[1]} elements\nevaluated at {eval_basis.shape[0]} sample points")
+plt.plot(equispaced_samples, eval_basis)
+plt.show()
+```
+
+Other Basis Types
+-----------------
+Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
+please refer to the [API Guide](nemos_basis). After instantiation, all classes
+share the same syntax for basis evaluation. The following is an example of how to instantiate and
+evaluate a log-spaced raised cosine basis.
+
+
+```{code-cell} ipython3
+# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
+raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50)
+
+# Evaluate the raised cosine basis at the equi-spaced sample points
+# (same method in all Basis elements)
+samples, eval_basis = raised_cosine_log.evaluate_on_grid(100)
+
+# Plot the evaluated log-spaced raised cosine basis
+plt.figure()
+plt.title(f"Log-spaced Raised Cosine basis with {eval_basis.shape[1]} elements")
+plt.plot(samples, eval_basis)
+plt.show()
+```
diff --git a/docs/background/plot_01_1D_basis_function.py b/docs/background/plot_01_1D_basis_function.py
deleted file mode 100644
index 3e22f052..00000000
--- a/docs/background/plot_01_1D_basis_function.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-# One-Dimensional Basis
-
-## Defining a 1D Basis Object
-
-We'll start by defining a 1D basis function object of the type `MSplineBasis`.
-The hyperparameters required to initialize this class are:
-
-- The number of basis functions, which should be a positive integer.
-- The order of the spline, which should be an integer greater than 1.
-"""
-
-import matplotlib.pylab as plt
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-# Initialize hyperparameters
-order = 4
-n_basis = 10
-
-# Define the 1D basis function object
-bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order)
-
-# %%
-# ## Evaluating a Basis
-#
-# The `Basis` object is callable, and can be evaluated as a function. By default, the support of the basis
-# is defined by the samples that we input to the `__call__` method, and covers from the smallest to the largest value.
-
-# Generate a time series of sample points
-samples = nap.Tsd(t=np.arange(1001), d=np.linspace(0, 1,1001))
-
-# Evaluate the basis at the sample points
-eval_basis = bspline(samples)
-
-# Output information about the evaluated basis
-print(f"Evaluated B-spline of order {order} with {eval_basis.shape[1]} "
- f"basis element and {eval_basis.shape[0]} samples.")
-
-plt.figure()
-plt.title("B-spline basis")
-plt.plot(eval_basis)
-
-# %%
-# ## Setting the basis support
-# Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
-# your basis covers the same range across multiple experimental sessions.
-# You can specify a range for the support of your basis by setting the `bounds`
-# parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN.
-
-bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))
-
-print("Evaluated basis:")
-# 0.5 is within the support, 0.1 is outside the support
-print(np.round(bspline_range([0.5, 0.1]), 3))
-
-
-# %%
-# Let's compare the default behavior of basis (estimating the range from the samples) with
-# the fixed range basis.
-
-fig, axs = plt.subplots(2,1, sharex=True)
-plt.suptitle("B-spline basis ")
-axs[0].plot(bspline(samples), color="k")
-axs[0].set_title("default")
-axs[1].plot(bspline_range(samples), color="tomato")
-axs[1].set_title("bounds=[0.2, 0.8]")
-plt.tight_layout()
-
-# %%
-# ## Basis `mode`
-# In constructing features, `Basis` objects can be used in two modalities: `"eval"` for evaluate or `"conv"`
-# for convolve. These two modalities change the behavior of the `construct_features` method of `Basis`, in particular,
-#
-# - If a basis is in mode `"eval"`, then `construct_features` simply returns the evaluated basis.
-# - If a basis is in mode `"conv"`, then `construct_features` will convolve the input with a kernel of basis
-# with `window_size` specified by the user.
-#
-# Let's see how this two modalities operate.
-
-eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval")
-conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100)
-
-# define an input
-angles = np.linspace(0, np.pi*4, 201)
-y = np.cos(angles)
-
-# compute features in the two modalities
-eval_feature = eval_mode.compute_features(y)
-conv_feature = conv_mode.compute_features(y)
-
-# plot results
-fig, axs = plt.subplots( 3, 1, sharex="all", figsize=(6, 4))
-
-# plot signal
-axs[0].set_title("Input")
-axs[0].plot(y)
-axs[0].set_xticks([])
-axs[0].set_ylabel("signal", fontsize=12)
-
-# plot eval results
-axs[1].set_title("eval features")
-axs[1].imshow(eval_feature.T, aspect="auto")
-axs[1].set_xticks([])
-axs[1].set_ylabel("basis", fontsize=12)
-
-# plot conv results
-axs[2].set_title("convolutional features")
-axs[2].imshow(conv_feature.T, aspect="auto")
-axs[2].set_xlabel("time", fontsize=12)
-axs[2].set_ylabel("basis", fontsize=12)
-plt.tight_layout()
-
-# %%
-#
-# !!! note "NaN-Padding"
-# Convolution is performed in "valid" mode, and then NaN-padded. The default behavior
-# is padding left, which makes the output feature causal.
-# This is why the first half of the `conv_feature` is full of NaNs and appears as white.
-# If you want to learn more about convolutions, as well as how and when to change defaults
-# check out the tutorial on [1D convolutions](../plot_03_1D_convolution).
-
-# %%
-# Plotting the Basis Function Elements:
-# --------------------------------------
-# We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
-# and then plotting the result. The method `Basis.evaluate_on_grid` is designed for this, as it generates and returns
-# the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
-# particularly evident when working with multidimensional basis functions. You can find more details and visual
-# background in the
-# [2D basis elements plotting section](../plot_02_ND_basis_function/#plotting-2d-additive-basis-elements).
-
-# Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples
-n_samples = 100
-equispaced_samples, eval_basis = bspline.evaluate_on_grid(n_samples)
-
-# Plot each basis element
-plt.figure()
-plt.title(f"B-spline basis with {eval_basis.shape[1]} elements\nevaluated at {eval_basis.shape[0]} sample points")
-plt.plot(equispaced_samples, eval_basis)
-plt.show()
-
-# %%
-# Other Basis Types
-# -----------------
-# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
-# please refer to the [API Guide](../../../reference/nemos/basis). After instantiation, all classes
-# share the same syntax for basis evaluation. The following is an example of how to instantiate and
-# evaluate a log-spaced cosine raised function basis.
-
-# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
-raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50)
-
-# Evaluate the raised cosine basis at the equi-spaced sample points
-# (same method in all Basis elements)
-samples, eval_basis = raised_cosine_log.evaluate_on_grid(100)
-
-# Plot the evaluated log-spaced raised cosine basis
-plt.figure()
-plt.title(f"Log-spaced Raised Cosine basis with {eval_basis.shape[1]} elements")
-plt.plot(samples, eval_basis)
-plt.show()
-
diff --git a/docs/background/plot_02_ND_basis_function.py b/docs/background/plot_02_ND_basis_function.md
similarity index 56%
rename from docs/background/plot_02_ND_basis_function.py
rename to docs/background/plot_02_ND_basis_function.md
index 095633e9..c14c0ba4 100644
--- a/docs/background/plot_02_ND_basis_function.py
+++ b/docs/background/plot_02_ND_basis_function.md
@@ -1,7 +1,47 @@
-# -*- coding: utf-8 -*-
-
-r"""
-# Multidimensional Basis
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+(composing_basis_function)=
+# Composing Basis Functions
## Background
@@ -18,28 +58,47 @@
Let's say we've defined two basis functions for these inputs:
- $ [ a_0 (\mathbf{x}), ..., a_{k-1} (\mathbf{x}) ] $ for $\mathbf{x}$
-- $[b_0 (\mathbf{y}), ..., b_{h-1} (\mathbf{y}) ]$ for $\mathbf{y}$.
+- $ [b_0 (\mathbf{y}), ..., b_{h-1} (\mathbf{y}) ] $ for $\mathbf{y}$.
These basis functions can be combined in the following ways:
1. **Addition:** If we assume that there is no interaction between the stimuli, the response function can be adequately described by the sum of the individual components. The function is defined as:
+
+ $$
+ f(\mathbf{x}, \mathbf{y}) \approx \sum_{i=0}^{k-1} \alpha_{i} \, a_i (\mathbf{x}) + \sum_{j=0}^{h-1} \beta_j b_j(\mathbf{y}).
$$
- f(\mathbf{x}, \mathbf{y}) \\approx \sum_{i=0}^{k-1} \\alpha_{i} \, a_i (\mathbf{x}) + \sum_{j=0}^{h-1} \\beta_j b_j(\mathbf{y}).
+
+ The resulting additive basis simply consists of the concatenation of the two basis sets:
+
+ $$
+ [A_0 (\mathbf{x}, \mathbf{y}), ..., A_{k+h-1} (\mathbf{x}, \mathbf{y})],
$$
- The resulting additive basis simply consists of the concatenation of the two basis sets: $$[A_0 (\mathbf{x}, \mathbf{y}), ..., A_{k+h-1} (\mathbf{x}, \mathbf{y})],$$ where
+
+ where
+
$$
- A_j(\mathbf{x}, \mathbf{y}) = \\begin{cases} a_j(\mathbf{x}) & \\text{if }\; j \leq k-1 \\\\\ b_{j-k+1}(\mathbf{y}) & \\text{otherwise.} \end{cases}
+ A_j(\mathbf{x}, \mathbf{y}) = \begin{cases} a_j(\mathbf{x}) &\text{if }\; j \leq k-1 \\
+ b_{j-k+1}(\mathbf{y}) &\text{otherwise.} \end{cases}
$$
+
Note that we have a total of $k+h$ basis elements, and that each element is constant in one of the axis.
2. **Multiplication:** If we expect the response function to capture arbitrary interactions between the inputs, we can approximate it as the external product of the two bases:
+
$$
- f(\mathbf{x}, \mathbf{y}) \\approx \sum_{i=0}^{k-1}\sum_{j=0}^{h-1} \\alpha_{ij} \, a_i (\mathbf{x}) b_j(\mathbf{y}).
+ f(\mathbf{x}, \mathbf{y}) \approx \sum_{i=0}^{k-1}\sum_{j=0}^{h-1} \alpha_{ij} \, a_i (\mathbf{x}) b_j(\mathbf{y}).
$$
- In this case, the resulting basis consists of the $h \cdot k$ products of the individual bases: $$[A_0(\mathbf{x}, \mathbf{y}),..., A_{k \cdot h-1}(\mathbf{x}, \mathbf{y})],$$
+
+ In this case, the resulting basis consists of the $h \cdot k$ products of the individual bases:
+
+ $$
+ [A_0(\mathbf{x}, \mathbf{y}),..., A_{k \cdot h-1}(\mathbf{x}, \mathbf{y})],
+ $$
+
where,
+
$$
- A_{i \cdot h + j}(\mathbf{x}, \mathbf{y}) = a_i(\mathbf{x})b_{j}(\mathbf{y}), \; \\text{for} \; i=0,\dots, k-1 \; \\text{ and } \; j=0,\dots,h-1.
+ A_{i \cdot h + j}(\mathbf{x}, \mathbf{y}) = a_i(\mathbf{x})b_{j}(\mathbf{y}), \; \text{for} \; i=0,\dots, k-1 \; \text{ and } \; j=0,\dots,h-1.
$$
In the subsequent sections, we will:
@@ -52,17 +111,20 @@
Consider an instance where we want to capture a neuron's response to an animal's position within a given arena.
In this scenario, the stimuli are the 2D coordinates (x, y) that represent the animal's position at each time point.
-"""
-# %%
-# ### Additive Basis Object
-# One way to model the response to our 2D stimuli is to hypothesize that it decomposes into two factors:
-# one due to the x-coordinate and another due to the y-coordinate. We can express this relationship as:
-# $$
-# f(x,y) \\approx \sum_i \alpha_i \cdot a_i(x) + \sum_j \beta_j \cdot b_j(y).
-# $$
-# Here, we simply add two basis objects, `a_basis` and `b_basis`, together to define the additive basis.
+### Additive Basis Object
+One way to model the response to our 2D stimuli is to hypothesize that it decomposes into two factors:
+one due to the x-coordinate and another due to the y-coordinate. We can express this relationship as:
+
+$$
+f(x,y) \approx \sum_i \alpha_i \cdot a_i(x) + \sum_j \beta_j \cdot b_j(y).
+$$
+
+Here, we simply add two basis objects, `a_basis` and `b_basis`, together to define the additive basis.
+
+
+```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
@@ -74,11 +136,13 @@
# Define the 2D additive basis object
additive_basis = a_basis + b_basis
+```
+
+Evaluating the additive basis will require two inputs, one for each coordinate.
+The total number of elements of the additive basis will be the sum of the elements of the two 1D bases.
-# %%
-# Evaluating the additive basis will require two inputs, one for each coordinate.
-# The total number of elements of the additive basis will be the sum of the elements of the 1D basis.
+```{code-cell} ipython3
# Define a trajectory with 1000 time-points representing the recorded trajectory of the animal
T = 1000
@@ -92,11 +156,14 @@
print(f"Sum of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
f"\t- a_basis had {a_basis.n_basis_funcs} elements\n\t- b_basis had {b_basis.n_basis_funcs} elements.")
+```
-# %%
-# #### Plotting 2D Additive Basis Elements
-# Let's select and plot a basis element from each of the basis we added.
+(plotting-2d-additive-basis-elements)=
+#### Plotting 2D Additive Basis Elements
+Let's select and plot a basis element from each of the bases we added.
+
+```{code-cell} ipython3
basis_a_element = 5
basis_b_element = 1
# Plot the 1D basis elements
@@ -112,21 +179,28 @@
axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b")
axs[1].set_xlabel("y-coord")
plt.tight_layout()
+```
+
+We can visualize how these elements are extended in 2D by evaluating the additive basis
+on a grid of points that spans its domain and plotting the result.
+We use the `evaluate_on_grid` method for this.
-# %%
-# We can visualize how these elements are extended in 2D by evaluating the additive basis
-# on a grid of points that spans its domain and plotting the result.
-# We use the `evaluate_on_grid` method for this.
+```{code-cell} ipython3
X, Y, Z = additive_basis.evaluate_on_grid(200, 200)
+```
-# %%
-# We can select the indices of the 2D additive basis that corresponds to the 1D original elements.
+We can select the indices of the 2D additive basis that correspond to the original 1D elements.
+
+```{code-cell} ipython3
basis_elem_idx = [basis_a_element, a_basis.n_basis_funcs + basis_b_element]
+```
+
+Finally, we can plot the 2D counterparts.
+
-# %%
-# Finally, we can plot the 2D counterparts.
+```{code-cell} ipython3
_, axs = plt.subplots(1, 2, subplot_kw={'aspect': 1})
# Plot the corresponding 2D elements.
@@ -143,27 +217,31 @@
axs[cc].set_ylabel("y-coord")
plt.tight_layout()
plt.show()
+```
+
+### Multiplicative Basis Object
+
+If the aim is to capture interactions between the coordinates, the response function can be modeled as the outer
+product of two 1D basis functions. The approximation of the response function in this scenario would be:
-# %%
-# ### Multiplicative Basis Object
-#
-# If the aim is to capture interactions between the coordinates, the response function can be modeled as the external
-# product of two 1D basis functions. The approximation of the response function in this scenario would be:
-#
-# $$
-# f(x, y) \\approx \sum_{ij} \\alpha_{ij} \, a_i (x) b_j(y).
-# $$
-#
-# In this model, we define the 2D basis function as the product of two 1D basis objects.
-# This allows the response to capture non-linear and interaction effects between the x and y coordinates.
+$$
+f(x, y) \approx \sum_{ij} \alpha_{ij} \, a_i (x) b_j(y).
+$$
+
+In this model, we define the 2D basis function as the product of two 1D basis objects.
+This allows the response to capture non-linear and interaction effects between the x and y coordinates.
+
+
+```{code-cell} ipython3
# 2D basis function as the product of the two 1D basis objects
prod_basis = a_basis * b_basis
+```
-# %%
-# Again evaluating the basis will require 2 inputs.
-# The number of elements of the product basis will be the product of the elements of the two 1D bases.
+Again, evaluating the basis requires two inputs, one for each coordinate.
+The number of elements of the product basis is the product of the number of elements of the two 1D bases.
+
+```{code-cell} ipython3
# Evaluate the product basis at the x and y coordinates
eval_basis = prod_basis(x_coord, y_coord)
@@ -172,14 +250,14 @@
print(f"Product of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
f"\t- a_basis had {a_basis.n_basis_funcs} elements\n\t- b_basis had {b_basis.n_basis_funcs} elements.")
+```
+
+#### Plotting 2D Multiplicative Basis Elements
+Plotting works in the same way as before. To demonstrate that, we select a few pairs of 1D basis elements,
+and we visualize the corresponding product.
-# %%
-# #### Plotting 2D Multiplicative Basis Elements
-# Plotting works in the same way as before. To demonstrate that, we select a few pairs of 1D basis elements,
-# and we visualize the corresponding product.
-# Set this figure as the thumbnail
-# mkdocs_gallery_thumbnail_number = 3
+```{code-cell} ipython3
X, Y, Z = prod_basis.evaluate_on_grid(200, 200)
@@ -213,29 +291,53 @@
axs[2, 1].set_xlabel('y-coord')
plt.tight_layout()
-
-# %%
-# !!! info
-# Basis objects of different types can be combined through multiplication or addition.
-# This feature is particularly useful when one of the axes represents a periodic variable and another is non-periodic.
-# A practical example would be characterizing the responses to position
-# in a linear maze and the LFP phase angle.
-
-
-# %%
-# N-Dimensional Basis
-# -------------------
-# Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of
-# an animal and its spatial position. In order to model an N-dimensional response function, you can combine
-# N 1D basis objects using additions and multiplications.
-#
-# !!! warning
-# If you multiply basis together, the dimension of the evaluated basis function
-# will increase exponentially with the number of dimensions potentially causing memory errors.
-# For example, evaluating a product of $N$ 1D bases with $T$ samples and $K$ basis element,
-# will output a $K^N \times T$ matrix.
-
-
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/background"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/background")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_02_ND_basis_function.svg")
+```
+
+:::{note}
+Basis objects of different types can be combined through multiplication or addition.
+This feature is particularly useful when one of the axes represents a periodic variable and another is non-periodic.
+A practical example would be characterizing the responses to position
+in a linear maze and the LFP phase angle.
+:::
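+
+For instance, here is a minimal sketch of such a mixed basis (this assumes a cyclic basis class, e.g. `CyclicBSplineBasis`, is available for the periodic variable; the pairing with `a_basis` is purely illustrative):
+
+```{code-cell} ipython3
+import nemos as nmo
+
+# periodic basis for a phase-like variable, combined with the non-periodic a_basis
+phase_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=8)
+mixed_basis = a_basis * phase_basis
+print(f"Mixed multiplicative basis with {mixed_basis.n_basis_funcs} elements.")
+```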
+
+
+
+N-Dimensional Basis
+-------------------
+Sometimes it may be useful to model even higher-dimensional interactions, for example between the heading direction of
+an animal and its spatial position. In order to model an N-dimensional response function, you can combine
+N 1D basis objects using additions and multiplications.
+
+:::{warning}
+If you multiply bases together, the dimension of the evaluated basis function
+increases exponentially with the number of dimensions, potentially causing memory errors.
+For example, evaluating a product of $N$ 1D bases, each with $K$ basis elements, at $T$ samples
+will output a $K^N \times T$ matrix.
+:::
+
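+To make the warning concrete, here is a rough, purely illustrative back-of-the-envelope estimate (the numbers below are not tied to the small example that follows):
+
+```{code-cell} ipython3
+# the product of N 1D bases, each with K elements, evaluated at T samples
+# yields K**N * T values; at 8 bytes per float64 value:
+K, N, T_long = 10, 3, 50_000
+n_values = K**N * T_long
+print(f"{n_values:,} values, roughly {n_values * 8 / 1e9:.1f} GB in float64")
+```
+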
+```{code-cell} ipython3
T = 10
n_basis = 8
@@ -249,11 +351,12 @@
print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} "
f"basis elements.\nEvaluation output of shape {eval_basis.shape}")
+```
-# %%
-# The evaluation of the product of 3 basis is a 4 dimensional tensor; we can visualize slices of it.
+The evaluation of the product of 3 bases is a 4-dimensional tensor; we can visualize slices of it.
+
+```{code-cell} ipython3
X, Y, W, Z = prod_basis_3.evaluate_on_grid(30, 30, 30)
# select any slice
@@ -278,16 +381,20 @@
# Check sparsity
print(f"Sparsity check: {(Z == 0).sum() / Z.size * 100: .2f}% of the evaluated basis is null.")
+```
+
+:::{note}
+The evaluated basis will be **sparse** if the supports of the basis elements do not cover the
+full domain of the basis.
+:::
+
+
-# %%
-# !!! info
-# The evaluated basis is going to be **sparse** if the basis elements support do not cover the
-# full domain of the basis.
+Here we demonstrate a shortcut syntax for multiplying bases of the same class.
+This is achieved using the power operator with an integer exponent.
-# %%
-# Here we demonstrate a shortcut syntax for multiplying bases of the same class.
-# This is achieved using the power operator with an integer exponent.
+```{code-cell} ipython3
# First, let's define a basis `power_basis` that is equivalent to `prod_basis_3`,
# but we use the power syntax this time:
power_basis = a_basis**3
@@ -299,3 +406,4 @@
# We can now assert that the original basis and the new `power_basis` match.
# If they do, the total number of mismatched entries should be zero.
print(f"Total mismatched entries: {(Z_pow_syntax != Z_prod_syntax).sum()}")
+```
diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md
new file mode 100644
index 00000000..fa7335e5
--- /dev/null
+++ b/docs/background/plot_03_1D_convolution.md
@@ -0,0 +1,216 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+(convolution_background)=
+# Convolution
+
+## Generate synthetic data
+Generate some simulated spike counts.
+
+
+```{code-cell} ipython3
+import matplotlib.patches as patches
+import matplotlib.pylab as plt
+import numpy as np
+
+import nemos as nmo
+
+np.random.seed(10)
+ws = 10
+# samples
+n_samples = 100
+
+spk = np.random.poisson(lam=0.1, size=(n_samples, ))
+
+# add spikes at the borders (extreme case, to make border effects visible)
+spk[0] = 1
+spk[3] = 1
+spk[-1] = 1
+spk[-4] = 1
+```
+
+## Convolution in `"valid"` mode
+Generate and plot a filter, then execute a convolution in "valid" mode for all trials and neurons.
+In nemos, you can use the [`tensor_convolve`](nemos.convolve.tensor_convolve) function for this.
+
+:::{note}
+The `"valid"` mode of convolution only calculates the product when the two input vectors overlap completely,
+avoiding border artifacts. The outcome of such a convolution will
+be an array of `max(M,N) - min(M,N) + 1` elements in length, where `M` and `N` represent the number
+of elements in the arrays being convolved. For more detailed information on this,
+see [jax.numpy.convolve](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.convolve.html).
+:::
+
+```{code-cell} ipython3
+# create three filters
+basis_obj = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=3)
+_, w = basis_obj.evaluate_on_grid(ws)
+
+plt.plot(w)
+
+spk_conv = nmo.convolve.tensor_convolve(spk, w)
+
+# valid convolution should be of shape n_samples - ws + 1
+print(f"Shape of the convolution output: {spk_conv.shape}")
+```
+
+## Causal, Anti-Causal, and Acausal filters
+Appropriately NaN-padding the output of the convolution allows us to model causal, anti-causal, and acausal filters.
+A causal filter captures how an event or task variable influences the future firing rate.
+An example use case is characterizing the refractory period of a neuron
+(i.e., the drop in firing rate immediately after a spike event). Another example could be characterizing how
+the current position of an animal in a maze affects its future spiking activity.
+
+On the other hand, if we are interested in capturing the firing rate modulation before an event occurs, we may want
+to use an anti-causal filter. An example of that may be the preparatory activity of pre-motor cortex that happens
+before a movement is initiated (here the event is "movement onset").
+
+Finally, if one wants to capture both causal
+and anti-causal effects, one should use an acausal filter.
+Below we use the function [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor), which runs the convolution in "valid" mode and pads the convolution output
+appropriately for the different filter types.
+
+
+```{code-cell} ipython3
+# pad according to the causal direction of the filter, after squeeze,
+# the dimension is (n_filters, n_samples)
+spk_causal_conv = nmo.convolve.create_convolutional_predictor(
+ w, spk, predictor_causality="causal"
+)
+spk_anticausal_conv = nmo.convolve.create_convolutional_predictor(
+ w, spk, predictor_causality="anti-causal"
+)
+spk_acausal_conv = nmo.convolve.create_convolutional_predictor(
+ w, spk, predictor_causality="acausal"
+)
+```
+
+Plot the results
+
+
+```{code-cell} ipython3
+# NaN padded area
+rect_causal = patches.Rectangle((0, -2.5), ws, 5, alpha=0.3, color='grey')
+rect_anticausal = patches.Rectangle((len(spk)-ws, -2.5), ws, 5, alpha=0.3, color='grey')
+rect_acausal_left = patches.Rectangle((0, -2.5), (ws-1)//2, 5, alpha=0.3, color='grey')
+rect_acausal_right = patches.Rectangle((len(spk) - (ws-1)//2, -2.5), (ws-1)//2, 5, alpha=0.3, color='grey')
+
+fig = plt.figure(figsize=(6, 4))
+
+shift_spk = - spk - 0.1
+ax = plt.subplot(311)
+
+plt.title('valid + nan-pad')
+ax.add_patch(rect_causal)
+plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
+plt.plot(np.arange(spk.shape[0]), spk_causal_conv)
+plt.ylabel('causal')
+
+ax = plt.subplot(312)
+ax.add_patch(rect_anticausal)
+plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
+plt.plot(np.arange(spk.shape[0]), spk_anticausal_conv)
+plt.ylabel('anti-causal')
+
+ax = plt.subplot(313)
+ax.add_patch(rect_acausal_left)
+ax.add_patch(rect_acausal_right)
+plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
+plt.plot(np.arange(spk.shape[0]), spk_acausal_conv)
+plt.ylabel('acausal')
+plt.tight_layout()
+
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/background"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/background")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_03_1D_convolution.svg")
+```
+
+## Convolve using [`Basis.compute_features`](nemos.basis.Basis.compute_features)
+All the parameters of [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) can be passed to a [`Basis`](nemos.basis.Basis) directly
+at initialization. Note that you must set `mode="conv"` to actually perform convolution
+with [`Basis.compute_features`](nemos.basis.Basis.compute_features). Let's see how we can get the same results through [`Basis`](nemos.basis.Basis).
+
+
+```{code-cell} ipython3
+# define basis with different predictor causality
+causal_basis = nmo.basis.RaisedCosineBasisLinear(
+ n_basis_funcs=3, mode="conv", window_size=ws,
+ predictor_causality="causal"
+)
+
+acausal_basis = nmo.basis.RaisedCosineBasisLinear(
+ n_basis_funcs=3, mode="conv", window_size=ws,
+ predictor_causality="acausal"
+)
+
+anticausal_basis = nmo.basis.RaisedCosineBasisLinear(
+ n_basis_funcs=3, mode="conv", window_size=ws,
+ predictor_causality="anti-causal"
+)
+
+# compute convolutions
+basis_causal_conv = causal_basis.compute_features(spk)
+basis_acausal_conv = acausal_basis.compute_features(spk)
+basis_anticausal_conv = anticausal_basis.compute_features(spk)
+```
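+
+As a quick sanity check, we can compare the causal convolution computed through the basis with the one computed earlier via [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor). This assumes both outputs share the `(n_samples, n_basis_funcs)` shape and the same NaN padding:
+
+```{code-cell} ipython3
+# equal_nan=True treats the NaN-padded entries as matching
+print(np.allclose(basis_causal_conv, spk_causal_conv, equal_nan=True))
+```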
diff --git a/docs/background/plot_03_1D_convolution.py b/docs/background/plot_03_1D_convolution.py
deleted file mode 100644
index 79f6ddd9..00000000
--- a/docs/background/plot_03_1D_convolution.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""
-One-dimensional convolutions
-"""
-
-# %%
-# ## Generate synthetic data
-# Generate some simulated spike counts.
-
-import matplotlib.patches as patches
-import matplotlib.pylab as plt
-import numpy as np
-
-import nemos as nmo
-
-np.random.seed(10)
-ws = 10
-# samples
-n_samples = 100
-
-spk = np.random.poisson(lam=0.1, size=(n_samples, ))
-
-# add borders (extreme case, general border effect are represented)
-spk[0] = 1
-spk[3] = 1
-spk[-1] = 1
-spk[-4] = 1
-
-
-# %%
-# ## Convolution in `"valid"` mode
-# Generate and plot a filter, then execute a convolution in "valid" mode for all trials and neurons.
-#
-# !!! info
-# The `"valid"` mode of convolution only calculates the product when the two input vectors overlap completely,
-# avoiding border artifacts. The outcome of such a convolution will
-# be an array of `max(M,N) - min(M,N) + 1` elements in length, where `M` and `N` represent the number
-# of elements in the arrays being convolved. For more detailed information on this,
-# see [jax.numpy.convolve](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.convolve.html).
-
-
-# create three filters
-basis_obj = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=3)
-_, w = basis_obj.evaluate_on_grid(ws)
-
-plt.plot(w)
-
-spk_conv = nmo.convolve.reshape_convolve(spk, w)
-
-# valid convolution should be of shape n_samples - ws + 1
-print(f"Shape of the convolution output: {spk_conv.shape}")
-
-# %%
-# ## Causal, Anti-Causal, and Acausal filters
-# NaN padding appropriately the output of the convolution allows to model causal, anti-causal and acausal filters.
-# A causal filter captures how an event or task variable influences the future firing-rate.
-# An example usage case would be that of characterizing the refractory period of a neuron
-# (i.e. the drop in firing rate immediately after a spike event). Another example could be characterizing how
-# the current position of an animal in a maze would affect its future spiking activity.
-#
-# On the other hand, if we are interested in capturing the firing rate modulation before an event occurs we may want
-# to use an anti-causal filter. An example of that may be the preparatory activity of pre-motor cortex that happens
-# before a movement is initiated (here the event is. "movement onset").
-#
-# Finally, if one wants to capture both causal
-# and anti-causal effects one should use the acausal filters.
-# Below we provide a function that runs the convolution in "valid" mode and pads the convolution output
-# for the different filter types.
-
-
-# pad according to the causal direction of the filter, after squeeze,
-# the dimension is (n_filters, n_samples)
-spk_causal_conv = nmo.convolve.create_convolutional_predictor(
- w, spk, predictor_causality="causal"
-)
-spk_anticausal_conv = nmo.convolve.create_convolutional_predictor(
- w, spk, predictor_causality="anti-causal"
-)
-spk_acausal_conv = nmo.convolve.create_convolutional_predictor(
- w, spk, predictor_causality="acausal"
-)
-
-
-# %%
-# Plot the results
-
-# NaN padded area
-rect_causal = patches.Rectangle((0, -2.5), ws, 5, alpha=0.3, color='grey')
-rect_anticausal = patches.Rectangle((len(spk)-ws, -2.5), ws, 5, alpha=0.3, color='grey')
-rect_acausal_left = patches.Rectangle((0, -2.5), (ws-1)//2, 5, alpha=0.3, color='grey')
-rect_acausal_right = patches.Rectangle((len(spk) - (ws-1)//2, -2.5), (ws-1)//2, 5, alpha=0.3, color='grey')
-
-# Set this figure as the thumbnail
-# mkdocs_gallery_thumbnail_number = 2
-
-plt.figure(figsize=(6, 4))
-
-shift_spk = - spk - 0.1
-ax = plt.subplot(311)
-
-plt.title('valid + nan-pad')
-ax.add_patch(rect_causal)
-plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
-plt.plot(np.arange(spk.shape[0]), spk_causal_conv)
-plt.ylabel('causal')
-
-ax = plt.subplot(312)
-ax.add_patch(rect_anticausal)
-plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
-plt.plot(np.arange(spk.shape[0]), spk_anticausal_conv)
-plt.ylabel('anti-causal')
-
-ax = plt.subplot(313)
-ax.add_patch(rect_acausal_left)
-ax.add_patch(rect_acausal_right)
-plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
-plt.plot(np.arange(spk.shape[0]), spk_acausal_conv)
-plt.ylabel('acausal')
-plt.tight_layout()
-
-# %%
-# ## Convolve using `Basis.compute_features`
-# All the parameters of `create_convolutional_predictor` can be passed to a `Basis` directly
-# at initialization. Note that you must set `mode == "conv"` to actually perform convolution
-# with `Basis.compute_features`. Let's see how we can get the same results through `Basis`.
-
-# define basis with different predictor causality
-causal_basis = nmo.basis.RaisedCosineBasisLinear(
- n_basis_funcs=3, mode="conv", window_size=ws,
- predictor_causality="causal"
-)
-
-acausal_basis = nmo.basis.RaisedCosineBasisLinear(
- n_basis_funcs=3, mode="conv", window_size=ws,
- predictor_causality="acausal"
-)
-
-anticausal_basis = nmo.basis.RaisedCosineBasisLinear(
- n_basis_funcs=3, mode="conv", window_size=ws,
- predictor_causality="anti-causal"
-)
-
-# compute convolutions
-basis_causal_conv = causal_basis.compute_features(spk)
-basis_acausal_conv = acausal_basis.compute_features(spk)
-basis_anticausal_conv = anticausal_basis.compute_features(spk)
-
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 00000000..5a5b2c43
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,160 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+import nemos
+import sys, os
+from pathlib import Path
+
+sys.path.insert(0, str(Path('..', 'src').resolve()))
+sys.path.insert(0, os.path.abspath('sphinxext'))
+
+
+project = 'nemos'
+copyright = '2024'
+author = 'E Balzani'
+version = release = nemos.__version__
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+# The Root document
+root_doc = "index"
+
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.autosummary',
+ 'sphinx.ext.coverage',
+ 'sphinx.ext.viewcode', # Links to source code
+ 'sphinx.ext.doctest',
+ 'sphinx_copybutton', # Adds copy button to code blocks
+ 'sphinx_design', # For layout components
+ 'myst_nb',
+ 'sphinx_contributors',
+ 'sphinx_code_tabs',
+ 'sphinx.ext.mathjax',
+ 'sphinx_autodoc_typehints',
+ 'sphinx_togglebutton',
+]
+
+myst_enable_extensions = [
+ "amsmath",
+ "attrs_inline",
+ "colon_fence",
+ "dollarmath",
+ "html_admonition",
+ "html_image",
+]
+
+templates_path = ['_templates']
+exclude_patterns = ['_build', "docstrings", 'Thumbs.db', 'nextgen', '.DS_Store']
+
+
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+
+# Generate the API documentation when building
+autosummary_generate = True
+numpydoc_show_class_members = True
+autodoc_default_options = {
+ 'members': True,
+ 'inherited-members': True,
+ 'undoc-members': True,
+ 'show-inheritance': True,
+ 'special-members': '__call__, __add__, __mul__, __pow__'
+}
+
+# napoleon configs
+napoleon_google_docstring = False
+napoleon_numpy_docstring = True
+napoleon_include_init_with_doc = False
+napoleon_include_private_with_doc = False
+napoleon_include_special_with_doc = True
+napoleon_use_admonition_for_examples = False
+napoleon_use_admonition_for_notes = False
+napoleon_use_admonition_for_references = False
+napoleon_use_ivar = False
+napoleon_use_param = True
+napoleon_use_rtype = True
+
+autodoc_typehints = "description" # Use "description" to place hints in the description
+autodoc_type_aliases = {
+ "ArrayLike": "ArrayLike",
+ "NDArray": "NDArray",
+ "TsdFrame": "pynapple.TsdFrame",
+ "JaxArray": "JaxArray",
+}
+autodoc_typehints_format = "short"
+
+numfig = True
+
+html_theme = 'pydata_sphinx_theme'
+
+html_favicon = "assets/NeMoS_favicon.ico"
+
+# Additional theme options
+html_theme_options = {
+ "icon_links": [
+ {
+ "name": "GitHub",
+ "url": "https://github.com/flatironinstitute/nemos/",
+ "icon": "fab fa-github",
+ "type": "fontawesome",
+ },
+ {
+ "name": "X",
+ "url": "https://x.com/nemos_neuro",
+ "icon": "fab fa-square-x-twitter",
+ "type": "fontawesome",
+ },
+ ],
+ "show_prev_next": True,
+ "header_links_before_dropdown": 6,
+ "navigation_depth": 3,
+ "logo": {
+ "image_light": "_static/NeMoS_Logo_CMYK_Full.svg",
+ "image_dark": "_static/NeMoS_Logo_CMYK_White.svg",
+ }
+}
+
+html_sidebars = {
+ "index": [],
+ "installation":[],
+ "quickstart": [],
+ "background/README": [],
+ "how_to_guide/README": [],
+ "tutorials/README": [],
+ "**": ["search-field.html", "sidebar-nav-bs.html"],
+}
+
+
+# Path for static files (custom stylesheets or JavaScript)
+html_static_path = ['assets/stylesheets', "assets"]
+html_css_files = ['custom.css']
+
+html_js_files = [
+ "https://code.iconify.design/2/2.2.1/iconify.min.js"
+]
+
+# Copybutton settings (to hide prompt)
+copybutton_prompt_text = r">>> |\$ |# "
+copybutton_prompt_is_regexp = True
+
+sphinxemoji_style = 'twemoji'
+
+nb_execution_timeout = 60 * 15 # Set timeout in seconds (e.g., 15 minutes)
+
+nitpicky = True
+
+# Get exclusion patterns from an environment variable
+exclude_tutorials = os.environ.get("EXCLUDE_TUTORIALS", "false").lower() == "true"
+
+if exclude_tutorials:
+ nb_execution_excludepatterns = ["tutorials/**", "how_to_guide/**", "background/**"]
\ No newline at end of file
diff --git a/docs/developers_notes/01-base_class.md b/docs/developers_notes/01-base_class.md
index 8050a6c8..70de4276 100644
--- a/docs/developers_notes/01-base_class.md
+++ b/docs/developers_notes/01-base_class.md
@@ -4,7 +4,7 @@
The `base_class` module introduces the `Base` class and abstract classes defining broad model categories and feature constructors. These abstract classes **must** inherit from `Base`.
-The `Base` class is envisioned as the foundational component for any object type (e.g., basis, regression, dimensionality reduction, clustering, observation models, regularizers etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_regressor.BaseRegressor` is building block for GLMs, GAMS, etc. while `observation_models.Observations` is the building block for the Poisson observations, Gamma observations, ... etc.).
+The `Base` class is envisioned as the foundational component for any object type (e.g., basis, regression, dimensionality reduction, clustering, observation models, regularizers, etc.). In contrast, abstract classes derived from `Base` define overarching object categories (e.g., `base_regressor.BaseRegressor` is the building block for GLMs, GAMs, etc., while [`observation_models.Observations`](nemos.observation_models.Observations) is the building block for Poisson observations, Gamma observations, etc.).
Designed to be compatible with the `scikit-learn` API, the class structure aims to facilitate access to `scikit-learn`'s robust pipeline and cross-validation modules. This is achieved while leveraging the accelerated computational capabilities of `jax` and `jaxopt` in the backend, which is essential for analyzing extensive neural recordings and fitting large models.
@@ -47,8 +47,9 @@ The `Base` class aligns with the `scikit-learn` API for `base.BaseEstimator`. Th
For a detailed understanding, consult the [`scikit-learn` API Reference](https://scikit-learn.org/stable/modules/classes.html) and [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html).
-!!! Note
- We've intentionally omitted the `get_metadata_routing` method. Given its current experimental status and its lack of relevance to the `GLM` class, this method was excluded. Should future needs arise around parameter routing, consider directly inheriting from `sklearn.BaseEstimator`. More information can be found [here](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing).
+:::{note}
+We've intentionally omitted the `get_metadata_routing` method. Given its current experimental status and its lack of relevance to the [`GLM`](nemos.glm.GLM) class, this method was excluded. Should future needs arise around parameter routing, consider directly inheriting from `sklearn.BaseEstimator`. More information can be found [here](https://scikit-learn.org/stable/metadata_routing.html#metadata-routing).
+:::
### Public methods
diff --git a/docs/developers_notes/02-base_regressor.md b/docs/developers_notes/02-base_regressor.md
index 2e03e751..d29887e5 100644
--- a/docs/developers_notes/02-base_regressor.md
+++ b/docs/developers_notes/02-base_regressor.md
@@ -1,21 +1,24 @@
# The Abstract Class `BaseRegressor`
-`BaseRegressor` is an abstract class that inherits from `Base`, stipulating the implementation of number of abstract methods such as `fit`, `predict`, `score`. This ensures seamless assimilation with `scikit-learn` pipelines and cross-validation procedures.
+`BaseRegressor` is an abstract class that inherits from `Base`, stipulating the implementation of a number of abstract methods such as [`fit`](nemos.glm.GLM.fit), [`predict`](nemos.glm.GLM.predict), and [`score`](nemos.glm.GLM.score). This ensures seamless integration with `scikit-learn` pipelines and cross-validation procedures.
-!!! Example
- The current package version includes a concrete class named `nemos.glm.GLM`. This class inherits from `BaseRegressor`, which in turn inherits `Base`, since it falls under the "GLM regression" category.
- As a `BaseRegressor`, it **must** implement the `fit`, `score`, `predict` and the other abstract methods of this class, see below.
+:::{admonition} Example
+:class: note
+
+The current package version includes a concrete class named [`nemos.glm.GLM`](nemos.glm.GLM). This class inherits from `BaseRegressor`, which in turn inherits `Base`, since it falls under the "GLM regression" category.
+As a `BaseRegressor`, it **must** implement the [`fit`](nemos.glm.GLM.fit), [`score`](nemos.glm.GLM.score), [`predict`](nemos.glm.GLM.predict) and the other abstract methods of this class, see below.
+:::
### Abstract Methods
For subclasses derived from `BaseRegressor` to function correctly, they must implement the following:
-1. `fit`: Adapt the model using input data `X` and corresponding observations `y`.
-2. `predict`: Provide predictions based on the trained model and input data `X`.
-3. `score`: Score the accuracy of model predictions using input data `X` against the actual observations `y`.
-4. `simulate`: Simulate data based on the trained regression model.
-5. `update`: Run a single optimization step, and stores the updated parameter and solver state. Used by stochastic optimization schemes.
+1. [`fit`](nemos.glm.GLM.fit): Adapt the model using input data `X` and corresponding observations `y`.
+2. [`predict`](nemos.glm.GLM.predict): Provide predictions based on the trained model and input data `X`.
+3. [`score`](nemos.glm.GLM.score): Score the accuracy of model predictions using input data `X` against the actual observations `y`.
+4. [`simulate`](nemos.glm.GLM.simulate): Simulate data based on the trained regression model.
+5. [`update`](nemos.glm.GLM.update): Run a single optimization step and store the updated parameters and solver state. Used by stochastic optimization schemes.
6. `_predict_and_compute_loss`: Compute prediction and evaluates the loss function prvided the parameters and `X` and `y`. This is used by the `instantiate_solver` method which sets up the solver.
7. `_check_params`: Check the parameter structure.
8. `_check_input_dimensionality`: Check the input dimensionality matches model expectation.
@@ -29,7 +32,7 @@ input and parameters conform with the model requirements.
Public attributes are stored as properties:
-- `regularizer`: An instance of the `nemos.regularizer.Regularizer` class. The setter for this property accepts either the instance directly or a string that is used to instantiate the appropriate regularizer.
+- `regularizer`: An instance of the [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) class. The setter for this property accepts either the instance directly or a string that is used to instantiate the appropriate regularizer.
- `regularizer_strength`: A float quantifying the amount of regularization.
- `solver_name`: One of the `jaxopt` solver supported solvers, currently "GradientDescent", "BFGS", "LBFGS", "ProximalGradient" and, "NonlinearCG".
- `solver_kwargs`: Extra keyword arguments to be passed at solver initialization.
@@ -37,11 +40,14 @@ Public attributes are stored as properties:
When implementing a new subclass of `BaseRegressor`, the only attributes you must interact directly with are those that operate on the solver, i.e. `solver_init_state`, `solver_update`, `solver_run`.
-Typically, in `YourRegressor` you will call `self.solver_init_state` at the parameter initialization step, `self.sovler_run` in `fit`, and `self.solver_update` in `update`.
+Typically, in `YourRegressor` you will call `self.solver_init_state` at the parameter initialization step, `self.solver_run` in [`fit`](nemos.glm.GLM.fit), and `self.solver_update` in [`update`](nemos.glm.GLM.update).
+
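+A minimal sketch of this wiring (illustrative only: the argument lists below are assumptions, and the actual signatures are dictated by `BaseRegressor` and the solver you configure):
+
+```python
+from nemos.base_regressor import BaseRegressor  # import path assumed from this note
+
+
+class YourRegressor(BaseRegressor):
+    def initialize_state(self, X, y, init_params):
+        # delegate state creation to the configured solver
+        return self.solver_init_state(init_params, X, y)
+
+    def fit(self, X, y):
+        init_params = self.initialize_params(X, y)  # or any model-specific initialization
+        params, state = self.solver_run(init_params, X, y)
+        # ...store the fitted parameters on the model...
+        return self
+
+    def update(self, params, state, X, y):
+        # one optimization step, e.g. inside a stochastic training loop
+        return self.solver_update(params, state, X, y)
+```
+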
+:::{admonition} Solvers
+:class: note
-!!! note "Solvers"
- Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines as long as they respect the `jaxopt` api (i.e., have a `run`, `init_state`, and `update` method with the appropriate input/output types).
- We rely on `jaxopt` because it provides a comprehensive set of robust, GPU accelerated, batchable and differentiable optimizers in JAX, that are highly customizable. In the future we may provide a number of custom solvers optimized for convex stochastic optimization.
+Solvers are typically optimizers from the `jaxopt` package, but in principle they could be custom optimization routines as long as they respect the `jaxopt` API (i.e., have a `run`, `init_state`, and `update` method with the appropriate input/output types).
+We rely on `jaxopt` because it provides a comprehensive set of robust, GPU-accelerated, batchable, and differentiable optimizers in JAX that are highly customizable.
+:::
## Contributor Guidelines
diff --git a/docs/developers_notes/03-glm.md b/docs/developers_notes/03-glm.md
index 3eb5928b..2ab98f10 100644
--- a/docs/developers_notes/03-glm.md
+++ b/docs/developers_notes/03-glm.md
@@ -6,53 +6,54 @@
Generalized Linear Models (GLM) provide a flexible framework for modeling a variety of data types while establishing a relationship between multiple predictors and a response variable. A GLM extends the traditional linear regression by allowing for response variables that have error distribution models other than a normal distribution, such as binomial or Poisson distributions.
-The `nemos.glm` module currently offers implementations of two GLM classes:
+The [`nemos.glm`](nemos_glm) module currently offers implementations of two GLM classes:
-1. **`GLM`:** A direct implementation of a feedforward GLM.
-2. **`PopulationGLM`:** An implementation of a GLM for fitting a populaiton of neuron in a vectorized manner. This class inherits from `GLM` and redefines the `fit` and `_predict` to fit the model and predict the firing rate.
+1. [`GLM`](nemos.glm.GLM): A direct implementation of a feedforward GLM.
+2. [`PopulationGLM`](nemos.glm.PopulationGLM): An implementation of a GLM for fitting a population of neurons in a vectorized manner. This class inherits from [`GLM`](nemos.glm.GLM) and redefines [`fit`](nemos.glm.PopulationGLM.fit) and `_predict` to fit the model and predict the firing rate.
Our design aligns with the `scikit-learn` API, facilitating seamless integration of our GLM classes with the well-established `scikit-learn` pipeline and its cross-validation tools.
The classes provided here are modular by design offering a standard foundation for any GLM variant.
-Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.), a regularization strategies (Ridge, Lasso, etc.) and an optimization scheme during initialization. This is done using the [`nemos.observation_models.Observations`](../05-observation_models/#the-abstract-class-observations), [`nemos.regularizer.Regularizer`](../06-regularizer/#the-abstract-class-regularizer) objects as well as the compatible `jaxopt` solvers, respectively.
+Instantiating a specific GLM simply requires providing an observation model (Gamma, Poisson, etc.), a regularization strategy (Ridge, Lasso, etc.), and an optimization scheme during initialization. This is done using the [`nemos.observation_models.Observations`](nemos.observation_models.Observations) and [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) objects, as well as the compatible `jaxopt` solvers, respectively.
-
+(the-concrete-class-glm)=
## The Concrete Class `GLM`
-The `GLM` class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind.
+The [`GLM`](nemos.glm.GLM) class provides a direct implementation of the GLM model and is designed with `scikit-learn` compatibility in mind.
### Inheritance
-`GLM` inherits from [`BaseRegressor`](../02-base_regressor/#the-abstract-class-baseregressor). This inheritance mandates the direct implementation of methods like `predict`, `fit`, `score` `update`, and `simulate`, plus a number of validation methods.
+[`GLM`](nemos.glm.GLM) inherits from [`BaseRegressor`](02-base_regressor.md). This inheritance mandates the direct implementation of methods like [`predict`](nemos.glm.GLM.predict), [`fit`](nemos.glm.GLM.fit), [`score`](nemos.glm.GLM.score), [`update`](nemos.glm.GLM.update), and [`simulate`](nemos.glm.GLM.simulate), plus a number of validation methods.
### Attributes
-- **`observation_model`**: Property that represents the GLM observation model, which is an object of the [`nemos.observation_models.Observations`](../05-observation_models/#the-abstract-class-observations) type. This model determines the log-likelihood and the emission probability mechanism for the `GLM`.
+- **`observation_model`**: Property that represents the GLM observation model, which is an object of the [`nemos.observation_models.Observations`](nemos.observation_models.Observations) type. This model determines the log-likelihood and the emission probability mechanism for the [`GLM`](nemos.glm.GLM).
- **`coef_`**: Stores the solution for spike basis coefficients as `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
- **`intercept_`**: Stores the bias terms' solutions as `jax.ndarray` after the fitting process. It is initialized as `None` during class instantiation.
- **`dof_resid_`**: The degrees of freedom of the model's residual. this quantity is used to estimate the scale parameter, see below, and compute frequentist confidence intervals.
- **`scale_`**: The scale parameter of the observation distribution, which together with the rate, uniquely specifies a distribution of the exponential family. Example: a 1D Gaussian is specified by the mean which is the rate, and the standard deviation, which is the scale.
- **`solver_state_`**: Indicates the solver's state. For specific solver states, refer to the [`jaxopt` documentation](https://jaxopt.github.io/stable/index.html#).
-Additionally, the `GLM` class inherits the attributes of `BaseRegressor`, see the [relative note](02-base_regressor.md) for more information.
+Additionally, the [`GLM`](nemos.glm.GLM) class inherits the attributes of `BaseRegressor`; see the [related note](02-base_regressor.md) for more information.
### Public Methods
-- **`predict`**: Validates input and computes the mean rates of the `GLM` by invoking the inverse-link function of the `observation_models` attribute.
-- **`score`**: Validates input and assesses the Poisson GLM using either log-likelihood or pseudo-$R^2$. This method uses the `observation_models` to determine log-likelihood or pseudo-$R^2$.
-- **`fit`**: Validates input and aligns the Poisson GLM with spike train data. It leverages the `observation_models` and `regularizer` to define the model's loss function and instantiate the regularizer.
-- **`simulate`**: Simulates spike trains using the GLM as a feedforward network, invoking the `observation_models.sample_generator` method for emission probability.
-- **`initialize_params`**: Initialize model parameters, setting to zero the coefficients, and setting the intercept by matching the firing rate.
-- **`initialize_state`**: Initialize the state of the solver.
-- **`update`**: Run a step of optimization and update the parameter and solver step.
+- [`predict`](nemos.glm.GLM.predict): Validates input and computes the mean rates of the [`GLM`](nemos.glm.GLM) by invoking the inverse-link function of the `observation_models` attribute.
+- [`score`](nemos.glm.GLM.score): Validates input and assesses the Poisson GLM using either log-likelihood or pseudo-$R^2$. This method uses the `observation_models` to determine log-likelihood or pseudo-$R^2$.
+- [`fit`](nemos.glm.GLM.fit): Validates input and aligns the Poisson GLM with spike train data. It leverages the `observation_models` and `regularizer` to define the model's loss function and instantiate the regularizer.
+- [`simulate`](nemos.glm.GLM.simulate): Simulates spike trains using the GLM as a feedforward network, invoking the `observation_models.sample_generator` method for emission probability.
+- [`initialize_params`](nemos.glm.GLM.initialize_params): Initialize model parameters, setting to zero the coefficients, and setting the intercept by matching the firing rate.
+- [`initialize_state`](nemos.glm.GLM.initialize_state): Initialize the state of the solver.
+- [`update`](nemos.glm.GLM.update): Run a step of optimization and update the parameter and solver step.
### Private Methods
@@ -61,20 +62,20 @@ Here we list the private method related to the model computations:
- **`_predict`**: Forecasts rates based on current model parameters and the inverse-link function of the `observation_models`.
- **`_predict_and_compute_loss`**: Predicts the rate and calculates the mean Poisson negative log-likelihood, excluding normalization constants.
-A number of `GLM` specific private methods are used for checking parameters and inputs, while the methods related for checking the solver-regularizer configurations/instantiation are inherited from `BaseRergessor`.
+A number of [`GLM`](nemos.glm.GLM)-specific private methods are used for checking parameters and inputs, while the methods related to checking the solver-regularizer configuration/instantiation are inherited from `BaseRegressor`.
## The Concrete Class `PopulationGLM`
-The `PopulationGLM` class is an extension of the `GLM`, designed to fit multiple neurons jointly. This involves vectorized fitting processes that efficiently handle multiple neurons simultaneously, leveraging the inherent parallelism.
+The [`PopulationGLM`](nemos.glm.PopulationGLM) class is an extension of the [`GLM`](nemos.glm.GLM), designed to fit multiple neurons jointly. This involves vectorized fitting processes that efficiently handle multiple neurons simultaneously, leveraging the inherent parallelism.
### `PopulationGLM` Specific Attributes
-- **`feature_mask`**: A mask that determines which features are used as predictors for each neuron. It can be a matrix of shape `(num_features, num_neurons)` or a `FeaturePytree` of binary values, where 1 indicates that a feature is used for a particular neuron and 0 indicates it is not.
+- **`feature_mask`**: A mask that determines which features are used as predictors for each neuron. It can be a matrix of shape `(num_features, num_neurons)` or a [`FeaturePytree`](nemos.pytrees.FeaturePytree) of binary values, where 1 indicates that a feature is used for a particular neuron and 0 indicates it is not.
### Overridden Methods
-- **`fit`**: Overridden to handle fitting of the model to a neural population. This method validates input including the mask and fits the model parameters (coefficients and intercepts) to the data.
+- [`fit`](nemos.glm.PopulationGLM.fit): Overridden to handle fitting of the model to a neural population. This method validates input including the mask and fits the model parameters (coefficients and intercepts) to the data.
- **`_predict`**: Computes the predicted firing rates using the model parameters and the feature mask.
@@ -85,9 +86,9 @@ The `PopulationGLM` class is an extension of the `GLM`, designed to fit multiple
When crafting a functional (i.e., concrete) GLM class:
-- You **must** inherit from `GLM` or one of its derivatives.
+- You **must** inherit from [`GLM`](nemos.glm.GLM) or one of its derivatives.
- If you inherit directly from `BaseRegressor`, you **must** implement all the abstract methods, see the [`BaseRegressor` page](02-base_regressor.md) for more details.
-- If you inherit `GLM` or any of the other concrete classes directly, there won't be any abstract methods.
+- If you inherit [`GLM`](nemos.glm.GLM) or any of the other concrete classes directly, there won't be any abstract methods.
- You **may** embed additional parameter and input checks if required by the specific GLM subclass.
- You **may** override some of the computations if needed by the model specifications.
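+
+For example, a minimal sketch that follows these guidelines (illustrative only; the extra check stands in for whatever validation your variant needs):
+
+```python
+from nemos.glm import GLM
+
+
+class MyGLM(GLM):
+    """A GLM variant adding an extra, subclass-specific input check before fitting."""
+
+    def fit(self, X, y):
+        # additional parameter/input checks would go here
+        return super().fit(X, y)
+```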
diff --git a/docs/developers_notes/04-basis_module.md b/docs/developers_notes/04-basis_module.md
index d834048d..d0f939a0 100644
--- a/docs/developers_notes/04-basis_module.md
+++ b/docs/developers_notes/04-basis_module.md
@@ -2,7 +2,7 @@
## Introduction
-The `nemos.basis` module provides objects that allow users to construct and evaluate basis functions of various types. The classes are hierarchically organized as follows:
+The [`nemos.basis`](nemos_basis) module provides objects that allow users to construct and evaluate basis functions of various types. The classes are hierarchically organized as follows:
```
Abstract Class Basis
@@ -26,44 +26,49 @@ Abstract Class Basis
└─ Concrete Subclass OrthExponentialBasis
```
-The super-class `Basis` provides two public methods, [`compute_features`](#the-public-method-compute_features) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method `__call__` that is specific for each concrete class. See below for more details.
+The super-class [`Basis`](nemos.basis.Basis) provides two public methods, [`compute_features`](the-public-method-compute_features) and [`evaluate_on_grid`](the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the abstract method [`__call__`](nemos.basis.Basis.__call__) that is specific for each concrete class. See below for more details.
## The Class `nemos.basis.Basis`
+(the-public-method-compute_features)=
### The Public Method `compute_features`
-The `compute_features` method checks input consistency and applies the basis function to the inputs.
-`Basis` can operate in two modes defined at initialization: `"eval"` and `"conv"`. When a basis is in mode `"eval"`,
-`compute_features` evaluates the basis at the given input samples. When in mode `"conv"`, it will convolve the samples
+The [`compute_features`](nemos.basis.Basis.compute_features) method checks input consistency and applies the basis function to the inputs.
+[`Basis`](nemos.basis.Basis) can operate in two modes defined at initialization: `"eval"` and `"conv"`. When a basis is in mode `"eval"`,
+[`compute_features`](nemos.basis.Basis.compute_features) evaluates the basis at the given input samples. When in mode `"conv"`, it will convolve the samples
with a bank of kernels, one per basis function.
It accepts one or more NumPy array or pynapple `Tsd` object as input, and performs the following steps:
1. Checks that the inputs all have the same sample size `M`, and raises a `ValueError` if this is not the case.
2. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
-3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using `evaluate_on_grid` and then applies the convolution to the input with `nemos.convolve.create_convolutional_predictor`.
+3. In `"eval"` mode, calls the `__call__` method on the input, which is the subclass-specific implementation of the basis set evaluation. In `"conv"` mode, generates a filter bank using [`compute_features`](nemos.basis.Basis.evaluate_on_grid) and then applies the convolution to the input with [`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor).
4. Returns a NumPy array or pynapple `TsdFrame` of shape `(M, n_basis_funcs)`, with each basis element evaluated at the samples.
-!!! note "Multiple epochs"
- Note that the convolution works gracefully with multiple disjoint epochs, when a pynapple time series is used as
- input.
+:::{admonition} Multiple epochs
+:class: note
+Note that the convolution works gracefully with multiple disjoint epochs, when a pynapple time series is used as
+input.
+:::
+
+(the-public-method-evaluate_on_grid)=
### The Public Method `evaluate_on_grid`
-The `evaluate_on_grid` method evaluates the basis set on a grid of equidistant sample points. The user specifies the input as a series of integers, one for each dimension of the basis function, that indicate the number of sample points in each coordinate of the grid.
+The [`evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) method evaluates the basis set on a grid of equidistant sample points. The user specifies the input as a series of integers, one for each dimension of the basis function, that indicate the number of sample points in each coordinate of the grid.
This method performs the following steps:
1. Checks that the number of inputs matches what the basis being evaluated expects (e.g., one input for a 1-D basis, N inputs for an N-D basis, or the sum of N 1-D bases), and raises a `ValueError` if this is not the case.
2. Calls `_get_samples` method, which returns equidistant samples over the domain of the basis function. The domain may depend on the type of basis.
-3. Calls the `__call__` method.
+3. Calls the [`__call__`](nemos.basis.Basis.__call__) method.
4. Returns both the sample grid points of shape `(m1, ..., mN)`, and the evaluation output at each grid point of shape `(m1, ..., mN, n_basis_funcs)`, where `mi` is the number of sample points for the i-th axis of the grid.
### Abstract Methods
-The `nemos.basis.Basis` class has the following abstract methods, which every concrete subclass must implement:
+The [`nemos.basis.Basis`](nemos.basis.Basis) class has the following abstract methods, which every concrete subclass must implement:
-1. `__call__`: Evaluates a basis over some specified samples.
+1. [`__call__`](nemos.basis.Basis.__call__): Evaluates a basis over some specified samples.
2. `_check_n_basis_min`: Checks the minimum number of basis functions required. This requirement can be specific to the type of basis.
## Contributors Guidelines
@@ -71,8 +76,8 @@ The `nemos.basis.Basis` class has the following abstract methods, which every co
### Implementing Concrete Basis Objects
To write a usable (i.e., concrete, non-abstract) basis object, you
-- **Must** inherit the abstract superclass `Basis`
-- **Must** define the `__call__` and `_check_n_basis_min` methods with the expected input/output format, see [API Guide](../../reference/nemos/basis/) for the specifics.
-- **Should not** overwrite the `compute_features` and `evaluate_on_grid` methods inherited from `Basis`.
-- **May** inherit any number of abstract intermediate classes (e.g., `SplineBasis`).
+- **Must** inherit the abstract superclass [`Basis`](nemos.basis.Basis)
+- **Must** define the [`__call__`](nemos.basis.Basis.__call__) and `_check_n_basis_min` methods with the expected input/output format, see [API Reference](nemos_basis) for the specifics.
+- **Should not** overwrite the [`compute_features`](nemos.basis.Basis.compute_features) and [`evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) methods inherited from [`Basis`](nemos.basis.Basis).
+- **May** inherit any number of abstract intermediate classes (e.g., [`SplineBasis`](nemos.basis.SplineBasis)).
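+
+An illustrative outline of such a subclass (method bodies and signatures are placeholders inferred from this note, not the actual nemos implementation):
+
+```python
+from nemos.basis import Basis
+
+
+class MyCustomBasis(Basis):
+    def __call__(self, sample_pts):
+        # evaluate the basis set at `sample_pts`, returning an array of
+        # shape (n_samples, n_basis_funcs)
+        ...
+
+    def _check_n_basis_min(self) -> None:
+        # raise ValueError if self.n_basis_funcs is below this class's minimum
+        ...
+```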
diff --git a/docs/developers_notes/05-observation_models.md b/docs/developers_notes/05-observation_models.md
index 9b5e9173..0573c6eb 100644
--- a/docs/developers_notes/05-observation_models.md
+++ b/docs/developers_notes/05-observation_models.md
@@ -2,45 +2,50 @@
## Introduction
-The `observation_models` module provides objects representing the observations of GLM-like models.
+The [`observation_models`](observation_models) module provides objects representing the observations of GLM-like models.
-The abstract class `Observations` defines the structure of the subclasses which specify observation types, such as Poisson, Gamma, etc. These objects serve as attributes of the [`nemos.glm.GLM`](../03-glm/#the-concrete-class-glm) class, equipping the GLM with a negative log-likelihood. This is used to define the optimization objective, the deviance which measures model fit quality, and the emission of new observations, for simulating new data.
+The abstract class [`Observations`](nemos.observation_models.Observations) defines the structure of the subclasses which specify observation types, such as Poisson, Gamma, etc. These objects serve as attributes of the [`nemos.glm.GLM`](the-concrete-class-glm) class, equipping the GLM with a negative log-likelihood. This is used to define the optimization objective, the deviance which measures model fit quality, and the emission of new observations, for simulating new data.
+(the-abstract-class-observations)=
## The Abstract class `Observations`
-The abstract class `Observations` is the backbone of any observation model. Any class inheriting `Observations` must reimplement the `_negative_log_likelihood`, `log_likelihood`, `sample_generator`, `deviance`, and `estimate_scale` methods.
+The abstract class [`Observations`](nemos.observation_models.Observations) is the backbone of any observation model. Any class inheriting [`Observations`](nemos.observation_models.Observations) must reimplement the `_negative_log_likelihood`, [`log_likelihood`](nemos.observation_models.Observations.log_likelihood), [`sample_generator`](nemos.observation_models.Observations.sample_generator), [`deviance`](nemos.observation_models.Observations.deviance), and [`estimate_scale`](nemos.observation_models.Observations.estimate_scale) methods.
### Abstract Methods
-For subclasses derived from `Observations` to function correctly, they must implement the following:
+For subclasses derived from [`Observations`](nemos.observation_models.Observations) to function correctly, they must implement the following:
- **_negative_log_likelihood**: Computes the negative-log likelihood of the model up to a normalization constant. This method is usually part of the objective function used to learn GLM parameters.
-- **log_likelihood**: Computes the full log-likelihood including the normalization constant.
+- [`log_likelihood`](nemos.observation_models.Observations.log_likelihood): Computes the full log-likelihood including the normalization constant.
-- **sample_generator**: Returns the random emission probability function. This typically invokes `jax.random` emission probability, provided some sufficient statistics, [see below](#suff-stat). For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (i.e., Poisson) or needs to be estimated (i.e., Gamma).
+- [`sample_generator`](nemos.observation_models.Observations.sample_generator): Returns the random emission probability function. This typically invokes `jax.random` emission probability, provided some sufficient statistics, [see below](#suff-stat). For distributions in the exponential family, the sufficient statistics are the canonical parameter and the scale. In GLMs, the canonical parameter is entirely specified by the model's weights, while the scale is either fixed (i.e., Poisson) or needs to be estimated (i.e., Gamma).
-- **deviance**: Computes the deviance based on the model's estimated rates and observations.
+- [`deviance`](nemos.observation_models.Observations.deviance): Computes the deviance based on the model's estimated rates and observations.
-- **estimate_scale**: A method for estimating the scale parameter of the model. Rate and scale are sufficient to fully characterize distributions from the exponential family.
+- [`estimate_scale`](nemos.observation_models.Observations.estimate_scale): A method for estimating the scale parameter of the model. Rate and scale are sufficient to fully characterize distributions from the exponential family.
-???+ info "Sufficient statistics"
- In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from [[1]](#ref-1).
+:::{dropdown} Sufficient statistics
+:color: info
+:icon: info
+
+In statistics, a statistic is sufficient with respect to a statistical model and its associated unknown parameters if "no other statistic that can be calculated from the same sample provides any additional information as to the value of the parameters", adapted from [[1]](#ref-1).
+:::
### Public Methods
-- **pseudo_r2**: Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$, what we used here is the definition by Cohen at al. 2003[$^{[2]}$](#ref-2).
-- **check_inverse_link_function**: Check that the link function is a auto-differentiable, vectorized function form $\mathbb{R} \longrightarrow \mathbb{R}$.
+- [`pseudo_r2`](nemos.observation_models.Observations.pseudo_r2): Method for computing the pseudo-$R^2$ of the model based on the residual deviance. There is no consensus definition for the pseudo-$R^2$; here we use the definition from Cohen et al. (2003)[$^{[2]}$](#ref-2).
+- [`check_inverse_link_function`](nemos.observation_models.Observations.check_inverse_link_function): Checks that the inverse link function is an auto-differentiable, vectorized function $\mathbb{R} \longrightarrow \mathbb{R}$.
## Contributor Guidelines
To implement an observation model class you
-- **Must** inherit from `Observations`
+- **Must** inherit from [`Observations`](nemos.observation_models.Observations)
- **Must** provide a concrete implementation of the abstract methods, see above.
-- **Should not** reimplement the `pseudo_r2` method as well as the `check_inverse_link_function` auxiliary method.
+- **Should not** reimplement the [`pseudo_r2`](nemos.observation_models.Observations.pseudo_r2) method or the [`check_inverse_link_function`](nemos.observation_models.Observations.check_inverse_link_function) auxiliary method.
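+
+For concreteness, here is a minimal, hypothetical sketch of a custom observation model. The method signatures are
+illustrative assumptions and may not match the exact abstract interface; consult the
+[`Observations`](nemos.observation_models.Observations) API reference before implementing.
+
+```python
+# Illustrative sketch only: signatures are assumptions, not the actual NeMoS interface.
+import jax
+import jax.numpy as jnp
+
+import nemos as nmo
+
+
+class BernoulliObservations(nmo.observation_models.Observations):
+    """Hypothetical observation model for binary (0/1) observations."""
+
+    def _negative_log_likelihood(self, y, predicted_rate):
+        # negative log-likelihood up to a normalization constant (zero for Bernoulli)
+        p = predicted_rate
+        return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))
+
+    def log_likelihood(self, y, predicted_rate, scale=1.0):
+        # full log-likelihood; the Bernoulli normalization constant is zero
+        return -self._negative_log_likelihood(y, predicted_rate)
+
+    def sample_generator(self, key, predicted_rate, scale=1.0):
+        # random emission based on jax.random, given the sufficient statistics
+        return jax.random.bernoulli(key, predicted_rate).astype(float)
+
+    def deviance(self, y, predicted_rate, scale=1.0):
+        # 2 * (saturated log-likelihood - model log-likelihood), element-wise
+        ll_model = y * jnp.log(predicted_rate) + (1 - y) * jnp.log(1 - predicted_rate)
+        ll_sat = jnp.where(y > 0, y * jnp.log(y), 0.0) + jnp.where(y < 1, (1 - y) * jnp.log(1 - y), 0.0)
+        return 2.0 * (ll_sat - ll_model)
+
+    def estimate_scale(self, y, predicted_rate, dof_resid=None):
+        # the Bernoulli scale is fixed
+        return 1.0
+```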
## References
[1]
diff --git a/docs/developers_notes/06-regularizer.md b/docs/developers_notes/06-regularizer.md
index 54b19afc..bfb049b7 100644
--- a/docs/developers_notes/06-regularizer.md
+++ b/docs/developers_notes/06-regularizer.md
@@ -2,11 +2,11 @@
## Introduction
-The `regularizer` module introduces an archetype class `Regularizer` which provides the structural components for each concrete sub-class.
+The [`regularizer`](regularizers) module introduces an archetype class [`Regularizer`](nemos.regularizer.Regularizer) which provides the structural components for each concrete sub-class.
-Objects of type `Regularizer` provide methods to define a regularized optimization objective. These objects serve as attribute of the [`nemos.glm.GLM`](../03-glm/#the-concrete-class-glm), equipping the glm with an appropriate regularization scheme.
+Objects of type [`Regularizer`](nemos.regularizer.Regularizer) provide methods to define a regularized optimization objective. These objects serve as attributes of the [`nemos.glm.GLM`](the-concrete-class-glm), equipping the GLM with an appropriate regularization scheme.
-Each `Regularizer` object defines a default solver, and a set of allowed solvers, which depends on the loss function characteristics (smooth vs non-smooth).
+Each [`Regularizer`](nemos.regularizer.Regularizer) object defines a default solver and a set of allowed solvers, which depend on the characteristics of the loss function (smooth vs. non-smooth).
```
Abstract Class Regularizer
@@ -20,30 +20,32 @@ Abstract Class Regularizer
└─ Concrete Class GroupLasso
```
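+
+As a minimal usage sketch (based on the public GLM API shown in the how-to guides), a concrete regularizer is attached to a GLM as follows:
+
+```python
+import nemos as nmo
+
+# equip a GLM with a Ridge regularization scheme and a regularizer strength
+model = nmo.glm.GLM(regularizer=nmo.regularizer.Ridge(), regularizer_strength=0.1)
+
+print(type(model.regularizer))            # the regularizer object
+print(model.regularizer.allowed_solvers)  # solvers compatible with this scheme
+```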
-!!! note
- If we need advanced adaptive solvers (e.g., Adam, LAMB etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`, see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
+:::{note}
+If we need advanced adaptive solvers (e.g., Adam, LAMB etc.) in the future, we should consider adding [`Optax`](https://optax.readthedocs.io/en/latest/) as a dependency, which is compatible with `jaxopt`, see [here](https://jaxopt.github.io/stable/_autosummary/jaxopt.OptaxSolver.html#jaxopt.OptaxSolver).
+:::
+(the-abstract-class-regularizer)=
## The Abstract Class `Regularizer`
-The abstract class `Regularizer` enforces the implementation of the `penalized_loss` and `get_proximal_operator` methods.
+The abstract class [`Regularizer`](nemos.regularizer.Regularizer) enforces the implementation of the [`penalized_loss`](nemos.regularizer.Regularizer.penalized_loss) and [`get_proximal_operator`](nemos.regularizer.Regularizer.get_proximal_operator) methods.
### Attributes
-The attributes of `Regularizer` consist of the `default_solver` and `allowed_solvers`, which are stored as read-only properties of type string and tuple of strings respectively.
+The attributes of [`Regularizer`](nemos.regularizer.Regularizer) consist of the `default_solver` and `allowed_solvers`, which are stored as read-only properties of type string and tuple of strings, respectively.
### Abstract Methods
-- **`penalized_loss`**: Returns a penalized version of the input loss function which is uniquely defined by the regularization scheme and the regularizer strength parameter.
-- **`get_proximal_operator`**: Returns the proximal projection operator which is uniquely defined by the regularization scheme.
+- [`penalized_loss`](nemos.regularizer.Regularizer.penalized_loss): Returns a penalized version of the input loss function which is uniquely defined by the regularization scheme and the regularizer strength parameter.
+- [`get_proximal_operator`](nemos.regularizer.Regularizer.get_proximal_operator): Returns the proximal projection operator which is uniquely defined by the regularization scheme.
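+
+A rough sketch of how these two methods could look for an L2-style penalty is shown below. The signatures and
+attribute names are assumptions for illustration and may differ from the actual [`Regularizer`](nemos.regularizer.Regularizer) interface.
+
+```python
+# Illustrative sketch only: not the actual NeMoS Ridge implementation.
+import jax.numpy as jnp
+
+import nemos as nmo
+
+
+class RidgeLike(nmo.regularizer.Regularizer):
+    """Hypothetical L2-style regularizer."""
+
+    # assumed attribute names for the solver configuration
+    _default_solver = "GradientDescent"
+    _allowed_solvers = ("GradientDescent", "BFGS", "LBFGS", "ProximalGradient")
+
+    def penalized_loss(self, loss, regularizer_strength):
+        # wrap the un-penalized loss with an additive squared-L2 penalty on the coefficients
+        def _penalized_loss(params, X, y):
+            coef = params[0]
+            return loss(params, X, y) + 0.5 * regularizer_strength * jnp.sum(coef ** 2)
+
+        return _penalized_loss
+
+    def get_proximal_operator(self):
+        # proximal operator of the squared-L2 penalty: element-wise shrinkage of the coefficients
+        def _prox_op(params, regularizer_strength, scaling=1.0):
+            coef, intercept = params
+            return coef / (1.0 + scaling * regularizer_strength), intercept
+
+        return _prox_op
+```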
## The `UnRegularized` Class
-The `UnRegularized` class extends the base `Regularizer` class and is designed specifically for optimizing unregularized models. This means that the solver instantiated by this class does not add any regularization penalty to the loss function during the optimization process.
+The [`UnRegularized`](nemos.regularizer.UnRegularized) class extends the base [`Regularizer`](nemos.regularizer.Regularizer) class and is designed specifically for optimizing unregularized models. This means that the solver instantiated by this class does not add any regularization penalty to the loss function during the optimization process.
### Concrete Methods Specifics
-- **`penalized_loss`**: Returns the original loss without any changes.
-- **`get_proximal_operator`**: Returns the identity operator.
+- [`penalized_loss`](nemos.regularizer.UnRegularized.penalized_loss): Returns the original loss without any changes.
+- [`get_proximal_operator`](nemos.regularizer.UnRegularized.get_proximal_operator): Returns the identity operator.
## Contributor Guidelines
@@ -52,18 +54,21 @@ The `UnRegularized` class extends the base `Regularizer` class and is designed s
When developing a functional (i.e., concrete) `Regularizer` class:
-- **Must** inherit from `Regularizer` or one of its derivatives.
-- **Must** implement the `penalized_loss` and `get_proximal_operator` methods.
+- **Must** inherit from [`Regularizer`](nemos.regularizer.Regularizer) or one of its derivatives.
+- **Must** implement the [`penalized_loss`](nemos.regularizer.Regularizer.penalized_loss) and [`get_proximal_operator`](nemos.regularizer.Regularizer.get_proximal_operator) methods.
- **Must** define a default solver and a tuple of allowed solvers.
-- **May** require extra initialization parameters, like the `mask` argument of `GroupLasso`.
-
-??? tip "Convergence Test"
- When adding a new regularizer, you must include a convergence test, which verifies that
- the model parameters the regularizer finds for a convex problem such as the GLM are identical
- whether one minimizes the penalized loss directly and uses the proximal operator (i.e., when
- using `ProximalGradient`). In practice, this means you should test the result of the `ProximalGradient`
- optimization against that of either `GradientDescent` (if your regularization is differentiable) or
- `Nelder-Mead` from [`scipy.optimize.minimize`](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-neldermead.html)
- (or another non-gradient based method, if your regularization is non-differentiable). You can refer to NeMoS `test_lasso_convergence`
- from `tests/test_convergence.py` for a concrete example.
-
+- **May** require extra initialization parameters, like the `mask` argument of [`GroupLasso`](nemos.regularizer.GroupLasso).
+
+:::{dropdown} Convergence Test
+:icon: light-bulb
+:color: success
+
+When adding a new regularizer, you must include a convergence test, which verifies that
+the model parameters the regularizer finds for a convex problem such as the GLM are identical
+whether one minimizes the penalized loss directly or uses the proximal operator (i.e., when
+using `ProximalGradient`). In practice, this means you should test the result of the `ProximalGradient`
+optimization against that of either `GradientDescent` (if your regularization is differentiable) or
+`Nelder-Mead` from [`scipy.optimize.minimize`](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-neldermead.html)
+(or another non-gradient based method, if your regularization is non-differentiable). You can refer to NeMoS `test_lasso_convergence`
+from `tests/test_convergence.py` for a concrete example.
+:::
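+
+A rough sketch of what such a test could look like (the synthetic data, solver choices, and tolerances below are illustrative assumptions, not the actual NeMoS test):
+
+```python
+# Illustrative convergence-test sketch.
+import numpy as np
+
+import nemos as nmo
+
+
+def test_ridge_like_convergence():
+    # small synthetic Poisson-GLM problem
+    np.random.seed(0)
+    X = np.random.normal(size=(500, 3))
+    y = np.random.poisson(np.exp(X @ np.array([0.5, -0.2, 0.1])))
+
+    # minimize the penalized loss directly (the regularization must be differentiable for this path)
+    model_gd = nmo.glm.GLM(
+        regularizer="Ridge", regularizer_strength=0.1, solver_name="GradientDescent"
+    ).fit(X, y)
+
+    # use the proximal operator instead
+    model_prox = nmo.glm.GLM(
+        regularizer="Ridge", regularizer_strength=0.1, solver_name="ProximalGradient"
+    ).fit(X, y)
+
+    # the two optimization paths should agree on the solution of the convex problem
+    np.testing.assert_allclose(model_gd.coef_, model_prox.coef_, atol=1e-4)
+    np.testing.assert_allclose(model_gd.intercept_, model_prox.intercept_, atol=1e-4)
+```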
\ No newline at end of file
diff --git a/docs/developers_notes/README.md b/docs/developers_notes/README.md
index 429775e5..6f44b3dd 100644
--- a/docs/developers_notes/README.md
+++ b/docs/developers_notes/README.md
@@ -1,4 +1,20 @@
-# Introduction
+
+# For Developers
+
+## Contents
+
+```{toctree}
+:maxdepth: 1
+
+01-base_class.md
+02-base_regressor.md
+03-glm.md
+04-basis_module.md
+05-observation_models.md
+06-regularizer.md
+```
+
+## Introduction
Welcome to the Developer Notes of the NeMoS project. These notes aim to provide detailed technical information about the various modules, classes, and functions that make up this library, as well as guidelines on how to write code that integrates nicely with our package. They are intended to help current and future developers understand the design decisions, structure, and functioning of the library, and to provide guidance on how to modify, extend, and maintain the codebase.
diff --git a/docs/gallery_conf.py b/docs/gallery_conf.py
deleted file mode 100644
index ca01af54..00000000
--- a/docs/gallery_conf.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-import re
-import sys
-
-from mkdocs_gallery.sorting import FileNameSortKey
-
-min_reported_time = 0
-if "SOURCE_DATE_EPOCH" in os.environ:
- min_reported_time = sys.maxint if sys.version_info[0] == 2 else sys.maxsize
-
-# To be used as the "base" config,
-# mkdocs-gallery is a port of sphinx-gallery. For a detailed list
-# of configuration options see https://sphinx-gallery.github.io/stable/configuration.html
-conf = {
- # report runtime if larger than this value
- "min_reported_time": min_reported_time,
- # order your section in file name alphabetical order
- "within_subsection_order": FileNameSortKey,
- # run every script that matches pattern
- # (here we match every file that ends in .py)
- "filename_pattern": re.escape(os.sep) + r"plot_.+\.py$",
- "ignore_pattern": r"(_plot_.+\.py$|_helpers\.py$)",
-}
diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py
deleted file mode 100644
index 541e0fd5..00000000
--- a/docs/gen_ref_pages.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""Generate the code reference pages and navigation.
-
-See [CCN template repo](https://ccn-template.readthedocs.io/en/latest/notes/03-documentation/) for why.
-"""
-
-from pathlib import Path
-import re
-import mkdocs_gen_files
-
-SKIP_MODULES = ("styles", "_documentation_utils", "_regularizer_builder")
-
-
-def skip_module(module_path: Path):
- return any(p in SKIP_MODULES for p in module_path.with_suffix("").parts)
-
-
-def filter_nav(iter_literate_nav):
- filtered_nav = []
- for line in iter_literate_nav:
- if not any(re.search(rf"\[{p}]", line) for p in SKIP_MODULES):
- filtered_nav.append(line)
- return filtered_nav
-
-
-nav = mkdocs_gen_files.Nav()
-
-for path in sorted(Path("src").rglob("*.py")):
-
- module_path = path.relative_to("src").with_suffix("")
-
- if skip_module(module_path):
- continue
-
- doc_path = path.relative_to("src").with_suffix(".md")
- full_doc_path = Path("reference", doc_path)
-
- parts = tuple(module_path.parts)
-
- if parts[-1] == "__init__":
- parts = parts[:-1]
- doc_path = doc_path.with_name("index.md")
- full_doc_path = full_doc_path.with_name("index.md")
- elif parts[-1] == "__main__":
- continue
-
- nav[parts] = doc_path.as_posix()
-
- # if the md file name is `module.md`, generate documentation from docstrings
- if full_doc_path.name != 'index.md':
- with mkdocs_gen_files.open(full_doc_path, "w") as fd:
- ident = ".".join(parts)
- fd.write(f"::: {ident}")
-
- # if the md file name is `index.md`, add the list of modules with hyperlinks
- else:
- this_module_path = Path("src") / Path(*parts)
- module_index = ""
- for module_scripts in sorted(this_module_path.rglob("*.py")):
-
- if "__init__" in module_scripts.name:
- continue
- elif skip_module(module_scripts):
- continue
-
- tabs = ""
- cumlative_path = []
- for i, p in enumerate(module_scripts.parts[len(this_module_path.parts):]):
- cumlative_path.append(p)
- relative_path = Path(*cumlative_path)
- module_index += tabs + (f"* [{p.replace('.py', '')}]"
- f"({relative_path.as_posix().replace('.py', '.md')})\n")
- tabs = "\t" + tabs
-
- with mkdocs_gen_files.open(full_doc_path, "w") as fd:
- fd.write(module_index)
-
-
-
-
- mkdocs_gen_files.set_edit_path(full_doc_path, path)
-
-with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file:
-
- literate_nav = nav.build_literate_nav()
-
- # Filter out private modules and the styles directory
- nav_file.writelines(filter_nav(literate_nav))
\ No newline at end of file
diff --git a/docs/getting_help.md b/docs/getting_help.md
index 9381a49c..eb6fc896 100644
--- a/docs/getting_help.md
+++ b/docs/getting_help.md
@@ -1,8 +1,5 @@
----
-hide:
- - navigation
- - toc
----
+
+# Getting Help
We communicate via several channels on Github:
diff --git a/docs/how_to_guide/README.md b/docs/how_to_guide/README.md
index fa4b58fb..1a3ec6c8 100644
--- a/docs/how_to_guide/README.md
+++ b/docs/how_to_guide/README.md
@@ -3,9 +3,82 @@
Familiarize with NeMoS modules and learn how to take advantage of the `pynapple` and `scikit-learn` compatibility.
-??? attention "Additional requirements"
- To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
- You can install all of the required packages with the following command:
- ```
- pip install nemos[examples]
- ```
+:::{dropdown} Additional requirements
+:color: warning
+:icon: alert
+To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
+You can install all of the required packages with the following command:
+```
+pip install nemos[examples]
+```
+:::
+
+
+::::{grid} 1 2 3 3
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_02_glm_demo.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_03_population_glm.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_04_batch_glm.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_05_sklearn_pipeline_cv_demo.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_06_glm_pytree.md
+```
+:::
+
+::::
diff --git a/docs/how_to_guide/plot_02_glm_demo.py b/docs/how_to_guide/plot_02_glm_demo.md
similarity index 55%
rename from docs/how_to_guide/plot_02_glm_demo.py
rename to docs/how_to_guide/plot_02_glm_demo.md
index f1c6e3b2..fe18d3ec 100644
--- a/docs/how_to_guide/plot_02_glm_demo.py
+++ b/docs/how_to_guide/plot_02_glm_demo.md
@@ -1,11 +1,54 @@
-"""
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+
# GLM Demo: Toy Model Examples
-!!! warning
- This demonstration is currently in its alpha stage. It presents various regularization techniques on
- GLMs trained on a Gaussian noise stimuli, and a minimal example of fitting and simulating a pair of coupled
- neurons. More work needs to be done to properly compare the performance of the regularization strategies on
- realistic simulations and real neural recordings.
+:::{warning}
+This demonstration is currently in its alpha stage. It presents various regularization techniques on
+GLMs trained on Gaussian noise stimuli, and a minimal example of fitting and simulating a pair of coupled
+neurons. More work needs to be done to properly compare the performance of the regularization strategies on
+realistic simulations and real neural recordings.
+:::
## Introduction
@@ -23,8 +66,7 @@
we are going to use for this tutorial, and generate some synthetic
data.
-"""
-
+```{code-cell} ipython3
import jax
import matplotlib.pyplot as plt
import numpy as np
@@ -47,33 +89,37 @@
# generate counts
rate = jax.numpy.exp(jax.numpy.einsum("k,tk->t", w_true, X) + b_true)
spikes = np.random.poisson(rate)
+```
+
+## The Feed-Forward GLM
-# %%
-# ## The Feed-Forward GLM
-#
-# ### Model Definition
-# The class implementing the feed-forward GLM is `nemos.glm.GLM`.
-# In order to define the class, one **must** provide:
-#
-# - **Observation Model**: The observation model for the GLM, e.g. an object of the class of type
-# `nemos.observation_models.Observations`. So far, only the `PoissonObservations`
-# model has been implemented.
-# - **Regularizer**: The desired regularizer, e.g. an object of the `nemos.regularizer.Regularizer` class.
-# Currently, we implemented the un-regularized, Ridge, Lasso, and Group-Lasso regularization.
-#
-# The default for the GLM class is the `PoissonObservations` with log-link function with a Ridge regularization.
-# Here is how to define the model.
+### Model Definition
+The class implementing the feed-forward GLM is [`nemos.glm.GLM`](nemos.glm.GLM).
+In order to define the class, one **must** provide:
+- **Observation Model**: The observation model for the GLM, e.g., an object of type
+[`nemos.observation_models.Observations`](nemos.observation_models.Observations). So far, only the [`PoissonObservations`](nemos.observation_models.PoissonObservations)
+model has been implemented.
+- **Regularizer**: The desired regularizer, e.g., an object of the [`nemos.regularizer.Regularizer`](nemos.regularizer.Regularizer) class.
+Currently, we have implemented un-regularized, Ridge, Lasso, and Group-Lasso regularization.
+
+The default for the GLM class is the [`PoissonObservations`](nemos.observation_models.PoissonObservations) with a log-link function and Ridge regularization.
+Here is how to define the model.
+
+
+```{code-cell} ipython3
# default Poisson GLM with Ridge regularization and Poisson observation model.
model = nmo.glm.GLM()
print("Regularization type: ", type(model.regularizer))
print("Observation model:", type(model.observation_model))
+```
+
+### Model Configuration
+One can visualize the model hyperparameters by calling the [`get_params`](nemos.glm.GLM.get_params) method.
-# %%
-# ### Model Configuration
-# One could visualize the model hyperparameters by calling `get_params` method.
+```{code-cell} ipython3
# get the glm model parameters only
print("\nGLM model parameters:")
for key, value in model.get_params(deep=False).items():
@@ -86,11 +132,13 @@
if key in model.get_params(deep=False):
continue
print(f"\t- {key}: {value}")
+```
-# %%
-# These parameters can be configured at initialization and/or
-# set after the model is initialized with the following syntax:
+These parameters can be configured at initialization and/or
+set after the model is initialized with the following syntax:
+
+```{code-cell} ipython3
# Poisson observation model with soft-plus NL
observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
@@ -104,10 +152,12 @@
print("Regularizer type: ", type(model.regularizer))
print("Observation model:", type(model.observation_model))
+```
+
+Hyperparameters can be set at any moment via the [`set_params`](nemos.glm.GLM.set_params) method.
-# %%
-# Hyperparameters can be set at any moment via the `set_params` method.
+```{code-cell} ipython3
model.set_params(
regularizer=nmo.regularizer.Lasso(),
observation_model__inverse_link_function=jax.numpy.exp
@@ -115,23 +165,27 @@
print("Updated regularizer: ", model.regularizer)
print("Updated NL: ", model.observation_model.inverse_link_function)
+```
+
+:::{warning}
+Each [`Regularizer`](regularizers) has an associated attribute [`Regularizer.allowed_solvers`](nemos.regularizer.Regularizer.allowed_solvers)
+which lists the optimizers that are suited for each optimization problem.
+For example, a [`Ridge`](nemos.regularizer.Ridge) penalty is differentiable and can be fit with `GradientDescent`,
+`BFGS`, etc., while a [`Lasso`](nemos.regularizer.Lasso) should use the `ProximalGradient` method instead.
+If the provided `solver_name` is not listed in `allowed_solvers`, this will raise an
+exception.
+:::
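+
+As a quick check (a minimal sketch, assuming the `model` defined above), you can inspect which solvers the current regularizer accepts:
+
+```{code-cell} ipython3
+# list the solver names compatible with the current regularizer
+print(model.regularizer.allowed_solvers)
+```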
+
-# %%
-# !!! warning
-# Each `Regularizer` has an associated attribute `Regularizer.allowed_solvers`
-# which lists the optimizers that are suited for each optimization problem.
-# For example, a `Ridge` is differentiable and can be fit with `GradientDescent`
-# , `BFGS`, etc., while a `Lasso` should use the `ProximalGradient` method instead.
-# If the provided `solver_name` is not listed in the `allowed_solvers` this will raise an
-# exception.
-
-# %%
-# ### Model Fit
-# Fitting the model is as straight forward as calling the `model.fit`
-# providing the design tensor and the population counts.
-# Additionally one may provide an initial parameter guess.
-# The same exact syntax works for any configuration.
+### Model Fit
+Fitting the model is as straightforward as calling `model.fit`,
+providing the design tensor and the population counts.
+Additionally, one may provide an initial parameter guess.
+The exact same syntax works for any configuration.
+
+
+```{code-cell} ipython3
# fit a ridge regression Poisson GLM
model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
model.fit(X, spikes)
@@ -139,18 +193,20 @@
print("Ridge results")
print("True weights: ", w_true)
print("Recovered weights: ", model.coef_)
+```
+
+## K-fold Cross Validation with `sklearn`
+Our implementation follows the `scikit-learn` API, which enables us
+to take advantage of the `scikit-learn` tool-box seamlessly, while at the same time
+benefiting from the `jax` GPU acceleration and auto-differentiation in the
+back-end.
+
+Here is an example of how we can perform 5-fold cross-validation via `scikit-learn`.
+
+**Ridge**
-# %%
-# ## K-fold Cross Validation with `sklearn`
-# Our implementation follows the `scikit-learn` api, this enables us
-# to take advantage of the `scikit-learn` tool-box seamlessly, while at the same time
-# we take advantage of the `jax` GPU acceleration and auto-differentiation in the
-# back-end.
-#
-# Here is an example of how we can perform 5-fold cross-validation via `scikit-learn`.
-#
-# **Ridge**
+```{code-cell} ipython3
parameter_grid = {"regularizer_strength": np.logspace(-1.5, 1.5, 6)}
# in practice, you should use more folds than 2, but for the purposes of this
# demo, 2 is sufficient.
@@ -161,12 +217,14 @@
print("Best hyperparameter: ", cls.best_params_)
print("True weights: ", w_true)
print("Recovered weights: ", cls.best_estimator_.coef_)
+```
-# %%
-# We can compare the Ridge cross-validated results with other regularization schemes.
-#
-# **Lasso**
+We can compare the Ridge cross-validated results with other regularization schemes.
+**Lasso**
+
+
+```{code-cell} ipython3
model.set_params(regularizer=nmo.regularizer.Lasso(), solver_name="ProximalGradient")
cls = model_selection.GridSearchCV(model, parameter_grid, cv=2)
cls.fit(X, spikes)
@@ -175,10 +233,12 @@
print("Best hyperparameter: ", cls.best_params_)
print("True weights: ", w_true)
print("Recovered weights: ", cls.best_estimator_.coef_)
+```
-# %%
-# **Group Lasso**
+**Group Lasso**
+
+```{code-cell} ipython3
# define groups by masking. Mask size (n_groups, n_features)
mask = np.zeros((2, 5))
mask[0, [0, -1]] = 1
@@ -195,12 +255,14 @@
print("Best hyperparameter: ", cls.best_params_)
print("True weights: ", w_true)
print("Recovered weights: ", cls.best_estimator_.coef_)
+```
+
+## Simulate Spikes
+We can generate spikes in response to a feedforward stimulus
+through the `model.simulate` method.
-# %%
-# ## Simulate Spikes
-# We can generate spikes in response to a feedforward-stimuli
-# through the `model.simulate` method.
+```{code-cell} ipython3
# here we are creating a new data input, of 20 timepoints (arbitrary)
# with the same number of features (mandatory)
Xnew = np.random.normal(size=(20, ) + X.shape[1:])
@@ -210,32 +272,36 @@
plt.figure()
plt.eventplot(np.where(spikes)[0])
+```
+## Simulate a Recurrently Coupled Network
+In this section, we will show you how to generate spikes from a population; we assume that the coupling
+filters are known or inferred.
-# %%
-# ## Simulate a Recurrently Coupled Network
-# In this section, we will show you how to generate spikes from a population; We assume that the coupling
-# filters are known or inferred.
-#
-# !!! warning
-# Making sure that the dynamics of your recurrent neural network are stable is non-trivial[$^{[1]}$](#ref-1). In particular,
-# coupling weights obtained by fitting a GLM by maximum-likelihood can generate unstable dynamics. If the
-# dynamics of your recurrently coupled model are unstable, you can try a `soft-plus` non-linearity
-# instead of an exponential, and you can "shrink" your weights until stability is reached.
-#
+:::{warning}
+Making sure that the dynamics of your recurrent neural network are stable is non-trivial[$^{[1]}$](#ref-1). In particular,
+coupling weights obtained by fitting a GLM by maximum-likelihood can generate unstable dynamics. If the
+dynamics of your recurrently coupled model are unstable, you can try a `soft-plus` non-linearity
+instead of an exponential, and you can "shrink" your weights until stability is reached.
+:::
+
+
+```{code-cell} ipython3
# Neural population parameters
n_neurons = 2
coupling_filter_duration = 100
+```
-# %%
-# Let's define the coupling filters that we will use to simulate
-# the pairwise interactions between the neurons. We will model the
-# filters as a difference of two Gamma probability density function.
-# The negative component will capture inhibitory effects such as the
-# refractory period of a neuron, while the positive component will
-# describe excitation.
+Let's define the coupling filters that we will use to simulate
+the pairwise interactions between the neurons. We will model the
+filters as a difference of two Gamma probability density functions.
+The negative component will capture inhibitory effects such as the
+refractory period of a neuron, while the positive component will
+describe excitation.
+
+```{code-cell} ipython3
np.random.seed(101)
# Gamma parameter for the inhibitory component of the filter
@@ -269,11 +335,13 @@
_, coupling_basis = basis.evaluate_on_grid(coupling_filter_bank.shape[0])
coupling_coeff = nmo.simulation.regress_filter(coupling_filter_bank, coupling_basis)
intercept = -4 * np.ones(n_neurons)
+```
+
+We can check that our approximation worked by plotting the original filters
+and the basis expansion
-# %%
-# We can check that our approximation worked by plotting the original filters
-# and the basis expansion
+```{code-cell} ipython3
# plot coupling functions
n_basis_coupling = coupling_basis.shape[1]
fig, axs = plt.subplots(n_neurons, n_neurons)
@@ -286,11 +354,13 @@
axs[unit_i, unit_j].plot(np.dot(coupling_basis, coeff), ls="--", color="k", label="basis function")
axs[0, 0].legend()
plt.tight_layout()
+```
-# %%
-# Define a squared stimulus current for the first neuron, and no stimulus for
-# the second neuron
+Define a squared stimulus current for the first neuron, and no stimulus for
+the second neuron
+
+```{code-cell} ipython3
# define a squared current parameters
simulation_duration = 1000
stimulus_onset = 200
@@ -321,11 +391,12 @@
# initialize the spikes for the recurrent simulation
init_spikes = np.zeros((coupling_filter_duration, n_neurons))
+```
+We can now simulate spikes by calling the `simulate_recurrent` function from the `nemos.simulation` module.
-# %%
-# We can now simulate spikes by calling the `simulate_recurrent` function for the `nemos.simulate` module.
+```{code-cell} ipython3
# call simulate, with both the recurrent coupling
# and the input
spikes, rates = nmo.simulation.simulate_recurrent(
@@ -337,12 +408,14 @@
coupling_basis_matrix=coupling_basis,
init_y=init_spikes
)
+```
+
+And finally plot the results for both neurons.
-# %%
-# And finally plot the results for both neurons.
+```{code-cell} ipython3
# mkdocs_gallery_thumbnail_number = 4
-plt.figure()
+fig = plt.figure()
ax = plt.subplot(111)
ax.spines['top'].set_visible(False)
@@ -360,7 +433,31 @@
plt.ylim(-0.011, .13)
plt.ylabel("count/bin")
plt.legend()
-
-# %%
-# ## References
-# [1] Arribas, Diego, Yuan Zhao, and Il Memming Park. "Rescuing neural spike train models from bad MLE." Advances in Neural Information Processing Systems 33 (2020): 2293-2303.
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/how_to_guide"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/how_to_guide")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_02_glm_demo.svg")
+```
+
+
+## References
+(ref-1)=
+[1] [Arribas, Diego, Yuan Zhao, and Il Memming Park. "Rescuing neural spike train models from bad MLE." Advances in Neural Information Processing Systems 33 (2020): 2293-2303.](https://arxiv.org/abs/2010.12362)
\ No newline at end of file
diff --git a/docs/how_to_guide/plot_03_glm_pytree.py b/docs/how_to_guide/plot_03_glm_pytree.py
deleted file mode 100644
index 02181d32..00000000
--- a/docs/how_to_guide/plot_03_glm_pytree.py
+++ /dev/null
@@ -1,298 +0,0 @@
-"""# FeaturePytree example
-
-This small example notebook shows how to use our custom FeaturePytree objects
-instead of arrays to represent the design matrix. It will show that these two
-representations are equivalent.
-
-This demo will fit the Poisson-GLM to some synthetic data. We will first show
-the simple case, with a single neuron receiving some input. We will then show a
-two-neuron system, to demonstrate how FeaturePytree can make it easier to
-separate examine separate types of inputs.
-
-First, however, let's briefly discuss FeaturePytrees.
-
-"""
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-import nemos as nmo
-
-np.random.seed(111)
-
-# %%
-# ## FeaturePytrees
-#
-# A FeaturePytree is a custom NeMoS object used to represent design matrices,
-# GLM coefficients, and other similar variables. It is a simple
-# [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), a dictionary
-# with strings as keys and arrays as values. These arrays must all have the
-# same number of elements along the first dimension, which represents the time
-# points, but can have different numbers of elements along the other dimensions
-# (and even different numbers of dimensions).
-
-example_pytree = nmo.pytrees.FeaturePytree(feature_0=np.random.normal(size=(100, 1, 2)),
- feature_1=np.random.normal(size=(100, 2)),
- feature_2=np.random.normal(size=(100, 5)))
-example_pytree
-
-# %%
-#
-# FeaturePytrees can be indexed into like dictionary, so we can grab a
-# single one of their features:
-
-example_pytree['feature_0'].shape
-
-# %%
-#
-# We can grab the number of time points by getting the length or using the
-# `shape` attribute
-
-print(len(example_pytree))
-print(example_pytree.shape)
-
-# %%
-#
-# We can also jointly index into the FeaturePytree's leaves:
-
-example_pytree[:10]
-
-# %%
-#
-# We can add new features after initialization, as long as they have the same
-# number of time points.
-
-example_pytree['feature_3'] = np.zeros((100, 2, 4))
-
-# %%
-#
-# However, if we try to add a new feature with the wrong number of time points,
-# we'll get an exception:
-
-try:
- example_pytree['feature_4'] = np.zeros((99, 2, 4))
-except ValueError as e:
- print(e)
-
-# %%
-#
-# Similarly, if we try to add a feature that's not an array:
-
-try:
- example_pytree['feature_4'] = "Strings are very predictive"
-except ValueError as e:
- print(e)
-
-# %%
-#
-# FeaturePytrees are intended to be used with
-# [jax.tree_util.tree_map](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html),
-# a useful function for performing computations on arbitrary pytrees,
-# preserving their structure.
-
-# %%
-# We can map lambda functions:
-mapped = jax.tree_util.tree_map(lambda x: x**2, example_pytree)
-print(mapped)
-mapped['feature_1']
-# %%
-# Or functions from jax or numpy that operate on arrays:
-mapped = jax.tree_util.tree_map(jnp.exp, example_pytree)
-print(mapped)
-mapped['feature_1']
-# %%
-# We can change the dimensionality of our pytree:
-mapped = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=-1), example_pytree)
-print(mapped)
-mapped['feature_1']
-# %%
-# Or the number of time points:
-mapped = jax.tree_util.tree_map(lambda x: x[::10], example_pytree)
-print(mapped)
-mapped['feature_1']
-# %%
-#
-# If we map something whose output cannot be a FeaturePytree (because its
-# values are scalars or non-arrays), we return a dictionary of arrays instead:
-print(jax.tree_util.tree_map(jnp.mean, example_pytree))
-print(jax.tree_util.tree_map(lambda x: x.shape, example_pytree))
-import matplotlib.pyplot as plt
-import pynapple as nap
-
-# %%
-#
-# ## FeaturePytrees and GLM
-#
-# These properties make FeaturePytrees useful for representing design matrices
-# and similar objects for the GLM.
-#
-# First, let's get our dataset and do some initial exploration of it. To do so,
-# we'll use pynapple to [stream
-# data](https://pynapple.org/examples/tutorial_pynapple_dandi.html)
-# from the DANDI archive.
-#
-# !!! attention
-#
-# We need some additional packages for this portion, which you can install
-# with `pip install dandi pynapple`
-
-io = nmo.fetch.download_dandi_data(
- "000582",
- "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb",
-)
-
-nwb = nap.NWBFile(io.read(), lazy_loading=False)
-
-print(nwb)
-
-# %%
-#
-# This data set has cells that are tuned for head direction and 2d position.
-# Let's compute some simple tuning curves to see if we can find a cell that
-# looks tuned for both.
-
-tc, binsxy = nap.compute_2d_tuning_curves(nwb['units'], nwb['SpatialSeriesLED1'].dropna(), 20)
-fig, axes = plt.subplots(3, 3, figsize=(9, 9))
-for i, ax in zip(tc.keys(), axes.flatten()):
- ax.imshow(tc[i], origin="lower", aspect="auto")
- ax.set_title("Unit {}".format(i))
-axes[-1,-1].remove()
-plt.tight_layout()
-
-# compute head direction.
-diff = nwb['SpatialSeriesLED1'].values-nwb['SpatialSeriesLED2'].values
-head_dir = np.arctan2(*diff.T)
-head_dir = nap.Tsd(nwb['SpatialSeriesLED1'].index, head_dir)
-
-tune_head = nap.compute_1d_tuning_curves(nwb['units'], head_dir.dropna(), 30)
-
-fig, axes = plt.subplots(3, 3, figsize=(9, 9), subplot_kw={'projection': 'polar'})
-for i, ax in zip(tune_head.columns, axes.flatten()):
- ax.plot(tune_head.index, tune_head[i])
- ax.set_title("Unit {}".format(i))
-axes[-1,-1].remove()
-
-# %%
-#
-# Okay, let's use unit number 7.
-#
-# Now let's set up our design matrix. First, let's fit the head direction by
-# itself. Head direction is a circular variable (pi and -pi are adjacent to
-# each other), so we need to use a basis that has this property as well.
-# `CyclicBSplineBasis` is one such basis.
-#
-# Let's create our basis and then arrange our data properly.
-
-unit_no = 7
-spikes = nwb['units'][unit_no]
-
-basis = nmo.basis.CyclicBSplineBasis(10, order=5)
-x = np.linspace(-np.pi, np.pi, 100)
-plt.figure()
-plt.plot(x, basis(x))
-
-# Find the interval on which head_dir has no NaNs
-head_dir = head_dir.dropna()
-# Grab the second (of two), since the first one is really short
-valid_data= head_dir.time_support.loc[[1]]
-head_dir = head_dir.restrict(valid_data)
-# Count spikes at the same rate as head direction, over the same epoch
-spikes = spikes.count(bin_size=1/head_dir.rate, ep=valid_data)
-# the time points for spike are in the middle of these bins (whereas for
-# head_dir they're at the ends), so use interpolate to shift head_dir to the
-# center.
-head_dir = head_dir.interpolate(spikes)
-
-X = nmo.pytrees.FeaturePytree(head_direction=basis(head_dir))
-
-# %%
-#
-# Now we'll fit our GLM and then see what our head direction tuning looks like:
-model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)
-model.fit(X, spikes)
-print(model.coef_['head_direction'])
-
-bs_vis = basis(x)
-tuning = jnp.einsum('b, tb->t', model.coef_['head_direction'], bs_vis)
-plt.figure()
-plt.polar(x, tuning)
-
-# %%
-#
-# This looks like a smoothed version of our tuning curve, like we'd expect!
-#
-# For a more direct comparison, we can plot the tuning function based on the model predicted
-# firing rates with that estimated from the counts.
-
-
-# predict rates and convert back to pynapple
-rates_nap = nap.TsdFrame(t=head_dir.t, d=np.asarray(model.predict(X)))
-# compute tuning function
-tune_head_model = nap.compute_1d_tuning_curves_continuous(rates_nap, head_dir, 30)
-# compare model prediction with data
-fig, ax = plt.subplots(1, 1, subplot_kw={'projection': 'polar'})
-ax.plot(tune_head[7], label="counts")
-# multiply by the sampling rate for converting to spike/sec.
-ax.plot(tune_head_model * rates_nap.rate, label="model")
-
-# Let's compare this to using arrays, to see what it looks like:
-
-model = nmo.glm.GLM()
-model.fit(X['head_direction'], spikes)
-model.coef_
-
-# %%
-#
-# We can see that the solution is identical, as is the way of interacting with
-# the GLM object.
-#
-# However, with a single type of feature, it's unclear why exactly this is
-# helpful. Let's add a feature for the animal's position in space. For this
-# feature, we need a 2d basis. Let's use some raised cosine bumps and organize
-# our data similarly.
-
-pos_basis = nmo.basis.RaisedCosineBasisLinear(10) * nmo.basis.RaisedCosineBasisLinear(10)
-spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data)
-
-X['spatial_position'] = pos_basis(*spatial_pos.values.T)
-
-# %%
-#
-# Running the GLM is identical to before, but we can see that our coef_
-# FeaturePytree now has two separate keys, one for each feature type.
-
-model = nmo.glm.GLM(solver_name="LBFGS")
-model.fit(X, spikes)
-model.coef_
-
-# %%
-#
-# Let's visualize our tuning. Head direction looks pretty much the same (though
-# the values are slightly different, as we can see when printing out the
-# coefficients).
-
-bs_vis = basis(x)
-tuning = jnp.einsum('b,nb->n', model.coef_['head_direction'], bs_vis)
-print(model.coef_['head_direction'])
-plt.figure()
-plt.polar(x, tuning.T)
-
-# %%
-#
-# And the spatial tuning again looks like a smoothed version of our earlier
-# tuning curves.
-_, _, pos_bs_vis = pos_basis.evaluate_on_grid(50, 50)
-pos_tuning = jnp.einsum('b,ijb->ij', model.coef_['spatial_position'], pos_bs_vis)
-plt.figure()
-plt.imshow(pos_tuning)
-
-# %%
-#
-# We could do all this with matrices as well, but we have to pay attention to
-# indices in a way that is annoying:
-
-X_mat = nmo.utils.pynapple_concatenate_jax([X['head_direction'], X['spatial_position']], -1)
-
-model = nmo.glm.GLM()
-model.fit(X_mat, spikes)
-model.coef_[..., :basis.n_basis_funcs]
diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_03_population_glm.md
similarity index 51%
rename from docs/how_to_guide/plot_04_population_glm.py
rename to docs/how_to_guide/plot_03_population_glm.md
index 70dac9cd..7c27acd5 100644
--- a/docs/how_to_guide/plot_04_population_glm.py
+++ b/docs/how_to_guide/plot_03_population_glm.md
@@ -1,29 +1,73 @@
-"""
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+
# Population GLM
Fitting the activity of a neural population with NeMoS can be much more efficient than fitting each individual
neuron in a loop. The reason for this is that NeMoS leverages the powerful GPU-vectorization implemented by `JAX`.
-!!! note
- For an unregularized, Lasso, Ridge, or group-Lasso GLM, fitting a GLM one neuron at the time, or fitting jointly
- the neural population is equivalent. The main difference between the approaches is that the former is more
- memory efficient, the latter is computationally more efficient (it takes less time to fit).
+:::{note}
+For an unregularized, Lasso, Ridge, or group-Lasso GLM, fitting a GLM one neuron at a time or fitting the
+neural population jointly is equivalent. The main difference between the approaches is that the former is more
+memory efficient, while the latter is computationally more efficient (it takes less time to fit).
+:::
## Fitting a Population GLM
-NeMoS has a dedicated `nemos.GLM.PopulationGLM` class for fitting jointly a neural population. The API
- is very similar to that the regular `nemos.glm.GLM`, but with a few differences:
+NeMoS has a dedicated [`nemos.glm.PopulationGLM`](nemos.glm.PopulationGLM) class for jointly fitting a neural population. The API
+ is very similar to that of the regular [`GLM`](nemos.glm.GLM), but with a few differences:
- 1. The `y` input to the methods `fit` and `score` must be a two-dimensional array of shape `(n_samples, n_neurons)`.
+ 1. The `y` input to the methods [`fit`](nemos.glm.PopulationGLM.fit) and [`score`](nemos.glm.PopulationGLM.score) must be a two-dimensional array of shape `(n_samples, n_neurons)`.
2. You can optionally pass a `feature_mask` in the form of an array of 0s and 1s with shape `(n_features, n_neurons)`
that specifies which features are used as predictors for each neuron. More on this [later](#neuron-specific-features).
Let's generate some synthetic data and fit a population model.
-"""
+```{code-cell} ipython3
import jax.numpy as jnp
import matplotlib.pyplot as plt
+
import numpy as np
import nemos as nmo
@@ -47,34 +91,41 @@
spikes = np.random.poisson(rate)
print(spikes.shape)
+```
+
+We can now instantiate the [`PopulationGLM`](nemos.glm.PopulationGLM) model and fit.
+
-# %%
-# We can now instantiate the `PopulationGLM` model and fit.
+```{code-cell} ipython3
model = nmo.glm.PopulationGLM()
model.fit(X, spikes)
print(f"population GLM log-likelihood: {model.score(X, spikes)}")
+```
-# %%
-# ## Neuron-specific features
-# If you want to model neurons with different input features, the way to do so is to specify a `feature_mask`.
-# Let's assume that we have two neurons, share one shared input, and have an extra private one, for a total of
-# 3 inputs.
+(neuron-specific-features)=
+## Neuron-specific features
+If you want to model neurons with different input features, the way to do so is to specify a `feature_mask`.
+Let's assume that we have two neurons that share one common input and each receive an extra private one, for a total of
+3 inputs.
+
+```{code-cell} ipython3
# let's take the first three input
n_features = 3
input_features = X[:, :3]
+```
+
+Let's assume that:
+
+ - `input_features[:, 0]` is shared.
+ - `input_features[:, 1]` is an input only for the first neuron.
+ - `input_features[:, 2]` is an input only for the second neuron.
+We can simulate this scenario,
-# %%
-# Let's assume that:
-#
-# - `input_features[:, 0]` is shared.
-# - `input_features[:, 1]` is an input only for the first neuron.
-# - `input_features[:, 2]` is an input only for the second neuron.
-#
-# We can simulate this scenario,
+```{code-cell} ipython3
# model the rate of the first neuron using only the first two features and weights.
rate_neuron_1 = jnp.exp(np.dot(input_features[:, [0, 1]], w_true[: 2, 0]))
@@ -84,10 +135,12 @@
# stack the rates in a (n_samples, n_neurons) array and generate spikes
rate = np.hstack((rate_neuron_1[:, np.newaxis], rate_neuron_2[:, np.newaxis]))
spikes = np.random.poisson(rate)
+```
-# %%
-# We can impose the same constraint to the `PopulationGLM` by masking the weights.
+We can impose the same constraint on the [`PopulationGLM`](nemos.glm.PopulationGLM) by masking the weights.
+
+```{code-cell} ipython3
# initialize the mask to a matrix of 1s.
feature_mask = np.ones((n_features, n_neurons))
@@ -99,11 +152,13 @@
# visualize the mask
print(feature_mask)
+```
+
+The mask can be passed at initialization or set after the model is initialized, but cannot be changed
+after the model is fit.
-# %%
-# The mask can be passed at initialization or set after the model is initialized, but cannot be changed
-# after the model is fit.
+```{code-cell} ipython3
# set a quasi-newton solver and low tolerance for better numerical precision
model = nmo.glm.PopulationGLM(solver_name="LBFGS", solver_kwargs={"tol": 10**-12})
@@ -112,19 +167,23 @@
# fit the model
model.fit(input_features, spikes)
+```
-# %%
-# If we print the model coefficients, we can see the effect of the mask.
+If we print the model coefficients, we can see the effect of the mask.
+
+```{code-cell} ipython3
print(model.coef_)
+```
+
+The coefficient for the first neuron corresponding to the last feature is zero, as well as
+the coefficient of the second neuron corresponding to the second feature.
+
+To convince ourselves that this is equivalent to fitting each neuron individually with the correct features,
+let's go ahead and try.
-# %%
-# The coefficient for the first neuron corresponding to the last feature is zero, as well as
-# the coefficient of the second neuron corresponding to the second feature.
-#
-# To convince ourselves that this is equivalent to fit each neuron individually with the correct features,
-# let's go ahead and try.
+```{code-cell} ipython3
# features for each neuron
features_by_neuron = {
0: [0, 1],
@@ -155,14 +214,38 @@
if neuron == 1:
plt.legend()
plt.tight_layout()
-
-# %%
-# ## FeaturePytree
-# `PopulationGLM` is compatible with [`FeaturePytree`](../plot_03_glm_pytree). If you structured your predictors
-# in a `FeaturePytree`, the `feature_mask` needs to be a dictionary of the same structure, containing arrays
-# of shape `(n_neurons, )`.
-# The example above can be reformulated as follows,
-
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/how_to_guide"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/how_to_guide")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_03_population_glm.svg")
+```
+
+## FeaturePytree
+[`PopulationGLM`](nemos.glm.PopulationGLM) is compatible with [`FeaturePytree`](nemos.pytrees.FeaturePytree). If you structured your predictors
+in a [`FeaturePytree`](nemos.pytrees.FeaturePytree), the `feature_mask` needs to be a dictionary of the same structure, containing arrays
+of shape `(n_neurons, )`.
+The example above can be reformulated as follows,
+
+
+```{code-cell} ipython3
# restructure the input as FeaturePytree
pytree_features = nmo.pytrees.FeaturePytree(
shared=input_features[:, :1],
@@ -183,3 +266,4 @@
# print the coefficients
print(model_tree.coef_)
+```
diff --git a/docs/how_to_guide/plot_04_batch_glm.md b/docs/how_to_guide/plot_04_batch_glm.md
new file mode 100644
index 00000000..707e90da
--- /dev/null
+++ b/docs/how_to_guide/plot_04_batch_glm.md
@@ -0,0 +1,293 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+
+# Batching example
+
+Here we demonstrate how to set up and run stochastic gradient descent in `nemos`
+by batching and using the [`update`](nemos.glm.GLM.update) method of the model class.
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+nap.nap_config.suppress_conversion_warnings = True
+
+# set random seed
+np.random.seed(123)
+```
+
+## Simulate data
+
+Let's generate some data artificially
+
+
+```{code-cell} ipython3
+n_neurons = 10
+T = 50
+
+times = np.linspace(0, T, 5000).reshape(-1, 1)
+rate = np.exp(np.sin(times + np.linspace(0, np.pi*2, n_neurons).reshape(1, n_neurons)))
+```
+
+Get the spike times from the rate and generate a `TsGroup` object
+
+
+```{code-cell} ipython3
+spike_t, spike_id = np.where(np.random.poisson(rate))
+units = nap.Tsd(spike_t/T, spike_id).to_tsgroup()
+```
+
+## Model configuration
+
+Let's imagine this dataset does not fit in memory. We can use a batching approach to train the GLM.
+First, we need to instantiate the [`PopulationGLM`](nemos.glm.PopulationGLM). The default algorithm for [`PopulationGLM`](nemos.glm.PopulationGLM) is gradient descent.
+We suggest using it for batching.
+
+:::{note}
+You must turn off the dynamic step-size update when fitting with batched (also called stochastic) gradient descent.
+In `jaxopt`, this can be done by setting the parameter `acceleration` to `False` and fixing the `stepsize`.
+:::
+
+
+```{code-cell} ipython3
+glm = nmo.glm.PopulationGLM(
+ solver_name="GradientDescent",
+ solver_kwargs={"stepsize": 0.1, "acceleration": False}
+ )
+```
+
+## Basis instantiation
+
+Here we instantiate the basis. `ws` is 40 time bins, which corresponds to a 200 ms window at the 5 ms bin size used below.
+
+
+```{code-cell} ipython3
+ws = 40
+basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=ws)
+```
+
+## Batch definition
+
+The batch size needs to be larger than the window size of the convolution kernel defined above.
+
+
+```{code-cell} ipython3
+batch_size = 5 # second
+```
+
+Here we define a batcher function that generates a random 5 s batch of design matrix and spike counts.
+This function will be called during each iteration of the stochastic gradient descent.
+
+
+```{code-cell} ipython3
+def batcher():
+ # Grab a random time within the time support. Here the time support is one epoch only, so it's easy.
+ t = np.random.uniform(units.time_support[0, 0], units.time_support[0, 1]-batch_size)
+
+ # Bin the spike trains within the batch epoch
+ ep = nap.IntervalSet(t, t+batch_size)
+ counts = units.restrict(ep).count(0.005) # count in 5 ms bins
+
+ # Convolve
+ X = basis.compute_features(counts)
+
+ # Return X and counts
+ return X, counts
+```
+
+## Solver initialization
+
+First we need to initialize the gradient descent solver within the [`PopulationGLM`](nemos.glm.PopulationGLM) .
+This gets you the initial parameters and the first state of the solver.
+
+
+```{code-cell} ipython3
+params = glm.initialize_params(*batcher())
+state = glm.initialize_state(*batcher(), params)
+```
+
+## Batch learning
+
+Let's do a few iterations of gradient descent, calling the `batcher` function at every step.
+At each step, we store the log-likelihood of the model evaluated on the batch.
+
+
+```{code-cell} ipython3
+n_step = 500
+logl = np.zeros(n_step)
+
+for i in range(n_step):
+
+ # Get a batch of data
+ X, Y = batcher()
+
+ # Do one step of gradient descent.
+ params, state = glm.update(params, state, X, Y)
+
+ # Score the model along the time axis
+ logl[i] = glm.score(X, Y, score_type="log-likelihood")
+```
+
+:::{admonition} Input validation
+:class: warning
+
+The `update` method does not perform input validation each time it is called.
+This design choice speeds up computation by avoiding repetitive checks. However,
+it requires that all inputs to the `update` method strictly conform to the expected
+dimensionality and structure as established during the initialization of the solver.
+Failure to comply with these expectations will likely result in runtime errors or
+incorrect computations.
+:::
+
+First let's plot the log-likelihood to see if the model is converging.
+
+
+```{code-cell} ipython3
+fig = plt.figure()
+plt.plot(logl)
+plt.xlabel("Iteration")
+plt.ylabel("Log-likelihood")
+plt.show()
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/how_to_guide"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/how_to_guide")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_04_batch_glm.svg")
+```
+
+We can see that the log-likelihood is increasing but has not reached a plateau yet.
+The number of iterations can be increased to continue learning.
+
+We can take a look at the coefficients.
+Here we extract the weight matrix of shape `(n_neurons*n_basis, n_neurons)`
+and reshape it to `(n_neurons, n_basis, n_neurons)`.
+We then average along the basis dimension to get a weight matrix of shape `(n_neurons, n_neurons)`.
+
+
+```{code-cell} ipython3
+W = glm.coef_.reshape(len(units), basis.n_basis_funcs, len(units))
+Wm = np.mean(np.abs(W), 1)
+
+# Let's plot it.
+
+plt.figure()
+plt.imshow(Wm)
+plt.xlabel("Neurons")
+plt.ylabel("Neurons")
+plt.show()
+```
+
+## Model comparison
+
+Since this example is small enough, we can fit the full model and compare the scores.
+Here we generate the design matrix and spike counts for the whole dataset.
+
+
+```{code-cell} ipython3
+Y = units.count(0.005)
+X = basis.compute_features(Y)
+full_model = nmo.glm.PopulationGLM().fit(X, Y)
+```
+
+Now that the full model is fitted, we score both the full model and the batch model against the full dataset to compare them.
+The score is the McFadden pseudo-R2.
+
+
+```{code-cell} ipython3
+full_scores = full_model.score(
+ X, Y, aggregate_sample_scores=lambda x:np.mean(x, axis=0), score_type="pseudo-r2-McFadden"
+)
+batch_scores = glm.score(
+ X, Y, aggregate_sample_scores=lambda x:np.mean(x, axis=0), score_type="pseudo-r2-McFadden"
+)
+```
+
+Let's compare the scores for each neuron as well as the coefficients.
+
+
+```{code-cell} ipython3
+plt.figure(figsize=(10, 8))
+gs = plt.GridSpec(3,2)
+plt.subplot(gs[0,:])
+plt.bar(np.arange(0, n_neurons), full_scores, 0.4, label="Full model")
+plt.bar(np.arange(0, n_neurons)+0.5, batch_scores, 0.4, label="Batch model")
+plt.ylabel("Pseudo R2")
+plt.xlabel("Neurons")
+plt.ylim(0, 1)
+plt.legend()
+plt.subplot(gs[1:,0])
+plt.imshow(Wm)
+plt.title("Batch model")
+plt.subplot(gs[1:,1])
+Wm2 = np.mean(
+ np.abs(
+ full_model.coef_.reshape(len(units), basis.n_basis_funcs, len(units))
+ )
+ , 1)
+plt.imshow(Wm2)
+plt.title("Full model")
+plt.tight_layout()
+plt.show()
+```
+
+As we can see, with a few iterations, the batch model manages to recover a similar coefficient matrix.
diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py
deleted file mode 100644
index 7454d6f1..00000000
--- a/docs/how_to_guide/plot_05_batch_glm.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""
-# Batching example
-
-Here we demonstrate how to setup and run a stochastic gradient descent in `nemos`
-by batching and using the `update` method of the model class.
-
-"""
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-nap.nap_config.suppress_conversion_warnings = True
-
-# set random seed
-np.random.seed(123)
-
-# %%
-# ## Simulate data
-#
-# Let's generate some data artificially
-n_neurons = 10
-T = 50
-
-times = np.linspace(0, T, 5000).reshape(-1, 1)
-rate = np.exp(np.sin(times + np.linspace(0, np.pi*2, n_neurons).reshape(1, n_neurons)))
-
-# %%
-# Get the spike times from the rate and generate a `TsGroup` object
-spike_t, spike_id = np.where(np.random.poisson(rate))
-units = nap.Tsd(spike_t/T, spike_id).to_tsgroup()
-
-
-# %%
-# ## Model configuration
-#
-# Let's imagine this dataset do not fit in memory. We can use a batching approach to train the GLM.
-# First we need to instantiate the `PopulationGLM`. The default algorithm for `PopulationGLM` is gradient descent.
-# We suggest to use it for batching.
-#
-# !!! Note
-# You must shutdown the dynamic update of the step for fitting a batched (also called stochastic) gradient descent.
-# In jaxopt, this can be done by setting the parameters `acceleration` to False and setting the `stepsize`.
-#
-glm = nmo.glm.PopulationGLM(
- solver_name="GradientDescent",
- solver_kwargs={"stepsize": 0.1, "acceleration": False}
- )
-
-# %%
-# ## Basis instantiation
-#
-# Here we instantiate the basis with a window size of 40 time bins. It corresponds to a 200ms windows
-# for a 5ms bin size.
-basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=40)
-
-# %%
-# ## Batch definition
-#
-# The batch size needs to be larger than the window size of the convolution kernel defined above.
-batch_size = 5 # second
-
-
-# %%
-# Here we define a batcher function that generate a random 5 s of design matrix and spike counts.
-# This function will be called during each iteration of the stochastic gradient descent.
-def batcher():
- # Grab a random time within the time support. Here is the time support is one epoch only so it's easy.
- t = np.random.uniform(units.time_support[0, 0], units.time_support[0, 1]-batch_size)
-
- # Bin the spike train in a 1s batch
- ep = nap.IntervalSet(t, t+batch_size)
- counts = units.restrict(ep).count(0.005) # count in 5 ms bins
-
- # Convolve
- X = basis.compute_features(counts)
-
- # Return X and counts
- return X, counts
-
-
-# %%
-# ## Solver initialization
-#
-# First we need to initialize the gradient descent solver within the `PopulationGLM`.
-# This gets you the initial parameters and the first state of the solver.
-params = glm.initialize_params(*batcher())
-state = glm.initialize_state(*batcher(), params)
-
-# %%
-# ## Batch learning
-#
-# Let's do a few iterations of gradient descent calling the `batcher` function at every step.
-# At each step, we store the log-likelihood of the model for each neuron evaluated on the batch
-n_step = 500
-logl = np.zeros(n_step)
-
-for i in range(n_step):
-
- # Get a batch of data
- X, Y = batcher()
-
- # Do one step of gradient descent.
- params, state = glm.update(params, state, X, Y)
-
- # Score the model along the time axis
- logl[i] = glm.score(X, Y, score_type="log-likelihood")
-
-
-# %%
-#
-# !!! Warning "Input validation"
-# The `update` method does not perform input validation each time it is called.
-# This design choice speeds up computation by avoiding repetitive checks. However,
-# it requires that all inputs to the `update` method strictly conform to the expected
-# dimensionality and structure as established during the initialization of the solver.
-# Failure to comply with these expectations will likely result in runtime errors or
-# incorrect computations.
-#
-# First let's plot the log-likelihood to see if the model is converging.
-
-plt.figure()
-plt.plot(logl)
-plt.xlabel("Iteration")
-plt.ylabel("Log-likelihood")
-plt.show()
-
-
-# %%
-# We can see that the log-likelihood is increasing but did not reach plateau yet.
-# The number of iterations can be increased to continue learning.
-#
-# We can take a look at the coefficients.
-# Here we extract the weight matrix of shape `(n_neurons*n_basis, n_neurons)`
-# and reshape it to `(n_neurons, n_basis, n_neurons)`.
-# We then average along basis to get a weight matrix of shape `(n_neurons, n_neurons)`.
-
-W = glm.coef_.reshape(len(units), basis.n_basis_funcs, len(units))
-Wm = np.mean(np.abs(W), 1)
-
-# Let's plot it.
-
-plt.figure()
-plt.imshow(Wm)
-plt.xlabel("Neurons")
-plt.ylabel("Neurons")
-plt.show()
-
-# %%
-# ## Model comparison
-#
-# Since this example is small enough, we can fit the full model and compare the scores.
-# Here we generate the design matrix and spike counts for the whole dataset.
-Y = units.count(0.005)
-X = basis.compute_features(Y)
-full_model = nmo.glm.PopulationGLM().fit(X, Y)
-
-# %%
-# Now that the full model is fitted, we are scoring the full model and the batch model against the full datasets to compare the scores.
-# The score is pseudo-R2
-full_scores = full_model.score(
- X, Y, aggregate_sample_scores=lambda x:np.mean(x, axis=0), score_type="pseudo-r2-McFadden"
-)
-batch_scores = glm.score(
- X, Y, aggregate_sample_scores=lambda x:np.mean(x, axis=0), score_type="pseudo-r2-McFadden"
-)
-
-# %%
-# Let's compare scores for each neurons as well as the coefficients.
-
-plt.figure(figsize=(10, 8))
-gs = plt.GridSpec(3,2)
-plt.subplot(gs[0,:])
-plt.bar(np.arange(0, n_neurons), full_scores, 0.4, label="Full model")
-plt.bar(np.arange(0, n_neurons)+0.5, batch_scores, 0.4, label="Batch model")
-plt.ylabel("Pseudo R2")
-plt.xlabel("Neurons")
-plt.ylim(0, 1)
-plt.legend()
-plt.subplot(gs[1:,0])
-plt.imshow(Wm)
-plt.title("Batch model")
-plt.subplot(gs[1:,1])
-Wm2 = np.mean(
- np.abs(
- full_model.coef_.reshape(len(units), basis.n_basis_funcs, len(units))
- )
- , 1)
-plt.imshow(Wm2)
-plt.title("Full model")
-plt.tight_layout()
-plt.show()
-
-# %%
-# As we can see, with a few iterations, the batch model manage to recover a similar coefficient matrix.
-
-
-
diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md
new file mode 100644
index 00000000..7684857c
--- /dev/null
+++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md
@@ -0,0 +1,608 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+(sklearn-how-to)=
+# Selecting basis by cross-validation with scikit-learn
+
+In this demo, we will demonstrate how to select an appropriate basis and its hyperparameters using cross-validation.
+In particular, we will learn:
+
+1. What a scikit-learn pipeline is.
+2. Why pipelines are useful.
+3. How to combine NeMoS [`Basis`](nemos.basis.Basis) and [`GLM`](nemos.glm.GLM) objects in a pipeline.
+4. How to select the number of bases and the basis type through cross-validation (or any other hyperparameter in the pipeline).
+5. How to use a custom scoring metric to quantify the performance of each configuration.
+
+
+
+## What is a scikit-learn pipeline
+
+
+
+A pipeline is a sequence of data transformations leading up to a model. Each step before the final one transforms the input data into a different representation, and then the final model step fits, predicts, or scores based on the previous step's output and some observations. Setting up such machinery can be simplified using the [`Pipeline`](https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html) class from scikit-learn.
+
+To set up a scikit-learn [`Pipeline`](https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html), ensure that:
+
+1. Each intermediate step is a [scikit-learn transformer object](https://scikit-learn.org/stable/data_transforms.html) with a `transform` and/or `fit_transform` method.
+2. The final step is an [estimator object](https://scikit-learn.org/stable/developers/develop.html#estimators) with a `fit` method, or a model with `fit`, `predict`, and `score` methods.
+
+Each transformation step takes a 2D array `X` of shape `(num_samples, num_original_features)` as input and outputs another 2D array of shape `(num_samples, num_transformed_features)`. The final step takes a pair `(X, y)`, where `X` is as before, and `y` is a 1D array of shape `(n_samples,)` containing the observations to be modeled.
+
+You can define a pipeline as follows:
+```python
+from sklearn.pipeline import Pipeline
+
+# Assume transformer_i/predictor is a transformer/model object
+pipe = Pipeline(
+ [
+ ("label_1", transformer_1),
+ ("label_2", transformer_2),
+ ...,
+ ("label_n", transformer_n),
+ ("label_model", model)
+ ]
+)
+```
+
+Note that you have to assign a label to each step of the pipeline.
+:::{tip}
+Here we used a placeholder `"label_i"` for demonstration; you should choose a more descriptive name depending on the type of transformation step.
+:::
+
+Calling `pipe.fit(X, y)` will perform the following computations:
+```python
+# Chain of transformations
+X1 = transformer_1.fit_transform(X)
+X2 = transformer_2.fit_transform(X1)
+# ...
+Xn = transformer_n.fit_transform(Xn_1)
+
+# Fit step
+model.fit(Xn, y)
+```
+And the same holds for `pipe.score` and `pipe.predict`.
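+
+For instance, calling `pipe.predict(X)` on a fitted pipeline chains the already-fitted transformers' `transform` methods and then calls the model's `predict`. A sketch, using the same placeholder names as above:
+```python
+# Chain of transformations (transformers are already fitted)
+X1 = transformer_1.transform(X)
+X2 = transformer_2.transform(X1)
+# ...
+Xn = transformer_n.transform(Xn_1)
+
+# Final predict step
+predictions = model.predict(Xn)
+```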
+
+## Why pipelines are useful
+
+Pipelines not only streamline and simplify your code but also offer several other advantages. The real power of pipelines becomes evident when combined with the scikit-learn [`model_selection`](https://scikit-learn.org/1.5/api/sklearn.model_selection.html) module, which includes cross-validation and similar methods. This combination allows you to tune hyperparameters at each step of the pipeline in a straightforward manner.
+
+In the following sections, we will showcase this approach with a concrete example: selecting the appropriate basis type and number of bases for a GLM regression in NeMoS.
+
+## Combining basis transformations and GLM in a pipeline
+Let's start by creating some toy data.
+
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scipy.stats
+import seaborn as sns
+from sklearn.model_selection import GridSearchCV
+from sklearn.pipeline import Pipeline
+
+import nemos as nmo
+
+# some helper plotting functions
+from nemos import _documentation_utils as doc_plots
+
+# predictors, shape (n_samples, n_features)
+X = np.random.uniform(low=0, high=1, size=(1000, 1))
+# observed counts, shape (n_samples,)
+rate = 2 * (
+ scipy.stats.norm.pdf(X, scale=0.1, loc=0.25)
+ + scipy.stats.norm.pdf(X, scale=0.1, loc=0.75)
+)
+y = np.random.poisson(rate).astype(float).flatten()
+```
+
+Let's now plot the simulated neuron's tuning curve, which is bimodal, Gaussian-shaped, and has peaks at 0.25 and 0.75.
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+ax.scatter(X.flatten(), y, alpha=0.2)
+ax.set_xlabel("input")
+ax.set_ylabel("spike count")
+sns.despine(ax=ax)
+```
+
+### Converting NeMoS `Basis` to a transformer
+In order to use NeMoS [`Basis`](nemos.basis.Basis) in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the [`TransformerBasis`](nemos.basis.TransformerBasis) wrapper class.
+
+Instantiating a [`TransformerBasis`](nemos.basis.TransformerBasis) can be done either using the constructor directly or with [`Basis.to_transformer()`](nemos.basis.Basis.to_transformer):
+
+
+```{code-cell} ipython3
+bas = nmo.basis.RaisedCosineBasisLinear(5, mode="conv", window_size=5)
+# these two ways of creating the TransformerBasis are equivalent
+trans_bas_a = nmo.basis.TransformerBasis(bas)
+trans_bas_b = bas.to_transformer()
+```
+
+[`TransformerBasis`](nemos.basis.TransformerBasis) provides convenient access to the underlying [`Basis`](nemos.basis.Basis) object's attributes:
+
+
+```{code-cell} ipython3
+print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs)
+```
+
+We can also set attributes of the underlying [`Basis`](nemos.basis.Basis). Note that -- because [`TransformerBasis`](nemos.basis.TransformerBasis) is created with a copy of the [`Basis`](nemos.basis.Basis) object passed to it -- this does not change the original [`Basis`](nemos.basis.Basis), and neither does changing the original [`Basis`](nemos.basis.Basis) change [`TransformerBasis`](nemos.basis.TransformerBasis) we created:
+
+
+```{code-cell} ipython3
+trans_bas_a.n_basis_funcs = 10
+bas.n_basis_funcs = 100
+
+print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs)
+```
+
+### Creating and fitting a pipeline
+We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data.
+
+This is exactly what `Pipeline` is for!
+
+
+```{code-cell} ipython3
+pipeline = Pipeline(
+ [
+ (
+ "transformerbasis",
+ nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6)),
+ ),
+ (
+ "glm",
+ nmo.glm.GLM(regularizer_strength=0.5, regularizer="Ridge"),
+ ),
+ ]
+)
+
+pipeline.fit(X, y)
+```
+
+Note how NeMoS models are already scikit-learn compatible and can be used directly in the pipeline.
+
+Visualize the fit:
+
+
+```{code-cell} ipython3
+# Predict the rate.
+# Note that the pipeline expects a 2D input; X already has shape (n_samples, 1),
+# so sorting along the sample axis keeps the required shape.
+x = np.sort(X, axis=0)
+predicted_rate = pipeline.predict(x)
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+
+ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
+ax.set_xlabel("input")
+ax.set_ylabel("spike count")
+
+
+ax.plot(
+ x,
+ predicted_rate,
+ label="predicted rate",
+ color="tab:orange",
+)
+
+ax.legend()
+sns.despine(ax=ax)
+```
+
+The current model captures the bimodal distribution of responses, appropriately picking out the peaks. However, it doesn't do a good job capturing the actual firing rate: the peaks are too low and the valleys are not low enough. This might be because of our choice of basis and/or regularizer strength, so let's see if tuning those parameters results in a better fit! We could do this manually, but doing this with the sklearn pipeline will make everything much easier!
+
+
+
+
+### Select the number of basis functions by cross-validation
+
+
+
+
+:::{warning}
+Please keep in mind that while [`GLM.score`](nemos.glm.GLM.score) supports different ways of evaluating goodness-of-fit through the `score_type` argument, `pipeline.score(X, y, score_type="...")` does not propagate this, and uses the default value of `log-likelihood`.
+
+To evaluate a pipeline, please create a custom scorer (e.g. `pseudo_r2` below) and call `my_custom_scorer(pipeline, X, y)`.
+:::
+
+#### Define the parameter grid
+
+Let's define candidate values for the parameters of each step of the pipeline we want to cross-validate. In this case the number of basis functions in the transformation step and the ridge regularization's strength in the GLM fit:
+
+
+```{code-cell} ipython3
+param_grid = dict(
+ glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
+ transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
+)
+```
+
+:::{admonition} Grid definition
+:class: info
+In order to define a parameter grid dictionary for a pipeline, you must structure the dictionary keys as follows:
+
+- Start with the pipeline label (`"glm"` or `"transformerbasis"` for us). This determines which pipeline step has the relevant hyperparameter.
+- Add `"__"` followed by the hyperparameter name (for example, `"n_basis_funcs"`).
+- If the hyperparameter is itself an object with attributes, add another `"__"` followed by the attribute name. For instance, `"glm__observation_model__inverse_link_function"`
+  would be a valid key for cross-validating over the link function of the GLM's `observation_model` attribute `inverse_link_function` (see the sketch after this note).
+The values in the dictionary are the parameters to be tested.
+:::
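+
+As a sketch of the nested case, the following grid would cross-validate over the inverse link function of the GLM's observation model; the two candidate callables are illustrative choices, not a recommendation:
+
+```python
+import jax
+import jax.numpy as jnp
+
+# cross-validate over a nested attribute of the "glm" step
+param_grid_nested = dict(
+    glm__observation_model__inverse_link_function=(jnp.exp, jax.nn.softplus),
+)
+```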
+
+
+
+#### Run the grid search
+Let's run a 5-fold cross-validation of the hyperparameters with the scikit-learn [`model_selection.GridsearchCV`](https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV) class.
+:::{dropdown} K-Fold cross-validation
+:color: info
+:icon: info
+
+
+K-fold cross-validation is a robust method used for selecting hyperparameters. In this procedure, the data is divided into K equally sized chunks (or folds). The model is trained on K-1 of these chunks, with the remaining chunk used for evaluation. This process is repeated K times, with each chunk being used exactly once as the evaluation set.
+After completing the K iterations, the K evaluation scores are averaged to provide a reliable estimate of the model's performance. To select the optimal hyperparameters, K-fold cross-validation can be applied over a grid of potential hyperparameters, with the set yielding the highest average score being chosen as the best.
+:::
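+
+As a small illustration of the splitting described above (not needed in practice, since `GridSearchCV` handles the folding internally), scikit-learn's `KFold` can generate the train/test indices directly:
+
+```python
+from sklearn.model_selection import KFold
+
+kf = KFold(n_splits=5)
+for train_idx, test_idx in kf.split(X):
+    # train on 4 folds, evaluate on the held-out fold
+    print(train_idx.shape, test_idx.shape)
+```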
+
+```{code-cell} ipython3
+gridsearch = GridSearchCV(
+ pipeline,
+ param_grid=param_grid,
+ cv=5
+)
+
+# run the 5-fold cross-validation grid search
+gridsearch.fit(X, y)
+```
+
+:::{dropdown} Manual cross-validation
+:color: info
+:icon: info
+To appreciate how much boilerplate code we are saving by calling scikit-learn cross-validation, below
+we can see what this cross-validation would look like in a manual loop.
+
+```python
+from itertools import product
+from copy import deepcopy
+
+regularizer_strength = (0.1, 0.01, 0.001, 1e-6)
+n_basis_funcs = (3, 5, 10, 20, 100)
+
+# define the folds
+n_folds = 5
+fold_idx = np.arange(X.shape[0] - X.shape[0] % n_folds).reshape(n_folds, -1)
+
+
+# Initialize the scores
+scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds))
+
+# Dictionary to store coefficients
+coeffs = {}
+
+# initialize basis and model
+basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6))
+model = nmo.glm.GLM(regularizer="Ridge")
+
+# loop over combinations
+for fold in range(n_folds):
+ test_idx = fold_idx[fold]
+ train_idx = fold_idx[[x for x in range(n_folds) if x != fold]].flatten()
+ for i, params in enumerate(product(regularizer_strength, n_basis_funcs)):
+ reg_strength, n_basis = params
+
+ # deepcopy the basis and model
+ bas = deepcopy(basis)
+ glm = deepcopy(model)
+
+ # set the parameters
+ bas.n_basis_funcs = n_basis
+ glm.regularizer_strength = reg_strength
+
+ # fit the model
+ glm.fit(bas.transform(X[train_idx]), y[train_idx])
+
+ # store score and coefficients
+ scores[i, fold] = glm.score(bas.transform(X[test_idx]), y[test_idx])
+ coeffs[(i, fold)] = (glm.coef_, glm.intercept_)
+
+# get the best mean test score
+i_best = np.argmax(scores.mean(axis=1))
+# get the overall best coeffs
+fold_best = np.argmax(scores[i_best])
+
+# set up the best model
+model.coef_ = coeffs[(i_best, fold_best)][0]
+model.intercept_ = coeffs[(i_best, fold_best)][1]
+
+# get the best hyperparameters
+best_reg_strength = regularizer_strength[i_best // len(n_basis_funcs)]
+best_n_basis = n_basis_funcs[i_best % len(n_basis_funcs)]
+```
+:::
+
+
+
+#### Visualize the scores
+
+Let's extract the scores from `gridsearch` and take a look at how the different parameter values of our pipeline influence the test score:
+
+
+```{code-cell} ipython3
+cvdf = pd.DataFrame(gridsearch.cv_results_)
+
+cvdf_wide = cvdf.pivot(
+ index="param_transformerbasis__n_basis_funcs",
+ columns="param_glm__regularizer_strength",
+ values="mean_test_score",
+)
+
+doc_plots.plot_heatmap_cv_results(cvdf_wide)
+```
+
+The plot displays the model's log-likelihood for each parameter combination in the grid. The parameter combination with the highest score, which is the one selected by the procedure, is highlighted with a blue rectangle. We can thus see that we need 10 or more basis functions, and that all of the tested regularization strengths agree with each other. In general, we want the fewest number of basis functions required to get a good fit, so we'll choose 10 here.
+
+#### Visualize the predicted rate
+Finally, visualize the predicted firing rates using the best model found by our grid-search, which gives a better fit than the randomly chosen parameter values we tried in the beginning:
+
+
+```{code-cell} ipython3
+# Predict the rate using the best configuration.
+x = np.sort(X, axis=0)
+predicted_rate = gridsearch.best_estimator_.predict(x)
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+
+ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
+ax.set_xlabel("input")
+ax.set_ylabel("spike count")
+
+
+ax.plot(
+ x,
+ predicted_rate,
+ label="predicted rate",
+ color="tab:orange",
+)
+
+ax.legend()
+sns.despine(ax=ax)
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/how_to_guide"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/how_to_guide")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_05_sklearn_pipeline_cv_demo.svg")
+```
+
+🚀🚀🚀 **Success!** 🚀🚀🚀
+
+We are now able to capture the distribution of the firing rate appropriately: both peaks and valleys in the spiking activity are matched by our model predictions.
+
+### Evaluating different bases directly
+
+In the previous example we set the number of basis functions of the [`Basis`](nemos.basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis.TransformerBasis). However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well.
+
+Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`:
+
+
+```{code-cell} ipython3
+param_grid = dict(
+ glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
+ transformerbasis___basis=(
+ nmo.basis.RaisedCosineBasisLinear(5),
+ nmo.basis.RaisedCosineBasisLinear(10),
+ nmo.basis.RaisedCosineBasisLog(5),
+ nmo.basis.RaisedCosineBasisLog(10),
+ nmo.basis.MSplineBasis(5),
+ nmo.basis.MSplineBasis(10),
+ ),
+)
+```
+
+Then run the grid search:
+
+
+```{code-cell} ipython3
+gridsearch = GridSearchCV(
+ pipeline,
+ param_grid=param_grid,
+ cv=5,
+)
+
+# run the 5-fold cross-validation grid search
+gridsearch.fit(X, y)
+```
+
+Wrangling the output data a bit and looking at the scores:
+
+
+```{code-cell} ipython3
+cvdf = pd.DataFrame(gridsearch.cv_results_)
+
+# Read out the basis type and the number of basis functions
+cvdf["transformerbasis_config"] = [
+ f"{b.__class__.__name__} - {b.n_basis_funcs}"
+ for b in cvdf["param_transformerbasis___basis"]
+]
+
+cvdf_wide = cvdf.pivot(
+ index="transformerbasis_config",
+ columns="param_glm__regularizer_strength",
+ values="mean_test_score",
+)
+
+doc_plots.plot_heatmap_cv_results(cvdf_wide)
+```
+
+As shown in the heatmap, the model with the highest score, highlighted in blue, used a RaisedCosineBasisLinear basis (as used above), which appears to be a suitable choice for our toy data.
+We can confirm that by plotting the firing rate predictions:
+
+
+```{code-cell} ipython3
+# Predict the rate using the optimal configuration
+x = np.sort(X, axis=0)
+predicted_rate = gridsearch.best_estimator_.predict(x)
+```
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+
+ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
+ax.set_xlabel("input")
+ax.set_ylabel("spike count")
+
+ax.plot(
+ x,
+ predicted_rate,
+ label="predicted rate",
+ color="tab:orange",
+)
+
+ax.legend()
+sns.despine(ax=ax)
+```
+
+The plot confirms that the firing rate distribution is accurately captured by our model predictions.
+
+
+
+
+:::{warning}
+Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error:
+
+```python
+param_grid = dict(
+ glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
+ transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
+ transformerbasis___basis=(
+ nmo.basis.RaisedCosineBasisLinear(5),
+ nmo.basis.RaisedCosineBasisLinear(10),
+ nmo.basis.RaisedCosineBasisLog(5),
+ nmo.basis.RaisedCosineBasisLog(10),
+ nmo.basis.MSplineBasis(5),
+ nmo.basis.MSplineBasis(10),
+ ),
+)
+```
+:::
+
+
+
+## Create a custom scorer
+By default, the GLM score method returns the model log-likelihood. If you want to try a different metric, such as the pseudo-R2, you can create a custom scorer and pass it to the cross-validation object:
+
+
+```{code-cell} ipython3
+from sklearn.metrics import make_scorer
+
+pseudo_r2 = make_scorer(
+ nmo.observation_models.PoissonObservations().pseudo_r2
+)
+```
+
+We can now run the grid search providing the custom scorer:
+
+
+```{code-cell} ipython3
+gridsearch = GridSearchCV(
+ pipeline,
+ param_grid=param_grid,
+ cv=5,
+ scoring=pseudo_r2,
+)
+
+# Run the 5-fold cross-validation grid search
+gridsearch.fit(X, y)
+```
+
+And finally, we can plot each model's pseudo-R2 score.
+
+
+```{code-cell} ipython3
+cvdf = pd.DataFrame(gridsearch.cv_results_)
+
+# Read out the basis type and the number of basis functions
+cvdf["transformerbasis_config"] = [
+ f"{b.__class__.__name__} - {b.n_basis_funcs}"
+ for b in cvdf["param_transformerbasis___basis"]
+]
+
+cvdf_wide = cvdf.pivot(
+ index="transformerbasis_config",
+ columns="param_glm__regularizer_strength",
+ values="mean_test_score",
+)
+
+doc_plots.plot_heatmap_cv_results(cvdf_wide, label="pseudo-R2")
+```
+
+As you can see, the results with the pseudo-R2 agree with those obtained with the log-likelihood. Note that this new metric is normalized between 0 and 1, with a higher score indicating better performance.
diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md
new file mode 100644
index 00000000..6945460e
--- /dev/null
+++ b/docs/how_to_guide/plot_06_glm_pytree.md
@@ -0,0 +1,404 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+# FeaturePytree example
+
+This small example notebook shows how to use our custom FeaturePytree objects
+instead of arrays to represent the design matrix. It will show that these two
+representations are equivalent.
+
+This demo will fit the Poisson GLM to some synthetic data. We will first show
+the simple case, with a single neuron receiving some input. We will then show a
+two-neuron system, to demonstrate how FeaturePytree can make it easier to
+examine separate types of inputs.
+
+First, however, let's briefly discuss [`FeaturePytrees`](nemos.pytrees.FeaturePytree).
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import nemos as nmo
+
+np.random.seed(111)
+```
+
+## FeaturePytrees
+
+A FeaturePytree is a custom NeMoS object used to represent design matrices,
+GLM coefficients, and other similar variables. It is a simple
+[pytree](https://jax.readthedocs.io/en/latest/pytrees.html), a dictionary
+with strings as keys and arrays as values. These arrays must all have the
+same number of elements along the first dimension, which represents the time
+points, but can have different numbers of elements along the other dimensions
+(and even different numbers of dimensions).
+
+
+```{code-cell} ipython3
+example_pytree = nmo.pytrees.FeaturePytree(feature_0=np.random.normal(size=(100, 1, 2)),
+ feature_1=np.random.normal(size=(100, 2)),
+ feature_2=np.random.normal(size=(100, 5)))
+example_pytree
+```
+
+FeaturePytrees can be indexed into like a dictionary, so we can grab a
+single one of their features:
+
+
+```{code-cell} ipython3
+example_pytree['feature_0'].shape
+```
+
+We can grab the number of time points by getting the length or using the
+`shape` attribute
+
+
+```{code-cell} ipython3
+print(len(example_pytree))
+print(example_pytree.shape)
+```
+
+We can also jointly index into the FeaturePytree's leaves:
+
+
+```{code-cell} ipython3
+example_pytree[:10]
+```
+
+We can add new features after initialization, as long as they have the same
+number of time points.
+
+
+```{code-cell} ipython3
+example_pytree['feature_3'] = np.zeros((100, 2, 4))
+```
+
+However, if we try to add a new feature with the wrong number of time points,
+we'll get an exception:
+
+
+```{code-cell} ipython3
+try:
+ example_pytree['feature_4'] = np.zeros((99, 2, 4))
+except ValueError as e:
+ print(e)
+```
+
+Similarly, if we try to add a feature that's not an array:
+
+
+```{code-cell} ipython3
+try:
+ example_pytree['feature_4'] = "Strings are very predictive"
+except ValueError as e:
+ print(e)
+```
+
+FeaturePytrees are intended to be used with
+[jax.tree_util.tree_map](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html),
+a useful function for performing computations on arbitrary pytrees,
+preserving their structure.
+
+
+
+
+We can map lambda functions:
+
+
+```{code-cell} ipython3
+mapped = jax.tree_util.tree_map(lambda x: x**2, example_pytree)
+print(mapped)
+mapped['feature_1']
+```
+
+Or functions from jax or numpy that operate on arrays:
+
+
+```{code-cell} ipython3
+mapped = jax.tree_util.tree_map(jnp.exp, example_pytree)
+print(mapped)
+mapped['feature_1']
+```
+
+We can change the dimensionality of our pytree:
+
+
+```{code-cell} ipython3
+mapped = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=-1), example_pytree)
+print(mapped)
+mapped['feature_1']
+```
+
+Or the number of time points:
+
+
+```{code-cell} ipython3
+mapped = jax.tree_util.tree_map(lambda x: x[::10], example_pytree)
+print(mapped)
+mapped['feature_1']
+```
+
+If we map a function whose output cannot be stored in a FeaturePytree (because
+the resulting values are scalars or other non-arrays), a plain dictionary is returned instead:
+
+
+```{code-cell} ipython3
+print(jax.tree_util.tree_map(jnp.mean, example_pytree))
+print(jax.tree_util.tree_map(lambda x: x.shape, example_pytree))
+import matplotlib.pyplot as plt
+import pynapple as nap
+
+nap.nap_config.suppress_conversion_warnings = True
+```
+
+## FeaturePytrees and GLM
+
+These properties make FeaturePytrees useful for representing design matrices
+and similar objects for the [`GLM`](nemos.glm.GLM).
+
+First, let's get our dataset and do some initial exploration of it. To do so,
+we'll use pynapple to [stream
+data](https://pynapple.org/examples/tutorial_pynapple_dandi.html)
+from the DANDI archive.
+
+:::{attention}
+
+We need some additional packages for this portion, which you can install
+with `pip install dandi pynapple`
+:::
+
+```{code-cell} ipython3
+io = nmo.fetch.download_dandi_data(
+ "000582",
+ "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb",
+)
+
+nwb = nap.NWBFile(io.read(), lazy_loading=False)
+
+print(nwb)
+```
+
+This data set has cells that are tuned for head direction and 2d position.
+Let's compute some simple tuning curves to see if we can find a cell that
+looks tuned for both.
+
+
+```{code-cell} ipython3
+tc, binsxy = nap.compute_2d_tuning_curves(nwb['units'], nwb['SpatialSeriesLED1'].dropna(), 20)
+fig, axes = plt.subplots(3, 3, figsize=(9, 9))
+for i, ax in zip(tc.keys(), axes.flatten()):
+ ax.imshow(tc[i], origin="lower", aspect="auto")
+ ax.set_title("Unit {}".format(i))
+axes[-1,-1].remove()
+plt.tight_layout()
+
+# compute head direction.
+diff = nwb['SpatialSeriesLED1'].values-nwb['SpatialSeriesLED2'].values
+head_dir = np.arctan2(*diff.T)
+head_dir = nap.Tsd(nwb['SpatialSeriesLED1'].index, head_dir)
+
+tune_head = nap.compute_1d_tuning_curves(nwb['units'], head_dir.dropna(), 30)
+
+fig, axes = plt.subplots(3, 3, figsize=(9, 9), subplot_kw={'projection': 'polar'})
+for i, ax in zip(tune_head.columns, axes.flatten()):
+ ax.plot(tune_head.index, tune_head[i])
+ ax.set_title("Unit {}".format(i))
+axes[-1,-1].remove()
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/how_to_guide"
+# if local store in ../_build/html/...
+else:
+ path = Path("../_build/html/_static/thumbnails/how_to_guide")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_06_glm_pytree.svg")
+```
+
+
+Okay, let's use unit number 7.
+
+Now let's set up our design matrix. First, let's fit the head direction by
+itself. Head direction is a circular variable (pi and -pi are adjacent to
+each other), so we need to use a basis that has this property as well.
+[`CyclicBSplineBasis`](nemos.basis.CyclicBSplineBasis) is one such basis.
+
+Let's create our basis and then arrange our data properly.
+
+
+```{code-cell} ipython3
+unit_no = 7
+spikes = nwb['units'][unit_no]
+
+basis = nmo.basis.CyclicBSplineBasis(10, order=5)
+x = np.linspace(-np.pi, np.pi, 100)
+plt.figure()
+plt.plot(x, basis(x))
+
+# Find the interval on which head_dir has no NaNs
+head_dir = head_dir.dropna()
+# Grab the second (of two), since the first one is really short
+valid_data = head_dir.time_support.loc[[1]]
+head_dir = head_dir.restrict(valid_data)
+# Count spikes at the same rate as head direction, over the same epoch
+spikes = spikes.count(bin_size=1/head_dir.rate, ep=valid_data)
+# the time points for spikes are in the middle of these bins (whereas for
+# head_dir they're at the ends), so use interpolate to shift head_dir to the
+# bin centers.
+head_dir = head_dir.interpolate(spikes)
+
+X = nmo.pytrees.FeaturePytree(head_direction=basis(head_dir))
+```
+
+Now we'll fit our GLM and then see what our head direction tuning looks like:
+
+
+```{code-cell} ipython3
+model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.001)
+model.fit(X, spikes)
+print(model.coef_['head_direction'])
+
+bs_vis = basis(x)
+tuning = jnp.einsum('b, tb->t', model.coef_['head_direction'], bs_vis)
+plt.figure()
+plt.polar(x, tuning)
+```
+
+This looks like a smoothed version of our tuning curve, like we'd expect!
+
+For a more direct comparison, we can plot the tuning function based on the model predicted
+firing rates with that estimated from the counts.
+
+
+```{code-cell} ipython3
+# predict rates and convert back to pynapple
+rates_nap = nap.TsdFrame(t=head_dir.t, d=np.asarray(model.predict(X)))
+# compute tuning function
+tune_head_model = nap.compute_1d_tuning_curves_continuous(rates_nap, head_dir, 30)
+# compare model prediction with data
+fig, ax = plt.subplots(1, 1, subplot_kw={'projection': 'polar'})
+ax.plot(tune_head[7], label="counts")
+# multiply by the sampling rate for converting to spike/sec.
+ax.plot(tune_head_model * rates_nap.rate, label="model")
+
+# Let's compare this to using arrays, to see what it looks like:
+
+model = nmo.glm.GLM()
+model.fit(X['head_direction'], spikes)
+model.coef_
+```
+
+We can see that the solution is identical, as is the way of interacting with
+the GLM object.
+
+However, with a single type of feature, it's unclear why exactly this is
+helpful. Let's add a feature for the animal's position in space. For this
+feature, we need a 2d basis. Let's use some raised cosine bumps and organize
+our data similarly.
+
+
+```{code-cell} ipython3
+pos_basis = nmo.basis.RaisedCosineBasisLinear(10) * nmo.basis.RaisedCosineBasisLinear(10)
+spatial_pos = nwb['SpatialSeriesLED1'].restrict(valid_data)
+
+X['spatial_position'] = pos_basis(*spatial_pos.values.T)
+```
+
+Running the GLM is identical to before, but we can see that our `coef_`
+FeaturePytree now has two separate keys, one for each feature type.
+
+
+```{code-cell} ipython3
+model = nmo.glm.GLM(solver_name="LBFGS")
+model.fit(X, spikes)
+model.coef_
+```
+
+Let's visualize our tuning. Head direction looks pretty much the same (though
+the values are slightly different, as we can see when printing out the
+coefficients).
+
+
+```{code-cell} ipython3
+bs_vis = basis(x)
+tuning = jnp.einsum('b,nb->n', model.coef_['head_direction'], bs_vis)
+print(model.coef_['head_direction'])
+plt.figure()
+plt.polar(x, tuning.T)
+```
+
+And the spatial tuning again looks like a smoothed version of our earlier
+tuning curves.
+
+
+```{code-cell} ipython3
+_, _, pos_bs_vis = pos_basis.evaluate_on_grid(50, 50)
+pos_tuning = jnp.einsum('b,ijb->ij', model.coef_['spatial_position'], pos_bs_vis)
+plt.figure()
+plt.imshow(pos_tuning)
+```
+
+We could do all this with matrices as well, but we have to pay attention to
+indices in a way that is annoying:
+
+
+```{code-cell} ipython3
+X_mat = nmo.utils.pynapple_concatenate_jax([X['head_direction'], X['spatial_position']], -1)
+
+model = nmo.glm.GLM()
+model.fit(X_mat, spikes)
+model.coef_[..., :basis.n_basis_funcs]
+```
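+
+For completeness, a sketch of how the remaining coefficients would be recovered under the same concatenation order (head direction first, then spatial position); keeping track of these slice bounds is exactly the bookkeeping that FeaturePytree spares us:
+
+```{code-cell} ipython3
+# the columns after the head-direction block belong to the spatial-position basis
+model.coef_[..., basis.n_basis_funcs:]
+```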
diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
deleted file mode 100644
index ca9b167a..00000000
--- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
+++ /dev/null
@@ -1,506 +0,0 @@
-"""
-# Selecting basis by cross-validation with scikit-learn
-
-In this demo, we will demonstrate how to select an appropriate basis and its hyperparameters using cross-validation.
-In particular, we will learn:
-
-1. What a scikit-learn pipeline is.
-2. Why pipelines are useful.
-3. How to combine NeMoS `Basis` and `GLM` objects in a pipeline.
-4. How to select the number of bases and the basis type through cross-validation (or any other hyperparameter in the pipeline).
-5. How to use a custom scoring metric to quantify the performance of each configuration.
-
-"""
-
-# %%
-# ## What is a scikit-learn pipeline
-#
-#
-#
-# A pipeline is a sequence of data transformations leading up to a model. Each step before the final one transforms the input data into a different representation, and then the final model step fits, predicts, or scores based on the previous step's output and some observations. Setting up such machinery can be simplified using the `Pipeline` class from scikit-learn.
-#
-# To set up a scikit-learn `Pipeline`, ensure that:
-#
-# 1. Each intermediate step is a [scikit-learn transformer object](https://scikit-learn.org/stable/data_transforms.html) with a `transform` and/or `fit_transform` method.
-# 2. The final step is an [estimator object](https://scikit-learn.org/stable/developers/develop.html#estimators) with a `fit` method, or a model with `fit`, `predict`, and `score` methods.
-#
-# Each transformation step takes a 2D array `X` of shape `(num_samples, num_original_features)` as input and outputs another 2D array of shape `(num_samples, num_transformed_features)`. The final step takes a pair `(X, y)`, where `X` is as before, and `y` is a 1D array of shape `(n_samples,)` containing the observations to be modeled.
-#
-# You can define a pipeline as follows:
-# ```python
-# from sklearn.pipeline import Pipeline
-#
-# # Assume transformer_i/predictor is a transformer/model object
-# pipe = Pipeline(
-# [
-# ("label_1", transformer_1),
-# ("label_2", transformer_2),
-# ...,
-# ("label_n", transformer_n),
-# ("label_model", model)
-# ]
-# )
-# ```
-#
-# Note that you have to assign a label to each step of the pipeline.
-# !!! tip
-# Here we used a placeholder `"label_i"` for demonstration; you should choose a more descriptive name depending on the type of transformation step.
-#
-# Calling `pipe.fit(X, y)` will perform the following computations:
-# ```python
-# # Chain of transformations
-# X1 = transformer_1.fit_transform(X)
-# X2 = transformer_2.fit_transform(X1)
-# # ...
-# Xn = transformer_n.fit_transform(Xn_1)
-#
-# # Fit step
-# model.fit(Xn, y)
-# ```
-# And the same holds for `pipe.score` and `pipe.predict`.
-#
-# ## Why pipelines are useful
-#
-# Pipelines not only streamline and simplify your code but also offer several other advantages. The real power of pipelines becomes evident when combined with the scikit-learn `model_selection` module, which includes cross-validation and similar methods. This combination allows you to tune hyperparameters at each step of the pipeline in a straightforward manner.
-#
-# In the following sections, we will showcase this approach with a concrete example: selecting the appropriate basis type and number of bases for a GLM regression in NeMoS.
-#
-# ## Combining basis transformations and GLM in a pipeline
-# Let's start by creating some toy data.
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import scipy.stats
-import seaborn as sns
-from sklearn.model_selection import GridSearchCV
-from sklearn.pipeline import Pipeline
-
-import nemos as nmo
-
-# some helper plotting functions
-from nemos import _documentation_utils as doc_plots
-
-# predictors, shape (n_samples, n_features)
-X = np.random.uniform(low=0, high=1, size=(1000, 1))
-# observed counts, shape (n_samples,)
-rate = 2 * (
- scipy.stats.norm.pdf(X, scale=0.1, loc=0.25)
- + scipy.stats.norm.pdf(X, scale=0.1, loc=0.75)
-)
-y = np.random.poisson(rate).astype(float).flatten()
-
-# %%
-# Let's now plot the simulated neuron's tuning curve, which is bimodal, Gaussian-shaped, and has peaks at 0.25 and 0.75.
-
-# %%
-fig, ax = plt.subplots()
-ax.scatter(X.flatten(), y, alpha=0.2)
-ax.set_xlabel("input")
-ax.set_ylabel("spike count")
-sns.despine(ax=ax)
-
-# %%
-# ### Converting NeMoS `Basis` to a transformer
-# In order to use NeMoS `Basis` in a pipeline, we need to convert it into a scikit-learn transformer. This can be achieved through the `TransformerBasis` wrapper class.
-#
-# Instantiating a `TransformerBasis` can be done either using the constructor directly or with `Basis.to_transformer()`:
-
-# %%
-bas = nmo.basis.RaisedCosineBasisLinear(5, mode="conv", window_size=5)
-# these two ways of creating the TransformerBasis are equivalent
-trans_bas_a = nmo.basis.TransformerBasis(bas)
-trans_bas_b = bas.to_transformer()
-
-# %%
-# `TransformerBasis` provides convenient access to the underlying `Basis` object's attributes:
-
-# %%
-print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs)
-
-# %%
-# We can also set attributes of the underlying `Basis`. Note that -- because `TransformerBasis` is created with a copy of the `Basis` object passed to it -- this does not change the original `Basis`, and neither does changing the original `Basis` change `TransformerBasis` we created:
-
-# %%
-trans_bas_a.n_basis_funcs = 10
-bas.n_basis_funcs = 100
-
-print(bas.n_basis_funcs, trans_bas_a.n_basis_funcs, trans_bas_b.n_basis_funcs)
-
-# %%
-# ### Creating and fitting a pipeline
-# We might want to combine first transforming the input data with our basis functions, then fitting a GLM on the transformed data.
-#
-# This is exactly what `Pipeline` is for!
-
-# %%
-pipeline = Pipeline(
- [
- (
- "transformerbasis",
- nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6)),
- ),
- (
- "glm",
- nmo.glm.GLM(regularizer_strength=0.5, regularizer="Ridge"),
- ),
- ]
-)
-
-pipeline.fit(X, y)
-
-# %%
-# Note how NeMoS models are already scikit-learn compatible and can be used directly in the pipeline.
-#
-# Visualize the fit:
-
-# %%
-
-# Predict the rate.
-# Note that you need a 2D input even if x is a flat array.
-# We are using expand dim to add the extra-dimension
-x = np.sort(X, axis=0)
-predicted_rate = pipeline.predict(x)
-
-# %%
-fig, ax = plt.subplots()
-
-ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
-ax.set_xlabel("input")
-ax.set_ylabel("spike count")
-
-
-ax.plot(
- x,
- predicted_rate,
- label="predicted rate",
- color="tab:orange",
-)
-
-ax.legend()
-sns.despine(ax=ax)
-
-# %%
-# The current model captures the bimodal distribution of responses, appropriately picking out the peaks. However, it doesn't do a good job capturing the actual firing rate: the peaks are too low and the valleys are not low enough. This might be because of our choice of basis and/or regularizer strength, so let's see if tuning those parameters results in a better fit! We could do this manually, but doing this with the sklearn pipeline will make everything much easier!
-
-# %%
-# ### Select the number of basis by cross-validation
-
-# %%
-# !!! warning
-# Please keep in mind that while `GLM.score` supports different ways of evaluating goodness-of-fit through the `score_type` argument, `pipeline.score(X, y, score_type="...")` does not propagate this, and uses the default value of `log-likelihood`.
-#
-# To evaluate a pipeline, please create a custom scorer (e.g. `pseudo_r2` below) and call `my_custom_scorer(pipeline, X, y)`.
-#
-# #### Define the parameter grid
-#
-# Let's define candidate values for the parameters of each step of the pipeline we want to cross-validate. In this case the number of basis functions in the transformation step and the ridge regularization's strength in the GLM fit:
-
-# %%
-param_grid = dict(
- glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
- transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
-)
-
-# %%
-# !!! note "Grid definition"
-# In order to define a parameter grid dictionary for a pipeline, you must structure the dictionary keys as follows:
-#
-# - Start with the pipeline label (`"glm"` or `"transformerbasis"` for us). This determines which pipeline step has the relevant hyperparameter.
-# - Add `"__"` followed by the hyperparameter name (for example, `"n_basis_funcs"`).
-# - If the hyperparameter is itself an object with attributes, add another `"__"` followed by the attribute name. For instance, `"glm__observation_model__inverse_link_function"`
-# would be a valid key for cross-validating over the link function of the GLM's `observation_model` attribute `inverse_link_function`.
-# The values in the dictionary are the parameters to be tested.
-
-
-
-# %%
-# #### Run the grid search
-# Let's run a 5-fold cross-validation of the hyperparameters with the scikit-learn `model_selection.GridsearchCV` class.
-# ??? info "K-Fold cross-validation"
-#
-#
-#
-# K-fold cross-validation (modified from scikit-learn docs)
-#
-# K-fold cross-validation is a robust method used for selecting hyperparameters. In this procedure, the data is divided into K equally sized chunks (or folds). The model is trained on K-1 of these chunks, with the remaining chunk used for evaluation. This process is repeated K times, with each chunk being used exactly once as the evaluation set.
-# After completing the K iterations, the K evaluation scores are averaged to provide a reliable estimate of the model's performance. To select the optimal hyperparameters, K-fold cross-validation can be applied over a grid of potential hyperparameters, with the set yielding the highest average score being chosen as the best.
-
-
-# %%
-gridsearch = GridSearchCV(
- pipeline,
- param_grid=param_grid,
- cv=5
-)
-
-# run the 5-fold cross-validation grid search
-gridsearch.fit(X, y)
-
-# %%
-#
-# ??? note "Manual cross-validation"
-# To appreciate how much boiler-plate code we are saving by calling scikit-learn cross-validation, below
-# we can see how this cross-validation will look like in a manual loop.
-#
-# ```python
-# from itertools import product
-# from copy import deepcopy
-#
-# regularizer_strength = (0.1, 0.01, 0.001, 1e-6)
-# n_basis_funcs = (3, 5, 10, 20, 100)
-#
-# # define the folds
-# n_folds = 5
-# fold_idx = np.arange(X.shape[0] - X.shape[0] % n_folds).reshape(n_folds, -1)
-#
-#
-# # Initialize the scores
-# scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds))
-#
-# # Dictionary to store coefficients
-# coeffs = {}
-#
-# # initialize basis and model
-# basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineBasisLinear(6))
-# model = nmo.glm.GLM(regularizer="Ridge")
-#
-# # loop over combinations
-# for fold in range(n_folds):
-# test_idx = fold_idx[fold]
-# train_idx = fold_idx[[x for x in range(n_folds) if x != fold]].flatten()
-# for i, params in enumerate(product(regularizer_strength, n_basis_funcs)):
-# reg_strength, n_basis = params
-#
-# # deepcopy the basis and model
-# bas = deepcopy(basis)
-# glm = deepcopy(model)
-#
-# # set the parameters
-# bas.n_basis_funcs = n_basis
-# glm.regularizer_strength = reg_strength
-#
-# # fit the model
-# glm.fit(bas.transform(X[train_idx]), y[train_idx])
-#
-# # store score and coefficients
-# scores[i, fold] = glm.score(bas.transform(X[test_idx]), y[test_idx])
-# coeffs[(i, fold)] = (glm.coef_, glm.intercept_)
-#
-# # get the best mean test score
-# i_best = np.argmax(scores.mean(axis=1))
-# # get the overall best coeffs
-# fold_best = np.argmax(scores[i_best])
-#
-# # set up the best model
-# model.coef_ = coeffs[(i_best, fold_best)][0]
-# model.intercept_ = coeffs[(i_best, fold_best)][1]
-#
-# # get the best hyperparameters
-# best_reg_strength = regularizer_strength[i_best // len(n_basis_funcs)]
-# best_n_basis = n_basis_funcs[i_best % len(n_basis_funcs)]
-# ```
-
-# %%
-# #### Visualize the scores
-#
-# Let's extract the scores from `gridsearch` and take a look at how the different parameter values of our pipeline influence the test score:
-
-
-cvdf = pd.DataFrame(gridsearch.cv_results_)
-
-cvdf_wide = cvdf.pivot(
- index="param_transformerbasis__n_basis_funcs",
- columns="param_glm__regularizer_strength",
- values="mean_test_score",
-)
-
-doc_plots.plot_heatmap_cv_results(cvdf_wide)
-
-# %%
-# The plot displays the model's log-likelihood for each parameter combination in the grid. The parameter combination with the highest score, which is the one selected by the procedure, is highlighted with a blue rectangle. We can thus see that we need 10 or more basis functions, and that all of the tested regularization strengths agree with each other. In general, we want the fewest number of basis functions required to get a good fit, so we'll choose 10 here.
-#
-# #### Visualize the predicted rate
-# Finally, visualize the predicted firing rates using the best model found by our grid-search, which gives a better fit than the randomly chosen parameter values we tried in the beginning:
-
-# %%
-
-# Predict the ate using the best configuration,
-x = np.sort(X, axis=0)
-predicted_rate = gridsearch.best_estimator_.predict(x)
-
-# %%
-fig, ax = plt.subplots()
-
-ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
-ax.set_xlabel("input")
-ax.set_ylabel("spike count")
-
-
-ax.plot(
- x,
- predicted_rate,
- label="predicted rate",
- color="tab:orange",
-)
-
-ax.legend()
-sns.despine(ax=ax)
-
-# %%
-# :rocket::rocket::rocket: **Success!** :rocket::rocket::rocket:
-# We are now able to capture the distribution of the firing rate appropriately: both peaks and valleys in the spiking activity are matched by our model predicitons.
-#
-# ### Evaluating different bases directly
-#
-# In the previous example we set the number of basis functions of the `Basis` wrapped in our `TransformerBasis`. However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well.
-#
-# Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`:
-
-# %%
-param_grid = dict(
- glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
- transformerbasis___basis=(
- nmo.basis.RaisedCosineBasisLinear(5),
- nmo.basis.RaisedCosineBasisLinear(10),
- nmo.basis.RaisedCosineBasisLog(5),
- nmo.basis.RaisedCosineBasisLog(10),
- nmo.basis.MSplineBasis(5),
- nmo.basis.MSplineBasis(10),
- ),
-)
-
-# %%
-# Then run the grid search:
-
-# %%
-gridsearch = GridSearchCV(
- pipeline,
- param_grid=param_grid,
- cv=5,
-)
-
-# run the 5-fold cross-validation grid search
-gridsearch.fit(X, y)
-
-
-# %%
-# Wrangling the output data a bit and looking at the scores:
-
-# %%
-cvdf = pd.DataFrame(gridsearch.cv_results_)
-
-# Read out the number of basis functions
-cvdf["transformerbasis_config"] = [
- f"{b.__class__.__name__} - {b.n_basis_funcs}"
- for b in cvdf["param_transformerbasis___basis"]
-]
-
-cvdf_wide = cvdf.pivot(
- index="transformerbasis_config",
- columns="param_glm__regularizer_strength",
- values="mean_test_score",
-)
-
-doc_plots.plot_heatmap_cv_results(cvdf_wide)
-
-
-# %%
-# As shown in the table, the model with the highest score, highlighted in blue, used a RaisedCosineBasisLinear basis (as used above), which appears to be a suitable choice for our toy data.
-# We can confirm that by plotting the firing rate predictions:
-
-# Predict the rate using the optimal configuration
-x = np.sort(X, axis=0)
-predicted_rate = gridsearch.best_estimator_.predict(x)
-
-# %%
-fig, ax = plt.subplots()
-
-ax.scatter(X.flatten(), y, alpha=0.2, label="generated spike counts")
-ax.set_xlabel("input")
-ax.set_ylabel("spike count")
-
-ax.plot(
- x,
- predicted_rate,
- label="predicted rate",
- color="tab:orange",
-)
-
-ax.legend()
-sns.despine(ax=ax)
-
-# %%
-# The plot confirms that the firing rate distribution is accurately captured by our model predictions.
-
-# %%
-# !!! warning
-# Please note that because it would lead to unexpected behavior, mixing the two ways of defining values for the parameter grid is not allowed. The following would lead to an error:
-#
-# ```python
-# param_grid = dict(
-# glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
-# transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
-# transformerbasis___basis=(
-# nmo.basis.RaisedCosineBasisLinear(5),
-# nmo.basis.RaisedCosineBasisLinear(10),
-# nmo.basis.RaisedCosineBasisLog(5),
-# nmo.basis.RaisedCosineBasisLog(10),
-# nmo.basis.MSplineBasis(5),
-# nmo.basis.MSplineBasis(10),
-# ),
-# )
-# ```
-
-# %%
-# ## Create a custom scorer
-# By default, the GLM score method returns the model log-likelihood. If you want to try a different metric, such as the pseudo-R2, you can create a custom scorer and pass it to the cross-validation object:
-
-# %%
-from sklearn.metrics import make_scorer
-
-pseudo_r2 = make_scorer(
- nmo.observation_models.PoissonObservations().pseudo_r2
-)
-
-# %%
-# We can now run the grid search providing the custom scorer
-
-# %%
-gridsearch = GridSearchCV(
- pipeline,
- param_grid=param_grid,
- cv=5,
- scoring=pseudo_r2,
-)
-
-# Run the 5-fold cross-validation grid search
-gridsearch.fit(X, y)
-
-# %%
-# And finally, we can plot each model's score.
-
-# %%
-# Plot the pseudo-R2 scores
-cvdf = pd.DataFrame(gridsearch.cv_results_)
-
-# Read out the number of basis functions
-cvdf["transformerbasis_config"] = [
- f"{b.__class__.__name__} - {b.n_basis_funcs}"
- for b in cvdf["param_transformerbasis___basis"]
-]
-
-cvdf_wide = cvdf.pivot(
- index="transformerbasis_config",
- columns="param_glm__regularizer_strength",
- values="mean_test_score",
-)
-
-doc_plots.plot_heatmap_cv_results(cvdf_wide, label="pseudo-R2")
-
-# %%
-# As you can see, the results with pseudo-R2 agree with those of the negative log-likelihood. Note that this new metric is normalized between 0 and 1, with a higher score indicating better performance.
-
diff --git a/docs/index.md b/docs/index.md
index d7c362d8..4da5ef06 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,29 +1,27 @@
----
-hide:
- - navigation
- - toc
----
+(id:_home)=
-#
-[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/flatironinstitute/nemos/blob/main/LICENSE)
-![Python version](https://img.shields.io/badge/python-3.10%7C3.11%7C3.12-blue.svg)
-[![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
-![PyPI - Version](https://img.shields.io/pypi/v/nemos)
-[![codecov](https://codecov.io/gh/flatironinstitute/nemos/graph/badge.svg?token=vvtrcTFNeu)](https://codecov.io/gh/flatironinstitute/nemos)
-[![Documentation Status](https://readthedocs.org/projects/nemos/badge/?version=latest)](https://nemos.readthedocs.io/en/latest/?badge=latest)
-[![nemos CI](https://github.com/flatironinstitute/nemos/actions/workflows/ci.yml/badge.svg)](https://github.com/flatironinstitute/nemos/actions/workflows/ci.yml)
+```{toctree}
+:maxdepth: 2
+:hidden:
-__Learning Resources:__ [:material-book-open-variant-outline: Neuromatch Academy's Lessons](https://compneuro.neuromatch.io/tutorials/W1D3_GeneralizedLinearModels/student/W1D3_Tutorial1.html) | [:material-youtube: Cosyne 2018 Tutorial](https://www.youtube.com/watch?v=NFeGW5ljUoI&t=424s)
-__Useful Links:__ [:material-chat-question: Getting Help](getting_help) | [:material-alert-circle-outline: Issue Tracker](https://github.com/flatironinstitute/nemos/issues) | [:material-order-bool-ascending-variant: Contributing Guidelines](https://github.com/flatironinstitute/nemos/blob/main/CONTRIBUTING.md)
-
-
+Install
+Quickstart
+Background
+How-To Guide
+Tutorials
+Getting Help
+API Reference
+For Developers
+```
+## __Neural ModelS__
-## __Overview__
NeMoS (Neural ModelS) is a statistical modeling framework optimized for systems neuroscience and powered by [JAX](https://jax.readthedocs.io/en/latest/).
It streamlines the process of defining and selecting models, through a collection of easy-to-use methods for feature design.
@@ -33,74 +31,99 @@ focusing on the Generalized Linear Model (GLM).
We provide a **Poisson GLM** for analyzing spike counts, and a **Gamma GLM** for calcium or voltage imaging traces.
-
-- :material-hammer-wrench:{ .lg .middle } __Installation Instructions__
+::::{grid} 1 2 3 3
+
+:::{grid-item-card} **Installation Instructions**
+:link: installation.html
+:link-alt: Install
+---
+
+Run the following `pip` command in your virtual environment.
+
+```{code-block}
- ---
-
- Run the following `pip` command in your __virtual environment__.
- === "macOS/Linux"
+pip install nemos
- ```bash
- pip install nemos
- ```
+```
- === "Windows"
-
- ```
- python -m pip install nemos
- ```
-
- *For more information see:*
- [:octicons-arrow-right-24: Install](installation)
+:::
-- :material-clock-fast:{ .lg .middle } __Getting Started__
+:::{grid-item-card} **Getting Started**
+:link: quickstart.html
+:link-alt: Quickstart
- ---
+---
- New to NeMoS? Get the ball rolling with our quickstart.
+New to NeMoS? Get the ball rolling with our quickstart.
- [:octicons-arrow-right-24: Quickstart](quickstart)
+:::
-- :material-book-open-variant-outline:{ .lg .middle } __Background__
+:::{grid-item-card} **Background**
+:link: background/README.html
+:link-alt: Background
- ---
+---
- Refresh your theoretical knowledge before diving into data analysis with our notes.
+Refresh your theoretical knowledge before diving into data analysis with our notes.
- [:octicons-arrow-right-24: Background](generated/background)
+:::
-- :material-brain:{ .lg .middle} __Neural Modeling__
+:::{grid-item-card} **How-to Guide**
+:link: how_to_guide/README.html
+:link-alt: How-to-Guide
- ---
+---
- Explore fully worked examples to learn how to analyze neural recordings from scratch.
+Already familiar with the concepts? Learn how to process and analyze your data with NeMoS.
- *Requires familiarity with the theory.*
- [:octicons-arrow-right-24: Tutorials](generated/tutorials)
-- :material-lightbulb-on-10:{ .lg .middle } __How-To Guide__
+
+
+:::
- *Requires familiarity with the theory.*
- [:octicons-arrow-right-24: How-To Guide](generated/how_to_guide)
+:::{grid-item-card} **Neural Modeling**
+:link: tutorials/README.html
+:link-alt: Tutorials
-- :material-cog:{ .lg .middle } __API Guide__
+---
- ---
+Explore fully worked examples to learn how to analyze neural recordings from scratch.
- Access a detailed description of each module and function, including parameters and functionality.
+
-## :material-scale-balance:{ .lg } License
+:::
+
+:::{grid-item-card} **API Reference**
+:link: api_reference.html
+:link-alt: API Reference
+
+---
+
+Access a detailed description of each module and function, including parameters and functionality.
+
+:::
+
+::::
+
+
+
+
+
+## __License__
Open source, [licensed under MIT](https://github.com/flatironinstitute/nemos/blob/main/LICENSE).
@@ -109,4 +132,16 @@ Open source, [licensed under MIT](https://github.com/flatironinstitute/nemos/blo
This package is supported by the Center for Computational Neuroscience, in the Flatiron Institute of the Simons Foundation.
-
+```{image} assets/logo_flatiron_white.svg
+:alt: Flatiron Center for Computational Neuroscience logo White.
+:class: only-dark
+:width: 200px
+:target: https://www.simonsfoundation.org/flatiron/center-for-computational-neuroscience/
+```
+
+```{image} assets/CCN-logo-wText.png
+:alt: Flatiron Center for Computational Neuroscience logo.
+:class: only-light
+:width: 200px
+:target: https://www.simonsfoundation.org/flatiron/center-for-computational-neuroscience/
+```
diff --git a/docs/installation.md b/docs/installation.md
index 80e2b3b5..0c2572d7 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,7 +1,4 @@
----
-hide:
- - navigation
----
+# Install
## Prerequisites
@@ -60,8 +57,11 @@ To install NeMoS on a system without a GPU, run this command from within your ac
### GPU Installation
-!!! warning
- JAX does not guarantee GPU support for Windows, see [here](https://jax.readthedocs.io/en/latest/installation.html#supported-platforms) for updates.
+:::{warning}
+
+JAX does not guarantee GPU support for Windows, see [here](https://jax.readthedocs.io/en/latest/installation.html#supported-platforms) for updates.
+
+:::
For systems equipped with a GPU, you need to specifically install the GPU-enabled versions of `jax` and `jaxlib` before installing NeMoS.
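+
+For example, at the time of writing the JAX documentation suggests a command along these
+lines for CUDA 12 (the exact extra depends on your CUDA version and platform, so check the
+[JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) first):
+
+```
+python -m pip install -U "jax[cuda12]"
+```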
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..32bb2452
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/quickstart.md b/docs/quickstart.md
index aa234489..062bdb25 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -2,6 +2,8 @@
hide:
- navigation
---
+# Quickstart
+
## **Overview**
NeMoS is a neural modeling software package designed to model neural spiking activity and other time-series data
@@ -151,12 +153,12 @@ The `basis` module includes objects that perform two types of transformations on
### **Non-linear Mapping**
Non-linear mapping is the default mode of operation of any `basis` object. To instantiate a basis for non-linear mapping,
-you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Guide](../reference/nemos/basis) for detailed information).
+you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Reference](nemos_basis) for detailed information).
```python
@@ -188,7 +190,7 @@ shape `(n_samples, n_basis_funcs)`, where each column represents a feature gener
### **Convolution**
@@ -211,8 +213,11 @@ Once the basis is initialized, you can call `compute_features` on an input of sh
`(n_samples, n_signals)` to perform the convolution. The output will be a 2-dimensional array of shape
`(n_samples, n_basis_funcs)` or `(n_samples, n_basis_funcs * n_signals)` respectively.
-!!! warning "Signal length and window size"
- The `window_size` must be shorter than the number of samples in the signal(s) being convolved.
+:::{admonition} Signal length and window size
+:class: warning
+
+The `window_size` must be shorter than the number of samples in the signal(s) being convolved.
+:::
```python
@@ -234,12 +239,12 @@ Once the basis is initialized, you can call `compute_features` on an input of sh
```
-For additional information on one-dimensional convolutions, see [here](../generated/background/plot_03_1D_convolution).
+For additional information on one-dimensional convolutions, see [here](convolution_background).
## **Continuous Observations**
-By default, NeMoS' GLM uses [Poisson observations](../reference/nemos/observation_models/#nemos.observation_models.PoissonObservations), which are a natural choice for spike counts. However, the package also supports a [Gamma](../reference/nemos/observation_models/#nemos.observation_models.GammaObservations) GLM, which is more appropriate for modeling continuous, non-negative observations such as calcium transients.
+By default, NeMoS' GLM uses [Poisson observations](nemos.observation_models.PoissonObservations), which are a natural choice for spike counts. However, the package also supports a [Gamma](nemos.observation_models.GammaObservations) GLM, which is more appropriate for modeling continuous, non-negative observations such as calcium transients.
To change the default observation model, set the `observation_model` argument during initialization:
@@ -254,13 +259,13 @@ To change the default observation model, set the `observation_model` argument du
```
-Take a look at our [tutorial](../generated/tutorials/plot_06_calcium_imaging) for a detailed example.
+Take a look at our [tutorial](tutorial-calcium-imaging) for a detailed example.
## **Regularization**
-NeMoS supports various regularization schemes, including [Ridge](../reference/nemos/regularizer/#nemos.regularizer.Ridge) ($L_2$), [Lasso](../reference/nemos/regularizer/#nemos.regularizer.Lasso) ($L_1$), and [Group Lasso](../reference/nemos/regularizer/#nemos.regularizer.GroupLasso), to prevent overfitting and improve model generalization.
+NeMoS supports various regularization schemes, including [Ridge](nemos.regularizer.Ridge) ($L_2$), [Lasso](nemos.regularizer.Lasso) ($L_1$), and [Group Lasso](nemos.regularizer.GroupLasso), to prevent overfitting and improve model generalization.
You can specify the regularization scheme and its strength when initializing the GLM model:
@@ -279,11 +284,11 @@ You can specify the regularization scheme and its strength when initializing the
## **Pre-processing with `pynapple`**
-!!! warning
-
- This section assumes some familiarity with the `pynapple` package for time series manipulation and data
- exploration. If you'd like to learn more about it, take a look at the [`pynapple` documentation](https://pynapple-org.github.io/pynapple/).
+:::{note}
+This section assumes some familiarity with the `pynapple` package for time series manipulation and data
+exploration. If you'd like to learn more about it, take a look at the [`pynapple` documentation](https://pynapple-org.github.io/pynapple/).
+:::
`pynapple` is an extremely helpful tool when working with time series data. You can easily perform operations such
as restricting your time series to specific epochs (sleep/wake, context A vs. context B, etc.), as well as common
@@ -296,7 +301,7 @@ also be a `pynapple` time series.
A canonical example of this behavior is the `predict` method of `GLM`.
-```python
+```ipython
>>> import numpy as np
>>> import pynapple as nap
@@ -321,11 +326,12 @@ A canonical example of this behavior is the `predict` method of `GLM`.
Let's see how you can greatly streamline your analysis pipeline by integrating `pynapple` and NeMoS.
-!!! note
- You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1).
+:::{note}
+You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1).
+:::
-```python
+```ipython
>>> import nemos as nmo
>>> import pynapple as nap
@@ -355,7 +361,7 @@ Let's see how you can greatly streamline your analysis pipeline by integrating `
Finally, let's compare the tuning curves
-```python
+```ipython
>>> import numpy as np
>>> import matplotlib.pyplot as plt
@@ -381,7 +387,7 @@ Finally, let's compare the tuning curves
```
-
+
## **Compatibility with `scikit-learn`**
@@ -394,7 +400,7 @@ For example, if we would like to tune the critical hyper-parameter `regularizer_
[^1]: For a detailed explanation and practical examples, refer to the [cross-validation page](https://scikit-learn.org/stable/modules/cross_validation.html) in the `scikit-learn` documentation.
-```python
+```ipython
>>> # set up the model
>>> import nemos as nmo
@@ -410,7 +416,7 @@ For example, if we would like to tune the critical hyper-parameter `regularizer_
Fit a 5-fold cross-validation scheme for comparing two different regularizer strengths:
-```python
+```ipython
>>> from sklearn.model_selection import GridSearchCV
@@ -426,13 +432,15 @@ Fit a 5-fold cross-validation scheme for comparing two different regularizer str
```
-!!! info "Cross-Validation in NeMoS"
+:::{admonition} Cross-Validation in NeMoS
+:class: info
- For more information and a practical example on how to construct a parameter grid and cross-validate hyperparameters across an entire pipeline, please refer to the [tutorial on pipelining and cross-validation](../generated/how_to_guide/plot_06_sklearn_pipeline_cv_demo).
+For more information and a practical example on how to construct a parameter grid and cross-validate hyperparameters across an entire pipeline, please refer to the [tutorial on pipelining and cross-validation](sklearn-how-to).
+:::
Finally, we can print the regularizer strength with the best cross-validated performance:
-```python
+```ipython
>>> # print best regularizer strength
>>> print(cls.best_params_)
diff --git a/docs/tutorials/README.md b/docs/tutorials/README.md
index eef9ea5b..5dde4e54 100644
--- a/docs/tutorials/README.md
+++ b/docs/tutorials/README.md
@@ -2,9 +2,94 @@
A gallery of fully worked out tutorials analyzing neural recordings from different brain regions and recording modalities.
-??? attention "Additional requirements"
- To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
- You can install all of the required packages with the following command:
- ```
- pip install nemos[examples]
- ```
+:::{dropdown} Additional requirements
+:color: warning
+:icon: alert
+To run the tutorials, you may need to install some additional packages used for plotting and data fetching.
+You can install all of the required packages with the following command:
+```
+pip install nemos[examples]
+```
+:::
+
+::::{grid} 1 2 3 3
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_01_current_injection.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_02_head_direction.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_03_grid_cells.md
+```
+
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_04_v1_cells.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_05_place_cells.md
+```
+:::
+
+:::{grid-item-card}
+
+
+
+```{toctree}
+:maxdepth: 2
+
+plot_06_calcium_imaging.md
+```
+:::
+
+::::
\ No newline at end of file
diff --git a/docs/tutorials/plot_01_current_injection.md b/docs/tutorials/plot_01_current_injection.md
new file mode 100644
index 00000000..70ebac16
--- /dev/null
+++ b/docs/tutorials/plot_01_current_injection.md
@@ -0,0 +1,787 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+# Fit injected current
+
+For our first example, we will look at a very simple dataset: patch-clamp
+recordings from a single neuron in layer 4 of mouse primary visual cortex. This
+data is from the [Allen Brain
+Atlas](https://celltypes.brain-map.org/experiment/electrophysiology/478498617),
+and experimenters injected current directly into the cell, while recording the
+neuron's membrane potential and spiking behavior. The experiments varied the
+shape of the current across many sweeps, mapping the neuron's behavior in
+response to a wide range of potential inputs.
+
+For our purposes, we will examine only one of these sweeps, "Noise 1", in which
+the experimentalists injected three pulses of current. The current is a square
+pulse multiplied by a sinusoid of a fixed frequency, with some random noise
+riding on top.
+
+![Allen Brain Atlas view of the data we will analyze.](../assets/allen_data.png)
+
+In the figure above (from the Allen Brain Atlas website), we see the
+approximately 22 second sweep, with the input current plotted in the first row,
+the intracellular voltage in the second, and the recorded spikes in the third.
+(The grey lines and dots in the second and third rows come from other sweeps
+with the same stimulus, which we'll ignore in this exercise.) When fitting the
+Generalized Linear Model, we are attempting to model the spiking behavior, and
+we generally do not have access to the intracellular voltage, so for the rest
+of this notebook, we'll use only the input current and the recorded spikes
+displayed in the first and third rows.
+
+First, let us see how to load in the data and reproduce the above figure, which
+we'll do using the [Pynapple package](https://pynapple-org.github.io/pynapple/). We will rely on
+pynapple throughout this notebook, as it simplifies handling this type of
+data (we will explain the essentials of pynapple as they are used, but see the
+[Pynapple docs](https://pynapple-org.github.io/pynapple/)
+if you are interested in learning more). After we've explored the data some, we'll introduce the Generalized
+Linear Model and how to fit it with NeMoS.
+
+## Learning objectives
+
+- Learn how to explore spiking data and do basic analyses using pynapple
+- Learn how to structure data for NeMoS
+- Learn how to fit a basic Generalized Linear Model using NeMoS
+- Learn how to retrieve the parameters and predictions from a fit GLM for
+  interpretation.
+
+```{code-cell} ipython3
+# Import everything
+import jax
+import matplotlib.pyplot as plt
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+# some helper plotting functions
+from nemos import _documentation_utils as doc_plots
+
+# configure plots some
+plt.style.use(nmo.styles.plot_style)
+```
+
+## Data Streaming
+
+While you can download the data directly from the Allen Brain Atlas and
+interact with it using their
+[AllenSDK](https://allensdk.readthedocs.io/en/latest/visual_behavior_neuropixels.html),
+we prefer the burgeoning [Neurodata Without Borders (NWB)
+standard](https://nwb-overview.readthedocs.io/en/latest/). We have converted
+this single dataset to NWB and uploaded it to the [Open Science
+Framework](https://osf.io/5crqj/). This allows us to easily load the data
+using pynapple, and it will immediately be in a format that pynapple understands!
+
+:::{tip}
+
+Pynapple can stream any NWB-formatted dataset! See [their
+documentation](https://pynapple.org/examples/tutorial_pynapple_dandi.html)
+for more details, and see the [DANDI Archive](https://dandiarchive.org/)
+for a repository of compliant datasets.
+:::
+
+The first time the following cell is run, it will take a little bit of time
+to download the data, and a progress bar will show the download's progress.
+On subsequent runs, the download is skipped and the cached copy of the data is
+loaded instead.
+
+
+```{code-cell} ipython3
+path = nmo.fetch.fetch_data("allen_478498617.nwb")
+```
+
+## Pynapple
+
+### Data structures and preparation
+
+Now that we've downloaded the data, let's open it with pynapple and examine
+its contents.
+
+
+```{code-cell} ipython3
+data = nap.load_file(path)
+print(data)
+```
+
+The dataset contains several different pynapple objects, which we will
+explore throughout this demo. The following illustrates how these fields relate to the data
+we visualized above:
+
+![Annotated view of the data we will analyze.](../assets/allen_data_annotated.gif)
+
+
+- `units`: dictionary of neurons, holding each neuron's spike timestamps.
+- `epochs`: start and end times of different intervals, defining the
+ experimental structure, specifying when each stimulation protocol began and
+ ended.
+- `stimulus`: injected current, in Amperes, sampled at 20 kHz.
+- `response`: the neuron's intracellular voltage, sampled at 20 kHz.
+  We will not use this information in this example.
+
+Now let's go through the relevant variables in some more detail:
+
+
+```{code-cell} ipython3
+trial_interval_set = data["epochs"]
+
+current = data["stimulus"]
+spikes = data["units"]
+```
+
+First, let's examine `trial_interval_set`:
+
+
+```{code-cell} ipython3
+trial_interval_set.keys()
+```
+
+`trial_interval_set` is a dictionary with strings for keys and
+[`IntervalSets`](https://pynapple.org/generated/pynapple.core.interval_set.IntervalSet.html)
+for values. Each key defines the stimulus protocol, with the value defining
+the beginning and end of that stimulation protocol.
+
+
+```{code-cell} ipython3
+noise_interval = trial_interval_set["Noise 1"]
+noise_interval
+```
+
+As described above, we will be examining "Noise 1". We can see it contains
+three rows, each defining a separate sweep. We'll just grab the first sweep
+(shown in blue in the pictures above) and ignore the other two (shown in
+gray).
+
+
+```{code-cell} ipython3
+noise_interval = noise_interval[0]
+noise_interval
+```
+
+Now let's examine `current`:
+
+
+```{code-cell} ipython3
+current
+```
+
+`current` is a `Tsd`
+([TimeSeriesData](https://pynapple.org/generated/pynapple.core.time_series.Tsd.html))
+object with 2 columns. Like all `Tsd` objects, the first column contains the
+time index and the second column contains the data; in this case, the current
+in Ampere (A).
+
+Currently, `current` contains the entire ~900 second experiment but, as
+discussed above, we only want one of the "Noise 1" sweeps. Fortunately,
+`pynapple` makes it easy to grab out the relevant time points by making use
+of the `noise_interval` we defined above:
+
+
+```{code-cell} ipython3
+current = current.restrict(noise_interval)
+# convert current from Ampere to pico-amperes, to match the above visualization
+# and move the values to a more reasonable range.
+current = current * 1e12
+current
+```
+
+Notice that the timestamps have changed and our shape is much smaller.
+
+Finally, let's examine the spike times. `spikes` is a
+[`TsGroup`](https://pynapple.org/generated/pynapple.core.ts_group.TsGroup.html#pynapple.core.ts_group.TsGroup),
+a dictionary-like object that holds multiple `Ts` (timeseries) objects with
+potentially different time indices:
+
+
+```{code-cell} ipython3
+spikes
+```
+
+Typically, this is used to hold onto the spike times for a population of
+neurons. In this experiment, we only have recordings from a single neuron, so
+there's only one row.
+
+We can index into the `TsGroup` to see the timestamps for this neuron's
+spikes:
+
+
+```{code-cell} ipython3
+spikes[0]
+```
+
+Similar to `current`, this object originally contains data from the entire
+experiment. To get only the data we need, we again use
+`restrict(noise_interval)`:
+
+
+```{code-cell} ipython3
+spikes = spikes.restrict(noise_interval)
+print(spikes)
+spikes[0]
+```
+
+Now, let's visualize the data from this trial, replicating rows 1 and 3
+from the Allen Brain Atlas figure at the beginning of this notebook:
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 1, figsize=(8, 2))
+ax.plot(current, "grey")
+ax.plot(spikes.to_tsd([-5]), "|", color="k", ms = 10)
+ax.set_ylabel("Current (pA)")
+ax.set_xlabel("Time (s)")
+```
+
+### Basic analyses
+
+Before using the Generalized Linear Model, or any model, it's worth taking
+some time to examine our data and think about what features are interesting
+and worth capturing. As we discussed in the [background](../../background/plot_00_conceptual_intro),
+the GLM is a model of the neuronal firing rate. However, in our experiments,
+we do not observe the firing rate, only the spikes! Moreover, neural
+responses are typically noisy—even in this highly controlled experiment
+where the same current was injected over multiple trials, the spike times
+were slightly different from trial-to-trial. No model can perfectly predict
+spike times on an individual trial, so how do we tell if our model is doing a
+good job?
+
+Our objective function is the log-likelihood of the observed spikes given the
+predicted firing rate. That is, we're trying to find the firing rate, as a
+function of time, for which the observed spikes are likely. Intuitively, this
+makes sense: the firing rate should be high where there are many spikes, and
+vice versa. However, it can be difficult to figure out if your model is doing
+a good job by squinting at the observed spikes and the predicted firing rates
+plotted together.
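+
+For reference, this is the standard Poisson log-likelihood (written here in generic
+notation, not NeMoS-specific code): if $y_t$ is the spike count in time bin $t$ and
+$\lambda_t$ the predicted rate for that bin, we are maximizing
+
+$$
+\log p(y \mid \lambda) = \sum_t \left( y_t \log \lambda_t - \lambda_t - \log y_t! \right),
+$$
+
+which is large when the predicted rate is high in bins that contain spikes and low in
+bins that do not.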
+
+One common way to visualize a rough estimate of firing rate is to smooth
+the spikes by convolving them with a Gaussian filter.
+
+:::{note}
+
+This is a heuristic for getting the firing rate, and shouldn't be taken
+as the literal truth (to see why, pass a firing rate through a Poisson
+process to generate spikes and then smooth the output to approximate the
+generating firing rate). A model should not be expected to match this
+approximate firing rate exactly, but visualizing the two firing rates
+together can help you reason about which phenomena in your data the model
+is able to adequately capture, and which it is missing.
+
+For more information, see section 1.2 of [*Theoretical
+Neuroscience*](https://boulderschool.yale.edu/sites/default/files/files/DayanAbbott.pdf),
+by Dayan and Abbott.
+:::
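+
+If you'd like to see this for yourself, here is a small, self-contained sketch of that
+exercise using plain NumPy and SciPy (independent of the dataset in this tutorial): draw
+spikes from a known rate, smooth them, and compare the result to the rate that generated
+them.
+
+```python
+import numpy as np
+from scipy.ndimage import gaussian_filter1d
+
+rng = np.random.default_rng(0)
+dt = 0.001                                       # 1 ms bins
+time = np.arange(0, 5, dt)                       # 5 seconds of simulated data
+true_rate = 20 * (1 + np.sin(2 * np.pi * time))  # known firing rate, in Hz
+
+# draw Poisson spike counts from the true rate, then smooth them back into a rate
+sim_counts = rng.poisson(true_rate * dt)
+smoothed = gaussian_filter1d(sim_counts.astype(float), sigma=50) / dt  # 50-bin (50 ms) kernel
+
+# the smoothed estimate only approximates the rate that generated the spikes
+print(f"mean true rate: {true_rate.mean():.1f} Hz")
+print(f"mean smoothed rate: {smoothed.mean():.1f} Hz")
+```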
+
+Pynapple can easily compute this approximate firing rate, and plotting this
+information will help us pull out some phenomena that we think are
+interesting and would like a model to capture.
+
+First, we must convert from our spike times to binned spikes:
+
+
+```{code-cell} ipython3
+# bin size in seconds
+bin_size = 0.001
+# Get spikes for neuron 0
+count = spikes[0].count(bin_size)
+count
+```
+
+Now, let's convert the binned spikes into the firing rate, by smoothing them
+with a gaussian kernel. Pynapple again provides a convenience function for
+this:
+
+
+```{code-cell} ipython3
+# the inputs to this function are the standard deviation of the gaussian in seconds and
+# the full width of the window, in standard deviations. So std=.05 and size_factor=20
+# gives a total filter size of 0.05 sec * 20 = 1 sec.
+firing_rate = count.smooth(std=0.05, size_factor=20)
+# convert from spikes per bin to spikes per second (Hz)
+firing_rate = firing_rate / bin_size
+```
+
+Note that firing_rate is a [`TsdFrame`](https://pynapple.org/generated/pynapple.core.time_series.TsdFrame.html)!
+
+
+
+```{code-cell} ipython3
+print(type(firing_rate))
+```
+
+Now that we've done all this preparation, let's make a plot to more easily
+visualize the data.
+
+:::{note}
+
+We're hiding the details of the plotting function for the purposes of this tutorial, but you can find it in [the source
+code](https://github.com/flatironinstitute/nemos/blob/development/src/nemos/_documentation_utils/plotting.py)
+if you are interested.
+:::
+
+```{code-cell} ipython3
+doc_plots.current_injection_plot(current, spikes, firing_rate)
+```
+
+So now that we can view the details of our experiment a little more clearly,
+what do we see?
+
+- We have three intervals of increasing current, and the firing rate
+ increases as the current does.
+
+- While the neuron is receiving the input, it does not fire continuously or
+ at a steady rate; there appears to be some periodicity in the response. The
+ neuron fires for a while, stops, and then starts again. There's periodicity
+ in the input as well, so this pattern in the response might be reflecting
+ that.
+
+- There's some decay in firing rate as the input remains on: there are three
+  to four "bumps" of neuronal firing in the second and third intervals and they
+  decrease in amplitude, with the first being the largest.
+
+These give us some good phenomena to try and predict! But there's something
+that's not quite obvious from the above plot: what is the relationship
+between the input and the firing rate? As described in the first bullet point
+above, it looks to be *monotonically increasing*: as the current increases,
+so does the firing rate. But is that exactly true? What form is that
+relationship?
+
+Pynapple can compute a tuning curve to help us answer this question, by
+binning our spikes based on the instantaneous input current and computing the
+firing rate within those bins:
+
+:::{admonition} Tuning curve in `pynapple`
+:class: note
+
+[`compute_1d_tuning_curves`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_1d_tuning_curves) : compute the firing rate as a function of a 1-dimensional feature.
+:::
+
+```{code-cell} ipython3
+tuning_curve = nap.compute_1d_tuning_curves(spikes, current, nb_bins=15)
+tuning_curve
+```
+
+`tuning_curve` is a pandas DataFrame where each column is a neuron (one
+neuron in this case) and each row is a bin over the feature (here, the input
+current). We can easily plot the tuning curve of the neuron:
+
+
+```{code-cell} ipython3
+doc_plots.tuning_curve_plot(tuning_curve);
+```
+
+We can see that, while the firing rate mostly increases with the current,
+it's definitely not a linear relationship, and it might start decreasing as
+the current gets too large.
+
+So this gives us three interesting phenomena we'd like our model to help
+explain: the tuning curve between the firing rate and the current, the firing
+rate's periodicity, and the gradual reduction in firing rate while the
+current remains on.
+
+
+
+
+## NeMoS
+
+### Preparing data
+
+Now that we understand our model, we're almost ready to put it together.
+Before we construct it, however, we need to get the data into the right
+format.
+
+NeMoS requires that the predictors and spike counts it operates on have the
+following properties:
+
+- predictors and spike counts must have the same number of time points.
+
+- predictors must be two-dimensional, with shape `(n_time_bins, n_features)`.
+ In this example, we have a single feature (the injected current).
+
+- spike counts must be one-dimensional, with shape `(n_time_bins, )`. As
+ discussed above, `n_time_bins` must be the same for both the predictors and
+ spike counts.
+
+- predictors and spike counts must be
+ [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html)
+ arrays, `numpy` arrays or `pynapple` `TsdFrame`/`Tsd`.
+
+:::{admonition} What is jax?
+:class: note
+
+[jax](https://github.com/google/jax) is a Google-supported python library
+for automatic differentiation. It has all sorts of neat features; the ones
+most relevant for NeMoS are its GPU compatibility and
+just-in-time compilation (both of which make code faster with little
+overhead!), as well as the collection of optimizers present in
+[jaxopt](https://jaxopt.github.io/stable/).
+:::
+
+First, we require that our predictors and our spike counts have the same
+number of time bins. We can achieve this by down-sampling our current to the
+spike counts to the proper resolution using the
+[`bin_average`](https://pynapple.org/generated/pynapple.core.time_series.Tsd.bin_average.html#pynapple.core.time_series.Tsd.bin_average)
+method from pynapple:
+
+
+```{code-cell} ipython3
+binned_current = current.bin_average(bin_size)
+
+print(f"current shape: {binned_current.shape}")
+# rate is in Hz, convert to kHz
+print(f"current sampling rate: {binned_current.rate/1000.:.02f} kHz")
+
+print(f"\ncount shape: {count.shape}")
+print(f"count sampling rate: {count.rate/1000:.02f} kHz")
+```
+
+Secondly, we have to reshape our variables so that they are the proper shape:
+
+- `predictors`: `(n_time_bins, n_features)`
+- `count`: `(n_time_bins, )`
+
+Because we only have a single predictor feature, we'll use
+[`np.expand_dims`](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html)
+to ensure it is a 2d array.
+
+
+```{code-cell} ipython3
+predictor = np.expand_dims(binned_current, 1)
+
+# check that the dimensionality matches NeMoS expectation
+print(f"predictor shape: {predictor.shape}")
+print(f"count shape: {count.shape}")
+```
+
+:::{admonition} What if I have more than one neuron?
+:class: info
+
+In this example, we're only fitting data for a single neuron, but you
+might wonder how the data should be shaped if you have more than one
+neuron -- do you add an extra dimension? or concatenate neurons along one
+of the existing dimensions?
+
+In NeMoS, we always fit Generalized Linear Models to a single neuron at a
+time. We'll discuss this more in the [following
+tutorial](plot_02_head_direction.md), but briefly: you get the same answer
+whether you fit the neurons separately or simultaneously, and fitting
+them separately can make your life easier.
+:::
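+
+Concretely, "fitting separately" just means looping over neurons and reusing the same
+design matrix with the `fit` call introduced in the next section. A rough, hypothetical
+sketch (the `population_counts` array below is assumed to have shape
+`(n_time_bins, n_neurons)`; it is not part of this single-neuron dataset):
+
+```python
+# hypothetical sketch: fit one GLM per neuron, reusing the same predictors
+population_models = []
+for neuron_id in range(population_counts.shape[1]):
+    glm = nmo.glm.GLM(solver_name="LBFGS")
+    glm.fit(predictor, population_counts[:, neuron_id])
+    population_models.append(glm)
+```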
+
+### Fitting the model
+
+Now we're ready to fit our model!
+
+First, we need to define our GLM model object. We intend for users
+to interact with our models like
+[scikit-learn](https://scikit-learn.org/stable/getting_started.html)
+estimators. In a nutshell, a model instance is initialized with
+hyperparameters that specify optimization and model details,
+and then the user calls the `.fit()` function to fit the model to data.
+We will walk you through the process below by example, but if you
+are interested in reading more details see the [Getting Started with scikit-learn](https://scikit-learn.org/stable/getting_started.html) webpage.
+
+To initialize our model, we need to specify the regularizer and observation
+model objects, both of which should be one of our custom objects:
+
+- Regularizer: this object specifies both the solver algorithm and the
+ regularization scheme. They are jointly specified because each
+ regularization scheme has a list of compatible solvers to choose between.
+ Regularization modifies the objective function to reflect your prior
+ beliefs about the parameters, such as sparsity. Regularization becomes more
+ important as the number of input features, and thus model parameters,
+ grows. They can be found within [`nemos.regularizer`](regularizers).
+
+:::{warning}
+
+With a convex problem like the GLM, in theory it does not matter which
+solver algorithm you use. In practice, due to numerical issues, it
+generally does. Thus, it's worth trying a couple to see how their
+solutions compare. (Different regularization schemes will always give
+different results.)
+:::
+
+- Observation model: this object links the firing rate and the observed
+ data (in this case spikes), describing the distribution of neural activity (and thus changing
+ the log-likelihood). For spiking data, we use the Poisson observation model, but
+ we discuss other options for continuous data
+ in [the calcium imaging analysis demo](plot_06_calcium_imaging.md).
+
+For this example, we'll use an un-regularized LBFGS solver. We'll discuss
+regularization in a later tutorial.
+
+:::{admonition} Why LBFGS?
+:class: info
+
+[LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) is a
+quasi-Newton method, that is, it uses the first derivative (the gradient)
+and approximates the second derivative (the Hessian) in order to solve
+the problem. This means that LBFGS tends to find a solution faster and is
+often less sensitive to step-size. Try other solvers to see how they
+behave!
+:::
+
+
+```{code-cell} ipython3
+# Initialize the GLM, specifying the solver (no regularization for now)
+model = nmo.glm.GLM(solver_name="LBFGS")
+```
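+
+As an aside, both of those choices are exposed when constructing the model. The
+following is a rough, hypothetical sketch of what alternative configurations could look
+like; the argument names follow the quickstart, but double-check the API reference
+before relying on them:
+
+```python
+# hypothetical sketch -- verify argument names and defaults against the API reference
+ridge_model = nmo.glm.GLM(
+    solver_name="LBFGS",
+    regularizer=nmo.regularizer.Ridge(),  # L2 penalty on the coefficients
+    regularizer_strength=0.1,             # penalty weight (a hyperparameter to tune)
+)
+gamma_model = nmo.glm.GLM(
+    observation_model=nmo.observation_models.GammaObservations(),  # continuous data
+)
+```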
+
+Now that we've initialized our model with the optimization parameters, we can
+fit our data! In the previous section, we prepared our model matrix
+(`predictor`) and target data (`count`), so to fit the model we just need to
+pass them to the model:
+
+
+```{code-cell} ipython3
+model.fit(predictor, count)
+```
+
+Now that we've fit our data, we can retrieve the resulting parameters.
+Similar to scikit-learn, these are stored as the `coef_` and `intercept_`
+attributes:
+
+
+```{code-cell} ipython3
+print(f"firing_rate(t) = exp({model.coef_} * current(t) + {model.intercept_})")
+```
+
+Note that `model.coef_` has shape `(n_features, )`, while `model.intercept_`
+is a scalar:
+
+
+```{code-cell} ipython3
+print(f"coef_ shape: {model.coef_.shape}")
+print(f"intercept_ shape: {model.intercept_.shape}")
+```
+
+It's nice to get the parameters above, but we can't tell how well our model
+is doing by looking at them. So how should we evaluate our model?
+
+First, we can use the model to predict the firing rates and compare that to
+the smoothed spike train. By calling [`predict()`](nemos.glm.GLM.predict) we can get the model's
+predicted firing rate for this data. Note that this is just the output of the
+model's linear-nonlinear step, as described earlier!
+
+
+```{code-cell} ipython3
+predicted_fr = model.predict(predictor)
+# convert units from spikes/bin to spikes/sec
+predicted_fr = predicted_fr / bin_size
+
+
+# and let's smooth the predicted rate the same way that we smoothed the spike counts
+smooth_predicted_fr = predicted_fr.smooth(0.05, size_factor=20)
+
+# and plot!
+fig = doc_plots.current_injection_plot(current, spikes, firing_rate,
+ # plot the predicted firing rate that has
+ # been smoothed the same way as the
+ # smoothed spike train
+ predicted_firing_rate=smooth_predicted_fr)
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+import os
+from pathlib import Path
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in assets
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_01_current_injection.svg")
+```
+
+What do we see above? Note that the y-axes in the final row are different for
+each subplot!
+
+- Predicted firing rate increases as injected current goes up — Success! 🎉
+
+- The amplitude of the predicted firing rate only matches the observed
+ amplitude in the third interval: it's too high in the first and too low in
+ the second — Failure! ❌
+
+- Our predicted firing rate has the periodicity we see in the smoothed spike
+ train — Success! 🎉
+
+- The predicted firing rate does not decay as the input remains on: the
+ amplitudes are identical for each of the bumps within a given interval —
+ Failure! ❌
+
+The failure described in the second point may seem particularly confusing —
+approximate amplitude feels like it should be very easy to capture, so what's
+going on?
+
+To get a better sense, let's look at the mean firing rate over the whole
+period:
+
+
+```{code-cell} ipython3
+# compare observed mean firing rate with the model predicted one
+print(f"Observed mean firing rate: {np.mean(count) / bin_size} Hz")
+print(f"Predicted mean firing rate: {np.mean(predicted_fr)} Hz")
+```
+
+We matched the average pretty well! So we've matched the average and the
+range of inputs from the third interval reasonably well, but overshot at low
+inputs and undershot in the middle.
+
+We can see this more directly by computing the tuning curve for our predicted
+firing rate and comparing that against our smoothed spike train from the
+beginning of this notebook. Pynapple can help us again with this:
+
+
+```{code-cell} ipython3
+tuning_curve_model = nap.compute_1d_tuning_curves_continuous(predicted_fr[:, np.newaxis], current, 15)
+fig = doc_plots.tuning_curve_plot(tuning_curve)
+fig.axes[0].plot(tuning_curve_model, color="tomato", label="glm")
+fig.axes[0].legend()
+```
+
+In addition to making that mismatch discussed earlier a little more obvious,
+this tuning curve comparison also highlights that this model thinks the
+firing rate will continue to grow as the injected current increases, which is
+not reflected in the data.
+
+Viewing this plot also makes it clear that the model's tuning curve is
+approximately exponential. We already knew that! That's what it means to be an
+LNP model of a single input. But it's nice to see it made explicit.
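+
+To make that explicit, here is a small sketch (not part of the original analysis) that
+overlays the closed-form exponential implied by the fitted weights on top of the
+empirical tuning curve; since the model's rate is per bin, we divide by `bin_size` to
+compare in Hz.
+
+```python
+# sketch: the GLM tuning curve is exp(coef * current + intercept), rescaled to Hz
+current_grid = np.linspace(tuning_curve.index.min(), tuning_curve.index.max(), 100)
+implied_rate = np.exp(model.coef_[0] * current_grid + model.intercept_) / bin_size
+
+fig = doc_plots.tuning_curve_plot(tuning_curve)
+fig.axes[0].plot(current_grid, implied_rate, "k--", label="exp(coef * I + intercept)")
+fig.axes[0].legend()
+```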
+
+### Finishing up
+
+There are a handful of other operations you might like to do with the GLM.
+First, you might be wondering how to simulate spikes — the GLM is an LNP
+model, but the firing rate is just the output of *LN*, its first two steps.
+The firing rate is just the mean of a Poisson process, so we can pass it to
+`jax.random.poisson`:
+
+
+```{code-cell} ipython3
+# the predicted rate is in spikes/sec, so rescale by the bin size to get the
+# expected count per bin before sampling
+spikes = jax.random.poisson(jax.random.PRNGKey(123), predicted_fr.values * bin_size)
+```
+
+Note that this is not actually that informative and, in general, it is
+recommended that you focus on firing rates when interpreting your model.
+
+Also, while
+including spike history is often helpful, it can sometimes make simulations unstable:
+if your GLM includes auto-regressive inputs (e.g., neurons are
+connected to themselves or each other), simulations can sometimes behave
+poorly because of runaway excitation [$^{[1, 2]}$](#ref-1).
+
+Finally, you may want a number with which to evaluate your model's
+performance. As discussed earlier, the model optimizes log-likelihood to find
+the best-fitting weights, and we can calculate this number using its [`score`](nemos.glm.GLM.score)
+method:
+
+
+```{code-cell} ipython3
+log_likelihood = model.score(predictor, count, score_type="log-likelihood")
+print(f"log-likelihood: {log_likelihood}")
+```
+
+This log-likelihood is un-normalized and thus doesn't mean that much by
+itself, other than "higher=better". When comparing alternative GLMs fit on
+the same dataset, whether that's models using different regularizers and
+solvers or those using different predictors, comparing log-likelihoods is a
+reasonable thing to do.
+
+:::{note}
+
+Under the hood, NeMoS is minimizing the negative log-likelihood, as is
+typical in many optimization contexts. [`score`](nemos.glm.GLM.score) returns the real
+log-likelihood, however, and thus higher is better.
+:::
+
+Because it's un-normalized, however, the log-likelihood should not be
+compared across datasets (because, e.g., it won't account for differences in
+noise levels). We provide the ability to compute the pseudo-$R^2$ for this
+purpose:
+
+
+```{code-cell} ipython3
+model.score(predictor, count, score_type='pseudo-r2-Cohen')
+```
+
+## Citation
+
+The data used in this tutorial is from the **Allen Brain Map**, with the
+[following
+citation](https://knowledge.brain-map.org/data/1HEYEW7GMUKWIQW37BO/summary):
+
+**Contributors:** Agata Budzillo, Bosiljka Tasic, Brian R. Lee, Fahimeh
+Baftizadeh, Gabe Murphy, Hongkui Zeng, Jim Berg, Nathan Gouwens, Rachel
+Dalley, Staci A. Sorensen, Tim Jarsky, Uygar Sümbül, Zizhen Yao
+
+**Dataset:** Allen Institute for Brain Science (2020). Allen Cell Types Database
+-- Mouse Patch-seq [dataset]. Available from
+brain-map.org/explore/classes/multimodal-characterization.
+
+**Primary publication:** Gouwens, N.W., Sorensen, S.A., et al. (2020). Integrated
+morphoelectric and transcriptomic classification of cortical GABAergic cells.
+Cell, 183(4), 935-953.E19. https://doi.org/10.1016/j.cell.2020.09.057
+
+**Patch-seq protocol:** Lee, B. R., Budzillo, A., et al. (2021). Scaled, high
+fidelity electrophysiological, morphological, and transcriptomic cell
+characterization. eLife, 2021;10:e65482. https://doi.org/10.7554/eLife.65482
+
+**Mouse VISp L2/3 glutamatergic neurons:** Berg, J., Sorensen, S. A., Miller, J.,
+Ting, J., et al. (2021) Human neocortical expansion involves glutamatergic
+neuron diversification. Nature, 598(7879):151-158. doi:
+10.1038/s41586-021-03813-8
+
+## References
+
+[1] Arribas, Diego, Yuan Zhao, and Il Memming Park. "Rescuing neural spike train models from bad MLE." Advances in Neural Information Processing Systems 33 (2020): 2293-2303.
+
+[2] Hocker, David, and Memming Park. "Multistep inference for generalized linear spiking models curbs runaway excitation." International IEEE/EMBS Conference on Neural Engineering, May 2017.
diff --git a/docs/tutorials/plot_01_current_injection.py b/docs/tutorials/plot_01_current_injection.py
deleted file mode 100644
index 0e54ffeb..00000000
--- a/docs/tutorials/plot_01_current_injection.py
+++ /dev/null
@@ -1,691 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""# Fit injected current
-
-For our first example, we will look at a very simple dataset: patch-clamp
-recordings from a single neuron in layer 4 of mouse primary visual cortex. This
-data is from the [Allen Brain
-Atlas](https://celltypes.brain-map.org/experiment/electrophysiology/478498617),
-and experimenters injected current directly into the cell, while recording the
-neuron's membrane potential and spiking behavior. The experiments varied the
-shape of the current across many sweeps, mapping the neuron's behavior in
-response to a wide range of potential inputs.
-
-For our purposes, we will examine only one of these sweeps, "Noise 1", in which
-the experimentalists injected three pulses of current. The current is a square
-pulse multiplied by a sinusoid of a fixed frequency, with some random noise
-riding on top.
-
-![Allen Brain Atlas view of the data we will analyze.](../../assets/allen_data.png)
-
-In the figure above (from the Allen Brain Atlas website), we see the
-approximately 22 second sweep, with the input current plotted in the first row,
-the intracellular voltage in the second, and the recorded spikes in the third.
-(The grey lines and dots in the second and third rows comes from other sweeps
-with the same stimulus, which we'll ignore in this exercise.) When fitting the
-Generalized Linear Model, we are attempting to model the spiking behavior, and
-we generally do not have access to the intracellular voltage, so for the rest
-of this notebook, we'll use only the input current and the recorded spikes
-displayed in the first and third rows.
-
-First, let us see how to load in the data and reproduce the above figure, which
-we'll do using the [Pynapple package](https://pynapple-org.github.io/pynapple/). We will rely on
-pynapple throughout this notebook, as it simplifies handling this type of
-data (we will explain the essentials of pynapple as they are used, but see the
-[Pynapple docs](https://pynapple-org.github.io/pynapple/)
-if you are interested in learning more). After we've explored the data some, we'll introduce the Generalized
-Linear Model and how to fit it with NeMoS.
-
-## Learning objectives {.keep-text}
-
-- Learn how to explore spiking data and do basic analyses using pynapple
-- Learn how to structure data for NeMoS
-- Learn how to fit a basic Generalized Linear Model using NeMoS
-- Learn how to retrieve the parameters and predictions from a fit GLM for
- intrepetation.
-
-"""
-
-
-
-# Import everything
-import jax
-import matplotlib.pyplot as plt
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-# some helper plotting functions
-from nemos import _documentation_utils as doc_plots
-
-# configure plots some
-plt.style.use(nmo.styles.plot_style)
-
-# %%
-# ## Data Streaming
-#
-# While you can download the data directly from the Allen Brain Atlas and
-# interact with it using their
-# [AllenSDK](https://allensdk.readthedocs.io/en/latest/visual_behavior_neuropixels.html),
-# we prefer the burgeoning [Neurodata Without Borders (NWB)
-# standard](https://nwb-overview.readthedocs.io/en/latest/). We have converted
-# this single dataset to NWB and uploaded it to the [Open Science
-# Framework](https://osf.io/5crqj/). This allows us to easily load the data
-# using pynapple, and it will immediately be in a format that pynapple understands!
-#
-# !!! tip
-#
-# Pynapple can stream any NWB-formatted dataset! See [their
-# documentation](https://pynapple.org/examples/tutorial_pynapple_dandi.html)
-# for more details, and see the [DANDI Archive](https://dandiarchive.org/)
-# for a repository of compliant datasets.
-#
-# The first time the following cell is run, it will take a little bit of time
-# to download the data, and a progress bar will show the download's progress.
-# On subsequent runs, the cell gets skipped: we do not need to redownload the
-# data.
-
-
-path = nmo.fetch.fetch_data("allen_478498617.nwb")
-
-# %%
-# ## Pynapple
-#
-# ### Data structures and preparation
-#
-# Now that we've downloaded the data, let's open it with pynapple and examine
-# its contents.
-
-
-data = nap.load_file(path)
-print(data)
-
-# %%
-#
-# The dataset contains several different pynapple objects, which we will
-# explore throughout this demo. The following illustrates how these fields relate to the data
-# we visualized above:
-#
-# ![Annotated view of the data we will analyze.](../../assets/allen_data_annotated.gif)
-#
-#
-# - `units`: dictionary of neurons, holding each neuron's spike timestamps.
-# - `epochs`: start and end times of different intervals, defining the
-# experimental structure, specifying when each stimulation protocol began and
-# ended.
-# - `stimulus`: injected current, in Amperes, sampled at 20k Hz.
-# - `response`: the neuron's intracellular voltage, sampled at 20k Hz.
-# We will not use this info in this example
-#
-# Now let's go through the relevant variables in some more detail:
-
-
-trial_interval_set = data["epochs"]
-
-current = data["stimulus"]
-spikes = data["units"]
-
-# %%
-# First, let's examine `trial_interval_set`:
-
-
-trial_interval_set.keys()
-
-# %%
-#
-# `trial_interval_set` is a dictionary with strings for keys and
-# [`IntervalSets`](https://pynapple.org/generated/pynapple.core.interval_set.IntervalSet.html)
-# for values. Each key defines the stimulus protocol, with the value defining
-# the beginning and end of that stimulation protocol.
-
-noise_interval = trial_interval_set["Noise 1"]
-noise_interval
-
-# %%
-#
-# As described above, we will be examining "Noise 1". We can see it contains
-# three rows, each defining a separate sweep. We'll just grab the first sweep
-# (shown in blue in the pictures above) and ignore the other two (shown in
-# gray).
-
-noise_interval = noise_interval[0]
-noise_interval
-
-# %%
-#
-# Now let's examine `current`:
-
-current
-
-# %%
-#
-# `current` is a `Tsd`
-# ([TimeSeriesData](https://pynapple.org/generated/pynapple.core.time_series.Tsd.html))
-# object with 2 columns. Like all `Tsd` objects, the first column contains the
-# time index and the second column contains the data; in this case, the current
-# in Ampere (A).
-#
-# Currently, `current` contains the entire ~900 second experiment but, as
-# discussed above, we only want one of the "Noise 1" sweeps. Fortunately,
-# `pynapple` makes it easy to grab out the relevant time points by making use
-# of the `noise_interval` we defined above:
-
-
-current = current.restrict(noise_interval)
-# convert current from Ampere to pico-amperes, to match the above visualization
-# and move the values to a more reasonable range.
-current = current * 1e12
-current
-
-# %%
-#
-# Notice that the timestamps have changed and our shape is much smaller.
-#
-# Finally, let's examine the spike times. `spikes` is a
-# [`TsGroup`](https://pynapple.org/generated/pynapple.core.ts_group.TsGroup.html#pynapple.core.ts_group.TsGroup),
-# a dictionary-like object that holds multiple `Ts` (timeseries) objects with
-# potentially different time indices:
-
-
-spikes
-
-# %%
-#
-# Typically, this is used to hold onto the spike times for a population of
-# neurons. In this experiment, we only have recordings from a single neuron, so
-# there's only one row.
-#
-# We can index into the `TsGroup` to see the timestamps for this neuron's
-# spikes:
-
-
-spikes[0]
-
-# %%
-#
-# Similar to `current`, this object originally contains data from the entire
-# experiment. To get only the data we need, we again use
-# `restrict(noise_interval)`:
-
-spikes = spikes.restrict(noise_interval)
-print(spikes)
-spikes[0]
-
-
-# %%
-#
-# Now, let's visualize the data from this trial, replicating rows 1 and 3
-# from the Allen Brain Atlas figure at the beginning of this notebook:
-
-
-fig, ax = plt.subplots(1, 1, figsize=(8, 2))
-ax.plot(current, "grey")
-ax.plot(spikes.to_tsd([-5]), "|", color="k", ms = 10)
-ax.set_ylabel("Current (pA)")
-ax.set_xlabel("Time (s)")
-
-# %%
-#
-# ### Basic analyses
-#
-# Before using the Generalized Linear Model, or any model, it's worth taking
-# some time to examine our data and think about what features are interesting
-# and worth capturing. As we discussed in the [background](../../background/plot_00_conceptual_intro),
-# the GLM is a model of the neuronal firing rate. However, in our experiments,
-# we do not observe the firing rate, only the spikes! Moreover, neural
-# responses are typically noisy—even in this highly controlled experiment
-# where the same current was injected over multiple trials, the spike times
-# were slightly different from trial-to-trial. No model can perfectly predict
-# spike times on an individual trial, so how do we tell if our model is doing a
-# good job?
-#
-# Our objective function is the log-likelihood of the observed spikes given the
-# predicted firing rate. That is, we're trying to find the firing rate, as a
-# function of time, for which the observed spikes are likely. Intuitively, this
-# makes sense: the firing rate should be high where there are many spikes, and
-# vice versa. However, it can be difficult to figure out if your model is doing
-# a good job by squinting at the observed spikes and the predicted firing rates
-# plotted together.
-#
-# One common way to visualize a rough estimate of firing rate is to smooth
-# the spikes by convolving them with a Gaussian filter.
-#
-# !!! info
-#
-# This is a heuristic for getting the firing rate, and shouldn't be taken
-# as the literal truth (to see why, pass a firing rate through a Poisson
-# process to generate spikes and then smooth the output to approximate the
-# generating firing rate). A model should not be expected to match this
-# approximate firing rate exactly, but visualizing the two firing rates
-# together can help you reason about which phenomena in your data the model
-# is able to adequately capture, and which it is missing.
-#
-# For more information, see section 1.2 of [*Theoretical
-# Neuroscience*](https://boulderschool.yale.edu/sites/default/files/files/DayanAbbott.pdf),
-# by Dayan and Abbott.
-#
-# Pynapple can easily compute this approximate firing rate, and plotting this
-# information will help us pull out some phenomena that we think are
-# interesting and would like a model to capture.
-#
-# First, we must convert from our spike times to binned spikes:
-
-# bin size in seconds
-bin_size = 0.001
-# Get spikes for neuron 0
-count = spikes[0].count(bin_size)
-count
-
-# %%
-#
-# Now, let's convert the binned spikes into the firing rate, by smoothing them
-# with a gaussian kernel. Pynapple again provides a convenience function for
-# this:
-
-# the inputs to this function are the standard deviation of the gaussian in seconds and
-# the full width of the window, in standard deviations. So std=.05 and size_factor=20
-# gives a total filter size of 0.05 sec * 20 = 1 sec.
-firing_rate = count.smooth(std=0.05, size_factor=20)
-# convert from spikes per bin to spikes per second (Hz)
-firing_rate = firing_rate / bin_size
-
-# %%
-#
-# Note that firing_rate is a [`TsdFrame`](https://pynapple.org/generated/pynapple.core.time_series.TsdFrame.html)!
-#
-
-print(type(firing_rate))
-
-# %%
-#
-# Now that we've done all this preparation, let's make a plot to more easily
-# visualize the data.
-#
-# !!! note
-#
-# We're hiding the details of the plotting function for the purposes of this
-# tutorial, but you can find it in [the source
-# code](https://github.com/flatironinstitute/nemos/blob/development/src/nemos/_documentation_utils/plotting.py)
-# if you are interested.
-
-doc_plots.current_injection_plot(current, spikes, firing_rate)
-
-# %%
-#
-# So now that we can view the details of our experiment a little more clearly,
-# what do we see?
-#
-# - We have three intervals of increasing current, and the firing rate
-# increases as the current does.
-#
-# - While the neuron is receiving the input, it does not fire continuously or
-# at a steady rate; there appears to be some periodicity in the response. The
-# neuron fires for a while, stops, and then starts again. There's periodicity
-# in the input as well, so this pattern in the response might be reflecting
-# that.
-#
-# - There's some decay in firing rate as the input remains on: there are three
-#   or four "bumps" of neuronal firing in the second and third intervals and
-#   they decrease in amplitude, with the first being the largest.
-#
-# These give us some good phenomena to try and predict! But there's something
-# that's not quite obvious from the above plot: what is the relationship
-# between the input and the firing rate? As described in the first bullet point
-# above, it looks to be *monotonically increasing*: as the current increases,
-# so does the firing rate. But is that exactly true? What form is that
-# relationship?
-#
-# Pynapple can compute a tuning curve to help us answer this question, by
-# binning our spikes based on the instantaneous input current and computing the
-# firing rate within those bins:
-#
-# !!! note "Tuning curve in `pynapple`"
-# [`compute_1d_tuning_curves`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_1d_tuning_curves) : compute the firing rate as a function of a 1-dimensional feature.
-
-tuning_curve = nap.compute_1d_tuning_curves(spikes, current, nb_bins=15)
-tuning_curve
-
-# %%
-#
-# `tuning_curve` is a pandas DataFrame where each column is a neuron (one
-# neuron in this case) and each row is a bin over the feature (here, the input
-# current). We can easily plot the tuning curve of the neuron:
-
-doc_plots.tuning_curve_plot(tuning_curve)
-
-# %%
-#
-# We can see that, while the firing rate mostly increases with the current,
-# it's definitely not a linear relationship, and it might start decreasing as
-# the current gets too large.
-#
-# So this gives us three interesting phenomena we'd like our model to help
-# explain: the tuning curve between the firing rate and the current, the firing
-# rate's periodicity, and the gradual reduction in firing rate while the
-# current remains on.
-
-# %%
-# ## NeMoS {.strip-code}
-#
-# ### Preparing data
-#
-# Now that we understand our model, we're almost ready to put it together.
-# Before we construct it, however, we need to get the data into the right
-# format.
-#
-# NeMoS requires that the predictors and spike counts it operates on have the
-# following properties:
-#
-# - predictors and spike counts must have the same number of time points.
-#
-# - predictors must be two-dimensional, with shape `(n_time_bins, n_features)`.
-# In this example, we have a single feature (the injected current).
-#
-# - spike counts must be one-dimensional, with shape `(n_time_bins, )`. As
-# discussed above, `n_time_bins` must be the same for both the predictors and
-# spike counts.
-#
-# - predictors and spike counts must be
-# [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html)
-# arrays, `numpy` arrays or `pynapple` `TsdFrame`/`Tsd`.
-#
-# !!! info "What is jax?"
-#
-#     [jax](https://github.com/google/jax) is a Google-supported Python library
-#     for automatic differentiation. It has all sorts of neat features; the
-#     most relevant for NeMoS are its GPU compatibility and just-in-time
-#     compilation (both of which make code faster with little overhead!), as
-#     well as the collection of optimizers available in
-#     [jaxopt](https://jaxopt.github.io/stable/).
-#
-# First, we require that our predictors and our spike counts have the same
-# number of time bins. We can achieve this by down-sampling our current to the
-# resolution of the spike counts using the
-# [`bin_average`](https://pynapple.org/generated/pynapple.core.time_series.Tsd.bin_average.html#pynapple.core.time_series.Tsd.bin_average)
-# method from pynapple:
-
-binned_current = current.bin_average(bin_size)
-
-print(f"current shape: {binned_current.shape}")
-# rate is in Hz, convert to KHz
-print(f"current sampling rate: {binned_current.rate/1000.:.02f} KHz")
-
-print(f"\ncount shape: {count.shape}")
-print(f"count sampling rate: {count.rate/1000:.02f} KHz")
-
-
-# %%
-#
-# Second, we have to reshape our variables so that they have the proper shape:
-#
-# - `predictors`: `(n_time_bins, n_features)`
-# - `count`: `(n_time_bins, )`
-#
-# Because we only have a single predictor feature, we'll use
-# [`np.expand_dims`](https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html)
-# to ensure it is a 2d array.
-
-predictor = np.expand_dims(binned_current, 1)
-
-# check that the dimensionality matches NeMoS expectation
-print(f"predictor shape: {predictor.shape}")
-print(f"count shape: {count.shape}")
-
-# %%
-# !!! info "What if I have more than one neuron?"
-#
-# In this example, we're only fitting data for a single neuron, but you
-# might wonder how the data should be shaped if you have more than one
-# neuron -- do you add an extra dimension? or concatenate neurons along one
-# of the existing dimensions?
-#
-# In NeMoS, we always fit Generalized Linear Models to a single neuron at a
-# time. We'll discuss this more in the [following
-# tutorial](../plot_02_head_direction/), but briefly: you get the same answer
-# whether you fit the neurons separately or simultaneously, and fitting
-# them separately can make your life easier.
-#
-# ### Fitting the model
-#
-# Now we're ready to fit our model!
-#
-# First, we need to define our GLM model object. We intend for users
-# to interact with our models like
-# [scikit-learn](https://scikit-learn.org/stable/getting_started.html)
-# estimators. In a nutshell, a model instance is initialized with
-# hyperparameters that specify optimization and model details,
-# and then the user calls the `.fit()` function to fit the model to data.
-# We will walk you through the process below by example, but if you
-# are interested in reading more details see the [Getting Started with scikit-learn](https://scikit-learn.org/stable/getting_started.html) webpage.
-#
-# To initialize our model, we need to specify the regularizer and observation
-# model objects, both of which should be one of our custom objects:
-#
-# - Regularizer: this object specifies both the solver algorithm and the
-# regularization scheme. They are jointly specified because each
-# regularization scheme has a list of compatible solvers to choose between.
-# Regularization modifies the objective function to reflect your prior
-# beliefs about the parameters, such as sparsity. Regularization becomes more
-# important as the number of input features, and thus model parameters,
-# grows. They can be found within `nemos.regularizer`.
-#
-# !!! warning
-#
-# With a convex problem like the GLM, in theory it does not matter which
-# solver algorithm you use. In practice, due to numerical issues, it
-# generally does. Thus, it's worth trying a couple to see how their
-# solutions compare. (Different regularization schemes will always give
-# different results.)
-#
-# - Observation model: this object links the firing rate and the observed
-# data (in this case spikes), describing the distribution of neural activity (and thus changing
-# the log-likelihood). For spiking data, we use the Poisson observation model, but
-# we discuss other options for continuous data
-# in [the calcium imaging analysis demo](../plot_06_calcium_imaging/).
-#
-# For this example, we'll use an un-regularized LBFGS solver. We'll discuss
-# regularization in a later tutorial.
-#
-# !!! info "Why LBFGS?"
-#
-# [LBFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) is a
-#     quasi-Newton method; that is, it uses the first derivative (the gradient)
-# and approximates the second derivative (the Hessian) in order to solve
-# the problem. This means that LBFGS tends to find a solution faster and is
-# often less sensitive to step-size. Try other solvers to see how they
-# behave!
-#
-
-# Initialize the model w/regularizer and solver
-model = nmo.glm.GLM(solver_name="LBFGS")
-
-# %%
-#
-# Now that we've initialized our model with the optimization parameters, we can
-# fit our data! In the previous section, we prepared our model matrix
-# (`predictor`) and target data (`count`), so to fit the model we just need to
-# pass them to the model:
-
-model.fit(predictor, count)
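-
-# As the earlier warning suggested, it can be worth trying a second solver and
-# checking that it reaches a similar solution. This is only a sketch: we assume
-# here that "GradientDescent" is among the other available solver names.
-model_gd = nmo.glm.GLM(solver_name="GradientDescent")
-model_gd.fit(predictor, count)
-print(f"LBFGS coef: {model.coef_}, GradientDescent coef: {model_gd.coef_}")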
-
-# %%
-#
-# Now that we've fit our data, we can retrieve the resulting parameters.
-# Similar to scikit-learn, these are stored as the `coef_` and `intercept_`
-# attributes:
-
-print(f"firing_rate(t) = exp({model.coef_} * current(t) + {model.intercept_})")
-
-# %%
-#
-# Note that `model.coef_` has shape `(n_features, )`, while `model.intercept_`
-# is a scalar:
-
-print(f"coef_ shape: {model.coef_.shape}")
-print(f"intercept_ shape: {model.intercept_.shape}")
-
-# %%
-#
-# It's nice to get the parameters above, but we can't tell how well our model
-# is doing by looking at them. So how should we evaluate our model?
-#
-# First, we can use the model to predict the firing rates and compare that to
-# the smoothed spike train. By calling `predict()` we can get the model's
-# predicted firing rate for this data. Note that this is just the output of the
-# model's linear-nonlinear step, as described earlier!
-
-# mkdocs_gallery_thumbnail_number = 4
-
-predicted_fr = model.predict(predictor)
-# convert units from spikes/bin to spikes/sec
-predicted_fr = predicted_fr / bin_size
-
-
-# and let's smooth the predicted firing rate the same way that we smoothed the spike train
-smooth_predicted_fr = predicted_fr.smooth(0.05, size_factor=20)
-
-# and plot!
-doc_plots.current_injection_plot(current, spikes, firing_rate,
- # plot the predicted firing rate that has
- # been smoothed the same way as the
- # smoothed spike train
- predicted_firing_rate=smooth_predicted_fr)
-
-# %%
-#
-# What do we see above? Note that the y-axes in the final row are different for
-# each subplot!
-#
-# - Predicted firing rate increases as injected current goes up — Success! :tada:
-#
-# - The amplitude of the predicted firing rate only matches the observed
-# amplitude in the third interval: it's too high in the first and too low in
-# the second — Failure! :x:
-#
-# - Our predicted firing rate has the periodicity we see in the smoothed spike
-# train — Success! :tada:
-#
-# - The predicted firing rate does not decay as the input remains on: the
-# amplitudes are identical for each of the bumps within a given interval —
-# Failure! :x:
-#
-# The failure described in the second point may seem particularly confusing —
-# approximate amplitude feels like it should be very easy to capture, so what's
-# going on?
-#
-# To get a better sense, let's look at the mean firing rate over the whole
-# period:
-
-# compare observed mean firing rate with the model predicted one
-print(f"Observed mean firing rate: {np.mean(count) / bin_size} Hz")
-print(f"Predicted mean firing rate: {np.mean(predicted_fr)} Hz")
-
-# %%
-#
-# We matched the average pretty well! So we've matched the average and the
-# range of inputs from the third interval reasonably well, but overshot at low
-# inputs and undershot in the middle.
-#
-# We can see this more directly by computing the tuning curve for our predicted
-# firing rate and comparing that against our smoothed spike train from the
-# beginning of this notebook. Pynapple can help us again with this:
-
-tuning_curve_model = nap.compute_1d_tuning_curves_continuous(predicted_fr[:, np.newaxis], current, 15)
-fig = doc_plots.tuning_curve_plot(tuning_curve)
-fig.axes[0].plot(tuning_curve_model, color="tomato", label="glm")
-fig.axes[0].legend()
-
-# %%
-#
-# In addition to making the mismatch discussed earlier a little more obvious,
-# this tuning curve comparison also highlights that this model thinks the
-# firing rate will continue to grow as the injected current increases, which is
-# not reflected in the data.
-#
-# Viewing this plot also makes it clear that the model's tuning curve is
-# approximately exponential. We already knew that! That's what it means to be an
-# LNP model of a single input. But it's nice to see it made explicit.
-#
-# ### Finishing up
-#
-# There are a handful of other operations you might like to do with the GLM.
-# First, you might be wondering how to simulate spikes — the GLM is an LNP
-# model, but the firing rate is just the output of *LN*, its first two steps.
-# The firing rate is just the mean of a Poisson process, so we can pass it to
-# `jax.random.poisson`:
-
-spikes = jax.random.poisson(jax.random.PRNGKey(123), predicted_fr.values)
-
-# %%
-#
-# Note that this is not actually that informative and, in general, it is
-# recommended that you focus on firing rates when interpreting your model.
-#
-# Also, while
-# including spike history is often helpful, it can sometimes make simulations unstable:
-# if your GLM includes auto-regressive inputs (e.g., neurons are
-# connected to themselves or each other), simulations can sometimes behave
-# poorly because of runaway excitation [$^{[1, 2]}$](#ref-1).
-#
-# Finally, you may want a number with which to evaluate your model's
-# performance. As discussed earlier, the model optimizes log-likelihood to find
-# the best-fitting weights, and we can calculate this number using its `score`
-# method:
-
-log_likelihood = model.score(predictor, count, score_type="log-likelihood")
-print(f"log-likelihood: {log_likelihood}")
-
-# %%
-#
-# This log-likelihood is un-normalized and thus doesn't mean that much by
-# itself, other than "higher=better". When comparing alternative GLMs fit on
-# the same dataset, whether that's models using different regularizers and
-# solvers or those using different predictors, comparing log-likelihoods is a
-# reasonable thing to do.
-#
-# !!! info
-#
-# Under the hood, NeMoS is minimizing the negative log-likelihood, as is
-# typical in many optimization contexts. `score` returns the real
-# log-likelihood, however, and thus higher is better.
-#
-# Because it's un-normalized, however, the log-likelihood should not be
-# compared across datasets (because, e.g., it won't account for differences in
-# noise levels). We provide the ability to compute the pseudo-$R^2$ for this
-# purpose:
-model.score(predictor, count, score_type='pseudo-r2-Cohen')
-
-# %%
-# ## Citation
-#
-# The data used in this tutorial is from the **Allen Brain Map**, with the
-# [following
-# citation](https://knowledge.brain-map.org/data/1HEYEW7GMUKWIQW37BO/summary):
-#
-# **Contributors:** Agata Budzillo, Bosiljka Tasic, Brian R. Lee, Fahimeh
-# Baftizadeh, Gabe Murphy, Hongkui Zeng, Jim Berg, Nathan Gouwens, Rachel
-# Dalley, Staci A. Sorensen, Tim Jarsky, Uygar Sümbül, Zizhen Yao
-#
-# **Dataset:** Allen Institute for Brain Science (2020). Allen Cell Types Database
-# -- Mouse Patch-seq [dataset]. Available from
-# brain-map.org/explore/classes/multimodal-characterization.
-#
-# **Primary publication:** Gouwens, N.W., Sorensen, S.A., et al. (2020). Integrated
-# morphoelectric and transcriptomic classification of cortical GABAergic cells.
-# Cell, 183(4), 935-953.E19. https://doi.org/10.1016/j.cell.2020.09.057
-#
-# **Patch-seq protocol:** Lee, B. R., Budzillo, A., et al. (2021). Scaled, high
-# fidelity electrophysiological, morphological, and transcriptomic cell
-# characterization. eLife, 2021;10:e65482. https://doi.org/10.7554/eLife.65482
-#
-# **Mouse VISp L2/3 glutamatergic neurons:** Berg, J., Sorensen, S. A., Miller, J.,
-# Ting, J., et al. (2021) Human neocortical expansion involves glutamatergic
-# neuron diversification. Nature, 598(7879):151-158. doi:
-# 10.1038/s41586-021-03813-8
-#
-# ## References
-#
-# [1] Arribas, Diego, Yuan Zhao, and Il Memming Park. "Rescuing neural spike train models from bad MLE." Advances in Neural Information Processing Systems 33 (2020): 2293-2303.
-#
-# [2] Hocker, David, and Memming Park. "Multistep inference for generalized linear spiking models curbs runaway excitation." International IEEE/EMBS Conference on Neural Engineering, May 2017.
diff --git a/docs/tutorials/plot_02_head_direction.md b/docs/tutorials/plot_02_head_direction.md
new file mode 100644
index 00000000..cf408662
--- /dev/null
+++ b/docs/tutorials/plot_02_head_direction.md
@@ -0,0 +1,724 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+
+# Fit head-direction population
+
+## Learning objectives
+
+- Learn how to add history-related predictors to NeMoS GLM
+- Learn about NeMoS [`Basis`](nemos_basis) objects
+- Learn how to use [`Basis`](nemos_basis) objects with convolution
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+# some helper plotting functions
+from nemos import _documentation_utils as doc_plots
+
+# configure pynapple to ignore conversion warning
+nap.nap_config.suppress_conversion_warnings = True
+
+# configure plots some
+plt.style.use(nmo.styles.plot_style)
+```
+
+## Data Streaming
+
+Here we load the data from OSF. The data is an NWB file.
+
+
+```{code-cell} ipython3
+path = nmo.fetch.fetch_data("Mouse32-140822.nwb")
+```
+
+## Pynapple
+We are going to open the NWB file with pynapple.
+
+
+```{code-cell} ipython3
+data = nap.load_file(path)
+
+data
+```
+
+Get spike timings
+
+
+```{code-cell} ipython3
+spikes = data["units"]
+
+spikes
+```
+
+Get the behavioural epochs (in this case, sleep and wakefulness)
+
+
+
+```{code-cell} ipython3
+epochs = data["epochs"]
+wake_ep = data["epochs"]["wake"]
+```
+
+Get the tracked orientation of the animal
+
+
+```{code-cell} ipython3
+angle = data["ry"]
+```
+
+This cell restricts the data to what we care about, i.e., the activity of head-direction neurons during wakefulness.
+
+
+
+```{code-cell} ipython3
+spikes = spikes.getby_category("location")["adn"]
+
+spikes = spikes.restrict(wake_ep).getby_threshold("rate", 1.0)
+angle = angle.restrict(wake_ep)
+```
+
+First let's check that they are head-direction neurons.
+
+
+```{code-cell} ipython3
+tuning_curves = nap.compute_1d_tuning_curves(
+ group=spikes, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
+)
+```
+
+Each row indicates an angular bin (in radians), and each column corresponds to a single unit.
+Let's plot the tuning curve of the first two neurons.
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 2, figsize=(12, 4))
+ax[0].plot(tuning_curves.iloc[:, 0])
+ax[0].set_xlabel("Angle (rad)")
+ax[0].set_ylabel("Firing rate (Hz)")
+ax[1].plot(tuning_curves.iloc[:, 1])
+ax[1].set_xlabel("Angle (rad)")
+plt.tight_layout()
+```
+
+Before using NeMoS, let's explore the data at the population level.
+
+Let's plot the preferred heading
+
+
+
+```{code-cell} ipython3
+fig = doc_plots.plot_head_direction_tuning(
+ tuning_curves, spikes, angle, threshold_hz=1, start=8910, end=8960
+)
+```
+
+As we can see, the population activity tracks the animal's current head direction very well.
+**Question: are neurons constantly tuned to head direction, and can we predict the spiking activity of each neuron based only on the activity of the other neurons?**
+
+To fit the GLM faster, we will use only the first 3 min of wakefulness.
+
+
+```{code-cell} ipython3
+wake_ep = nap.IntervalSet(
+ start=wake_ep.start[0], end=wake_ep.start[0] + 3 * 60
+)
+```
+
+To use the GLM, we first need to bin the spike trains. Here we use pynapple.
+
+
+```{code-cell} ipython3
+bin_size = 0.01
+count = spikes.count(bin_size, ep=wake_ep)
+```
+
+Here we are going to rearrange the neuron order based on their preferred directions.
+
+
+
+```{code-cell} ipython3
+pref_ang = tuning_curves.idxmax()
+
+count = nap.TsdFrame(
+ t=count.t,
+ d=count.values[:, np.argsort(pref_ang.values)],
+)
+```
+
+## NeMoS
+It's time to use NeMoS. Our goal is to estimate the pairwise interactions between neurons.
+This can be quantified with a GLM if we use the recent population spike history to predict the current time step.
+### Self-Connected Single Neuron
+To keep things simple, let's first see how we can model spike history effects in a single neuron.
+The simplest approach is to use the counts in a fixed-length window of size $i$, $y_{t-i}, \dots, y_{t-1}$, to predict the next
+count $y_{t}$. Let's plot the count history.
+
+
+
+```{code-cell} ipython3
+# select a neuron's spike count time series
+neuron_count = count[:, 0]
+
+# restrict to a smaller time interval
+epoch_one_spk = nap.IntervalSet(
+ start=count.time_support.start[0], end=count.time_support.start[0] + 1.2
+)
+plt.figure(figsize=(8, 3.5))
+plt.step(
+ neuron_count.restrict(epoch_one_spk).t, neuron_count.restrict(epoch_one_spk).d, where="post"
+)
+plt.title("Spike Count Time Series")
+plt.xlabel("Time (sec)")
+plt.ylabel("Counts")
+plt.tight_layout()
+```
+
+#### Features Construction
+Let's fix the spike history window size that we will use as predictor.
+
+
+```{code-cell} ipython3
+# set the size of the spike history window in seconds
+window_size_sec = 0.8
+
+doc_plots.plot_history_window(neuron_count, epoch_one_spk, window_size_sec);
+```
+
+For each time point, we shift our window one bin at a time and vertically stack the spike count history in a matrix.
+Each row of the matrix will be used as the predictors for the rate in the next bin (narrow red rectangle in
+the figure).
+
+
+```{code-cell} ipython3
+doc_plots.run_animation(neuron_count, epoch_one_spk.start[0])
+```
+
+If $t$ is smaller than the window size, we won't have a full window of spike history for estimating the rate.
+One may think of padding the window (with zeros, for example), but this may generate strange border artifacts.
+To avoid that, we can simply restrict our analysis to times $t$ larger than the window and NaN-pad the earlier
+time points.
+
+A fast way to compute this feature matrix is convolving the counts with the identity matrix.
+We can apply the convolution and NaN-padding in a single step using the
+[`nemos.convolve.create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor)
+function.
+
+
+```{code-cell} ipython3
+# convert the prediction window to bins (by multiplying with the sampling rate)
+window_size = int(window_size_sec * neuron_count.rate)
+
+# convolve the counts with the identity matrix.
+plt.close("all")
+input_feature = nmo.convolve.create_convolutional_predictor(
+ np.eye(window_size), neuron_count
+)
+
+# print the NaN indices along the time axis
+print("NaN indices:\n", np.where(np.isnan(input_feature[:, 0]))[0])
+```
+
+The binned counts originally have shape `(n_samples,)`; we should check that the
+dimensions match our expectations.
+
+
+```{code-cell} ipython3
+print(f"Time bins in counts: {neuron_count.shape[0]}")
+print(f"Convolution window size in bins: {window_size}")
+print(f"Feature shape: {input_feature.shape}")
+```
+
+We can visualize the output for a few time bins
+
+
+```{code-cell} ipython3
+suptitle = "Input feature: Count History"
+neuron_id = 0
+doc_plots.plot_features(input_feature, count.rate, suptitle);
+```
+
+As you may notice, the time axis is backwards; this happens because convolution flips the time axis.
+The result is equivalent: we can interpret it as how much a spike will affect the future rate.
+In the previous tutorial our feature was 1-dimensional (just the current); now
+the feature dimension is 80, because our bin size is 0.01 sec and the window size is 0.8 sec.
+We can learn these weights by maximum likelihood by fitting a GLM.
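+
+Written in the same form as the previous tutorial's `exp(coef * current(t) + intercept)`, the model
+we are about to fit predicts the rate in bin $t$ as
+$\lambda_t = \exp\left(\beta_0 + \sum_{i=1}^{80} \beta_i \, y_{t-i}\right)$,
+where $y_{t-i}$ are the past spike counts and $\beta_0, \dots, \beta_{80}$ are the parameters
+estimated by maximum likelihood.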
+
+
+
+
+#### Fitting the Model
+
+When working with a real dataset, it is good practice to train your models on a chunk of the data and
+use the other chunk to assess the model performance. This process is known as "cross-validation".
+There is no single strategy for how to cross-validate your model; what works best
+depends on the characteristics of your data (time series or independent samples,
+presence or absence of trials...) and those of your model. Here, for simplicity, we use the first
+half of the wake epochs for training and the second half for testing. This is a reasonable
+choice if the statistics of the neural activity do not change during the course of
+the recording. We will learn about better cross-validation strategies with other
+examples.
+
+
+```{code-cell} ipython3
+# construct the train and test epochs
+duration = input_feature.time_support.tot_length("s")
+start = input_feature.time_support["start"]
+end = input_feature.time_support["end"]
+first_half = nap.IntervalSet(start, start + duration / 2)
+second_half = nap.IntervalSet(start + duration / 2, end)
+```
+
+Fit the GLM to the first half of the recording and visualize the ML weights.
+
+
+```{code-cell} ipython3
+# define the GLM object
+model = nmo.glm.GLM(solver_name="LBFGS")
+
+# Fit over the training epochs
+model.fit(
+ input_feature.restrict(first_half),
+ neuron_count.restrict(first_half)
+)
+```
+
+```{code-cell} ipython3
+plt.figure()
+plt.title("Spike History Weights")
+plt.plot(np.arange(window_size) / count.rate, np.squeeze(model.coef_), lw=2, label="GLM raw history 1st Half")
+plt.axhline(0, color="k", lw=0.5)
+plt.xlabel("Time From Spike (sec)")
+plt.ylabel("Kernel")
+plt.legend()
+```
+
+The response in the previous figure looks like noise added to a decay, so the response
+can probably be described with fewer degrees of freedom. In other words, it looks like we
+are using far too many weights to describe a simple response.
+If we are correct, what would happen if we re-fit the weights on the other half of the data?
+#### Inspecting the results
+
+
+```{code-cell} ipython3
+# fit on the test set
+
+model_second_half = nmo.glm.GLM(solver_name="LBFGS")
+model_second_half.fit(
+ input_feature.restrict(second_half),
+ neuron_count.restrict(second_half)
+)
+
+plt.figure()
+plt.title("Spike History Weights")
+plt.plot(np.arange(window_size) / count.rate, np.squeeze(model.coef_),
+ label="GLM raw history 1st Half", lw=2)
+plt.plot(np.arange(window_size) / count.rate, np.squeeze(model_second_half.coef_),
+ color="orange", label="GLM raw history 2nd Half", lw=2)
+plt.axhline(0, color="k", lw=0.5)
+plt.xlabel("Time From Spike (sec)")
+plt.ylabel("Kernel")
+plt.legend()
+```
+
+What can we conclude?
+
+The fast fluctuations are inconsistent across fits, indicating that
+they are probably capturing noise, a phenomenon known as over-fitting.
+On the other hand, the decaying trend is fairly consistent, even if
+our estimate is noisy. You can imagine how things could get
+worse if we needed a finer temporal resolution, such as 1 ms time bins
+(which would require 800 coefficients instead of 80).
+What can we do to mitigate over-fitting now?
+
+#### Reducing feature dimensionality
+One way to proceed is to find a lower-dimensional representation of the response
+by parametrizing the decay effect. For instance, we could try to model it
+with an exponentially decaying function $f(t) = \exp( - \alpha t)$, with
+$\alpha > 0$ a positive scalar (sketched below). This is not a bad idea, because we would greatly
+simplify the dimensionality of our features (from 80 to 1). Unfortunately,
+there is no way to know a priori what a good parameterization is. More
+importantly, not all parametrizations guarantee a unique and stable solution
+to the maximum likelihood estimation of the coefficients (convexity).
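+
+As a quick, purely illustrative sketch of that single-parameter idea (the value of
+$\alpha$ below is arbitrary), such a kernel could be built directly from the window
+we already defined:
+
+```{code-cell} ipython3
+# an exponentially decaying kernel with a single free parameter alpha
+alpha = 10.0  # units of 1/seconds, chosen only for illustration
+kernel_time = np.arange(window_size) / neuron_count.rate  # seconds
+exp_kernel = np.exp(-alpha * kernel_time)
+
+plt.figure(figsize=(4, 2.5))
+plt.plot(kernel_time, exp_kernel)
+plt.xlabel("Time from spike (sec)")
+plt.ylabel("f(t)")
+plt.title(r"$f(t) = \exp(-\alpha t)$")
+```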
+
+In the GLM framework, the main way to construct a lower-dimensional parametrization
+while preserving convexity is to use a set of basis functions.
+For history-type inputs, whether of the spiking history or of the current
+history, we'll use the raised cosine log-stretched basis first described in
+[Pillow et al., 2005](https://www.jneurosci.org/content/25/47/11003). This
+basis set has the nice property that its precision drops linearly with
+distance from the event, which makes sense for many history-related inputs
+in neuroscience: whether an input happened 1 or 5 msec ago matters a lot,
+whereas whether an input happened 51 or 55 msec ago is less important.
+
+
+```{code-cell} ipython3
+doc_plots.plot_basis();
+```
+
+:::{note}
+
+We provide a handful of different choices for basis functions, and
+selecting the proper basis function for your input is an important
+analytical step. We will eventually provide guidance on this choice, but
+for now we'll give you a decent choice.
+:::
+
+NeMoS includes [`Basis`](nemos_basis) objects to handle the construction and use of these
+basis functions.
+
+When we instantiate this object, the only arguments we need to specify are the
+number of functions we want, the mode of operation of the basis (`"conv"`),
+and the window size for the convolution. With more basis functions, we'll be able to
+represent the effect of the corresponding input with higher precision, at
+the cost of additional parameters.
+
+
+```{code-cell} ipython3
+# a basis object can be instantiated in "conv" mode for convolving the input.
+basis = nmo.basis.RaisedCosineBasisLog(
+ n_basis_funcs=8, mode="conv", window_size=window_size
+)
+
+# `basis.evaluate_on_grid` is a convenience method to view all basis functions
+# across their whole domain:
+time, basis_kernels = basis.evaluate_on_grid(window_size)
+
+print(basis_kernels.shape)
+
+# time takes equi-spaced values between 0 and 1, we could multiply by the
+# duration of our window to scale it to seconds.
+time *= window_size_sec
+```
+
+To appreciate why the raised-cosine basis can approximate our response well,
+we can learn a "good" set of weights for the basis elements such that
+a weighted sum of the basis approximates the GLM weights for the count history.
+One way to do so is by minimizing the least-squares error.
+
+
+```{code-cell} ipython3
+# compute the least-squares weights
+lsq_coef, _, _, _ = np.linalg.lstsq(basis_kernels, np.squeeze(model.coef_), rcond=-1)
+
+# plot the basis and the approximation
+doc_plots.plot_weighted_sum_basis(time, model.coef_, basis_kernels, lsq_coef);
+```
+
+The first plot is the response of each of the 8 basis functions to a single
+pulse. This is known as the impulse response function, and is a useful way to
+characterize linear systems like our basis objects. The second plot is a
+bar plot representing the least-squares coefficients. The third shows the
+impulse responses scaled by the weights. The last plot shows the sum of the
+scaled responses overlaid on the original spike count history weights.
+
+Our predictor was previously huge: every possible 80-time-point chunk of the
+data, for 1,440,000 total numbers. By using this basis set we can instead reduce
+the predictor to 8 numbers for every 80-time-point window, for 144,000 total
+numbers: basically an order of magnitude less. With 1 ms bins we would have
+achieved a two-order-of-magnitude reduction in input size. This is a huge benefit
+in terms of memory allocation and computing time. As an additional benefit,
+we will reduce over-fitting.
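+
+As a quick sanity check on those numbers (using the shapes defined above), we can
+compute them directly:
+
+```{code-cell} ipython3
+# entries in the raw history predictor vs. the basis-compressed one
+n_samples = input_feature.shape[0]
+print(f"raw history predictor entries: {n_samples * window_size}")
+print(f"basis-compressed predictor entries: {n_samples * basis.n_basis_funcs}")
+```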
+
+Let's see our basis in action. We can "compress" the spike history feature by convolving the basis
+with the counts (without creating the large spike history feature matrix).
+In NeMoS, this is done by calling the `compute_features` method of the basis.
+
+
+```{code-cell} ipython3
+# equivalent to
+# `nmo.convolve.create_convolutional_predictor(basis_kernels, neuron_count)`
+conv_spk = basis.compute_features(neuron_count)
+
+print(f"Raw count history as feature: {input_feature.shape}")
+print(f"Compressed count history as feature: {conv_spk.shape}")
+
+# Visualize the convolution results
+epoch_one_spk = nap.IntervalSet(8917.5, 8918.5)
+epoch_multi_spk = nap.IntervalSet(8979.2, 8980.2)
+
+doc_plots.plot_convolved_counts(neuron_count, conv_spk, epoch_one_spk, epoch_multi_spk);
+
+# find interval with two spikes to show the accumulation, in a second row
+```
+
+Now that we have our "compressed" history feature matrix, we can fit the ML parameters for a GLM.
+
+
+
+
+#### Fit and compare the models
+
+
+```{code-cell} ipython3
+# use restrict on interval set training
+model_basis = nmo.glm.GLM(solver_name="LBFGS")
+model_basis.fit(conv_spk.restrict(first_half), neuron_count.restrict(first_half))
+```
+
+We can plot the resulting response, noting that the weights we just learned need to be "expanded" back
+to the original `window_size` dimension by multiplying them with the basis kernels.
+We now have 8 coefficients:
+
+
+```{code-cell} ipython3
+print(model_basis.coef_)
+```
+
+In order to get the response we need to multiply the coefficients by their corresponding
+basis function, and sum them.
+
+
+```{code-cell} ipython3
+self_connection = np.matmul(basis_kernels, np.squeeze(model_basis.coef_))
+
+print(self_connection.shape)
+```
+
+We can now compare this model with the one based on the raw count history.
+
+
+```{code-cell} ipython3
+plt.figure()
+plt.title("Spike History Weights")
+plt.plot(time, np.squeeze(model.coef_), alpha=0.3, label="GLM raw history")
+plt.plot(time, self_connection, "--k", label="GLM basis", lw=2)
+plt.axhline(0, color="k", lw=0.5)
+plt.xlabel("Time from spike (sec)")
+plt.ylabel("Weight")
+plt.legend()
+```
+
+Let's check if our new estimate does a better job in terms of over-fitting. We can do that
+by visual comparison, as we did previously. Let's fit the second half of the dataset.
+
+
+```{code-cell} ipython3
+model_basis_second_half = nmo.glm.GLM(solver_name="LBFGS")
+model_basis_second_half.fit(conv_spk.restrict(second_half), neuron_count.restrict(second_half))
+
+# compute responses for the 2nd half fit
+self_connection_second_half = np.matmul(basis_kernels, np.squeeze(model_basis_second_half.coef_))
+
+plt.figure()
+plt.title("Spike History Weights")
+plt.plot(time, np.squeeze(model.coef_), "k", alpha=0.3, label="GLM raw history 1st half")
+plt.plot(time, np.squeeze(model_second_half.coef_), alpha=0.3, color="orange", label="GLM raw history 2nd half")
+plt.plot(time, self_connection, "--k", lw=2, label="GLM basis 1st half")
+plt.plot(time, self_connection_second_half, color="orange", lw=2, ls="--", label="GLM basis 2nd half")
+plt.axhline(0, color="k", lw=0.5)
+plt.xlabel("Time from spike (sec)")
+plt.ylabel("Weight")
+plt.legend()
+```
+
+Alternatively, we can score the model predictions, using one half of the data for training
+and the other half for testing.
+
+
+```{code-cell} ipython3
+# compare model scores; as expected, the training score is better with more parameters,
+# but this could be over-fitting.
+print(f"full history train score: {model.score(input_feature.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
+print(f"basis train score: {model_basis.score(conv_spk.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
+```
+
+To check that, let's see how the models perform on unseen data by computing a test
+score.
+
+
+```{code-cell} ipython3
+print(f"\nfull history test score: {model.score(input_feature.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
+print(f"basis test score: {model_basis.score(conv_spk.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
+```
+
+Let's extract and plot the rates
+
+
+```{code-cell} ipython3
+rate_basis = model_basis.predict(conv_spk) * conv_spk.rate
+rate_history = model.predict(input_feature) * conv_spk.rate
+ep = nap.IntervalSet(start=8819.4, end=8821)
+
+# plot the rates
+doc_plots.plot_rates_and_smoothed_counts(
+ neuron_count,
+    {"Self-connection raw history": rate_history, "Self-connection basis": rate_basis}
+);
+```
+
+### All-to-all Connectivity
+The same approach can be applied to the whole population. Now the firing rate of a neuron
+is predicted not only by its own count history, but also by the rest of the
+simultaneously recorded population. We can convolve the basis with the counts of each neuron
+to get an array of predictors of shape `(num_time_points, num_neurons * num_basis_funcs)`.
+
+#### Preparing the features
+
+```{code-cell} ipython3
+# re-initialize basis
+basis = nmo.basis.RaisedCosineBasisLog(
+ n_basis_funcs=8, mode="conv", window_size=window_size
+)
+
+# convolve all the neurons
+convolved_count = basis.compute_features(count)
+```
+
+Check the dimensions to make sure they make sense:
+the shape should be `(n_samples, n_basis_funcs * n_neurons)`.
+
+
+```{code-cell} ipython3
+print(f"Convolved count shape: {convolved_count.shape}")
+```
+
+#### Fitting the Model
+This is an all-to-all neurons model.
+We are using the class [`PopulationGLM`](nemos.glm.PopulationGLM) to fit the whole population at once.
+
+:::{note}
+
+Once we condition on past activity, the log-likelihood of the population is the sum of the log-likelihoods
+of the individual neurons. Maximizing the sum (i.e., the population log-likelihood) is equivalent to
+maximizing each individual term separately (i.e., fitting one neuron at a time).
+:::
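+
+In symbols, writing $\theta_n$ for the coefficients that enter neuron $n$'s rate, the population
+log-likelihood factorizes as
+$\log p(y_1, \dots, y_N \mid X) = \sum_{n=1}^{N} \log p(y_n \mid X, \theta_n)$,
+and since each term depends on its own set of coefficients, maximizing the sum is the same as
+maximizing each term on its own.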
+
+
+```{code-cell} ipython3
+model = nmo.glm.PopulationGLM(
+ regularizer="Ridge",
+ solver_name="LBFGS",
+ regularizer_strength=0.1
+ ).fit(convolved_count, count)
+```
+
+#### Comparing model predictions
+Predict the rate (the counts are already sorted by preferred direction).
+
+
+```{code-cell} ipython3
+predicted_firing_rate = model.predict(convolved_count) * conv_spk.rate
+```
+
+Plot fit predictions over a short window not used for training.
+
+
+```{code-cell} ipython3
+# use pynapple for time axis for all variables plotted for tick labels in imshow
+doc_plots.plot_head_direction_tuning_model(tuning_curves, predicted_firing_rate, spikes, angle, threshold_hz=1,
+ start=8910, end=8960, cmap_label="hsv");
+```
+
+Let's see if our firing rate predictions improved and in what sense.
+
+
+```{code-cell} ipython3
+# mkdocs_gallery_thumbnail_number = 2
+fig = doc_plots.plot_rates_and_smoothed_counts(
+ neuron_count,
+ {"Self-connection: raw history": rate_history,
+     "Self-connection: basis": rate_basis,
+ "All-to-all: basis": predicted_firing_rate[:, 0]}
+)
+```
+
+#### Visualizing the connectivity
+Compute the tuning curves from the predicted rates.
+
+
+```{code-cell} ipython3
+tuning = nap.compute_1d_tuning_curves_continuous(predicted_firing_rate,
+ feature=angle,
+ nb_bins=61,
+ minmax=(0, 2 * np.pi))
+```
+
+Extract the weights and store them in a `(n_neurons, n_basis_funcs, n_neurons)` array.
+
+
+```{code-cell} ipython3
+weights = model.coef_.reshape(count.shape[1], basis.n_basis_funcs, count.shape[1])
+```
+
+Multiply the weights by the basis kernels to get the history filters.
+
+
+```{code-cell} ipython3
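+# sum over the shared basis index k:
+# responses[i, j, t] = sum_k weights[j, k, i] * basis_kernels[t, k]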
+responses = np.einsum("jki,tk->ijt", weights, basis_kernels)
+
+print(responses.shape)
+```
+
+Finally, we can visualize the pairwise interactions by plotting
+all the coupling filters.
+
+
+```{code-cell} ipython3
+fig = doc_plots.plot_coupling(responses, tuning)
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in assets
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_02_head_direction.svg")
+```
\ No newline at end of file
diff --git a/docs/tutorials/plot_02_head_direction.py b/docs/tutorials/plot_02_head_direction.py
deleted file mode 100644
index dc4740d0..00000000
--- a/docs/tutorials/plot_02_head_direction.py
+++ /dev/null
@@ -1,603 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-# Fit head-direction population
-
-## Learning objectives
-
-- Learn how to add history-related predictors to NeMoS GLM
-- Learn about NeMoS `Basis` objects
-- Learn how to use `Basis` objects with convolution
-
-"""
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-# some helper plotting functions
-from nemos import _documentation_utils as doc_plots
-
-# configure pynapple to ignore conversion warning
-nap.nap_config.suppress_conversion_warnings = True
-
-# configure plots some
-plt.style.use(nmo.styles.plot_style)
-
-# %%
-# ## Data Streaming
-#
-# Here we load the data from OSF. The data is a NWB file.
-
-path = nmo.fetch.fetch_data("Mouse32-140822.nwb")
-
-# %%
-# ## Pynapple
-# We are going to open the NWB file with pynapple.
-
-
-data = nap.load_file(path)
-
-data
-
-# %%
-#
-# Get spike timings
-
-
-spikes = data["units"]
-
-spikes
-
-# %%
-#
-# Get the behavioural epochs (in this case, sleep and wakefulness)
-#
-
-
-epochs = data["epochs"]
-wake_ep = data["epochs"]["wake"]
-
-# %%
-# Get the tracked orientation of the animal
-
-
-angle = data["ry"]
-
-
-# %%
-# This cell will restrict the data to what we care about i.e. the activity of head-direction neurons during wakefulness.
-#
-
-spikes = spikes.getby_category("location")["adn"]
-
-spikes = spikes.restrict(wake_ep).getby_threshold("rate", 1.0)
-angle = angle.restrict(wake_ep)
-
-# %%
-# First let's check that they are head-direction neurons.
-
-
-tuning_curves = nap.compute_1d_tuning_curves(
- group=spikes, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
-)
-
-# %%
-# Each row indicates an angular bin (in radians), and each column corresponds to a single unit.
-# Let's plot the tuning curve of the first two neurons.
-
-fig, ax = plt.subplots(1, 2, figsize=(12, 4))
-ax[0].plot(tuning_curves.iloc[:, 0])
-ax[0].set_xlabel("Angle (rad)")
-ax[0].set_ylabel("Firing rate (Hz)")
-ax[1].plot(tuning_curves.iloc[:, 1])
-ax[1].set_xlabel("Angle (rad)")
-plt.tight_layout()
-
-# %%
-# Before using NeMoS, let's explore the data at the population level.
-#
-# Let's plot the preferred heading
-#
-
-fig = doc_plots.plot_head_direction_tuning(
- tuning_curves, spikes, angle, threshold_hz=1, start=8910, end=8960
-)
-
-# %%
-# As we can see, the population activity tracks very well the current head-direction of the animal.
-# **Question : are neurons constantly tuned to head-direction and can we use it to predict the spiking activity of each neuron based only on the activity of other neurons?**
-#
-# To fit the GLM faster, we will use only the first 3 min of wake
-
-
-wake_ep = nap.IntervalSet(
- start=wake_ep.start[0], end=wake_ep.start[0] + 3 * 60
-)
-
-# %%
-# To use the GLM, we need first to bin the spike trains. Here we use pynapple
-
-bin_size = 0.01
-count = spikes.count(bin_size, ep=wake_ep)
-
-# %%
-# Here we are going to rearrange neurons order based on their prefered directions.
-#
-
-
-pref_ang = tuning_curves.idxmax()
-
-count = nap.TsdFrame(
- t=count.t,
- d=count.values[:, np.argsort(pref_ang.values)],
-)
-
-# %%
-# ## NeMoS {.strip-code}
-# It's time to use NeMoS. Our goal is to estimate the pairwise interaction between neurons.
-# This can be quantified with a GLM if we use the recent population spike history to predict the current time step.
-# ### Self-Connected Single Neuron
-# To simplify our life, let's see first how we can model spike history effects in a single neuron.
-# The simplest approach is to use counts in fixed length window $i$, $y_{t-i}, \dots, y_{t-1}$ to predict the next
-# count $y_{t}$. Let's plot the count history,
-#
-
-
-# select a neuron's spike count time series
-neuron_count = count[:, 0]
-
-# restrict to a smaller time interval
-epoch_one_spk = nap.IntervalSet(
- start=count.time_support.start[0], end=count.time_support.start[0] + 1.2
-)
-plt.figure(figsize=(8, 3.5))
-plt.step(
- neuron_count.restrict(epoch_one_spk).t, neuron_count.restrict(epoch_one_spk).d, where="post"
-)
-plt.title("Spike Count Time Series")
-plt.xlabel("Time (sec)")
-plt.ylabel("Counts")
-plt.tight_layout()
-
-# %%
-# #### Features Construction
-# Let's fix the spike history window size that we will use as predictor.
-
-
-# set the size of the spike history window in seconds
-window_size_sec = 0.8
-
-doc_plots.plot_history_window(neuron_count, epoch_one_spk, window_size_sec)
-
-
-# %%
-# For each time point, we shift our window one bin at the time and vertically stack the spike count history in a matrix.
-# Each row of the matrix will be used as the predictors for the rate in the next bin (red narrow rectangle in
-# the figure).
-
-
-doc_plots.run_animation(neuron_count, epoch_one_spk.start[0])
-
-# %%
-# If $t$ is smaller than the window size, we won't have a full window of spike history for estimating the rate.
-# One may think of padding the window (with zeros for example) but this may generate weird border artifacts.
-# To avoid that, we can simply restrict our analysis to times $t$ larger than the window and NaN-pad earlier
-# time-points;
-#
-# A fast way to compute this feature matrix is convolving the counts with the identity matrix.
-# We can apply the convolution and NaN-padding in a single step using the
-# [`nemos.convolve.create_convolutional_predictor`](../../../reference/nemos/convolve/#nemos.convolve.create_convolutional_predictor)
-# function.
-
-# convert the prediction window to bins (by multiplying with the sampling rate)
-window_size = int(window_size_sec * neuron_count.rate)
-
-# convolve the counts with the identity matrix.
-plt.close("all")
-input_feature = nmo.convolve.create_convolutional_predictor(
- np.eye(window_size), neuron_count
-)
-
-# print the NaN indices along the time axis
-print("NaN indices:\n", np.where(np.isnan(input_feature[:, 0]))[0])
-
-# %%
-# The binned counts originally have shape "number of samples", we should check that the
-# dimension are matching our expectation
-
-print(f"Time bins in counts: {neuron_count.shape[0]}")
-print(f"Convolution window size in bins: {window_size}")
-print(f"Feature shape: {input_feature.shape}")
-
-# %%
-#
-# We can visualize the output for a few time bins
-
-
-suptitle = "Input feature: Count History"
-neuron_id = 0
-doc_plots.plot_features(input_feature, count.rate, suptitle)
-
-# %%
-# As you may see, the time axis is backward, this happens because convolution flips the time axis.
-# This is equivalent, as we can interpret the result as how much a spike will affect the future rate.
-# In the previous tutorial our feature was 1-dimensional (just the current), now
-# instead the feature dimension is 80, because our bin size was 0.01 sec and the window size is 0.8 sec.
-# We can learn these weights by maximum likelihood by fitting a GLM.
-
-# %%
-# #### Fitting the Model
-#
-# When working a real dataset, it is good practice to train your models on a chunk of the data and
-# use the other chunk to assess the model performance. This process is known as "cross-validation".
-# There is no unique strategy on how to cross-validate your model; What works best
-# depends on the characteristic of your data (time series or independent samples,
-# presence or absence of trials...), and that of your model. Here, for simplicity use the first
-# half of the wake epochs for training and the second half for testing. This is a reasonable
-# choice if the statistics of the neural activity does not change during the course of
-# the recording. We will learn about better cross-validation strategies with other
-# examples.
-
-# construct the train and test epochs
-duration = input_feature.time_support.tot_length("s")
-start = input_feature.time_support["start"]
-end = input_feature.time_support["end"]
-first_half = nap.IntervalSet(start, start + duration / 2)
-second_half = nap.IntervalSet(start + duration / 2, end)
-
-# %%
-# Fit the glm to the first half of the recording and visualize the ML weights.
-
-
-# define the GLM object
-model = nmo.glm.GLM(solver_name="LBFGS")
-
-# Fit over the training epochs
-model.fit(
- input_feature.restrict(first_half),
- neuron_count.restrict(first_half)
-)
-
-# %%
-
-plt.figure()
-plt.title("Spike History Weights")
-plt.plot(np.arange(window_size) / count.rate, np.squeeze(model.coef_), lw=2, label="GLM raw history 1st Half")
-plt.axhline(0, color="k", lw=0.5)
-plt.xlabel("Time From Spike (sec)")
-plt.ylabel("Kernel")
-plt.legend()
-
-# %%
-# The response in the previous figure seems noise added to a decay, therefore the response
-# can be described with fewer degrees of freedom. In other words, it looks like we
-# are using way too many weights to describe a simple response.
-# If we are correct, what would happen if we re-fit the weights on the other half of the data?
-# #### Inspecting the results
-
-# fit on the test set
-
-model_second_half = nmo.glm.GLM(solver_name="LBFGS")
-model_second_half.fit(
- input_feature.restrict(second_half),
- neuron_count.restrict(second_half)
-)
-
-plt.figure()
-plt.title("Spike History Weights")
-plt.plot(np.arange(window_size) / count.rate, np.squeeze(model.coef_),
- label="GLM raw history 1st Half", lw=2)
-plt.plot(np.arange(window_size) / count.rate, np.squeeze(model_second_half.coef_),
- color="orange", label="GLM raw history 2nd Half", lw=2)
-plt.axhline(0, color="k", lw=0.5)
-plt.xlabel("Time From Spike (sec)")
-plt.ylabel("Kernel")
-plt.legend()
-
-# %%
-# What can we conclude?
-#
-# The fast fluctuations are inconsistent across fits, indicating that
-# they are probably capturing noise, a phenomenon known as over-fitting;
-# On the other hand, the decaying trend is fairly consistent, even if
-# our estimate is noisy. You can imagine how things could get
-# worst if we needed a finer temporal resolution, such 1ms time bins
-# (which would require 800 coefficients instead of 80).
-# What can we do to mitigate over-fitting now?
-#
-# #### Reducing feature dimensionality
-# One way to proceed is to find a lower-dimensional representation of the response
-# by parametrizing the decay effect. For instance, we could try to model it
-# with an exponentially decaying function $f(t) = \exp( - \alpha t)$, with
-# $\alpha >0$ a positive scalar. This is not a bad idea, because we would greatly
-# simplify the dimensionality our features (from 80 to 1). Unfortunately,
-# there is no way to know a-priori what is a good parameterization. More
-# importantly, not all the parametrizations guarantee a unique and stable solution
-# to the maximum likelihood estimation of the coefficients (convexity).
-#
-# In the GLM framework, the main way to construct a lower dimensional parametrization
-# while preserving convexity, is to use a set of basis functions.
-# For history-type inputs, whether of the spiking history or of the current
-# history, we'll use the raised cosine log-stretched basis first described in
-# [Pillow et al., 2005](https://www.jneurosci.org/content/25/47/11003). This
-# basis set has the nice property that their precision drops linearly with
-# distance from event, which is a makes sense for many history-related inputs
-# in neuroscience: whether an input happened 1 or 5 msec ago matters a lot,
-# whereas whether an input happened 51 or 55 msec ago is less important.
-
-
-doc_plots.plot_basis()
-
-# %%
-# !!! info
-#
-# We provide a handful of different choices for basis functions, and
-# selecting the proper basis function for your input is an important
-# analytical step. We will eventually provide guidance on this choice, but
-# for now we'll give you a decent choice.
-#
-# NeMoS includes `Basis` objects to handle the construction and use of these
-# basis functions.
-#
-# When we instantiate this object, the only arguments we need to specify is the
-# number of functions we want, the mode of operation of the basis (`"conv"`),
-# and the window size for the convolution. With more basis functions, we'll be able to
-# represent the effect of the corresponding input with the higher precision, at
-# the cost of adding additional parameters.
-
-# a basis object can be instantiated in "conv" mode for convolving the input.
-basis = nmo.basis.RaisedCosineBasisLog(
- n_basis_funcs=8, mode="conv", window_size=window_size
-)
-
-# `basis.evaluate_on_grid` is a convenience method to view all basis functions
-# across their whole domain:
-time, basis_kernels = basis.evaluate_on_grid(window_size)
-
-print(basis_kernels.shape)
-
-# time takes equi-spaced values between 0 and 1, we could multiply by the
-# duration of our window to scale it to seconds.
-time *= window_size_sec
-
-# %%
-# To appreciate why the raised-cosine basis can approximate well our response
-# we can learn a "good" set of weight for the basis element such that
-# a weighted sum of the basis approximates the GLM weights for the count history.
-# One way to do so is by minimizing the least-squares.
-
-
-# compute the least-squares weights
-lsq_coef, _, _, _ = np.linalg.lstsq(basis_kernels, np.squeeze(model.coef_), rcond=-1)
-
-# plot the basis and the approximation
-doc_plots.plot_weighted_sum_basis(time, model.coef_, basis_kernels, lsq_coef)
-
-# %%
-#
-# The first plot is the response of each of the 8 basis functions to a single
-# pulse. This is known as the impulse response function, and is a useful way to
-# characterize linear systems like our basis objects. The second plot are is a
-# bar plot representing the least-square coefficients. The third one are the
-# impulse responses scaled by the weights. The last plot shows the sum of the
-# scaled response overlapped to the original spike count history weights.
-#
-# Our predictor previously was huge: every possible 80 time point chunk of the
-# data, for 1440000 total numbers. By using this basis set we can instead reduce
-# the predictor to 8 numbers for every 80 time point window for 144000 total
-# numbers. Basically an order of magnitude less. With 1ms bins we would have
-# achieved 2 order of magnitude reduction in input size. This is a huge benefit
-# in terms of memory allocation and, computing time. As an additional benefit,
-# we will reduce over-fitting.
-#
-# Let's see our basis in action. We can "compress" spike history feature by convolving the basis
-# with the counts (without creating the large spike history feature matrix).
-# This can be performed in NeMoS by calling the "compute_features" method of basis.
-
-
-# equivalent to
-# `nmo.convolve.create_convolutional_predictor(basis_kernels, neuron_count)`
-conv_spk = basis.compute_features(neuron_count)
-
-print(f"Raw count history as feature: {input_feature.shape}")
-print(f"Compressed count history as feature: {conv_spk.shape}")
-
-# Visualize the convolution results
-epoch_one_spk = nap.IntervalSet(8917.5, 8918.5)
-epoch_multi_spk = nap.IntervalSet(8979.2, 8980.2)
-
-doc_plots.plot_convolved_counts(neuron_count, conv_spk, epoch_one_spk, epoch_multi_spk)
-
-# find interval with two spikes to show the accumulation, in a second row
-
-# %%
-# Now that we have our "compressed" history feature matrix, we can fit the ML parameters for a GLM.
-
-# %%
-# #### Fit and compare the models
-
-# use restrict on interval set training
-model_basis = nmo.glm.GLM(solver_name="LBFGS")
-model_basis.fit(conv_spk.restrict(first_half), neuron_count.restrict(first_half))
-
-# %%
-# We can plot the resulting response, noting that the weights we just learned needs to be "expanded" back
-# to the original `window_size` dimension by multiplying them with the basis kernels.
-# We have now 8 coefficients,
-
-print(model_basis.coef_)
-
-# %%
-# In order to get the response we need to multiply the coefficients by their corresponding
-# basis function, and sum them.
-
-self_connection = np.matmul(basis_kernels, np.squeeze(model_basis.coef_))
-
-print(self_connection.shape)
-
-# %%
-# We can now compare this model that based on the raw count history.
-
-plt.figure()
-plt.title("Spike History Weights")
-plt.plot(time, np.squeeze(model.coef_), alpha=0.3, label="GLM raw history")
-plt.plot(time, self_connection, "--k", label="GLM basis", lw=2)
-plt.axhline(0, color="k", lw=0.5)
-plt.xlabel("Time from spike (sec)")
-plt.ylabel("Weight")
-plt.legend()
-
-# %%
-# Let's check if our new estimate does a better job in terms of over-fitting. We can do that
-# by visual comparison, as we did previously. Let's fit the second half of the dataset.
-
-model_basis_second_half = nmo.glm.GLM(solver_name="LBFGS")
-model_basis_second_half.fit(conv_spk.restrict(second_half), neuron_count.restrict(second_half))
-
-# compute responses for the 2nd half fit
-self_connection_second_half = np.matmul(basis_kernels, np.squeeze(model_basis_second_half.coef_))
-
-plt.figure()
-plt.title("Spike History Weights")
-plt.plot(time, np.squeeze(model.coef_), "k", alpha=0.3, label="GLM raw history 1st half")
-plt.plot(time, np.squeeze(model_second_half.coef_), alpha=0.3, color="orange", label="GLM raw history 2nd half")
-plt.plot(time, self_connection, "--k", lw=2, label="GLM basis 1st half")
-plt.plot(time, self_connection_second_half, color="orange", lw=2, ls="--", label="GLM basis 2nd half")
-plt.axhline(0, color="k", lw=0.5)
-plt.xlabel("Time from spike (sec)")
-plt.ylabel("Weight")
-plt.legend()
-
-
-# %%
-# Or we can score the model predictions using one half of the set for training
-# and the other half for testing.
-
-# compare model scores; as expected, the training score is better with more parameters,
-# but this may be over-fitting.
-print(f"full history train score: {model.score(input_feature.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
-print(f"basis train score: {model_basis.score(conv_spk.restrict(first_half), neuron_count.restrict(first_half), score_type='pseudo-r2-Cohen')}")
-
-# %%
-# To check that, let's see how the model performs on unseen data by obtaining a test
-# score.
-print(f"\nfull history test score: {model.score(input_feature.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
-print(f"basis test score: {model_basis.score(conv_spk.restrict(second_half), neuron_count.restrict(second_half), score_type='pseudo-r2-Cohen')}")
-
-# %%
-# Let's extract and plot the rates
-
-
-rate_basis = model_basis.predict(conv_spk) * conv_spk.rate
-rate_history = model.predict(input_feature) * conv_spk.rate
-ep = nap.IntervalSet(start=8819.4, end=8821)
-
-# plot the rates
-doc_plots.plot_rates_and_smoothed_counts(
- neuron_count,
- {"Self-connection raw history":rate_history, "Self-connection bsais": rate_basis}
-)
-
-# %%
-# ### All-to-all Connectivity
-# The same approach can be applied to the whole population. Now the firing rate of a neuron
-# is predicted not only by its own count history, but also by the rest of the
-# simultaneously recorded population. We can convolve the basis with the counts of each neuron
-# to get an array of predictors of shape `(num_time_points, num_neurons * num_basis_funcs)`.
-#
-# #### Preparing the features
-
-# define a basis function that expects an input of shape (num_samples, num_neurons).
-num_neurons = count.shape[1]
-basis = nmo.basis.RaisedCosineBasisLog(
- n_basis_funcs=8, mode="conv", window_size=window_size, label="convolved counts"
-)
-
-# convolve all the neurons
-convolved_count = basis.compute_features(count)
-
-# %%
-# Check the dimensions to make sure they make sense.
-# The shape should be (n_samples, n_basis_func * num_neurons)
-print(f"Convolved count shape: {convolved_count.shape}")
-
-# %%
-# #### Fitting the Model
-# This is an all-to-all neurons model.
-# We are using the class `PopulationGLM` to fit the whole population at once.
-#
-# !!! note
-#     Once we condition on past activity, the log-likelihood of the population is the sum of the log-likelihoods
-#     of the individual neurons. Maximizing the sum (i.e. the population log-likelihood) is equivalent to
-#     maximizing each individual term separately (i.e. fitting one neuron at a time).
-#
-
-model = nmo.glm.PopulationGLM(
- regularizer="Ridge",
- solver_name="LBFGS",
- regularizer_strength=0.1
- ).fit(convolved_count, count)
-
-# %%
-# #### Comparing model predictions.
-# Predict the rate (counts are already sorted by tuning prefs)
-
-predicted_firing_rate = model.predict(convolved_count) * conv_spk.rate
-
-# %%
-# Plot fit predictions over a short window not used for training.
-
-# use pynapple for time axis for all variables plotted for tick labels in imshow
-doc_plots.plot_head_direction_tuning_model(tuning_curves, predicted_firing_rate, spikes, angle, threshold_hz=1,
- start=8910, end=8960, cmap_label="hsv")
-# %%
-# Let's see if our firing rate predictions improved and in what sense.
-
-# mkdocs_gallery_thumbnail_number = 2
-doc_plots.plot_rates_and_smoothed_counts(
- neuron_count,
- {"Self-connection: raw history": rate_history,
- "Self-connection: bsais": rate_basis,
- "All-to-all: basis": predicted_firing_rate[:, 0]}
-)
-
-# %%
-# #### Visualizing the connectivity
-# Compute the tuning curve from the predicted rates.
-
-tuning = nap.compute_1d_tuning_curves_continuous(predicted_firing_rate,
- feature=angle,
- nb_bins=61,
- minmax=(0, 2 * np.pi))
-
-# %%
-# Extract the weights and store them in a (n_neurons, n_basis_funcs, n_neurons) array.
-# We can use `basis.split_by_feature` for this. The method will return a dictionary with an array
-# for each feature, whose keys are the labels we provided to the basis.
-# In this case, "convolved counts" is the only feature.
-
-# split the coefficients by feature
-weights = basis.split_by_feature(model.coef_, axis=0)
-
-# the output is a dictionary containing an array of shape (n_neurons, n_basis_funcs, n_neurons)
-for k, v in weights.items():
- print(f"{k}: {v.shape}")
-
-# get the array
-weights = weights["convolved counts"]
-
-# %%
-# Multiply the weights by the basis, to get the history filters.
-
-responses = np.einsum("jki,tk->ijt", weights, basis_kernels)
-
-print(responses.shape)
-
-# %%
-# Finally, we can visualize the pairwise interactions by plotting
-# all the coupling filters.
-
-doc_plots.plot_coupling(responses, tuning)
diff --git a/docs/tutorials/plot_03_grid_cells.py b/docs/tutorials/plot_03_grid_cells.md
similarity index 51%
rename from docs/tutorials/plot_03_grid_cells.py
rename to docs/tutorials/plot_03_grid_cells.md
index f8532835..f4977cbd 100644
--- a/docs/tutorials/plot_03_grid_cells.py
+++ b/docs/tutorials/plot_03_grid_cells.md
@@ -1,58 +1,104 @@
-# -*- coding: utf-8 -*-
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
-"""
-# Fit grid cell
-"""
+# Fit grid cell
+```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap
from scipy.ndimage import gaussian_filter
import nemos as nmo
+```
+
+## Data Streaming
-# %%
-# ## Data Streaming
-#
-# Here we load the data from OSF. The data is a NWB file.
+Here we load the data from OSF. The data is a NWB file.
+
+```{code-cell} ipython3
io = nmo.fetch.download_dandi_data(
"000582",
"sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb",
)
+```
+
+## Pynapple
-# %%
-# ## Pynapple
-#
-# Let's load the dataset and see what's inside
+Let's load the dataset and see what's inside
+
+```{code-cell} ipython3
dataset = nap.NWBFile(io.read(), lazy_loading=False)
print(dataset)
+```
-
-# %%
-# In this case, the data were used in this [publication](https://www.science.org/doi/full/10.1126/science.1125572).
-# We thus expect to find neurons tuned to position and head-direction of the animal.
-# Let's verify that with pynapple.
-# First, extract the spike times and the position of the animal.
+In this case, the data were used in this [publication](https://www.science.org/doi/full/10.1126/science.1125572).
+We thus expect to find neurons tuned to position and head-direction of the animal.
+Let's verify that with pynapple.
+First, extract the spike times and the position of the animal.
+```{code-cell} ipython3
spikes = dataset["units"] # Get spike timings
position = dataset["SpatialSeriesLED1"] # Get the tracked orientation of the animal
+```
-# %%
-# Here we compute quickly the head-direction of the animal from the position of the LEDs.
+Here we quickly compute the head-direction of the animal from the position of the LEDs.
+
+```{code-cell} ipython3
diff = dataset["SpatialSeriesLED1"].values - dataset["SpatialSeriesLED2"].values
head_dir = (np.arctan2(*diff.T) + (2 * np.pi)) % (2 * np.pi)
head_dir = nap.Tsd(dataset["SpatialSeriesLED1"].index, head_dir).dropna()
+```
+Let's quickly compute some tuning curves for head-direction and spatial position.
-# %%
-# Let's quickly compute some tuning curves for head-direction and spatial position.
+```{code-cell} ipython3
hd_tuning = nap.compute_1d_tuning_curves(
group=spikes, feature=head_dir, nb_bins=61, minmax=(0, 2 * np.pi)
)
@@ -60,11 +106,12 @@
pos_tuning, binsxy = nap.compute_2d_tuning_curves(
group=spikes, features=position, nb_bins=12
)
+```
+Let's plot the tuning curves for each neuron.
-# %%
-# Let's plot the tuning curves for each neuron.
+```{code-cell} ipython3
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(2, len(spikes))
for i in range(len(spikes)):
@@ -74,60 +121,76 @@
ax = plt.subplot(gs[1, i])
ax.imshow(gaussian_filter(pos_tuning[i], sigma=1))
plt.tight_layout()
+```
-# %%
-# ## NeMoS
-# It's time to use NeMoS.
-# Let's try to predict the spikes as a function of position and see if we can generate better tuning curves
-# First we start by binning the spike trains in 10 ms bins.
+## NeMoS
+It's time to use NeMoS.
+Let's try to predict the spikes as a function of position and see if we can generate better tuning curves.
+First, we start by binning the spike trains in 10 ms bins.
+```{code-cell} ipython3
bin_size = 0.01 # second
counts = spikes.count(bin_size, ep=position.time_support)
+```
+
+We need to interpolate the position to the same time resolution.
+We can still use pynapple for this.
-# %%
-# We need to interpolate the position to the same time resolution.
-# We can still use pynapple for this.
+```{code-cell} ipython3
position = position.interpolate(counts)
+```
-# %%
-# We can define a two-dimensional basis for position by multiplying two one-dimensional bases,
-# see [here](../../background/plot_02_ND_basis_function) for more details.
+We can define a two-dimensional basis for position by multiplying two one-dimensional bases,
+see [here](../../background/plot_02_ND_basis_function) for more details.
+
+```{code-cell} ipython3
basis_2d = nmo.basis.RaisedCosineBasisLinear(
n_basis_funcs=10
) * nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=10)
+```
+
+Let's see what a few basis elements look like. Here we evaluate the basis on a 100 x 100 grid.
-# %%
-# Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid.
+```{code-cell} ipython3
X, Y, Z = basis_2d.evaluate_on_grid(100, 100)
+```
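+
+The product of two 10-element bases gives 10 × 10 = 100 two-dimensional elements,
+stacked along the last axis of `Z`:
+
+```{code-cell} ipython3
+# (grid_x, grid_y, n_basis) -- the last axis indexes the 100 basis elements
+print(Z.shape)
+```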
-# %%
-# We can visualize the basis.
+We can visualize the basis.
+
+```{code-cell} ipython3
fig, axs = plt.subplots(2, 5, figsize=(10, 4))
for k in range(2):
for h in range(5):
axs[k][h].contourf(X, Y, Z[:, :, 50 + 2 * (k + h)], cmap="Blues")
plt.tight_layout()
+```
+
+Each basis element represents a possible position of the animal in the arena.
+Now we can "evaluate" the basis for each position of the animal.
-# %%
-# Each basis element represent a possible position of the animal in an arena.
-# Now we can "evaluate" the basis for each position of the animal
+```{code-cell} ipython3
position_basis = basis_2d(position["x"], position["y"])
+```
+
+Now let's try to make sense of what it is.
+
-# %%
-# Now try to make sense of what it is
+```{code-cell} ipython3
print(position_basis.shape)
+```
-# %%
-# The shape is (n_samples, n_basis). This means that for each time point "t", we evaluated the basis at the
-# corresponding position. Let's plot 5 time steps.
+The shape is (n_samples, n_basis). This means that for each time point "t", we evaluated the basis at the
+corresponding position. Let's plot 5 time steps.
+
+```{code-cell} ipython3
fig = plt.figure(figsize=(12, 4))
gs = plt.GridSpec(2, 5)
xt = np.arange(0, 1000, 200)
@@ -148,39 +211,48 @@
ax.plot(position["x"][xt[i]], position["y"][xt[i]], "o", color=cmap(colors[i]))
plt.tight_layout()
+```
+Now we can fit the GLM and see what we get. In this case, we use Ridge for regularization.
+Here we will focus on the last neuron (neuron 7), which has a nice grid pattern.
-# %%
-# Now we can fit the GLM and see what we get. In this case, we use Ridge for regularization.
-# Here we will focus on the last neuron (neuron 7) who has a nice grid pattern
+```{code-cell} ipython3
model = nmo.glm.GLM(
regularizer="Ridge",
regularizer_strength=0.001
)
+```
+
+Let's fit the model
-# %%
-# Let's fit the model
+```{code-cell} ipython3
neuron = 7
model.fit(position_basis, counts[:, neuron])
+```
-# %%
-# We can compute the model predicted firing rate.
+We can compute the model predicted firing rate.
+
+```{code-cell} ipython3
rate_pos = model.predict(position_basis)
+```
+
+And compute the tuning curves.
-# %%
-# And compute the tuning curves/
+```{code-cell} ipython3
model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
tsdframe=rate_pos[:, np.newaxis] * rate_pos.rate, features=position, nb_bins=12
)
+```
-# %%
-# Let's compare the tuning curve predicted by the model with that based on the actual spikes.
+Let's compare the tuning curve predicted by the model with that based on the actual spikes.
+
+```{code-cell} ipython3
smooth_pos_tuning = gaussian_filter(pos_tuning[neuron], sigma=1)
smooth_model = gaussian_filter(model_tuning[0], sigma=1)
@@ -194,13 +266,15 @@
ax = plt.subplot(gs[0, 1])
ax.imshow(smooth_model, vmin=vmin, vmax=vmax)
plt.tight_layout()
+```
+
+The grid pattern shows, but the peak firing rate is off; we might have over-regularized.
+We can fix this by tuning the regularization strength by means of cross-validation.
+This can be done through scikit-learn. Let's apply a grid search over different
+values, and select the regularization strength by k-fold cross-validation.
-# %%
-# The grid shows but the peak firing rate is off, we might have over-regularized.
-# We can fix this by tuning the regularization strength by means of cross-validation.
-# This can be done through scikit-learn. Let's apply a grid-search over different
-# values, and select the regularization by k-fold cross-validation.
+```{code-cell} ipython3
# import the grid-search cross-validation from scikit-learn
from sklearn.model_selection import GridSearchCV
@@ -212,15 +286,19 @@
# run the search, the default is a 5-fold cross-validation strategy
cls.fit(position_basis, counts[:, neuron])
+```
-# %%
-# Let's get the best estimator and see what we get.
+Let's get the best estimator and see what we get.
+
+```{code-cell} ipython3
best_model = cls.best_estimator_
+```
+
+Let's predict and compute the tuning curves once again.
-# %%
-# Let's predict and compute the tuning curves once again.
+```{code-cell} ipython3
# predict the rate with the selected model
best_rate_pos = best_model.predict(position_basis)
@@ -228,10 +306,12 @@
best_model_tuning, binsxy = nap.compute_2d_tuning_curves_continuous(
tsdframe=best_rate_pos[:, np.newaxis] * best_rate_pos.rate, features=position, nb_bins=12
)
+```
-# %%
-# We can now plot the results.
+We can now plot the results.
+
+```{code-cell} ipython3
 # plot the results
smooth_best_model = gaussian_filter(best_model_tuning[0], sigma=1)
@@ -247,6 +327,26 @@
axs[2].set_title(f"Ridge - strength: {best_model.regularizer_strength}")
axs[2].imshow(smooth_best_model, vmin=vmin, vmax=vmax)
plt.tight_layout()
-
-
-
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in ../_build/html...
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_03_grid_cells.svg")
+```
\ No newline at end of file
diff --git a/docs/tutorials/plot_04_v1_cells.md b/docs/tutorials/plot_04_v1_cells.md
new file mode 100644
index 00000000..bd91d5ef
--- /dev/null
+++ b/docs/tutorials/plot_04_v1_cells.md
@@ -0,0 +1,381 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+# Fit V1 cell
+
+The data presented in this notebook was collected by [Sonica Saraf](https://www.cns.nyu.edu/~saraf/) from the [Movshon lab](https://www.cns.nyu.edu/labs/movshonlab/) at NYU.
+
+The notebook focuses on fitting a V1 cell model.
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+import numpy as np
+import pynapple as nap
+
+import nemos as nmo
+
+# configure plots some
+plt.style.use(nmo.styles.plot_style)
+
+
+# utility for filling a time series
+def fill_forward(time_series, data, ep=None, out_of_range=np.nan):
+ """
+ Fill a time series forward in time with data.
+
+ Parameters
+ ----------
+ time_series:
+ The time series to match.
+ data: Tsd, TsdFrame, or TsdTensor
+        The time series with data to be extended.
+    ep: IntervalSet, optional
+        The epochs over which to fill forward. Defaults to the time support of `time_series`.
+    out_of_range: float, optional
+        Value assigned to time points that have no preceding data sample within an epoch.
+
+ Returns
+ -------
+ : Tsd, TsdFrame, or TsdTensor
+ The data time series filled forward.
+
+ """
+ assert isinstance(data, (nap.Tsd, nap.TsdFrame, nap.TsdTensor))
+
+ if ep is None:
+ ep = time_series.time_support
+ else:
+ assert isinstance(ep, nap.IntervalSet)
+        time_series = time_series.restrict(ep)  # restrict the target times to the requested epochs
+
+ data = data.restrict(ep)
+ starts = ep.start
+ ends = ep.end
+
+ filled_d = np.full((time_series.t.shape[0], *data.shape[1:]), out_of_range, dtype=data.dtype)
+ fill_idx = 0
+ for start, end in zip(starts, ends):
+ data_ep = data.get(start, end)
+ ts_ep = time_series.get(start, end)
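+        # forward fill: find, for each time point in this epoch, the index of the
+        # most recent data sample at or before it (-1 means no sample yet; skipped below)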
+ idxs = np.searchsorted(data_ep.t, ts_ep.t, side="right") - 1
+ filled_d[fill_idx:fill_idx + ts_ep.t.shape[0]][idxs >= 0] = data_ep.d[idxs[idxs>=0]]
+ fill_idx += ts_ep.t.shape[0]
+ return type(data)(t=time_series.t, d=filled_d, time_support=ep)
+```
+
+## Data Streaming
+
+
+
+```{code-cell} ipython3
+path = nmo.fetch.fetch_data("m691l1.nwb")
+```
+
+## Pynapple
+The data have been copied to your local station.
+We are going to open the NWB file with pynapple.
+
+
+```{code-cell} ipython3
+dataset = nap.load_file(path)
+```
+
+What does it look like?
+
+
+```{code-cell} ipython3
+print(dataset)
+```
+
+Let's extract the data.
+
+
+```{code-cell} ipython3
+epochs = dataset["epochs"]
+units = dataset["units"]
+stimulus = dataset["whitenoise"]
+```
+
+The stimulus is white noise shown at 40 Hz.
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 1, figsize=(12,4))
+ax.imshow(stimulus[0], cmap='Greys_r')
+stimulus.shape
+```
+
+There are 73 neurons recorded together in V1. To fit the GLM faster, we will focus on one neuron.
+
+
+```{code-cell} ipython3
+print(units)
+# this returns TsGroup with one neuron only
+spikes = units[[34]]
+```
+
+How could we predict the neuron's response to the white noise stimulus?
+
+- We could fit the instantaneous spatial response; that is, just predict the
+  neuron's response to a given frame of white noise. This will give an x by y
+  filter. It implicitly assumes that there's no temporal info: only what we've
+  just seen matters.
+
+- We could fit a spatiotemporal filter. Instead of an x by y filter that we use
+  independently on each frame, fit (x, y, t) over, say, 100 msecs, and then
+  fit each of these independently (like in the head direction example).
+
+- That's a lot of parameters! We can simplify by assuming that the response is
+  separable: fit a single (x, y) filter and then modulate it over time (see the
+  sketch below). This wouldn't catch, e.g., direction-selectivity, because it
+  assumes that the phase preference is constant over time.
+
+- We could make use of our knowledge of V1 and try to fit a more complex
+  functional form, e.g., a Gabor.
+
+That last one is very non-linear and thus non-convex; we'll do the third one.
+
+In this example, we'll fit the spatial filter outside of the GLM framework,
+using the spike-triggered average, and then we'll use the GLM to fit the
+temporal timecourse.
+
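+To make "separable" concrete, here is a minimal numpy sketch (with made-up
+arrays, nothing fit to data): the full spatiotemporal filter is just a single
+spatial map scaled by a temporal kernel at each lag.
+
+```{code-cell} ipython3
+# hypothetical spatial map (h x w) and temporal kernel (n_lags,)
+spatial_toy = np.random.randn(51, 51)
+temporal_toy = np.exp(-np.arange(6) / 2.0)
+
+# separable spatiotemporal filter: one spatial map scaled over time lags
+separable_toy = np.einsum("t, h w -> t h w", temporal_toy, spatial_toy)
+print(separable_toy.shape)
+```
+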
+## Spike-triggered average
+
+Spike-triggered average says: every time our neuron spikes, we store the
+stimulus that was on the screen. for the whole recording, we'll have many of
+these, which we then average to get this STA, which is the "optimal stimulus"
+/ spatial filter.
+
+In practice, we do not use just the stimulus on screen at the spike time, but
+the stimulus in some window of time around it (it takes some time for info to
+travel through the eye/LGN to V1). Pynapple makes this easy:
+
+
+```{code-cell} ipython3
+sta = nap.compute_event_trigger_average(spikes, stimulus, binsize=0.025,
+ windowsize=(-0.15, 0.0))
+```
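+
+Conceptually, for a single time lag this amounts to grabbing the stimulus frame
+shown some fixed time before each spike and averaging those frames. A rough
+numpy-only sketch with hypothetical toy arrays (the pynapple call above handles
+the epochs, multiple lags, and time alignment for us):
+
+```{code-cell} ipython3
+lag = 0.05                                    # seconds before each spike
+spike_times_toy = np.array([0.4, 1.1, 2.3])   # hypothetical spike times (s)
+stim_times_toy = np.arange(0, 3, 0.025)       # hypothetical frame times (40 Hz)
+stim_frames_toy = np.random.rand(stim_times_toy.size, 51, 51)
+
+# index of the frame that was on screen `lag` seconds before each spike
+idx_toy = np.searchsorted(stim_times_toy, spike_times_toy - lag, side="right") - 1
+sta_toy = stim_frames_toy[idx_toy].mean(axis=0)
+print(sta_toy.shape)
+```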
+
+sta is a [`TsdTensor`](https://pynapple.org/generated/pynapple.core.time_series.TsdTensor.html), which gives us the 2d receptive field at each of the
+time points.
+
+
+```{code-cell} ipython3
+sta
+```
+
+We index into this in a 2d manner: row, column (here we only have 1 column).
+
+
+```{code-cell} ipython3
+sta[1, 0]
+```
+
+We can easily plot this:
+
+
+```{code-cell} ipython3
+fig, axes = plt.subplots(1, len(sta), figsize=(3*len(sta),3))
+for i, t in enumerate(sta.t):
+ axes[i].imshow(sta[i,0], vmin = np.min(sta), vmax = np.max(sta),
+ cmap='Greys_r')
+ axes[i].set_title(str(t)+" s")
+```
+
+That looks pretty reasonable for a V1 simple cell: localized in space, with a
+clear orientation and spatial frequency preference; that is, it looks Gabor-ish.
+
+To convert this to the spatial filter we'll use for the GLM, let's take the
+average across the bins that look informative: -0.125 to -0.05 s.
+
+
+```{code-cell} ipython3
+receptive_field = np.mean(sta.get(-0.125, -0.05), axis=0)[0]
+
+fig, ax = plt.subplots(1, 1, figsize=(4,4))
+ax.imshow(receptive_field, cmap='Greys_r')
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in assets
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_04_v1_cells.svg")
+```
+
+This receptive field gives us the spatial part of the linear response: it
+gives a map of weights that we use for a weighted sum on an image. There are
+multiple ways of performing this operation:
+
+
+```{code-cell} ipython3
+# element-wise multiplication and sum
+print((receptive_field * stimulus[0]).sum())
+# dot product of flattened versions
+print(np.dot(receptive_field.flatten(), stimulus[0].flatten()))
+```
+
+When performing this operation on multiple stimuli, things become slightly
+more complicated. For loops on the above methods would work, but would be
+slow. Reshaping and using the dot product is one common method, as are
+methods like `np.tensordot`.
+
+We'll use einsum to do this, which is a convenient way of representing many
+different matrix operations:
+
+
+```{code-cell} ipython3
+filtered_stimulus = np.einsum('t h w, h w -> t', stimulus, receptive_field)
+```
+
+This notation says: take these arrays with dimensions `(t,h,w)` and `(h,w)`
+and multiply and sum to get an array of shape `(t,)`. This performs the same
+operations as above.
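+
+As a quick sanity check, the first entry of `filtered_stimulus` should match the
+explicit weighted sum we computed on the first frame above:
+
+```{code-cell} ipython3
+# einsum result for frame 0 vs. the element-wise multiply-and-sum
+print(np.allclose(filtered_stimulus[0], (receptive_field * stimulus[0]).sum()))
+```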
+
+And this remains a pynapple object, so we can easily visualize it!
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 1, figsize=(12,4))
+ax.plot(filtered_stimulus)
+```
+
+But what is this? It's how much each frame in the video should drive our
+neuron, based on the receptive field we fit using the spike-triggered
+average.
+
+This, then, is the spatial component of our input, as described above.
+
+## Preparing data for NeMoS
+
+We'll now use the GLM to fit the temporal component. To do that, let's get
+this and our spike counts into the proper format for NeMoS:
+
+
+```{code-cell} ipython3
+# grab spikes from when we were showing our stimulus, and bin at 1 msec
+# resolution
+bin_size = .001
+counts = spikes[34].restrict(filtered_stimulus.time_support).count(bin_size)
+print(counts.rate)
+print(filtered_stimulus.rate)
+```
+
+Hold on, our stimulus is at a much lower rate than what we want for our rates
+-- in previous tutorials, our input has been at a higher rate than our spikes,
+and so we used `bin_average` to down-sample to the appropriate rate. When the
+input is at a lower rate, we need to think a little more carefully about how
+to up-sample.
+
+
+```{code-cell} ipython3
+print(counts[:5])
+print(filtered_stimulus[:5])
+```
+
+What was the visual input to the neuron at time 0.005? It was the same input
+as time 0. At time 0.0015? Same thing, up until we pass time 0.025017. Thus,
+we want to "fill forward" the values of our input, and we have pynapple
+convenience function to do so:
+
+
+```{code-cell} ipython3
+filtered_stimulus = fill_forward(counts, filtered_stimulus)
+filtered_stimulus
+```
+
+We can see that the time points are now aligned, and we've filled forward the
+values the way we'd like.
+
+Now, similar to the [head direction tutorial](plot_02_head_direction), we'll
+use the log-stretched raised cosine basis to create the predictor for our
+GLM:
+
+
+```{code-cell} ipython3
+window_size = 100
+basis = nmo.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size)
+
+convolved_input = basis.compute_features(filtered_stimulus)
+```
+
+`convolved_input` has shape `(n_time_pts, n_features * n_basis_funcs)`, because
+`n_features` is the singleton dimension from `filtered_stimulus`.
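+
+We can confirm this quickly:
+
+```{code-cell} ipython3
+# a single feature convolved with 8 basis functions gives 8 columns
+print(convolved_input.shape)
+```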
+
+## Fitting the GLM
+
+Now we're ready to fit the model! Let's do it, same as before:
+
+
+```{code-cell} ipython3
+model = nmo.glm.GLM()
+model.fit(convolved_input, counts)
+```
+
+We now have a coefficient for each of our 8 basis functions; let's combine
+them to get the temporal time course of our input:
+
+
+```{code-cell} ipython3
+time, basis_kernels = basis.evaluate_on_grid(window_size)
+time *= bin_size * window_size
+temp_weights = np.einsum('b, t b -> t', model.coef_, basis_kernels)
+plt.plot(time, temp_weights)
+plt.xlabel("time[sec]")
+plt.ylabel("amplitude")
+```
+
+When taken together, the results of the GLM and the spike-triggered average
+give us the linear component of our LNP model: the separable spatio-temporal
+filter.
diff --git a/docs/tutorials/plot_04_v1_cells.py b/docs/tutorials/plot_04_v1_cells.py
deleted file mode 100644
index ea3da54a..00000000
--- a/docs/tutorials/plot_04_v1_cells.py
+++ /dev/null
@@ -1,293 +0,0 @@
-# # -*- coding: utf-8 -*-
-#
-"""# Fit V1 cell
-
-The data presented in this notebook was collected by [Sonica Saraf](https://www.cns.nyu.edu/~saraf/) from the [Movshon lab](https://www.cns.nyu.edu/labs/movshonlab/) at NYU.
-
-The notebook focuses on fitting a V1 cell model.
-
-"""
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pynapple as nap
-
-import nemos as nmo
-
-# configure plots some
-plt.style.use(nmo.styles.plot_style)
-
-
-# utility for filling a time series
-def fill_forward(time_series, data, ep=None, out_of_range=np.nan):
- """
- Fill a time series forward in time with data.
-
- Parameters
- ----------
- time_series:
- The time series to match.
- data: Tsd, TsdFrame, or TsdTensor
- The time series with data to be extend.
-
- Returns
- -------
- : Tsd, TsdFrame, or TsdTensor
- The data time series filled forward.
-
- """
- assert isinstance(data, (nap.Tsd, nap.TsdFrame, nap.TsdTensor))
-
- if ep is None:
- ep = time_series.time_support
- else:
- assert isinstance(ep, nap.IntervalSet)
- time_series.restrict(ep)
-
- data = data.restrict(ep)
- starts = ep.start
- ends = ep.end
-
- filled_d = np.full((time_series.t.shape[0], *data.shape[1:]), out_of_range, dtype=data.dtype)
- fill_idx = 0
- for start, end in zip(starts, ends):
- data_ep = data.get(start, end)
- ts_ep = time_series.get(start, end)
- idxs = np.searchsorted(data_ep.t, ts_ep.t, side="right") - 1
- filled_d[fill_idx:fill_idx + ts_ep.t.shape[0]][idxs >= 0] = data_ep.d[idxs[idxs>=0]]
- fill_idx += ts_ep.t.shape[0]
- return type(data)(t=time_series.t, d=filled_d, time_support=ep)
-
-
-# %%
-# ## Data Streaming
-#
-
-path = nmo.fetch.fetch_data("m691l1.nwb")
-
-# %%
-# ## Pynapple
-# The data have been copied to your local station.
-# We are gonna open the NWB file with pynapple
-
-dataset = nap.load_file(path)
-
-# %%
-# What does it look like?
-print(dataset)
-
-# %%
-# Let's extract the data.
-epochs = dataset["epochs"]
-units = dataset["units"]
-stimulus = dataset["whitenoise"]
-
-# %%
-# Stimulus is white noise shown at 40 Hz
-
-fig, ax = plt.subplots(1, 1, figsize=(12,4))
-ax.imshow(stimulus[0], cmap='Greys_r')
-stimulus.shape
-
-# %%
-# There are 73 neurons recorded together in V1. To fit the GLM faster, we will focus on one neuron.
-print(units)
-# this returns TsGroup with one neuron only
-spikes = units[[34]]
-
-# %%
-# How could we predict neuron's response to white noise stimulus?
-#
-# - we could fit the instantaneous spatial response. that is, just predict
-# neuron's response to a given frame of white noise. this will give an x by y
-# filter. implicitly assumes that there's no temporal info: only matters what
-# we've just seen
-#
-# - could fit spatiotemporal filter. instead of an x by y that we use
-# independently on each frame, fit (x, y, t) over, say 100 msecs. and then
-# fit each of these independently (like in head direction example)
-#
-# - that's a lot of parameters! can simplify by assumping that the response is
-# separable: fit a single (x, y) filter and then modulate it over time. this
-# wouldn't catch e.g., direction-selectivity because it assumes that phase
-# preference is constant over time
-#
-# - could make use of our knowledge of V1 and try to fit a more complex
-# functional form, e.g., a Gabor.
-#
-# That last one is very non-linear and thus non-convex. we'll do the third one.
-#
-# in this example, we'll fit the spatial filter outside of the GLM framework,
-# using spike-triggered average, and then we'll use the GLM to fit the temporal
-# timecourse.
-#
-# ## Spike-triggered average
-#
-# Spike-triggered average says: every time our neuron spikes, we store the
-# stimulus that was on the screen. for the whole recording, we'll have many of
-# these, which we then average to get this STA, which is the "optimal stimulus"
-# / spatial filter.
-#
-# In practice, we do not just the stimulus on screen, but in some window of
-# time around it. (it takes some time for info to travel through the eye/LGN to
-# V1). Pynapple makes this easy:
-
-
-sta = nap.compute_event_trigger_average(spikes, stimulus, binsize=0.025,
- windowsize=(-0.15, 0.0))
-# %%
-#
-# sta is a `TsdTensor`, which gives us the 2d receptive field at each of the
-# time points.
-
-sta
-
-# %%
-#
-# We index into this in a 2d manner: row, column (here we only have 1 column).
-sta[1, 0]
-
-# %%
-# we can easily plot this
-
-fig, axes = plt.subplots(1, len(sta), figsize=(3*len(sta),3))
-for i, t in enumerate(sta.t):
- axes[i].imshow(sta[i,0], vmin = np.min(sta), vmax = np.max(sta),
- cmap='Greys_r')
- axes[i].set_title(str(t)+" s")
-
-
-# %%
-#
-# that looks pretty reasonable for a V1 simple cell: localized in space,
-# orientation, and spatial frequency. that is, looks Gabor-ish
-#
-# To convert this to the spatial filter we'll use for the GLM, let's take the
-# average across the bins that look informative: -.125 to -.05
-
-# mkdocs_gallery_thumbnail_number = 3
-receptive_field = np.mean(sta.get(-0.125, -0.05), axis=0)[0]
-
-fig, ax = plt.subplots(1, 1, figsize=(4,4))
-ax.imshow(receptive_field, cmap='Greys_r')
-
-# %%
-#
-# This receptive field gives us the spatial part of the linear response: it
-# gives a map of weights that we use for a weighted sum on an image. There are
-# multiple ways of performing this operation:
-
-# element-wise multiplication and sum
-print((receptive_field * stimulus[0]).sum())
-# dot product of flattened versions
-print(np.dot(receptive_field.flatten(), stimulus[0].flatten()))
-
-# %%
-#
-# When performing this operation on multiple stimuli, things become slightly
-# more complicated. For loops on the above methods would work, but would be
-# slow. Reshaping and using the dot product is one common method, as are
-# methods like `np.tensordot`.
-#
-# We'll use einsum to do this, which is a convenient way of representing many
-# different matrix operations:
-
-filtered_stimulus = np.einsum('t h w, h w -> t', stimulus, receptive_field)
-
-# %%
-#
-# This notation says: take these arrays with dimensions `(t,h,w)` and `(h,w)`
-# and multiply and sum to get an array of shape `(t,)`. This performs the same
-# operations as above.
-#
-# And this remains a pynapple object, so we can easily visualize it!
-
-fig, ax = plt.subplots(1, 1, figsize=(12,4))
-ax.plot(filtered_stimulus)
-
-# %%
-#
-# But what is this? It's how much each frame in the video should drive our
-# neuron, based on the receptive field we fit using the spike-triggered
-# average.
-#
-# This, then, is the spatial component of our input, as described above.
-#
-# ## Preparing data for NeMoS
-#
-# We'll now use the GLM to fit the temporal component. To do that, let's get
-# this and our spike counts into the proper format for NeMoS:
-
-# grab spikes from when we were showing our stimulus, and bin at 1 msec
-# resolution
-bin_size = .001
-counts = spikes[34].restrict(filtered_stimulus.time_support).count(bin_size)
-print(counts.rate)
-print(filtered_stimulus.rate)
-
-# %%
-#
-# Hold on, our stimulus is at a much lower rate than what we want for our rates
-# -- in previous tutorials, our input has been at a higher rate than our spikes,
-# and so we used `bin_average` to down-sample to the appropriate rate. When the
-# input is at a lower rate, we need to think a little more carefully about how
-# to up-sample.
-
-print(counts[:5])
-print(filtered_stimulus[:5])
-
-# %%
-#
-# What was the visual input to the neuron at time 0.005? It was the same input
-# as time 0. At time 0.0015? Same thing, up until we pass time 0.025017. Thus,
-# we want to "fill forward" the values of our input, and we have pynapple
-# convenience function to do so:
-filtered_stimulus = fill_forward(counts, filtered_stimulus)
-filtered_stimulus
-
-# %%
-#
-# We can see that the time points are now aligned, and we've filled forward the
-# values the way we'd like.
-#
-# Now, similar to the [head direction tutorial](../plot_02_head_direction), we'll
-# use the log-stretched raised cosine basis to create the predictor for our
-# GLM:
-
-window_size = 100
-basis = nmo.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size)
-
-convolved_input = basis.compute_features(filtered_stimulus)
-
-# %%
-#
-# convolved_input has shape (n_time_pts, n_features * n_basis_funcs), because
-# n_features is the singleton dimension from filtered_stimulus.
-#
-# ## Fitting the GLM
-#
-# Now we're ready to fit the model! Let's do it, same as before:
-
-
-model = nmo.glm.GLM()
-model.fit(convolved_input, counts)
-
-# %%
-#
-# We have our coefficients for each of our 8 basis functions, let's combine
-# them to get the temporal time course of our input:
-
-time, basis_kernels = basis.evaluate_on_grid(window_size)
-time *= bin_size * window_size
-temp_weights = np.einsum('b, t b -> t', model.coef_, basis_kernels)
-plt.plot(time, temp_weights)
-plt.xlabel("time[sec]")
-plt.ylabel("amplitude")
-
-# %%
-#
-# When taken together, the results of the GLM and the spike-triggered average
-# give us the linear component of our LNP model: the separable spatio-temporal
-# filter.
-
-
diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md
new file mode 100644
index 00000000..82256887
--- /dev/null
+++ b/docs/tutorials/plot_05_place_cells.md
@@ -0,0 +1,541 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+
+# Fit place cell
+
+The data for this example are from [Grosmark, Andres D., and György Buzsáki. "Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences." Science 351.6280 (2016): 1440-1443](https://www.science.org/doi/full/10.1126/science.aad1935).
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import pynapple as nap
+from scipy.ndimage import gaussian_filter
+
+import nemos as nmo
+
+# some helper plotting functions
+from nemos import _documentation_utils as doc_plots
+
+# configure plots some
+plt.style.use(nmo.styles.plot_style)
+
+# shut down jax to numpy conversion warning
+nap.nap_config.suppress_conversion_warnings = True
+```
+
+## Data Streaming
+
+Here we load the data from OSF. The data is a NWB file.
+
+
+```{code-cell} ipython3
+path = nmo.fetch.fetch_data("Achilles_10252013.nwb")
+```
+
+## Pynapple
+We are going to open the NWB file with pynapple.
+
+
+```{code-cell} ipython3
+data = nap.load_file(path)
+
+data
+```
+
+Let's extract the spike times, the position and the theta phase.
+
+
+```{code-cell} ipython3
+spikes = data["units"]
+position = data["position"]
+theta = data["theta_phase"]
+```
+
+The NWB file also contains the time at which the animal was traversing the linear track. We can use it to restrict the position and assign it as the `time_support` of position.
+
+
+```{code-cell} ipython3
+position = position.restrict(data["trials"])
+```
+
+The recording contains both inhibitory and excitatory neurons. Here we will focus on the excitatory cells, which have already been labelled.
+
+
+```{code-cell} ipython3
+spikes = spikes.getby_category("cell_type")["pE"]
+```
+
+We can discard the low firing neurons as well.
+
+
+```{code-cell} ipython3
+spikes = spikes.getby_threshold("rate", 0.3)
+```
+
+## Place fields
+Let's plot some data. We start by making place fields, i.e. firing rate as a function of position.
+
+
+```{code-cell} ipython3
+pf = nap.compute_1d_tuning_curves(spikes, position, 50, position.time_support)
+```
+
+Let's do a quick sort of the place fields for display.
+
+
+```{code-cell} ipython3
+order = pf.idxmax().sort_values().index.values
+```
+
+Here, each row is one neuron.
+
+
+```{code-cell} ipython3
+fig = plt.figure(figsize=(12, 10))
+gs = plt.GridSpec(len(spikes), 1)
+for i, n in enumerate(order):
+ plt.subplot(gs[i, 0])
+ plt.fill_between(pf.index.values, np.zeros(len(pf)), pf[n].values)
+ if i < len(spikes) - 1:
+ plt.xticks([])
+ else:
+ plt.xlabel("Position (cm)")
+ plt.yticks([])
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in assets
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_05_place_cells.svg")
+```
+
+## Phase precession
+
+In addition to place modulation, place cells are also modulated by the theta oscillation. The phase at which neurons fire depends on the position. This phenomenon is called "phase precession" (see [J. O’Keefe, M. L. Recce, "Phase relationship between hippocampal place units and the EEG theta rhythm." Hippocampus 3, 317–330 (1993).](https://doi.org/10.1002/hipo.450030307)).
+
+Let's compute the response of the neuron as a function of both theta and position. The theta phase has already been computed, but we have to bring it to the same time resolution as the position feature. While the position has been sampled at 40 Hz, the theta phase has been computed at 1250 Hz.
+Later on during the analysis, we will use a bin size of 5 ms for counting the spikes. Since this corresponds to 200 Hz, an intermediate frequency between 40 and 1250 Hz, we will bring all the features to 200 Hz now.
+
+
+```{code-cell} ipython3
+bin_size = 0.005
+
+theta = theta.bin_average(bin_size, position.time_support)
+theta = (theta + 2 * np.pi) % (2 * np.pi)
+
+data = nap.TsdFrame(
+ t=theta.t,
+ d=np.vstack(
+ (position.interpolate(theta, ep=position.time_support).values, theta.values)
+ ).T,
+ time_support=position.time_support,
+ columns=["position", "theta"],
+)
+
+print(data)
+```
+
+`data` is a [`TsdFrame`](https://pynapple.org/generated/pynapple.core.time_series.TsdFrame.html) that contains the position and phase. Before calling [`compute_2d_tuning_curves`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_2d_tuning_curves) from pynapple to observe the theta phase precession, we will restrict the analysis to the place field of one neuron.
+
+There are a lot of neurons but for this analysis, we will focus on one neuron only.
+
+
+```{code-cell} ipython3
+neuron = 175
+
+plt.figure(figsize=(5,3))
+plt.fill_between(pf[neuron].index.values, np.zeros(len(pf)), pf[neuron].values)
+plt.xlabel("Position (cm)")
+plt.ylabel("Firing rate (Hz)")
+```
+
+This neuron's place field is between 0 and 60 cm within the linear track. Here we will use the `threshold` function of pynapple to quickly compute the epochs for which the animal is within the place field:
+
+
+```{code-cell} ipython3
+within_ep = position.threshold(60.0, method="below").time_support
+```
+
+`within_ep` is an [`IntervalSet`](https://pynapple.org/generated/pynapple.core.interval_set.IntervalSet.html). We can now give it to [`compute_2d_tuning_curves`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_2d_tuning_curves) along with the spiking activity and the position-phase features.
+
+
+```{code-cell} ipython3
+tc_pos_theta, xybins = nap.compute_2d_tuning_curves(spikes, data, 20, within_ep)
+```
+
+To show the theta phase precession, we can also display the spikes as a function of both position and theta phase. In this case, we use the function `value_from` from pynapple.
+
+
+```{code-cell} ipython3
+theta_pos_spikes = spikes[neuron].value_from(data, ep = within_ep)
+```
+
+Now we can plot everything together:
+
+
+```{code-cell} ipython3
+plt.figure()
+gs = plt.GridSpec(2, 2)
+plt.subplot(gs[0, 0])
+plt.fill_between(pf[neuron].index.values, np.zeros(len(pf)), pf[neuron].values)
+plt.xlabel("Position (cm)")
+plt.ylabel("Firing rate (Hz)")
+
+plt.subplot(gs[1, 0])
+extent = (xybins[0][0], xybins[0][-1], xybins[1][0], xybins[1][-1])
+plt.imshow(gaussian_filter(tc_pos_theta[neuron].T, 1), aspect="auto", origin="lower", extent=extent)
+plt.xlabel("Position (cm)")
+plt.ylabel("Theta Phase (rad)")
+
+plt.subplot(gs[1, 1])
+plt.plot(theta_pos_spikes["position"], theta_pos_spikes["theta"], "o", markersize=0.5)
+plt.xlabel("Position (cm)")
+plt.ylabel("Theta Phase (rad)")
+
+plt.tight_layout()
+```
+
+## Speed modulation
+The speed at which the animal traverses the field is not homogeneous. Does it influence the firing rate of hippocampal neurons? We can compute tuning curves for speed as well as the average speed across the maze.
+In the next block, we compute the speed of the animal for each epoch (i.e. crossing of the linear track) by taking the difference between two consecutive positions and multiplying it by the sampling rate of the position.
+
+
+```{code-cell} ipython3
+speed = []
+for s, e in data.time_support.values: # Time support contains the epochs
+ pos_ep = data["position"].get(s, e)
+ speed_ep = np.abs(np.diff(pos_ep)) # Absolute difference of two consecutive points
+ speed_ep = np.pad(speed_ep, [0, 1], mode="edge") # Adding one point at the end to match the size of the position array
+ speed_ep = speed_ep * data.rate # Converting to cm/s
+ speed.append(speed_ep)
+
+speed = nap.Tsd(t=data.t, d=np.hstack(speed), time_support=data.time_support)
+```
+
+Now that we have the speed of the animal, we can compute the tuning curves for speed modulation. Here we call pynapple [`compute_1d_tuning_curves`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_1d_tuning_curves):
+
+
+```{code-cell} ipython3
+tc_speed = nap.compute_1d_tuning_curves(spikes, speed, 20)
+```
+
+To assess the variability in speed when the animal is traversing the linear track, we can compute the average speed and estimate its standard deviation. Here we use numpy only:
+
+
+```{code-cell} ipython3
+bins = np.linspace(np.min(data["position"]), np.max(data["position"]), 20)
+
+idx = np.digitize(data["position"].values, bins)
+
+mean_speed = np.array([np.mean(speed[idx==i]) for i in np.unique(idx)])
+std_speed = np.array([np.std(speed[idx==i]) for i in np.unique(idx)])
+```
+
+Here we plot the speed tuning curve of one neuron, as well as the average speed as a function of the animal's position.
+
+
+```{code-cell} ipython3
+plt.figure(figsize=(8, 3))
+plt.subplot(121)
+plt.plot(bins, mean_speed)
+plt.fill_between(
+ bins,
+ mean_speed - std_speed,
+ mean_speed + std_speed,
+ alpha=0.1,
+)
+plt.xlabel("Position (cm)")
+plt.ylabel("Speed (cm/s)")
+plt.title("Animal speed")
+plt.subplot(122)
+plt.fill_between(
+ tc_speed.index.values, np.zeros(len(tc_speed)), tc_speed[neuron].values
+)
+plt.xlabel("Speed (cm/s)")
+plt.ylabel("Firing rate (Hz)")
+plt.title("Neuron {}".format(neuron))
+plt.tight_layout()
+```
+
+This neuron shows a strong modulation of firing rate as a function of speed, but we can also notice that the animal, on average, accelerates when traversing the field. Is the speed tuning we observe a true modulation, or a spurious correlation caused by traversing the place field at different speeds and at different theta phases? We can use NeMoS to model the activity and give the position, the phase and the speed as input variables.
+
+We will use speed, phase and position to model the activity of the neuron.
+All the features have already been brought to the same time resolution thanks to `pynapple`.
+
+
+```{code-cell} ipython3
+position = data["position"]
+theta = data["theta"]
+count = spikes[neuron].count(bin_size, data.time_support)
+
+print(position.shape)
+print(theta.shape)
+print(speed.shape)
+print(count.shape)
+```
+
+## Basis evaluation
+
+For each feature, we will use a different basis:
+
+ - position : [`MSplineBasis`](nemos.basis.MSplineBasis)
+ - theta phase : [`CyclicBSplineBasis`](nemos.basis.CyclicBSplineBasis)
+ - speed : [`MSplineBasis`](nemos.basis.MSplineBasis)
+
+
+```{code-cell} ipython3
+position_basis = nmo.basis.MSplineBasis(n_basis_funcs=10)
+phase_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12)
+speed_basis = nmo.basis.MSplineBasis(n_basis_funcs=15)
+```
+
+In addition, we will consider position and phase to be a joint variable. In NeMoS, we can combine bases by multiplying and adding them. In this case, the final basis object for our model can be made in one line:
+
+
+```{code-cell} ipython3
+basis = position_basis * phase_basis + speed_basis
+```
+
+The basis object only tells us how each basis covers the feature space. For each timestep, we need to _evaluate_ the feature values. For that, we can call the NeMoS basis:
+
+
+```{code-cell} ipython3
+X = basis(position, theta, speed)
+```
+
+`X` is our design matrix. For each timestamp, it contains the information about the current position,
+speed and theta phase of the experiment. Notice how passing a pynapple object to the basis
+also returns a `pynapple` object.
+
+
+```{code-cell} ipython3
+print(X)
+```
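+
+As a quick sanity check (assuming the multiplied basis contributes the product of
+its factors' sizes and the added basis appends its own columns), we expect
+10 × 12 + 15 = 135 columns:
+
+```{code-cell} ipython3
+# position * phase -> 10 * 12 columns, speed adds 15 more
+print(X.shape[1], 10 * 12 + 15)
+```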
+
+## Model learning
+
+We can now use the Poisson GLM from NeMoS to learn the model.
+
+
+```{code-cell} ipython3
+glm = nmo.glm.GLM(
+ solver_kwargs=dict(tol=10**-12),
+ solver_name="LBFGS"
+)
+
+glm.fit(X, count)
+```
+
+## Prediction
+
+Let's check first if our model can accurately predict the different tuning curves we displayed above. We can use the [`predict`](nemos.glm.GLM.predict) function of NeMoS and then compute new tuning curves
+
+
+```{code-cell} ipython3
+predicted_rate = glm.predict(X) / bin_size
+
+glm_pf = nap.compute_1d_tuning_curves_continuous(predicted_rate[:, np.newaxis], position, 50)
+glm_pos_theta, xybins = nap.compute_2d_tuning_curves_continuous(
+ predicted_rate[:, np.newaxis], data, 30, ep=within_ep
+)
+glm_speed = nap.compute_1d_tuning_curves_continuous(predicted_rate[:, np.newaxis], speed, 30)
+```
+
+Let's display both tuning curves together.
+
+
+```{code-cell} ipython3
+fig = doc_plots.plot_position_phase_speed_tuning(
+ pf[neuron],
+ glm_pf[0],
+ tc_speed[neuron],
+ glm_speed[0],
+ tc_pos_theta[neuron],
+ glm_pos_theta[0],
+ xybins
+ )
+```
+
+## Model selection
+
+While this model nicely captures the feature-rate relationship, it is not necessarily the simplest model. Let's construct several models and evaluate their scores to determine the best one.
+
+:::{note}
+
+To shorten this notebook, only a few combinations are tested. Feel free to expand this list.
+:::
+
+
+```{code-cell} ipython3
+models = {
+ "position": position_basis,
+ "position + speed": position_basis + speed_basis,
+ "position + phase": position_basis + phase_basis,
+ "position * phase + speed": position_basis * phase_basis + speed_basis,
+}
+features = {
+ "position": (position,),
+ "position + speed": (position, speed),
+ "position + phase": (position, theta),
+ "position * phase + speed": (position, theta, speed),
+}
+```
+
+In a loop, we can (1) evaluate the basis, (2) fit the model, (3) compute the score, and (4) predict the firing rate. To evaluate the score, we define a train set of intervals and a test set of intervals.
+
+
+```{code-cell} ipython3
+train_iset = position.time_support[::2] # Taking every other epoch
+test_iset = position.time_support[1::2]
+```
+
+Let's train all the models.
+
+
+```{code-cell} ipython3
+scores = {}
+predicted_rates = {}
+
+for m in models:
+ print("1. Evaluating basis : ", m)
+ X = models[m](*features[m])
+
+ print("2. Fitting model : ", m)
+ glm.fit(
+ X.restrict(train_iset),
+ count.restrict(train_iset),
+ )
+
+ print("3. Scoring model : ", m)
+ scores[m] = glm.score(
+ X.restrict(test_iset),
+ count.restrict(test_iset),
+ score_type="pseudo-r2-McFadden",
+ )
+
+ print("4. Predicting rate")
+ predicted_rates[m] = glm.predict(X.restrict(test_iset)) / bin_size
+
+
+scores = pd.Series(scores)
+scores = scores.sort_values()
+```
+
+Let's plot the scores of each model.
+
+
+```{code-cell} ipython3
+plt.figure(figsize=(5, 3))
+plt.barh(np.arange(len(scores)), scores)
+plt.yticks(np.arange(len(scores)), scores.index)
+plt.xlabel("Pseudo r2")
+plt.tight_layout()
+```
+
+Some models are doing better than others.
+
+:::{warning}
+A proper model comparison should be done by scoring the models repeatedly on various train and test sets. Here we only do a partial model comparison for the sake of conciseness.
+:::
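+
+For instance, a minimal sketch of this idea (not used for the figures below) is to
+score each model on both halves of the epoch split, swapping the train and test
+sets, and average the two scores:
+
+```{code-cell} ipython3
+cv_scores = {}
+for m in models:
+    X_m = models[m](*features[m])
+    fold_scores = []
+    # swap which half of the epochs is used for training vs. testing
+    for train_ep, test_ep in ((train_iset, test_iset), (test_iset, train_iset)):
+        glm.fit(X_m.restrict(train_ep), count.restrict(train_ep))
+        fold_scores.append(
+            glm.score(
+                X_m.restrict(test_ep),
+                count.restrict(test_ep),
+                score_type="pseudo-r2-McFadden",
+            )
+        )
+    cv_scores[m] = float(np.mean(fold_scores))
+
+print(pd.Series(cv_scores).sort_values())
+```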
+
+Alternatively, we can plot some tuning curves to compare the models visually.
+
+
+```{code-cell} ipython3
+tuning_curves = {}
+
+for m in models:
+ tuning_curves[m] = {
+ "position": nap.compute_1d_tuning_curves_continuous(
+ predicted_rates[m][:, np.newaxis], position, 50, ep=test_iset
+ ),
+ "speed": nap.compute_1d_tuning_curves_continuous(
+ predicted_rates[m][:, np.newaxis], speed, 20, ep=test_iset
+ ),
+ }
+
+# recompute tuning from spikes restricting to the test-set
+pf = nap.compute_1d_tuning_curves(spikes, position, 50, ep=test_iset)
+tc_speed = nap.compute_1d_tuning_curves(spikes, speed, 20, ep=test_iset)
+
+
+fig = plt.figure(figsize=(8, 4))
+outer_grid = fig.add_gridspec(2, 2)
+for i, m in enumerate(models):
+ doc_plots.plot_position_speed_tuning(
+ outer_grid[i // 2, i % 2],
+ tuning_curves[m],
+ pf[neuron],
+ tc_speed[neuron],
+ m)
+
+plt.tight_layout()
+plt.show()
+```
+
+## Conclusion
+
+Various combinations of features can lead to different results. Feel free to explore more. To go beyond this notebook, you can check the following references:
+
+ - [Hardcastle, Kiah, et al. "A multiplexed, heterogeneous, and adaptive code for navigation in medial entorhinal cortex." Neuron 94.2 (2017): 375-387](https://www.cell.com/neuron/pdf/S0896-6273(17)30237-4.pdf)
+
+ - [McClain, Kathryn, et al. "Position–theta-phase model of hippocampal place cell activity applied to quantification of running speed modulation of firing rate." Proceedings of the National Academy of Sciences 116.52 (2019): 27035-27042](https://www.pnas.org/doi/abs/10.1073/pnas.1912792116)
+
+ - [Peyrache, Adrien, Natalie Schieferstein, and Gyorgy Buzsáki. "Transformation of the head-direction signal into a spatial code." Nature communications 8.1 (2017): 1752.](https://www.nature.com/articles/s41467-017-01908-3)
diff --git a/docs/tutorials/plot_05_place_cells.py b/docs/tutorials/plot_05_place_cells.py
deleted file mode 100644
index 24ea3f63..00000000
--- a/docs/tutorials/plot_05_place_cells.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-# Fit place cell
-
-The data for this example are from [Grosmark, Andres D., and György Buzsáki. "Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences." Science 351.6280 (2016): 1440-1443](https://www.science.org/doi/full/10.1126/science.aad1935).
-
-"""
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import pynapple as nap
-from scipy.ndimage import gaussian_filter
-
-import nemos as nmo
-
-# some helper plotting functions
-from nemos import _documentation_utils as doc_plots
-
-# configure plots some
-plt.style.use(nmo.styles.plot_style)
-
-# %%
-# ## Data Streaming
-#
-# Here we load the data from OSF. The data is a NWB file.
-
-path = nmo.fetch.fetch_data("Achilles_10252013.nwb")
-
-# %%
-# ## Pynapple
-# We are going to open the NWB file with pynapple
-
-data = nap.load_file(path)
-
-data
-
-# %%
-# Let's extract the spike times, the position and the theta phase.
-
-spikes = data["units"]
-position = data["position"]
-theta = data["theta_phase"]
-
-# %%
-# The NWB file also contains the time at which the animal was traversing the linear track. We can use it to restrict the position and assign it as the `time_support` of position.
-
-position = position.restrict(data["trials"])
-
-# %%
-# The recording contains both inhibitory and excitatory neurons. Here we will focus of the excitatory cells. Neurons have already been labelled before.
-spikes = spikes.getby_category("cell_type")["pE"]
-
-# %%
-# We can discard the low firing neurons as well.
-spikes = spikes.getby_threshold("rate", 0.3)
-
-# %%
-# ## Place fields
-# Let's plot some data. We start by making place fields i.e firing rate as a function of position.
-
-pf = nap.compute_1d_tuning_curves(spikes, position, 50, position.time_support)
-
-# %%
-# Let's do a quick sort of the place fields for display
-order = pf.idxmax().sort_values().index.values
-
-# %%
-# Here each row is one neuron
-
-plt.figure(figsize=(12, 10))
-gs = plt.GridSpec(len(spikes), 1)
-for i, n in enumerate(order):
- plt.subplot(gs[i, 0])
- plt.fill_between(pf.index.values, np.zeros(len(pf)), pf[n].values)
- if i < len(spikes) - 1:
- plt.xticks([])
- else:
- plt.xlabel("Position (cm)")
- plt.yticks([])
-
-
-# %%
-# ## Phase precession
-#
-# In addition to place modulation, place cells are also modulated by the theta oscillation. The phase at which neurons fire is dependant of the position. This phenomen is called "phase precession" (see [J. O’Keefe, M. L. Recce, "Phase relationship between hippocampal place units and the EEG theta rhythm." Hippocampus 3, 317–330 (1993).](https://doi.org/10.1002/hipo.450030307)
-#
-# Let's compute the response of the neuron as a function of both theta and position. The phase of theta has already been computed but we have to bring it to the same dimension as the position feature. While the position has been sampled at 40Hz, the theta phase has been computed at 1250Hz.
-# Later on during the analysis, we will use a bin size of 5 ms for counting the spikes. Since this corresponds to an intermediate frequency between 40 and 1250 Hz, we will bring all the features to 200Hz already.
-
-bin_size = 0.005
-
-theta = theta.bin_average(bin_size, position.time_support)
-theta = (theta + 2 * np.pi) % (2 * np.pi)
-
-data = nap.TsdFrame(
- t=theta.t,
- d=np.vstack(
- (position.interpolate(theta, ep=position.time_support).values, theta.values)
- ).T,
- time_support=position.time_support,
- columns=["position", "theta"],
-)
-
-print(data)
-
-
-
-
-
-# %%
-# `data` is a `TsdFrame` that contains the position and phase. Before calling `compute_2d_tuning_curves` from pynapple to observe the theta phase precession, we will restrict the analysis to the place field of one neuron.
-#
-# There are a lot of neurons but for this analysis, we will focus on one neuron only.
-neuron = 175
-
-plt.figure(figsize=(5,3))
-plt.fill_between(pf[neuron].index.values, np.zeros(len(pf)), pf[neuron].values)
-plt.xlabel("Position (cm)")
-plt.ylabel("Firing rate (Hz)")
-
-# %%
-# This neurons place field is between 0 and 60 cm within the linear track. Here we will use the `threshold` function of pynapple to quickly compute the epochs for which the animal is within the place field :
-
-within_ep = position.threshold(60.0, method="below").time_support
-
-# %%
-# `within_ep` is an `IntervalSet`. We can now give it to `compute_2d_tuning_curves` along with the spiking activity and the position-phase features.
-
-tc_pos_theta, xybins = nap.compute_2d_tuning_curves(spikes, data, 20, within_ep)
-
-# %%
-# To show the theta phase precession, we can also display the spike as a function of both position and theta. In this case, we use the function `value_from` from pynapple.
-
-theta_pos_spikes = spikes[neuron].value_from(data, ep = within_ep)
-
-# %%
-# Now we can plot everything together :
-
-plt.figure()
-gs = plt.GridSpec(2, 2)
-plt.subplot(gs[0, 0])
-plt.fill_between(pf[neuron].index.values, np.zeros(len(pf)), pf[neuron].values)
-plt.xlabel("Position (cm)")
-plt.ylabel("Firing rate (Hz)")
-
-plt.subplot(gs[1, 0])
-extent = (xybins[0][0], xybins[0][-1], xybins[1][0], xybins[1][-1])
-plt.imshow(gaussian_filter(tc_pos_theta[neuron].T, 1), aspect="auto", origin="lower", extent=extent)
-plt.xlabel("Position (cm)")
-plt.ylabel("Theta Phase (rad)")
-
-plt.subplot(gs[1, 1])
-plt.plot(theta_pos_spikes["position"], theta_pos_spikes["theta"], "o", markersize=0.5)
-plt.xlabel("Position (cm)")
-plt.ylabel("Theta Phase (rad)")
-
-plt.tight_layout()
-
-# %%
-# ## Speed modulation
-# The speed at which the animal traverse the field is not homogeneous. Does it influence the firing rate of hippocampal neurons? We can compute tuning curves for speed as well as average speed across the maze.
-# In the next block, we compute the speed of the animal for each epoch (i.e. crossing of the linear track) by doing the difference of two consecutive position multiplied by the sampling rate of the position.
-
-speed = []
-for s, e in data.time_support.values: # Time support contains the epochs
- pos_ep = data["position"].get(s, e)
- speed_ep = np.abs(np.diff(pos_ep)) # Absolute difference of two consecutive points
- speed_ep = np.pad(speed_ep, [0, 1], mode="edge") # Adding one point at the end to match the size of the position array
- speed_ep = speed_ep * data.rate # Converting to cm/s
- speed.append(speed_ep)
-
-speed = nap.Tsd(t=data.t, d=np.hstack(speed), time_support=data.time_support)
-
-# %%
-# Now that we have the speed of the animal, we can compute the tuning curves for speed modulation. Here we call pynapple `compute_1d_tuning_curves`:
-
-tc_speed = nap.compute_1d_tuning_curves(spikes, speed, 20)
-
-# %%
-# To assess the variabilty in speed when the animal is travering the linear track, we can compute the average speed and estimate the standard deviation. Here we use numpy only and put the results in a pandas `DataFrame`:
-
-bins = np.linspace(np.min(data["position"]), np.max(data["position"]), 20)
-
-idx = np.digitize(data["position"].values, bins)
-
-mean_speed = np.array([np.mean(speed[idx==i]) for i in np.unique(idx)])
-std_speed = np.array([np.std(speed[idx==i]) for i in np.unique(idx)])
-
-# %%
-# Here we plot the tuning curve of one neuron for speed as well as the average speed as a function of the animal position
-
-plt.figure(figsize=(8, 3))
-plt.subplot(121)
-plt.plot(bins, mean_speed)
-plt.fill_between(
- bins,
- mean_speed - std_speed,
- mean_speed + std_speed,
- alpha=0.1,
-)
-plt.xlabel("Position (cm)")
-plt.ylabel("Speed (cm/s)")
-plt.title("Animal speed")
-plt.subplot(122)
-plt.fill_between(
- tc_speed.index.values, np.zeros(len(tc_speed)), tc_speed[neuron].values
-)
-plt.xlabel("Speed (cm/s)")
-plt.ylabel("Firing rate (Hz)")
-plt.title("Neuron {}".format(neuron))
-plt.tight_layout()
-
-# %%
-# This neurons show a strong modulation of firing rate as a function of speed but we can also notice that the animal, on average, accelerates when travering the field. Is the speed tuning we observe a true modulation or spurious correlation caused by traversing the place field at different speed and for different theta phase? We can use NeMoS to model the activity and give the position, the phase and the speed as input variable.
-#
-# We will use speed, phase and position to model the activity of the neuron.
-# All the feature have already been brought to the same dimension thanks to `pynapple`.
-
-position = data["position"]
-theta = data["theta"]
-count = spikes[neuron].count(bin_size, data.time_support)
-
-print(position.shape)
-print(theta.shape)
-print(speed.shape)
-print(count.shape)
-
-# %%
-# ## Basis evaluation
-#
-# For each feature, we will use a different set of basis :
-#
-# - position : `nmo.basis.MSplineBasis`
-# - theta phase : `nmo.basis.CyclicBSplineBasis`
-# - speed : `nmo.basis.MSplineBasis`
-
-position_basis = nmo.basis.MSplineBasis(n_basis_funcs=10)
-phase_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12)
-speed_basis = nmo.basis.MSplineBasis(n_basis_funcs=15)
-
-# %%
-# In addition, we will consider position and phase to be a joint variable. In NeMoS, we can combine basis by multiplying them and adding them. In this case the final basis object for our model can be made in one line :
-
-basis = position_basis * phase_basis + speed_basis
-
-# %%
-# The object basis only tell us how each basis covers the feature space. For each timestep, we need to _evaluate_ what are the features value. For that we can call NeMoS basis:
-
-X = basis(position, theta, speed)
-
-# %%
-# `X` is our design matrix. For each timestamps, it contains the information about the current position,
-# speed and theta phase of the experiment. Notice how passing a pynapple object to the basis
-# also returns a `pynapple` object.
-
-print(X)
-
-# %%
-# ## Model learning
-#
-# We can now use the Poisson GLM from NeMoS to learn the model.
-
-glm = nmo.glm.GLM(
- solver_kwargs=dict(tol=10**-12),
- solver_name="LBFGS"
-)
-
-glm.fit(X, count)
-
-# %%
-# ## Prediction
-#
-# Let's check first if our model can accurately predict the different tuning curves we displayed above. We can use the `predict` function of NeMoS and then compute new tuning curves
-
-predicted_rate = glm.predict(X) / bin_size
-
-glm_pf = nap.compute_1d_tuning_curves_continuous(predicted_rate[:, np.newaxis], position, 50)
-glm_pos_theta, xybins = nap.compute_2d_tuning_curves_continuous(
- predicted_rate[:, np.newaxis], data, 30, ep=within_ep
-)
-glm_speed = nap.compute_1d_tuning_curves_continuous(predicted_rate[:, np.newaxis], speed, 30)
-
-# %%
-# Let's display both tuning curves together.
-fig = doc_plots.plot_position_phase_speed_tuning(
- pf[neuron],
- glm_pf[0],
- tc_speed[neuron],
- glm_speed[0],
- tc_pos_theta[neuron],
- glm_pos_theta[0],
- xybins
- )
-
-# %%
-# ## Model selection
-#
-# While this model captures nicely the features-rate relationship, it is not necessarily the simplest model. Let's construct several models and evaluate their score to determine the best model.
-#
-# !!! note
-# To shorten this notebook, only a few combinations are tested. Feel free to expand this list.
-#
-
-models = {
- "position": position_basis,
- "position + speed": position_basis + speed_basis,
- "position + phase": position_basis + phase_basis,
- "position * phase + speed": position_basis * phase_basis + speed_basis,
-}
-features = {
- "position": (position,),
- "position + speed": (position, speed),
- "position + phase": (position, theta),
- "position * phase + speed": (position, theta, speed),
-}
-
-# %%
-# In a loop, we can (1) evaluate the basis, (2), fit the model, (3) compute the score and (4) predict the firing rate. For evaluating the score, we can define a train set of intervals and a test set of intervals.
-
-train_iset = position.time_support[::2] # Taking every other epoch
-test_iset = position.time_support[1::2]
-
-# %%
-# Let's train all the models.
-scores = {}
-predicted_rates = {}
-
-for m in models:
- print("1. Evaluating basis : ", m)
- X = models[m](*features[m])
-
- print("2. Fitting model : ", m)
- glm.fit(
- X.restrict(train_iset),
- count.restrict(train_iset),
- )
-
- print("3. Scoring model : ", m)
- scores[m] = glm.score(
- X.restrict(test_iset),
- count.restrict(test_iset),
- score_type="pseudo-r2-McFadden",
- )
-
- print("4. Predicting rate")
- predicted_rates[m] = glm.predict(X.restrict(test_iset)) / bin_size
-
-
-scores = pd.Series(scores)
-scores = scores.sort_values()
-
-# %%
-# Let's compute scores for each models.
-
-plt.figure(figsize=(5, 3))
-plt.barh(np.arange(len(scores)), scores)
-plt.yticks(np.arange(len(scores)), scores.index)
-plt.xlabel("Pseudo r2")
-plt.tight_layout()
-
-
-# %%
-# Some models are doing better than others.
-#
-# !!! warning
-# A proper model comparison should be done by scoring models repetitively on various train and test set. Here we are only doing partial models comparison for the sake of conciseness.
-#
-# Alternatively, we can plot some tuning curves to compare each models visually.
-
-tuning_curves = {}
-
-for m in models:
- tuning_curves[m] = {
- "position": nap.compute_1d_tuning_curves_continuous(
- predicted_rates[m][:, np.newaxis], position, 50, ep=test_iset
- ),
- "speed": nap.compute_1d_tuning_curves_continuous(
- predicted_rates[m][:, np.newaxis], speed, 20, ep=test_iset
- ),
- }
-
-# recompute tuning from spikes restricting to the test-set
-pf = nap.compute_1d_tuning_curves(spikes, position, 50, ep=test_iset)
-tc_speed = nap.compute_1d_tuning_curves(spikes, speed, 20, ep=test_iset)
-
-
-fig = plt.figure(figsize=(8, 4))
-outer_grid = fig.add_gridspec(2, 2)
-for i, m in enumerate(models):
- doc_plots.plot_position_speed_tuning(
- outer_grid[i // 2, i % 2],
- tuning_curves[m],
- pf[neuron],
- tc_speed[neuron],
- m)
-
-plt.tight_layout()
-plt.show()
-
-# %%
-# ## Conclusion
-#
-# Various combinations of features can lead to different results. Feel free to explore more. To go beyond this notebook, you can check the following references :
-#
-# - [Hardcastle, Kiah, et al. "A multiplexed, heterogeneous, and adaptive code for navigation in medial entorhinal cortex." Neuron 94.2 (2017): 375-387](https://www.cell.com/neuron/pdf/S0896-6273(17)30237-4.pdf)
-#
-# - [McClain, Kathryn, et al. "Position–theta-phase model of hippocampal place cell activity applied to quantification of running speed modulation of firing rate." Proceedings of the National Academy of Sciences 116.52 (2019): 27035-27042](https://www.pnas.org/doi/abs/10.1073/pnas.1912792116)
-#
-# - [Peyrache, Adrien, Natalie Schieferstein, and Gyorgy Buzsáki. "Transformation of the head-direction signal into a spatial code." Nature communications 8.1 (2017): 1752.](https://www.nature.com/articles/s41467-017-01908-3)
-#
diff --git a/docs/tutorials/plot_06_calcium_imaging.md b/docs/tutorials/plot_06_calcium_imaging.md
new file mode 100644
index 00000000..1dbfc46e
--- /dev/null
+++ b/docs/tutorials/plot_06_calcium_imaging.md
@@ -0,0 +1,393 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+%matplotlib inline
+import warnings
+
+# Ignore the first specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
+ category=UserWarning,
+)
+
+# Ignore the second specific warning
+warnings.filterwarnings(
+ "ignore",
+ message="Ignoring cached namespace 'core'",
+ category=UserWarning,
+)
+
+warnings.filterwarnings(
+ "ignore",
+ message=(
+ "invalid value encountered in div "
+ ),
+ category=RuntimeWarning,
+)
+```
+
+(tutorial-calcium-imaging)=
+Fit Calcium Imaging
+============
+
+
+For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging at 30Hz using the genetically encoded calcium indicator GCaMP6f). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction.
+
+The data were collected by Sofia Skromne Carrasco from the [Peyrache Lab](https://www.peyrachelab.com/).
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import pynapple as nap
+from sklearn.linear_model import LinearRegression
+
+import nemos as nmo
+```
+
+configure plots
+
+
+```{code-cell} ipython3
+plt.style.use(nmo.styles.plot_style)
+```
+
+## Data Streaming
+
+Here we load the data from OSF. The data is an NWB file.
+
+
+```{code-cell} ipython3
+path = nmo.fetch.fetch_data("A0670-221213.nwb")
+```
+
+***
+## pynapple preprocessing
+
+Now that we have the file, let's load the data. The NWB file contains multiple entries.
+
+
+```{code-cell} ipython3
+data = nap.load_file(path)
+print(data)
+```
+
+In the NWB file, the calcium traces are saved in the RoiResponseSeries field. Let's store them in a variable called `transients` and print it.
+
+
+```{code-cell} ipython3
+transients = data['RoiResponseSeries']
+print(transients)
+```
+
+`transients` is a [`TsdFrame`](https://pynapple.org/generated/pynapple.core.time_series.TsdFrame.html). Each column contains the activity of one neuron.
+
+The mouse was recorded for a single 20-minute epoch, as we can see from the `time_support` property of the `transients` object.
+
+
+```{code-cell} ipython3
+ep = transients.time_support
+print(ep)
+```
+
+There are a few different ways we can explore the data. First, let's inspect the raw calcium traces for neurons 4 and 35 for the first 250 seconds of the experiment.
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 2, figsize=(12, 4))
+ax[0].plot(transients[:, 4].get(0,250))
+ax[0].set_ylabel("Fluorescence")
+ax[0].set_title("Trace 4")
+ax[0].set_xlabel("Time(s)")
+ax[1].plot(transients[:, 35].get(0,250))
+ax[1].set_title("Trace 35")
+ax[1].set_xlabel("Time(s)")
+plt.tight_layout()
+```
+
+You can see that the calcium signals are both nonnegative and noisy. One (neuron 4) has much higher SNR than the other. We cannot typically resolve individual action potentials, but instead see slow calcium fluctuations that result from an unknown underlying electrical signal (estimating the spikes from calcium traces is known as _deconvolution_ and is beyond the scope of this demo).
+
+
+
+
+We can also plot tuning curves, plotting mean calcium activity as a function of head direction, using the function [`compute_1d_tuning_curves_continuous`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_1d_tuning_curves_continuous).
+Here `data['ry']` is a [`Tsd`](https://pynapple.org/generated/pynapple.core.time_series.Tsd.html) that contains the angular head-direction of the animal between 0 and 2$\pi$.
+
+
+```{code-cell} ipython3
+tcurves = nap.compute_1d_tuning_curves_continuous(transients, data['ry'], 120)
+```
+
+The function returns a pandas DataFrame. Let's plot the tuning curves for neurons 4 and 35.
+
+
+```{code-cell} ipython3
+fig, ax = plt.subplots(1, 2, figsize=(12, 4))
+ax[0].plot(tcurves.iloc[:, 4])
+ax[0].set_xlabel("Angle (rad)")
+ax[0].set_ylabel("Mean fluorescence")
+ax[0].set_title("Trace 4")
+ax[1].plot(tcurves.iloc[:, 35])
+ax[1].set_xlabel("Angle (rad)")
+ax[1].set_title("Trace 35")
+plt.tight_layout()
+```
+
+As a first processing step, let's bin the calcium traces to a 100ms resolution.
+
+
+```{code-cell} ipython3
+Y = transients.bin_average(0.1, ep)
+```
+
+We can visualize the downsampled transients for the first 50 seconds of data.
+
+
+```{code-cell} ipython3
+plt.figure()
+plt.plot(transients[:,0].get(0, 50), linewidth=5, label="30 Hz")
+plt.plot(Y[:,0].get(0, 50), '--', linewidth=2, label="10 Hz")
+plt.xlabel("Time (s)")
+plt.ylabel("Fluorescence")
+plt.legend()
+plt.show()
+```
+
+The downsampling did not destroy the fast transient dynamics, so it seems fine to use. We can now move on to using NeMoS to fit a model.
+
+
+
+
+## Basis instantiation
+
+We can define a cyclic-BSpline for capturing the encoding of the heading angle, and a log-spaced raised cosine basis for the coupling filters between neurons. Note that we are not including a self-coupling (spike history) filter, because in practice we have found it results in overfitting.
+
+We can combine the two bases.
+
+
+```{code-cell} ipython3
+heading_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12)
+coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=10)
+```
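+
+If you are curious what these bases look like before combining them, they can be sampled on a grid. This is an optional sketch; `evaluate_on_grid` returns equi-spaced sample points on a normalized domain together with the evaluated basis functions.
+
+```{code-cell} ipython3
+# Optional: visualize the two bases on a normalized grid.
+x, heading_vals = heading_basis.evaluate_on_grid(100)
+t, coupling_vals = coupling_basis.evaluate_on_grid(100)
+
+fig_b, axs = plt.subplots(1, 2, figsize=(10, 3))
+axs[0].plot(x, heading_vals)
+axs[0].set_title("Cyclic B-spline (heading)")
+axs[1].plot(t, coupling_vals)
+axs[1].set_title("Raised cosine log (coupling filters)")
+plt.tight_layout()
+```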
+
+We need to make sure the design matrix will be full rank by applying identifiability constraints to the cyclic B-spline, and then combine the bases (the returned object will be an [`AdditiveBasis`](nemos.basis.AdditiveBasis) object).
+
+
+```{code-cell} ipython3
+heading_basis.identifiability_constraints = True
+basis = heading_basis + coupling_basis
+```
+
+## Gamma GLM
+
+Until now, we have been modeling spike trains, and have used a Poisson distribution for the observation model. With calcium traces, things are quite different: we no longer have counts but continuous signals, so the Poisson assumption is no longer appropriate. A Gaussian model is also not ideal since the calcium traces are non-negative. To satisfy these constraints, we will use a Gamma distribution from NeMoS with a soft-plus nonlinearity.
+:::{admonition} Non-linearity
+:class: note
+
+Different options are possible. With a soft-plus we are assuming an "additive" effect of the predictors, while an exponential nonlinearity assumes multiplicative effects. Deciding which firing rate model works best is an empirical question. You can fit different configurations to see which one best captures the neural activity.
+:::
+
+
+```{code-cell} ipython3
+model = nmo.glm.GLM(
+ solver_kwargs=dict(tol=10**-13),
+ regularizer="Ridge",
+ regularizer_strength=0.02,
+ observation_model=nmo.observation_models.GammaObservations(inverse_link_function=jax.nn.softplus)
+)
+```
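+
+If you want to try the multiplicative variant mentioned in the note above, the inverse link can be swapped. This is a hypothetical configuration rather than one tuned for this dataset; the Ridge strength would likely need to be re-adjusted.
+
+```{code-cell} ipython3
+# Hypothetical alternative: exponential inverse link (multiplicative effects).
+model_exp = nmo.glm.GLM(
+    solver_kwargs=dict(tol=10**-13),
+    regularizer="Ridge",
+    regularizer_strength=0.02,
+    observation_model=nmo.observation_models.GammaObservations(
+        inverse_link_function=jnp.exp
+    ),
+)
+```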
+
+We select one neuron to fit later, so we remove it from the list of predictors.
+
+
+```{code-cell} ipython3
+neu = 4
+selected_neurons = jnp.hstack(
+ (jnp.arange(0, neu), jnp.arange(neu+1, Y.shape[1]))
+)
+
+print(selected_neurons)
+```
+
+We need to bring the head-direction of the animal to the same size as the transients matrix.
+We can use the function [`bin_average`](https://pynapple.org/generated/pynapple.core.time_series.Tsd.bin_average.html#pynapple.core.time_series.Tsd.bin_average) of pynapple. Notice how we pass the parameter `ep`
+that is the `time_support` of the transients.
+
+
+```{code-cell} ipython3
+head_direction = data['ry'].bin_average(0.1, ep)
+```
+
+Let's check that `head_direction` and `Y` are of the same size.
+
+
+```{code-cell} ipython3
+print(head_direction.shape)
+print(Y.shape)
+```
+
+## Design matrix
+
+We can now create the design matrix by combining the head-direction of the animal and the activity of all other neurons.
+
+
+
+```{code-cell} ipython3
+X = basis.compute_features(head_direction, Y[:, selected_neurons])
+```
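+
+As a quick, optional sanity check, we can look at the shape of the design matrix and count how many rows are valid. This sketch assumes `X` can be viewed as a NumPy array and that the NaNs come from the causal-convolution padding at the start of the epoch.
+
+```{code-cell} ipython3
+import numpy as np
+
+print("design matrix shape:", X.shape)
+valid_rows = ~np.isnan(np.asarray(X)).any(axis=1)
+print(f"valid rows: {valid_rows.sum()} / {X.shape[0]}")
+```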
+
+## Train & test set
+
+Let's create a train epoch and a test epoch to fit and test the models. Since `X` is a pynapple time series, we can create [`IntervalSet`](https://pynapple.org/generated/pynapple.core.interval_set.IntervalSet.html) objects to restrict them into a train set and test set.
+
+
+```{code-cell} ipython3
+train_ep = nap.IntervalSet(start=X.time_support.start, end=X.time_support.get_intervals_center().t)
+test_ep = X.time_support.set_diff(train_ep) # Removing the train_ep from time_support
+
+print(train_ep)
+print(test_ep)
+```
+
+We can now restrict the `X` and `Y` to create our train set and test set.
+
+
+```{code-cell} ipython3
+Xtrain = X.restrict(train_ep)
+Ytrain = Y.restrict(train_ep)
+
+Xtest = X.restrict(test_ep)
+Ytest = Y.restrict(test_ep)
+```
+
+## Model fitting
+
+It's time to fit the model to the activity of the neuron we left out of the predictors.
+
+
+```{code-cell} ipython3
+model.fit(Xtrain, Ytrain[:, neu])
+```
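+
+After fitting, the learned parameters are available on the model object. Here is a quick, optional look, assuming the scikit-learn-style `coef_` and `intercept_` attributes exposed by the NeMoS GLM.
+
+```{code-cell} ipython3
+print("coefficients shape:", model.coef_.shape)  # one weight per design-matrix column
+print("intercept:", model.intercept_)
+```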
+
+## Model comparison
+
+
+
+
+We can compare this to scikit-learn [linear regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html).
+
+
+```{code-cell} ipython3
+mdl = LinearRegression()
+valid = ~jnp.isnan(Xtrain.d.sum(axis=1)) # Scikit learn does not like nans.
+mdl.fit(Xtrain[valid], Ytrain[valid, neu])
+```
+
+We now have 2 models we can compare. Let's predict the activity of the neuron during the test epoch.
+
+
+```{code-cell} ipython3
+yp = model.predict(Xtest)
+ylreg = mdl.predict(Xtest) # Notice that this is not a pynapple object.
+```
+
+Unfortunately scikit-learn has not adopted pynapple yet. Let's convert `ylreg` to a pynapple object to make our life easier.
+
+
+```{code-cell} ipython3
+ylreg = nap.Tsd(t=yp.t, d=ylreg, time_support = yp.time_support)
+```
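+
+Before plotting, a rough quantitative comparison can be made on the test set. This is only a sketch: mean squared error is one possible metric among many, it ignores the nonnegativity issue discussed below, and bins made invalid by the convolution padding are simply dropped.
+
+```{code-cell} ipython3
+import numpy as np
+
+true_act = np.asarray(Ytest[:, neu])
+for name, pred in {"gamma-nemos": np.asarray(yp), "linreg-sklearn": np.asarray(ylreg)}.items():
+    mask = ~np.isnan(pred) & ~np.isnan(true_act)
+    print(f"{name}: test MSE = {np.mean((pred[mask] - true_act[mask]) ** 2):.4f}")
+```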
+
+Let's plot the predicted activity over a 60-second window of the test epoch.
+
+
+```{code-cell} ipython3
+ep_to_plot = nap.IntervalSet(test_ep.start+20, test_ep.start+80)
+
+plt.figure()
+plt.plot(Ytest[:,neu].restrict(ep_to_plot), "r", label="true", linewidth=2)
+plt.plot(yp.restrict(ep_to_plot), "k", label="gamma-nemos", alpha=1)
+plt.plot(ylreg.restrict(ep_to_plot), "g", label="linreg-sklearn", alpha=0.5)
+plt.legend(loc='best')
+plt.xlabel("Time (s)")
+plt.ylabel("Fluorescence")
+plt.show()
+```
+
+While there is some variability in the fit for both models, one advantage of the Gamma GLM is clear: its predictions respect the nonnegativity of the data.
+This matters when using GLMs to predict firing rates, which must be positive, in response to simulated inputs. See Peyrache et al. 2017[$^{[1]}$](#ref-1) for an example of simulating activity with a GLM.
+
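+We can check this directly by counting negative predicted values over the test epoch (comparisons with NaN evaluate to False, so padded bins are ignored):
+
+```{code-cell} ipython3
+import numpy as np
+
+print("negative gamma-GLM predictions:", int(np.sum(np.asarray(yp) < 0)))
+print("negative linear-regression predictions:", int(np.sum(np.asarray(ylreg) < 0)))
+```
+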
+Another way to compare models is to compute tuning curves. Here we use the function [`compute_1d_tuning_curves_continuous`](https://pynapple.org/generated/pynapple.process.tuning_curves.html#pynapple.process.tuning_curves.compute_1d_tuning_curves_continuous) from pynapple.
+
+
+```{code-cell} ipython3
+real_tcurves = nap.compute_1d_tuning_curves_continuous(transients, data['ry'], 120, ep=test_ep)
+gamma_tcurves = nap.compute_1d_tuning_curves_continuous(yp, data['ry'], 120, ep=test_ep)
+linreg_tcurves = nap.compute_1d_tuning_curves_continuous(ylreg, data['ry'], 120, ep=test_ep)
+```
+
+Let's plot them.
+
+
+```{code-cell} ipython3
+fig = plt.figure()
+plt.plot(real_tcurves[neu], "r", label="true", linewidth=2)
+plt.plot(gamma_tcurves, "k", label="gamma-nemos", alpha=1)
+plt.plot(linreg_tcurves, "g", label="linreg-sklearn", alpha=0.5)
+plt.legend(loc='best')
+plt.ylabel("Fluorescence")
+plt.xlabel("Head-direction (rad)")
+plt.show()
+```
+
+```{code-cell} ipython3
+:tags: [hide-input]
+
+# save image for thumbnail
+from pathlib import Path
+import os
+
+root = os.environ.get("READTHEDOCS_OUTPUT")
+if root:
+ path = Path(root) / "html/_static/thumbnails/tutorials"
+# if local store in assets
+else:
+ path = Path("../_build/html/_static/thumbnails/tutorials")
+
+# make sure the folder exists if run from build
+if root or Path("../_build/html/_static").exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+if path.exists():
+ fig.savefig(path / "plot_06_calcium_imaging.svg")
+```
+
+
+:::{admonition} Gamma-GLM for Calcium Imaging Analysis
+:class: note
+
+Using Gamma-GLMs for fitting calcium imaging data is still in its early stages, and hasn't been through
+the level of review and validation that GLMs have for fitting spike data. Users should consider
+this relatively unexplored territory, and we hope that NeMoS will help researchers
+explore this new space of models.
+:::
+
+## References
+
+[1] Peyrache, A., Schieferstein, N. & Buzsáki, G. Transformation of the head-direction signal into a spatial code. Nat Commun 8, 1752 (2017). https://doi.org/10.1038/s41467-017-01908-3
diff --git a/docs/tutorials/plot_06_calcium_imaging.py b/docs/tutorials/plot_06_calcium_imaging.py
deleted file mode 100644
index 54d83666..00000000
--- a/docs/tutorials/plot_06_calcium_imaging.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Fit Calcium Imaging
-============
-
-
-For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging at 30Hz using the genetically encoded calcium indicator GCaMP6f). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction.
-
-The data were collected by Sofia Skromne Carrasco from the [Peyrache Lab](https://www.peyrachelab.com/).
-
-"""
-import warnings
-from warnings import catch_warnings
-
-import jax
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import numpy as np
-import pynapple as nap
-from sklearn.linear_model import LinearRegression
-
-import nemos as nmo
-from nemos.identifiability_constraints import (
- apply_identifiability_constraints_by_basis_component,
-)
-
-# ignore dtype warnings
-nap.nap_config.suppress_conversion_warnings = True
-warnings.filterwarnings("ignore", category=UserWarning, message="The feature matrix is not of dtype")
-
-
-
-# %%
-# configure plots
-plt.style.use(nmo.styles.plot_style)
-
-
-# %%
-# ## Data Streaming
-#
-# Here we load the data from OSF. The data is a NWB file.
-
-path = nmo.fetch.fetch_data("A0670-221213.nwb")
-
-# %%
-# ***
-# ## pynapple preprocessing
-#
-# Now that we have the file, let's load the data. The NWB file contains multiple entries.
-data = nap.load_file(path)
-print(data)
-
-# %%
-# In the NWB file, the calcium traces are saved the RoiResponseSeries field. Let's save them in a variable called 'transients' and print it.
-
-transients = data['RoiResponseSeries']
-print(transients)
-
-# %%
-# `transients` is a `TsdFrame`. Each column contains the activity of one neuron.
-#
-# The mouse was recorded for a 20 minute recording epoch as we can see from the `time_support` property of the `transients` object.
-
-ep = transients.time_support
-print(ep)
-
-
-# %%
-# There are a few different ways we can explore the data. First, let's inspect the raw calcium traces for neurons 4 and 35 for the first 250 seconds of the experiment.
-
-fig, ax = plt.subplots(1, 2, figsize=(12, 4))
-ax[0].plot(transients[:, 4].get(0,250))
-ax[0].set_ylabel("Firing rate (Hz)")
-ax[0].set_title("Trace 4")
-ax[0].set_xlabel("Time(s)")
-ax[1].plot(transients[:, 35].get(0,250))
-ax[1].set_title("Trace 35")
-ax[1].set_xlabel("Time(s)")
-plt.tight_layout()
-
-# %%
-# You can see that the calcium signals are both nonnegative, and noisy. One (neuron 4) has much higher SNR than the other. We cannot typically resolve individual action potentials, but instead see slow calcium fluctuations that result from an unknown underlying electrical signal (estimating the spikes from calcium traces is known as _deconvolution_ and is beyond the scope of this demo).
-
-# %%
-# We can also plot tuning curves, plotting mean calcium activity as a function of head direction, using the function `compute_1d_tuning_curves_continuous`.
-# Here `data['ry']` is a `Tsd` that contains the angular head-direction of the animal between 0 and 2$\pi$.
-
-tcurves = nap.compute_1d_tuning_curves_continuous(transients, data['ry'], 120)
-
-
-
-# %%
-# The function returns a pandas DataFrame. Let's plot the tuning curves for neurons 4 and 35.
-
-fig, ax = plt.subplots(1, 2, figsize=(12, 4))
-ax[0].plot(tcurves.iloc[:, 4])
-ax[0].set_xlabel("Angle (rad)")
-ax[0].set_ylabel("Firing rate (Hz)")
-ax[0].set_title("Trace 4")
-ax[1].plot(tcurves.iloc[:, 35])
-ax[1].set_xlabel("Angle (rad)")
-ax[1].set_title("Trace 35")
-plt.tight_layout()
-
-# %%
-# As a first processing step, let's bin the calcium traces to a 100ms resolution.
-
-Y = transients.bin_average(0.1, ep)
-
-# %%
-# We can visualize the downsampled transients for the first 50 seconds of data.
-plt.figure()
-plt.plot(transients[:,0].get(0, 50), linewidth=5, label="30 Hz")
-plt.plot(Y[:,0].get(0, 50), '--', linewidth=2, label="10 Hz")
-plt.xlabel("Time (s)")
-plt.ylabel("Fluorescence")
-plt.legend()
-plt.show()
-
-# %%
-# The downsampling did not destroy the fast transient dynamics, so seems fine to use. We can now move on to using NeMoS to fit a model.
-
-# %%
-# ## Basis instantiation
-#
-# We can define a cyclic-BSpline for capturing the encoding of the heading angle, and a log-spaced raised cosine basis for the coupling filters between neurons. Note that we are not including a self-coupling (spike history) filter, because in practice we have found it results in overfitting.
-#
-# We can combine the two bases.
-
-heading_basis = nmo.basis.CyclicBSplineBasis(n_basis_funcs=12)
-
-# define a basis that expect all the other neurons as predictors, i.e. shape (num_samples, num_neurons - 1)
-num_neurons = Y.shape[1]
-ws = 10
-coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=ws)
-
-# %%
-# We need to combine the bases.
-basis = heading_basis + coupling_basis
-
-
-# %%
-# ## Gamma GLM
-#
-# Until now, we have been modeling spike trains, and have used a Poisson distribution for the observation model. With calcium traces, things are quite different: we no longer have counts but continuous signals, so the Poisson assumption is no longer appropriate. A Gaussian model is also not ideal since the calcium traces are non-negative. To satisfy these constraints, we will use a Gamma distribution from NeMoS with a soft-plus non linearity.
-# !!! note "Non-linearity"
-# Different option are possible. With a soft-plus we are assuming an "additive" effect of the predictors, while an exponential non-linearity assumes multiplicative effects. Deciding which firing rate model works best is an empirical question. You can fit different configurations to see which one capture best the neural activity.
-
-model = nmo.glm.GLM(
- solver_kwargs=dict(tol=10**-13),
- regularizer="Ridge",
- regularizer_strength=0.02,
- observation_model=nmo.observation_models.GammaObservations(inverse_link_function=jax.nn.softplus)
-)
-
-
-# %%
-# We select one neuron to fit later, so remove it from the list of predictors
-
-neu = 4
-selected_neurons = jnp.hstack(
- (jnp.arange(0, neu), jnp.arange(neu+1, Y.shape[1]))
-)
-
-print(selected_neurons)
-
-# %%
-# We need to bring the head-direction of the animal to the same size as the transients matrix.
-# We can use the function `bin_average` of pynapple. Notice how we pass the parameter `ep`
-# that is the `time_support` of the transients.
-
-head_direction = data['ry'].bin_average(0.1, ep)
-
-# %%
-# Let's check that `head_direction` and `Y` are of the same size.
-print(head_direction.shape)
-print(Y.shape)
-
-# %%
-# ## Design matrix
-#
-# We can now create the design matrix by combining the head-direction of the animal and the activity of all other neurons.
-X = basis.compute_features(head_direction, Y[:, selected_neurons])
-
-# %%
-#
-# Before we use this design matrix to fit the population, we need to take a brief detour
-# into linear algebra. Depending on your design matrix is constructed, it is likely to
-# be rank-deficient, in which case it has a null space. Practically, that means that
-# there are infinitely many different sets of parameters that predict the same firing
-# rate. If you want to interpret your parameters, this is bad!
-#
-# While this multiplicity of solutions is always a potential issue when fitting models,
-# it is particularly relevant when using basis objects in nemos, as many of our basis
-# sets completely tile the input space (i.e., summing across all $n$ basis functions
-# returns 1 everywhere), which, when combined with the intercept term always present in
-# the GLM (i.e., the base firing rate), will give you a rank-deficient matrix.
-#
-# We thus recommend that you always check the rank of your design matrix and provide
-# some tools to drop the linearly-dependent columns, if necessary, which will guarantee
-# that your design matrix is full rank and thus that there is one unique solution.
-#
-# !!! tip "Linear Algebra"
-#
-# To read more about matrix rank, see
-# [Wikipedia](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Main_definitions).
-# Gil Strang's Linear Algebra course, [available for free
-# online](https://web.mit.edu/18.06/www/), and [NeuroMatch
-# Academy](https://compneuro.neuromatch.io/tutorials/W0D3_LinearAlgebra/student/W0D3_Tutorial2.html)
-# are also great resources.
-#
-# In this case, we are using the CyclicBSpline basis functions, which uniformly tile and
-# thus will result in a rank-deficient matrix. Therefore, we will use a utility function
-# to drop a column from the matrix and make it full-rank:
-
-# The number of features is the number of columns plus one (for the intercept)
-print(f"Number of features in the rank-deficient design matrix: {X.shape[1] + 1}")
-X, _ = apply_identifiability_constraints_by_basis_component(basis, X)
-# We have dropped one column
-print(f"Number of features in the full-rank design matrix: {X.shape[1] + 1}")
-
-# %%
-# ## Train & test set
-#
-# Let's create a train epoch and a test epoch to fit and test the models. Since `X` is a pynapple time series, we can create `IntervalSet` objects to restrict them into a train set and test set.
-
-train_ep = nap.IntervalSet(start=X.time_support.start, end=X.time_support.get_intervals_center().t)
-test_ep = X.time_support.set_diff(train_ep) # Removing the train_ep from time_support
-
-print(train_ep)
-print(test_ep)
-
-# %%
-# We can now restrict the `X` and `Y` to create our train set and test set.
-Xtrain = X.restrict(train_ep)
-Ytrain = Y.restrict(train_ep)
-
-Xtest = X.restrict(test_ep)
-Ytest = Y.restrict(test_ep)
-
-# %%
-# ## Model fitting
-#
-# It's time to fit the model on the data from the neuron we left out.
-
-model.fit(Xtrain, Ytrain[:, neu])
-
-
-# %%
-# ## Model comparison
-
-# %%
-# We can compare this to scikit-learn [linear regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html).
-
-mdl = LinearRegression()
-valid = ~jnp.isnan(Xtrain.d.sum(axis=1)) # Scikit learn does not like nans.
-mdl.fit(Xtrain[valid], Ytrain[valid, neu])
-
-
-# %%
-# We now have 2 models we can compare. Let's predict the activity of the neuron during the test epoch.
-
-yp = model.predict(Xtest)
-ylreg = mdl.predict(Xtest) # Notice that this is not a pynapple object.
-
-# %%
-# Unfortunately scikit-learn has not adopted pynapple yet. Let's convert `ylreg` to a pynapple object to make our life easier.
-
-ylreg = nap.Tsd(t=yp.t, d=ylreg, time_support = yp.time_support)
-
-
-# %%
-# Let's plot the predicted activity for the first 60 seconds of data.
-
-# mkdocs_gallery_thumbnail_number = 3
-
-ep_to_plot = nap.IntervalSet(test_ep.start+20, test_ep.start+80)
-
-plt.figure()
-plt.plot(Ytest[:,neu].restrict(ep_to_plot), "r", label="true", linewidth=2)
-plt.plot(yp.restrict(ep_to_plot), "k", label="gamma-nemos", alpha=1)
-plt.plot(ylreg.restrict(ep_to_plot), "g", label="linreg-sklearn", alpha=0.5)
-plt.legend(loc='best')
-plt.xlabel("Time (s)")
-plt.ylabel("Fluorescence")
-plt.show()
-
-# %%
-# While there is some variability in the fit for both models, one advantage of the gamma distribution is clear: the nonnegativity constraint is followed with the data.
-# This is required for using GLMs to predict the firing rate, which must be positive, in response to simulated inputs. See Peyrache et al. 2018[$^{[1]}$](#ref-1) for an example of simulating activity with a GLM.
-#
-# Another way to compare models is to compute tuning curves. Here we use the function `compute_1d_tuning_curves_continuous` from pynapple.
-
-real_tcurves = nap.compute_1d_tuning_curves_continuous(transients, data['ry'], 120, ep=test_ep)
-gamma_tcurves = nap.compute_1d_tuning_curves_continuous(yp, data['ry'], 120, ep=test_ep)
-linreg_tcurves = nap.compute_1d_tuning_curves_continuous(ylreg, data['ry'], 120, ep=test_ep)
-
-# %%
-# Let's plot them.
-
-plt.figure()
-plt.plot(real_tcurves[neu], "r", label="true", linewidth=2)
-plt.plot(gamma_tcurves, "k", label="gamma-nemos", alpha=1)
-plt.plot(linreg_tcurves, "g", label="linreg-sklearn", alpha=0.5)
-plt.legend(loc='best')
-plt.ylabel("Fluorescence")
-plt.xlabel("Head-direction (rad)")
-plt.show()
-
-#%%
-# !!! note "Gamma-GLM for Calcium Imaging Analysis"
-# Using Gamma-GLMs for fitting calcium imaging data is still in early stages, and hasn't been through
-# the levels of review and validation that they have for fitting spike data. Users should consider
-# this a relatively unexplored territory, and we hope that we hope that NeMoS will help researchers
-# explore this new space of models.
-#
-# ## References
-#
-# [1] Peyrache, A., Schieferstein, N. & Buzsáki, G. Transformation of the head-direction signal into a spatial code. Nat Commun 8, 1752 (2017). https://doi.org/10.1038/s41467-017-01908-3
diff --git a/mkdocs.yml b/mkdocs.yml
deleted file mode 100644
index b7deaeb7..00000000
--- a/mkdocs.yml
+++ /dev/null
@@ -1,105 +0,0 @@
-site_name: NeMoS
-repo_url: https://github.com/flatironinstitute/nemos
-
-strict: true
-
-theme:
- name: 'material' # The theme name, using the 'material' theme
- favicon: assets/NeMoS_favicon.ico
- logo: assets/NeMoS_Icon_CMYK_White.svg
- palette:
- primary: 'light blue' # The primary color palette for the theme
- features:
- - navigation.tabs # Enable navigation tabs feature for the theme
- markdown_extensions:
- - attr_list
- - admonition
- - tables
- - pymdownx.emoji:
- emoji_index: !!python/name:material.extensions.emoji.twemoji
- emoji_generator: !!python/name:material.extensions.emoji.to_svg
-
- features:
- - content.tabs.link
- - content.code.annotate
- - content.code.copy
- - announce.dismiss
- - navigation.tabs
- - navigation.instant
- - navigation.instant.prefetch
- - navigation.instant.preview
- - navigation.instant.progress
- - navigation.path
- - navigation.sections
- - navigation.top
- - search.highlights
- - search.share
- - search.suggest
-
-validation:
- omitted_files: info
- absolute_links: warn # Or 'relative_to_docs' - new in MkDocs 1.6
- unrecognized_links: info
- anchors: warn # New in MkDocs 1.6
-
-markdown_extensions:
- - md_in_html
- - footnotes
- - pymdownx.superfences
- - pymdownx.details # add notes toggleable notes ???
- - pymdownx.tabbed:
- alternate_style: true
- - toc:
- title: On this page
-
-
-plugins:
- - search
- - gallery:
- conf_script: docs/gallery_conf.py
- # These directories contain the input .py scripts for mkdocs-gallery
- examples_dirs:
- - docs/background
- - docs/how_to_guide
- - docs/tutorials
- # These are the output directories for mkdocs-gallery, and correspond
- # directly to the input dirs listed above. their contents should not be
- # touched
- gallery_dirs:
- - docs/generated/background
- - docs/generated/how_to_guide
- - docs/generated/tutorials
- - gen-files:
- scripts:
- - docs/gen_ref_pages.py # Specify the script to generate the code reference pages
- - literate-nav:
- nav_file: docs/SUMMARY.md # Specify the navigation file for literate-style navigation
- - section-index # Enable the section-index plugin for generating a section index
- - mkdocstrings:
- handlers:
- python:
- options:
- docstring_style: numpy
- show_source: true
- members_order: source
- inherited_members: true
-
-extra_javascript:
- - javascripts/katex.js
- - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.js
- - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/contrib/auto-render.min.js
-
-extra_css:
- - https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.16.7/katex.min.css
- - assets/stylesheets/extra.css
-
-nav:
- - Home: index.md # Link to the index.md file (home page)
- - Install: installation.md # Link to the installation.md file
- - Quickstart: quickstart.md
- - Background: generated/background # Link to the generated gallery Tutorials
- - How-To Guide: generated/how_to_guide # Link to the generated gallery tutorials
- - Tutorials: generated/tutorials # Link to the generated gallery tutorials
- - Getting Help: getting_help.md
- - API Guide: reference/ # Link to the reference/ directory
- - For Developers: developers_notes/ # Link to the developers notes
diff --git a/pyproject.toml b/pyproject.toml
index 1e4b7e4a..1937d7d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,20 +57,29 @@ dev = [
"seaborn", # Required by doctest for _documentation_utils module
]
docs = [
- "mkdocs", # Documentation generator
- "mkdocstrings[python]", # Python-specific plugin for mkdocs
- "mkdocs-section-index", # Plugin for generating a section index in mkdocs
- "mkdocs-gen-files", # Plugin for generating additional files in mkdocs
- "mkdocs-literate-nav>=0.6.1", # Plugin for literate-style navigation in mkdocs
- "mkdocs-gallery", # Plugin for adding image galleries to mkdocs
- "mkdocs-material", # Material theme for mkdocs
- "mkdocs-autorefs>=0.5",
- "mkdocs-linkcheck", # Check relative links
+ "numpydoc",
+ "sphinx",
+ "pydata-sphinx-theme",
+ "sphinx-autodoc-typehints",
+ "sphinx-copybutton",
+ "sphinx-design",
+ "sphinx-issues",
+ "sphinxcontrib-apidoc",
+ "sphinx-togglebutton",
+ "sphinx_code_tabs",
+ "sphinxemoji",
+ "myst-parser",
+ "myst-nb",
+ "dandi",
+ "sphinx-autobuild",
+ "sphinx-contributors",
"scikit-learn",
"dandi",
"matplotlib>=3.7",
"seaborn",
"pooch",
+ "ipywidgets",
+ "ipykernel",
]
examples = [
"scikit-learn",
diff --git a/src/nemos/_documentation_utils/plotting.py b/src/nemos/_documentation_utils/plotting.py
index 5514787f..086ccb2c 100644
--- a/src/nemos/_documentation_utils/plotting.py
+++ b/src/nemos/_documentation_utils/plotting.py
@@ -292,6 +292,7 @@ def current_injection_plot(
bbox_to_anchor=(0.5, -0.4),
bbox_transform=zoom_axes[1].transAxes,
)
+ return fig
def plot_weighted_sum_basis(time, weights, basis_kernels, basis_coeff):
diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py
index c60da8da..692f7067 100644
--- a/src/nemos/base_regressor.py
+++ b/src/nemos/base_regressor.py
@@ -608,7 +608,7 @@ def _get_solver_class(solver_name: str):
return solver_class
- def optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
+ def _optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
"""
Compute and update solver parameters with optimal defaults if available.
@@ -635,7 +635,7 @@ def optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
# get the model specific configs
compute_defaults, compute_l_smooth, strong_convexity = (
- self.get_optimal_solver_params_config()
+ self._get_optimal_solver_params_config()
)
if compute_defaults and compute_l_smooth:
# Check if the user has provided batch size or stepsize, or else use None
@@ -658,6 +658,6 @@ def optimize_solver_params(self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray) -> dict:
return new_solver_kwargs
@abstractmethod
- def get_optimal_solver_params_config(self):
+ def _get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
pass
diff --git a/src/nemos/basis.py b/src/nemos/basis.py
index c709177d..164936c7 100644
--- a/src/nemos/basis.py
+++ b/src/nemos/basis.py
@@ -77,8 +77,8 @@ def min_max_rescale_samples(
sample_pts:
The original samples.
bounds:
- Sample bounds. `bounds[0]` and `bounds[1]` are mapped to 0 and 1, respectively.
- Default are `min(sample_pts), max(sample_pts)`.
+ Sample bounds. ``bounds[0]`` and ``bounds[1]`` are mapped to 0 and 1, respectively.
+ Default are ``min(sample_pts), max(sample_pts)``.
Warns
-----
@@ -107,7 +107,7 @@ def min_max_rescale_samples(
class TransformerBasis:
- """Basis as `scikit-learn` transformers.
+ """Basis as ``scikit-learn`` transformers.
This class abstracts the underlying basis function details, offering methods
similar to scikit-learn's transformers but specifically designed for basis
@@ -115,14 +115,14 @@ class TransformerBasis:
of the basis functions), transforming data (applying the basis functions to
data), and both fitting and transforming in one step.
- `TransformerBasis`, unlike `Basis`, is compatible with scikit-learn pipelining and
+ ``TransformerBasis``, unlike ``Basis``, is compatible with scikit-learn pipelining and
model selection, enabling the cross-validation of the basis type and parameters,
- for example `n_basis_funcs`. See the example section below.
+ for example ``n_basis_funcs``. See the example section below.
Parameters
----------
basis :
- A concrete subclass of `Basis`.
+ A concrete subclass of ``Basis``.
Examples
--------
@@ -132,7 +132,6 @@ class TransformerBasis:
>>> from sklearn.model_selection import GridSearchCV
>>> import numpy as np
>>> np.random.seed(123)
-
>>> # Generate data
>>> num_samples, num_features = 10000, 1
>>> x = np.random.normal(size=(num_samples, )) # raw time series
@@ -140,7 +139,6 @@ class TransformerBasis:
>>> features = basis.compute_features(x) # basis transformed time series
>>> weights = np.random.normal(size=basis.n_basis_funcs) # true weights
>>> y = np.random.poisson(np.exp(features.dot(weights))) # spike counts
-
>>> # transformer can be used in pipelines
>>> transformer = TransformerBasis(basis)
>>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),])
@@ -162,9 +160,9 @@ def __init__(self, basis: Basis):
def _unpack_inputs(X: FeatureMatrix):
"""Unpack impute without using transpose.
- Unpack horizontally stacked inputs using slicing. This works gracefully with `pynapple`,
- returning a list of Tsd objects. Attempt to unpack using *X.T will raise a `pynapple`
- exception since `pynapple` assumes that the time axis is the first axis.
+ Unpack horizontally stacked inputs using slicing. This works gracefully with ``pynapple``,
+        returning a list of Tsd objects. Attempting to unpack using ``*X.T`` will raise a ``pynapple``
+ exception since ``pynapple`` assumes that the time axis is the first axis.
Parameters
----------
@@ -201,10 +199,8 @@ def fit(self, X: FeatureMatrix, y=None):
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
-
>>> # Example input
>>> X = np.random.normal(size=(100, 2))
-
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
@@ -233,21 +229,17 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
-
>>> # Example input
>>> X = np.random.normal(size=(10000, 2))
-
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_
-
>>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)
-
>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1])
"""
@@ -280,14 +272,11 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
-
>>> # Example input
>>> X = np.random.normal(size=(100, 1))
-
>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
-
>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X)
"""
@@ -343,7 +332,7 @@ def __setattr__(self, name: str, value) -> None:
Raises
------
ValueError
- If the attribute being set is not `_basis` or an attribute of `_basis`.
+ If the attribute being set is not ``_basis`` or an attribute of ``_basis``.
Examples
--------
@@ -389,7 +378,7 @@ def set_params(self, **parameters) -> TransformerBasis:
"""
Set TransformerBasis parameters.
- When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly
+ When used with ``sklearn.model_selection``, users can set either the ``_basis`` attribute directly
or the parameters of the underlying Basis, but not both.
Examples
@@ -397,7 +386,6 @@ def set_params(self, **parameters) -> TransformerBasis:
>>> from nemos.basis import BSplineBasis, MSplineBasis, TransformerBasis
>>> basis = MSplineBasis(10)
>>> transformer_basis = TransformerBasis(basis=basis)
-
>>> # setting parameters of _basis is allowed
>>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs)
8
@@ -508,27 +496,30 @@ class Basis(Base, abc.ABC):
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+ ``mode='conv'``; These arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Raises
------
ValueError:
- - If `mode` is not 'eval' or 'conv'.
- - If `kwargs` are not None and `mode =="eval".
- - If `kwargs` include parameters not recognized or do not have
- default values in `create_convolutional_predictor`.
- - If `axis` different from 0 is provided as a keyword argument (samples must always be in the first axis).
+ If ``mode`` is not 'eval' or 'conv'.
+ ValueError:
+ If ``kwargs`` are not None and ``mode =="eval"``.
+ ValueError:
+ If ``kwargs`` include parameters not recognized or do not have
+ default values in ``create_convolutional_predictor``.
+ ValueError:
+ If ``axis`` different from 0 is provided as a keyword argument (samples must always be in the first axis).
"""
def __init__(
@@ -578,11 +569,11 @@ def _check_convolution_kwargs(self):
Raises
------
ValueError:
- - If `self._conv_kwargs` are not None and `mode =="eval".
- - If `axis` is provided as an argument, and it is different from 0
+ - If ``self._conv_kwargs`` are not None and ``mode =="eval"``.
+ - If ``axis`` is provided as an argument, and it is different from 0
(samples must always be in the first axis).
- - If `self._conv_kwargs` include parameters not recognized or that do not have
- default values in `create_convolutional_predictor`.
+ - If ``self._conv_kwargs`` include parameters not recognized or that do not have
+ default values in ``create_convolutional_predictor``.
"""
if self._mode == "eval" and self._conv_kwargs:
raise ValueError(
@@ -622,28 +613,34 @@ def _check_convolution_kwargs(self):
@property
def n_output_features(self) -> int | None:
"""
- Read-only property indicating the number of features returned by the basis, when available.
+ Number of features returned by the basis.
Notes
-----
The number of output features can be determined only when the number of inputs
- provided to the basis is known. Therefore, before the first call to `compute_features`,
- this property will return `None`. After that call, `n_output_features` will be available.
+ provided to the basis is known. Therefore, before the first call to ``compute_features``,
+ this property will return ``None``. After that call, ``n_output_features`` will be available.
"""
return self._n_output_features
@property
def label(self) -> str:
+ """Label for the basis."""
return self._label
@property
def n_basis_input(self) -> tuple | None:
+ """Number of expected inputs.
+
+        The number of inputs ``compute_features`` expects.
+ """
if self._n_basis_input is None:
return
return self._n_basis_input
@property
def n_basis_funcs(self):
+ """Number of basis functions."""
return self._n_basis_funcs
@n_basis_funcs.setter
@@ -658,6 +655,7 @@ def n_basis_funcs(self, value):
@property
def bounds(self):
+ """Range of values covered by the basis."""
return self._bounds
@bounds.setter
@@ -684,10 +682,15 @@ def bounds(self, values: Union[None, Tuple[float, float]]):
@property
def mode(self):
+ """Mode of operation, either ``"conv"`` or ``"eval"``."""
return self._mode
@property
def window_size(self):
+ """Window size as number of samples.
+
+ Duration of the convolutional kernel in number of samples.
+ """
return self._window_size
@window_size.setter
@@ -714,9 +717,9 @@ def window_size(self, window_size):
@staticmethod
def _apply_identifiability_constraints(X: NDArray):
- """Apply identifiability constraints to a design matrix `X`.
+ """Apply identifiability constraints to a design matrix ``X``.
- Removes columns from `X` until `[1, X]` is full rank to ensure the uniqueness
+ Removes columns from ``X`` until ``[1, X]`` is full rank to ensure the uniqueness
of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly
crucial for models using bases like BSplines and CyclicBspline, which, due to their
construction, sum to 1 and can cause rank deficiency when combined with an intercept.
@@ -768,16 +771,16 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
-------
:
A matrix with the transformed features. The shape of the output depends on the operation mode:
- - If `mode == 'eval'`, the basis evaluated at the samples, or $b_i(*xi)$, where $b_i$ is a
- basis element. xi[k] must be a one-dimensional array or a pynapple Tsd.
+ - If ``mode == 'eval'``, the basis evaluated at the samples, or :math:`b_i(*xi)`, where :math:`b_i`
+ is a basis element. ``xi[k]`` must be a one-dimensional array or a pynapple Tsd.
- - If `mode == 'conv'`, a bank of basis filters (created by calling fit) is convolved with the
+ - If ``mode == 'conv'``, a bank of basis filters (created by calling fit) is convolved with the
samples. Samples can be a NDArray, or a pynapple Tsd/TsdFrame/TsdTensor. All the dimensions
except for the sample-axis are flattened, so that the method always returns a matrix.
For example, if samples are of shape (num_samples, 2, 3), the output will be
(num_samples, num_basis_funcs * 2 * 3).
- The time-axis can be specified at basis initialization by setting the keyword argument `axis`.
- For example, if `axis == 1` your samples should be (N1, num_samples N3, ...), the output of
+ The time-axis can be specified at basis initialization by setting the keyword argument ``axis``.
+        For example, if ``axis == 1`` your samples should be (N1, num_samples, N3, ...), the output of
transform will be (num_samples, num_basis_funcs * N1 * N3 *...).
Raises
@@ -830,7 +833,6 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
--------
>>> import numpy as np
>>> from nemos.basis import BSplineBasis
-
>>> # Generate data
>>> num_samples = 10000
>>> X = np.random.normal(size=(num_samples, )) # raw time series
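The doctest above is truncated by the hunk; for orientation, a short, hedged sketch of the convolutional mode follows (it assumes the ``BSplineBasis`` constructor arguments and the sample-count-preserving output documented elsewhere in this diff):

    import numpy as np
    from nemos.basis import BSplineBasis

    counts = np.random.poisson(2.0, size=1000)                  # e.g. a spike-count time series
    conv_basis = BSplineBasis(10, mode="conv", window_size=50)
    X_conv = conv_basis.compute_features(counts)                # one column per basis function
    print(X_conv.shape)                                         # expected: (1000, 10)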
@@ -878,7 +880,7 @@ def _set_kernel(self, *xi: ArrayLike) -> Basis:
-----
Subclasses implementing this method should detail the specifics of how the kernel is
computed and how the input parameters are utilized. If the basis operates in 'eval'
- mode exclusively, this method should simply return `self` without modification.
+ mode exclusively, this method should simply return ``self`` without modification.
"""
if self.mode == "conv":
self.kernel_ = self.__call__(np.linspace(0, 1, self.window_size))
@@ -924,7 +926,7 @@ def _get_samples(self, *n_samples: int) -> Generator[NDArray]:
Returns
-------
:
- A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by `n_samples`.
+ A generator yielding numpy arrays of linspaces from 0 to 1 of sizes specified by ``n_samples``.
"""
# handling of defaults when evaluating on a grid
# (i.e. when we cannot use max and min of samples)
@@ -983,8 +985,8 @@ def _check_has_kernel(self) -> None:
def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
"""Evaluate the basis set on a grid of equi-spaced sample points.
- The i-th axis of the grid will be sampled with n_samples[i] equi-spaced points.
- The method uses numpy.meshgrid with `indexing="ij"`, returning matrix indexing
+ The i-th axis of the grid will be sampled with ``n_samples[i]`` equi-spaced points.
+ The method uses numpy.meshgrid with ``indexing="ij"``, returning matrix indexing
instead of the default cartesian indexing, see Notes.
Parameters
@@ -998,24 +1000,25 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
*Xs :
A tuple of arrays containing the meshgrid values, one element for each of the n dimension of the grid,
where n equals to the number of inputs.
- The size of Xs[i] is (n_samples[0], ... , n_samples[n]).
+ The size of ``Xs[i]`` is ``(n_samples[0], ... , n_samples[n])``.
Y :
The basis function evaluated at the samples,
- shape (n_samples[0], ... , n_samples[n], number of basis).
+ shape ``(n_samples[0], ... , n_samples[n], number of basis)``.
Raises
------
ValueError
- - If the time point number is inconsistent between inputs or if the number of inputs doesn't match what
+          If the number of time points is inconsistent between inputs or if the number of inputs doesn't match what
the Basis object requires.
- - If one of the n_samples is <= 0.
+ ValueError
+ If one of the n_samples is <= 0.
Notes
-----
- Setting "indexing = 'ij'" returns a meshgrid with matrix indexing. In the N-D case with inputs of size
- $M_1,...,M_N$, outputs are of shape $(M_1, M_2, M_3, ....,M_N)$.
+ Setting ``indexing = 'ij'`` returns a meshgrid with matrix indexing. In the N-D case with inputs of size
+ :math:`M_1,...,M_N`, outputs are of shape :math:`(M_1, M_2, M_3, ....,M_N)`.
This differs from the numpy.meshgrid default, which uses Cartesian indexing.
- For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$.
+ For the same input, Cartesian indexing would return an output of shape :math:`(M_2, M_1, M_3, ....,M_N)`.
Examples
--------
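The ``indexing='ij'`` behaviour described in the Notes can be seen with plain NumPy, independently of nemos:

    import numpy as np

    a, b = np.arange(3), np.arange(5)            # M_1 = 3, M_2 = 5
    Xc, Yc = np.meshgrid(a, b)                   # default Cartesian ('xy') indexing
    Xi, Yi = np.meshgrid(a, b, indexing="ij")    # matrix indexing, as used by evaluate_on_grid
    print(Xc.shape, Xi.shape)                    # (5, 3) (3, 5)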
@@ -1152,7 +1155,7 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis:
Returns
-------
:
- The product of the basis with itself "exponent" times. Equivalent to self * self * ... * self.
+ The product of the basis with itself "exponent" times. Equivalent to ``self * self * ... * self``.
Raises
------
@@ -1211,19 +1214,19 @@ def _get_feature_slicing(
Calculate and return the slicing for features based on the input structure.
This method determines how to slice the features for different basis types.
- If the instance is of `AdditiveBasis` type, the slicing is calculated recursively
+ If the instance is of ``AdditiveBasis`` type, the slicing is calculated recursively
for each component basis. Otherwise, it determines the slicing based on
- the number of basis functions and `split_by_input` flag.
+ the number of basis functions and ``split_by_input`` flag.
Parameters
----------
n_inputs :
- The number of input basis for each component, by default it uses `self._n_basis_input`.
+ The number of input basis for each component, by default it uses ``self._n_basis_input``.
start_slice :
The starting index for slicing, by default it starts from 0.
split_by_input :
Flag indicating whether to split the slicing by individual inputs or not.
- If `False`, a single slice is generated for all inputs.
+ If ``False``, a single slice is generated for all inputs.
Returns
-------
@@ -1326,45 +1329,48 @@ def split_by_feature(
**How it works:**
- - If the basis expects an input shape `(n_samples, n_inputs)`, then the feature axis length will
- be `total_n_features = n_inputs * n_basis_funcs`. This axis is reshaped into dimensions
- `(n_inputs, n_basis_funcs)`.
- - If the basis expects an input of shape `(n_samples,)`, then the feature axis length will
- be `total_n_features = n_basis_funcs`. This axis is reshaped into `(1, n_basis_funcs)`.
+ - If the basis expects an input shape ``(n_samples, n_inputs)``, then the feature axis length will
+ be ``total_n_features = n_inputs * n_basis_funcs``. This axis is reshaped into dimensions
+ ``(n_inputs, n_basis_funcs)``.
- For example, if the input array `x` has shape `(1, 2, total_n_features, 4, 5)`,
- then after applying this method, it will be reshaped into `(1, 2, n_inputs, n_basis_funcs, 4, 5)`.
+ - If the basis expects an input of shape ``(n_samples,)``, then the feature axis length will
+ be ``total_n_features = n_basis_funcs``. This axis is reshaped into ``(1, n_basis_funcs)``.
- The specified axis (`axis`) determines where the split occurs, and all other dimensions
+ For example, if the input array ``x`` has shape ``(1, 2, total_n_features, 4, 5)``,
+ then after applying this method, it will be reshaped into ``(1, 2, n_inputs, n_basis_funcs, 4, 5)``.
+
+ The specified axis (``axis``) determines where the split occurs, and all other dimensions
remain unchanged. See the example section below for the most common use cases.
Parameters
----------
x :
The input array to be split, representing concatenated features, coefficients,
- or other data. The shape of `x` along the specified axis must match the total
- number of features generated by the basis, i.e., `self.n_output_features`.
+ or other data. The shape of ``x`` along the specified axis must match the total
+ number of features generated by the basis, i.e., ``self.n_output_features``.
**Examples:**
- - For a design matrix: `(n_samples, total_n_features)`
- - For model coefficients: `(total_n_features,)` or `(total_n_features, n_neurons)`.
+
+ - For a design matrix: ``(n_samples, total_n_features)``
+
+ - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``.
axis : int, optional
The axis along which to split the features. Defaults to 1.
- Use `axis=1` for design matrices (features along columns) and `axis=0` for
+ Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for
coefficient arrays (features along rows). All other dimensions are preserved.
Raises
------
ValueError
- If the shape of `x` along the specified axis does not match `self.n_output_features`.
+ If the shape of ``x`` along the specified axis does not match ``self.n_output_features``.
Returns
-------
dict
A dictionary where:
- **Key**: Label of the basis.
- - **Value**: the array reshaped to: `(..., n_inputs, n_basis_funcs, ...)
+ - **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)``
Examples
--------
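A hedged sketch of the reshaping described above (the label, shapes, and window size are illustrative; it assumes the ``label`` argument and the dictionary return value documented in this diff):

    import numpy as np
    from nemos.basis import BSplineBasis

    basis = BSplineBasis(8, mode="conv", window_size=20, label="speed")
    x = np.random.normal(size=(500, 3))            # 3 inputs sharing the same basis
    X = basis.compute_features(x)                  # feature axis: 3 * 8 = 24 columns
    split = basis.split_by_feature(X, axis=1)
    print(split["speed"].shape)                    # expected: (500, 3, 8)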
@@ -1464,8 +1470,8 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis:
This function computes the number of inputs that are provided to the basis and uses
that number, and the n_basis_funcs to calculate the number of output features that
- `self.compute_features` will return. These quantities and the input shape (excluding the sample axis)
- are stored in `self._n_basis_input` and `self._n_output_features`, and `self._input_shape`
+ ``self.compute_features`` will return. These quantities and the input shape (excluding the sample axis)
+ are stored in ``self._n_basis_input`` and ``self._n_output_features``, and ``self._input_shape``
respectively.
Parameters
@@ -1481,15 +1487,15 @@ def _set_num_output_features(self, *xi: NDArray) -> Basis:
Raises
------
ValueError:
- If the number of inputs do not match `self._n_basis_input`, if `self._n_basis_input` was
+            If the number of inputs does not match ``self._n_basis_input``, if ``self._n_basis_input`` was
not None.
Notes
-----
- Once a `compute_features` is called, we enforce that for all subsequent calls of the method,
+ Once a ``compute_features`` is called, we enforce that for all subsequent calls of the method,
the input that the basis receives preserves the shape of all axes, except for the sample axis.
This condition guarantees the consistency of the feature axis, and therefore that
- `self.split_by_feature` behaves appropriately.
+ ``self.split_by_feature`` behaves appropriately.
"""
# Check that the input shape matches expectation
@@ -1524,7 +1530,7 @@ class AdditiveBasis(Basis):
Attributes
----------
- n_basis_funcs : int
+ n_basis_funcs :
Number of basis functions.
Examples
@@ -1533,16 +1539,15 @@ class AdditiveBasis(Basis):
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 2))
-
>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = basis_1 + basis_2
-
>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
+
"""
def __init__(self, basis1: Basis, basis2: Basis) -> None:
@@ -1641,7 +1646,7 @@ def _set_kernel(self, *xi: ArrayLike) -> Basis:
Parameters
----------
*xi:
- The sample inputs. Unused, necessary to conform to `scikit-learn` API.
+ The sample inputs. Unused, necessary to conform to ``scikit-learn`` API.
Returns
-------
@@ -1666,52 +1671,50 @@ def split_by_feature(
**How It Works:**
- Suppose the basis is made up of **m components**, each with $b_i$ basis functions and $n_i$ inputs.
- The total number of features, $N$, is calculated as:
+ Suppose the basis is made up of **m components**, each with :math:`b_i` basis functions and :math:`n_i` inputs.
+ The total number of features, :math:`N`, is calculated as:
- $$
- N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m
- $$
+ .. math::
+ N = b_1 \cdot n_1 + b_2 \cdot n_2 + \ldots + b_m \cdot n_m
- This method splits any axis of length $N$ into sub-arrays, one for each basis component.
+ This method splits any axis of length :math:`N` into sub-arrays, one for each basis component.
The sub-array for the i-th basis component is reshaped into dimensions
- $(n_i, b_i)$.
+ :math:`(n_i, b_i)`.
- For example, if the array shape is $(1, 2, N, 4, 5)$, then each split sub-array will have shape:
+ For example, if the array shape is :math:`(1, 2, N, 4, 5)`, then each split sub-array will have shape:
- $$
- (1, 2, n_i, b_i, 4, 5)
- $$
+ .. math::
+ (1, 2, n_i, b_i, 4, 5)
where:
- - $n_i$ represents the number of inputs associated with the i-th component,
- - $b_i$ represents the number of basis functions in that component.
+ - :math:`n_i` represents the number of inputs associated with the i-th component,
+ - :math:`b_i` represents the number of basis functions in that component.
- The specified axis (`axis`) determines where the split occurs, and all other dimensions
+ The specified axis (``axis``) determines where the split occurs, and all other dimensions
remain unchanged. See the example section below for the most common use cases.
Parameters
----------
x :
The input array to be split, representing concatenated features, coefficients,
- or other data. The shape of `x` along the specified axis must match the total
- number of features generated by the basis, i.e., `self.n_output_features`.
+ or other data. The shape of ``x`` along the specified axis must match the total
+ number of features generated by the basis, i.e., ``self.n_output_features``.
**Examples:**
- - For a design matrix: `(n_samples, total_n_features)`
- - For model coefficients: `(total_n_features,)` or `(total_n_features, n_neurons)`.
+ - For a design matrix: ``(n_samples, total_n_features)``
+ - For model coefficients: ``(total_n_features,)`` or ``(total_n_features, n_neurons)``.
axis : int, optional
The axis along which to split the features. Defaults to 1.
- Use `axis=1` for design matrices (features along columns) and `axis=0` for
+ Use ``axis=1`` for design matrices (features along columns) and ``axis=0`` for
coefficient arrays (features along rows). All other dimensions are preserved.
Raises
------
ValueError
- If the shape of `x` along the specified axis does not match `self.n_output_features`.
+ If the shape of ``x`` along the specified axis does not match ``self.n_output_features``.
Returns
-------
@@ -1720,12 +1723,11 @@ def split_by_feature(
- **Keys**: Labels of the additive basis components.
- **Values**: Sub-arrays corresponding to each component. Each sub-array has the shape:
- $$
- (..., n_i, b_i, ...)
- $$
+ .. math::
+ (..., n_i, b_i, ...)
- - `n_i`: The number of inputs processed by the i-th basis component.
- - `b_i`: The number of basis functions for the i-th basis component.
+ - ``n_i``: The number of inputs processed by the i-th basis component.
+ - ``b_i``: The number of basis functions for the i-th basis component.
These sub-arrays are reshaped along the specified axis, with all other dimensions
remaining the same.
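As a concrete instance of the formula above (a sketch; labels and sizes are illustrative):

    import numpy as np
    import nemos as nmo

    b1 = nmo.basis.BSplineBasis(10, label="position")           # b_1 = 10
    b2 = nmo.basis.RaisedCosineBasisLinear(5, label="speed")    # b_2 = 5
    add = b1 + b2

    x1 = np.random.normal(size=(200, 2))     # n_1 = 2 inputs for the first component
    x2 = np.random.normal(size=(200,))       # n_2 = 1 input for the second component
    X = add.compute_features(x1, x2)         # N = 10 * 2 + 5 * 1 = 25 columns
    split = add.split_by_feature(X, axis=1)
    print({k: v.shape for k, v in split.items()})   # expected: (200, 2, 10) and (200, 1, 5)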
@@ -1783,7 +1785,7 @@ class MultiplicativeBasis(Basis):
Attributes
----------
- n_basis_funcs : int
+ n_basis_funcs :
Number of basis functions.
Examples
@@ -1792,12 +1794,10 @@ class MultiplicativeBasis(Basis):
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 3))
-
>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = basis_1 * basis_2
-
>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
@@ -1828,7 +1828,7 @@ def _set_kernel(self, *xi: NDArray) -> Basis:
Parameters
----------
*xi:
- The sample inputs. Unused, necessary to conform to `scikit-learn` API.
+ The sample inputs. Unused, necessary to conform to ``scikit-learn`` API.
Returns
-------
@@ -1921,18 +1921,18 @@ class SplineBasis(Basis, abc.ABC):
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Attributes
----------
@@ -1964,6 +1964,10 @@ def __init__(
@property
def order(self):
+ """Spline order.
+
+ Spline order, i.e. the polynomial degree of the spline plus one.
+ """
return self._order
@order.setter
@@ -2056,16 +2060,16 @@ def _check_n_basis_min(self) -> None:
class MSplineBasis(SplineBasis):
r"""
- M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation.
+ M-spline basis functions for modeling and data transformation.
- M-splines are a type of spline basis function used for smooth curve fitting
+ M-splines [1]_ are a type of spline basis function used for smooth curve fitting
and data representation. They are positive and integrate to one, making them
suitable for probabilistic models and density estimation. The order of an
M-spline defines its smoothness, with higher orders resulting in smoother
splines.
This class provides functionality to create M-spline basis functions, allowing
- for flexible and smooth modeling of data. It inherits from the `SplineBasis`
+ for flexible and smooth modeling of data. It inherits from the ``SplineBasis``
abstract class, providing specific implementations for M-splines.
Parameters
@@ -2083,18 +2087,18 @@ class MSplineBasis(SplineBasis):
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs:
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Examples
--------
@@ -2106,16 +2110,16 @@ class MSplineBasis(SplineBasis):
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = mspline_basis(sample_points)
- # References
- ------------
- [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science,
+ References
+ ----------
+ .. [1] Ramsay, J. O. (1988). Monotone regression splines in action. Statistical science,
3(4), 425-441.
Notes
-----
- MSplines must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain
- (x-axis) of an MSpline basis is expanded by a factor of $\alpha$, the values on the co-domain (y-axis) values
- will shrink by a factor of $1/\alpha$.
+ ``MSplines`` must integrate to 1 over their domain (the area under the curve is 1). Therefore, if the domain
+    (x-axis) of an MSpline basis is expanded by a factor of :math:`\alpha`, the values on the co-domain (y-axis)
+ will shrink by a factor of :math:`1/\alpha`.
For example, over the standard bounds of (0, 1), the maximum value of the MSpline is 18.
If we set the bounds to (0, 2), the maximum value will be 9, i.e., 18 / 2.
"""
@@ -2202,10 +2206,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-------
X : NDArray
A 1D array of uniformly spaced sample points within the domain [0, 1].
- Shape: `(n_samples,)`.
+ Shape: ``(n_samples,)``.
Y : NDArray
A 2D array where each row corresponds to the evaluated M-spline basis
- function values at the points in X. Shape: `(n_samples, n_basis_funcs)`.
+ function values at the points in X. Shape: ``(n_samples, n_basis_funcs)``.
Examples
--------
@@ -2231,34 +2235,36 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class BSplineBasis(SplineBasis):
"""
- B-spline[$^{[1]}$](#references) 1-dimensional basis functions.
+ B-spline 1-dimensional basis functions.
+
+ Implementation of the one-dimensional BSpline basis [1]_.
Parameters
----------
n_basis_funcs :
Number of basis functions.
mode :
- The mode of operation. 'eval' for evaluation at sample points,
+ The mode of operation. ``'eval'`` for evaluation at sample points,
'conv' for convolutional operation.
order :
- Order of the splines used in basis functions. Must lie within [1, n_basis_funcs].
+ Order of the splines used in basis functions. Must lie within ``[1, n_basis_funcs]``.
The B-splines have (order-2) continuous derivatives at each interior knot.
The higher this number, the smoother the basis representation will be.
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Attributes
----------
@@ -2266,16 +2272,15 @@ class BSplineBasis(SplineBasis):
Spline order.
- # References
- ------------
- [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
+ References
+ ----------
+ .. [1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import BSplineBasis
-
>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
@@ -2325,7 +2330,7 @@ def __call__(self, sample_pts: ArrayLike) -> FeatureMatrix:
Notes
-----
- The evaluation is performed by looping over each element and using `splev`
+ The evaluation is performed by looping over each element and using ``splev``
from SciPy to compute the basis values.
"""
sample_pts, _ = min_max_rescale_samples(sample_pts, self.bounds)
@@ -2348,14 +2353,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Returns
-------
X :
- Array of shape (n_samples,) containing the equi-spaced sample
+ Array of shape ``(n_samples,)`` containing the equi-spaced sample
points where we've evaluated the basis.
basis_funcs :
- Raised cosine basis functions, shape (n_samples, n_basis_funcs)
+            B-spline basis functions, shape ``(n_samples, n_basis_funcs)``
Notes
-----
- The evaluation is performed by looping over each element and using `splev` from
+ The evaluation is performed by looping over each element and using ``splev`` from
SciPy to compute the basis values.
Examples
@@ -2387,32 +2392,31 @@ class CyclicBSplineBasis(SplineBasis):
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Attributes
----------
- n_basis_funcs : int
- Number of basis functions.
- order : int
- Order of the splines used in basis functions.
+ n_basis_funcs :
+        Number of basis functions.
+ order :
+        Order of the splines used in basis functions.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import CyclicBSplineBasis
>>> X = np.random.normal(size=(1000, 1))
-
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
@@ -2465,7 +2469,7 @@ def __call__(
Notes
-----
- The evaluation is performed by looping over each element and using `splev` from
+ The evaluation is performed by looping over each element and using ``splev`` from
SciPy to compute the basis values.
"""
@@ -2513,14 +2517,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Returns
-------
X :
- Array of shape (n_samples,) containing the equi-spaced sample
+ Array of shape ``(n_samples,)`` containing the equi-spaced sample
points where we've evaluated the basis.
basis_funcs :
- Raised cosine basis functions, shape (n_samples, n_basis_funcs)
+            Cyclic B-spline basis functions, shape ``(n_samples, n_basis_funcs)``
Notes
-----
- The evaluation is performed by looping over each element and using `splev` from
+ The evaluation is performed by looping over each element and using ``splev`` from
SciPy to compute the basis values.
Examples
@@ -2537,7 +2541,7 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
class RaisedCosineBasisLinear(Basis):
"""Represent linearly-spaced raised cosine basis functions.
- This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
+ This implementation is based on the cosine bumps used by Pillow et al. [1]_
to uniformly tile the internal points of the domain.
Parameters
@@ -2552,35 +2556,34 @@ class RaisedCosineBasisLinear(Basis):
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLinear
>>> X = np.random.normal(size=(1000, 1))
-
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
- # References
- ------------
- [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
+ References
+ ----------
+ .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
- 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
+ 11003–11013.
"""
def __init__(
@@ -2709,10 +2712,10 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Returns
-------
X :
- Array of shape (n_samples,) containing the equi-spaced sample
+ Array of shape ``(n_samples,)`` containing the equi-spaced sample
points where we've evaluated the basis.
basis_funcs :
- Raised cosine basis functions, shape (n_samples, n_basis_funcs)
+ Raised cosine basis functions, shape ``(n_samples, n_basis_funcs)``
Examples
--------
@@ -2744,8 +2747,8 @@ def _check_n_basis_min(self) -> None:
class RaisedCosineBasisLog(RaisedCosineBasisLinear):
"""Represent log-spaced raised cosine basis functions.
- Similar to `RaisedCosineBasisLinear` but the basis functions are log-spaced.
- This implementation is based on the cosine bumps used by Pillow et al.[$^{[1]}$](#references)
+ Similar to ``RaisedCosineBasisLinear`` but the basis functions are log-spaced.
+ This implementation is based on the cosine bumps used by Pillow et al. [1]_
to uniformly tile the internal points of the domain.
Parameters
@@ -2762,41 +2765,40 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
larger values resulting in more stretching. As this approaches 0, the
transformation becomes linear.
enforce_decay_to_zero:
- If set to True, the algorithm first constructs a basis with `n_basis_funcs + ceil(width)` elements
+ If set to True, the algorithm first constructs a basis with ``n_basis_funcs + ceil(width)`` elements
and subsequently trims off the extra basis elements. This ensures that the final basis element
decays to 0.
window_size :
The window size for convolution. Required if mode is 'conv'.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLog
>>> X = np.random.normal(size=(1000, 1))
-
>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
- # References
- ------------
- [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
+ References
+ ----------
+ .. [1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
C. E. (2005). Prediction and decoding of retinal ganglion cell responses
with a probabilistic spiking model. Journal of Neuroscience, 25(47),
- 11003–11013. http://dx.doi.org/10.1523/jneurosci.3305-05.2005
+ 11003–11013.
"""
def __init__(
@@ -2933,25 +2935,25 @@ class OrthExponentialBasis(Basis):
n_basis_funcs
Number of basis functions.
decay_rates :
- Decay rates of the exponentials, shape (n_basis_funcs,).
+ Decay rates of the exponentials, shape ``(n_basis_funcs,)``.
mode :
- The mode of operation. 'eval' for evaluation at sample points,
- 'conv' for convolutional operation.
+ The mode of operation. ``'eval'`` for evaluation at sample points,
+ ``'conv'`` for convolutional operation.
window_size :
- The window size for convolution. Required if mode is 'conv'.
+ The window size for convolution. Required if mode is ``'conv'``.
bounds :
- The bounds for the basis domain in `mode="eval"`. The default `bounds[0]` and `bounds[1]` are the
+ The bounds for the basis domain in ``mode="eval"``. The default ``bounds[0]`` and ``bounds[1]`` are the
minimum and the maximum of the samples provided when evaluating the basis.
If a sample is outside the bounds, the basis will return NaN.
label :
The label of the basis, intended to be descriptive of the task variable being processed.
For example: velocity, position, spike_counts.
**kwargs :
- Additional keyword arguments passed to `nemos.convolve.create_convolutional_predictor` when
- `mode='conv'`; These arguments are used to change the default behavior of the convolution.
- For example, changing the `predictor_causality`, which by default is set to `"causal"`.
- Note that one cannot change the default value for the `axis` parameter. Basis assumes
- that the convolution axis is `axis=0`.
+ Additional keyword arguments passed to :func:`nemos.convolve.create_convolutional_predictor` when
+        ``mode='conv'``; these arguments are used to change the default behavior of the convolution.
+ For example, changing the ``predictor_causality``, which by default is set to ``"causal"``.
+ Note that one cannot change the default value for the ``axis`` parameter. Basis assumes
+ that the convolution axis is ``axis=0``.
Examples
--------
@@ -2969,8 +2971,8 @@ class OrthExponentialBasis(Basis):
def __init__(
self,
n_basis_funcs: int,
- decay_rates: NDArray[np.floating],
- mode="eval",
+ decay_rates: NDArray,
+ mode: Literal["eval", "conv"] = "eval",
window_size: Optional[int] = None,
bounds: Optional[Tuple[float, float]] = None,
label: Optional[str] = "OrthExponentialBasis",
@@ -2990,7 +2992,11 @@ def __init__(
@property
def decay_rates(self):
- """Decay rate getter."""
+ r"""Decay rate.
+
+ The rate of decay of the exponential functions. If :math:`f_i(t) = e^{-\alpha_i t}` is the i-th decay
+        exponential before orthogonalization, :math:`\alpha_i` is the i-th element of the ``decay_rates`` vector.
+ """
return self._decay_rates
@decay_rates.setter
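A plain-NumPy sketch of the idea described by this property (not the nemos implementation): exponentials with the given decay rates are evaluated on a sample grid and then numerically orthogonalised, e.g. with a QR decomposition.

    import numpy as np

    decay_rates = np.array([0.5, 1.0, 2.0, 4.0])
    t = np.linspace(0, 10, 1_000)

    F = np.exp(-np.outer(t, decay_rates))      # columns f_i(t) = exp(-alpha_i * t)
    Q, _ = np.linalg.qr(F)                     # orthonormal columns, shape (1000, 4)
    print(np.allclose(Q.T @ Q, np.eye(4)))     # True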
@@ -3072,14 +3078,14 @@ def __call__(
Parameters
----------
sample_pts
- Spacing for basis functions, holding elements on the interval [0,
- inf), shape (n_samples,).
+            Spacing for basis functions, holding elements on the interval :math:`[0, \infty)`,
+ shape ``(n_samples,)``.
Returns
-------
basis_funcs
Evaluated exponentially decaying basis functions, numerically
- orthogonalized, shape (n_samples, n_basis_funcs)
+ orthogonalized, shape ``(n_samples, n_basis_funcs)``.
"""
self._check_sample_size(sample_pts)
@@ -3113,11 +3119,11 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Returns
-------
X :
- Array of shape (n_samples,) containing the equi-spaced sample
+ Array of shape ``(n_samples,)`` containing the equi-spaced sample
points where we've evaluated the basis.
basis_funcs :
Evaluated exponentially decaying basis functions, numerically
- orthogonalized, shape (n_samples, n_basis_funcs)
+ orthogonalized, shape ``(n_samples, n_basis_funcs)``
Examples
--------
@@ -3157,7 +3163,6 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray:
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import mspline
-
>>> sample_points = linspace(0, 1, 100)
>>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline
>>> mspline_eval.shape
@@ -3223,7 +3228,7 @@ def bspline(
Raises
------
AssertionError
- If `outer_ok` is False and the sample points lie outside the B-spline knots range.
+ If ``outer_ok`` is False and the sample points lie outside the B-spline knots range.
Notes
-----
@@ -3234,7 +3239,6 @@ def bspline(
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import bspline
-
>>> sample_points = linspace(0, 1, 100)
>>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1])
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
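The Notes of the spline classes above state that evaluation loops over basis elements with SciPy's ``splev``; a standalone illustration of that idea (toy knot vector, cubic splines; not the nemos implementation):

    import numpy as np
    from scipy.interpolate import splev

    order = 4                                              # cubic B-splines (degree 3)
    knots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1], dtype=float)
    n_basis = len(knots) - order                           # 7 basis elements
    x = np.linspace(0, 1, 200)

    # Evaluate the i-th basis element by setting its coefficient to 1 and the rest to 0.
    basis = np.stack(
        [splev(x, (knots, np.eye(n_basis)[i], order - 1)) for i in range(n_basis)],
        axis=1,
    )
    print(basis.shape)                                     # (200, 7)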
diff --git a/src/nemos/convolve.py b/src/nemos/convolve.py
index 6a6294c3..248305c7 100644
--- a/src/nemos/convolve.py
+++ b/src/nemos/convolve.py
@@ -21,7 +21,7 @@
@jax.jit
-def reshape_convolve(array: NDArray, eval_basis: NDArray):
+def tensor_convolve(array: NDArray, eval_basis: NDArray):
"""
Apply a convolution on the given array with the evaluation basis and reshapes the result.
@@ -32,22 +32,23 @@ def reshape_convolve(array: NDArray, eval_basis: NDArray):
Parameters
----------
array :
- The input array to convolve. It is expected to be at least 1D.
+        The input array to convolve. It is expected to be at least 1D. The first axis is expected to be
+        the sample axis, i.e. the shape of ``array`` is ``(num_samples, ...)``.
eval_basis :
The evaluation basis array for convolution. It should be 2D, where the first dimension
- represents the window size for convolution.
+ represents the window size for convolution. Shape ``(window_size, n_basis_funcs)``.
Returns
-------
:
The convolved array, reshaped to maintain the original dimensions except for the first one,
- which is adjusted based on the window size of `eval_basis`.
+ which is adjusted based on the window size of ``eval_basis``.
Notes
-----
- The convolution implemented here is in mode `"valid"`. This implies that the time axis shrinks
- `num_samples - window_size + 1`, where num_samples is the first size of the first axis of `array`
- and `window_size` is the size of the first axis in `eval_basis`.
+    The convolution implemented here is in mode ``"valid"``. This implies that the time axis shrinks
+    to ``num_samples - window_size + 1``, where ``num_samples`` is the size of the first axis of ``array``
+ and ``window_size`` is the size of the first axis in ``eval_basis``.
"""
# flatten over other dims & apply vectorized conv
conv = _CORR_VEC(array.reshape(array.shape[0], -1), eval_basis)
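The shrinkage stated in the Notes, ``num_samples - window_size + 1``, is the standard 'valid'-mode behaviour; a plain-NumPy sketch:

    import numpy as np

    num_samples, window_size = 100, 5
    signal = np.random.normal(size=num_samples)
    kernel = np.ones(window_size) / window_size
    out = np.convolve(signal, kernel, mode="valid")
    print(out.shape)     # (96,) == (num_samples - window_size + 1,)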
@@ -95,7 +96,7 @@ def _shift_time_axis_and_convolve(array: NDArray, eval_basis: NDArray, axis: int
# convolve
if array.ndim > 1:
- conv = reshape_convolve(array, eval_basis)
+ conv = tensor_convolve(array, eval_basis)
else:
conv = _CORR_VEC_BASIS(array, eval_basis)
@@ -247,14 +248,18 @@ def create_convolutional_predictor(
Raises
------
- ValueError
- - If `basis_matrix` is not a 2-dimensional array or has a singleton first dimension.
- - If `time_series` does not contain arrays of at least one dimension or contains
- arrays with a dimensionality less than `axis`.
- - If any array within `time_series` or `basis_matrix` is empty.
- - If the number of elements along the convolution axis in any array within `time_series`
- is less than the window size of the `basis_matrix`.
- - If shifting is attempted with 'acausal' causality.
+ ValueError:
+        If ``basis_matrix`` is not a 2-dimensional array or has a singleton first dimension.
+    ValueError:
+        If ``time_series`` does not contain arrays of at least one dimension or contains
+        arrays with a dimensionality less than ``axis``.
+    ValueError:
+        If any array within ``time_series`` or ``basis_matrix`` is empty.
+    ValueError:
+        If the number of elements along the convolution axis in any array within ``time_series``
+        is less than the window size of the ``basis_matrix``.
+ ValueError:
+ If shifting is attempted with 'acausal' causality.
"""
# convert to jnp.ndarray
basis_matrix = jnp.asarray(basis_matrix)
diff --git a/src/nemos/fetch/fetch_data.py b/src/nemos/fetch/fetch_data.py
index 1246c993..faf43f59 100644
--- a/src/nemos/fetch/fetch_data.py
+++ b/src/nemos/fetch/fetch_data.py
@@ -14,9 +14,11 @@
try:
import pooch
from pooch import Pooch
+ from tqdm.auto import tqdm
except ImportError:
pooch = None
Pooch = None
+ tqdm = None
try:
import dandi
@@ -124,7 +126,7 @@ def fetch_data(
)
retriever = _create_retriever(path)
# Fetch the dataset using pooch.
- return retriever.fetch(dataset_name, progressbar=True)
+ return retriever.fetch(dataset_name)
def download_dandi_data(dandiset_id: str, filepath: str) -> NWBHDF5IO:
diff --git a/src/nemos/glm.py b/src/nemos/glm.py
index 9b442ebf..509d8beb 100644
--- a/src/nemos/glm.py
+++ b/src/nemos/glm.py
@@ -54,38 +54,47 @@ class GLM(BaseRegressor):
don't follow a normal distribution.
Below is a table listing the default and available solvers for each regularizer.
+ +---------------+------------------+-------------------------------------------------------------+
| Regularizer | Default Solver | Available Solvers |
- | ------------- | ---------------- | ----------------------------------------------------------- |
- | UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
- | Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
- | Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
- | GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |
-
+ +===============+==================+=============================================================+
+ | UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | Lasso | ProximalGradient | ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | GroupLasso | ProximalGradient | ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
- ([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
- ([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
+ :class:`nemos.solvers._svrg.SVRG` or its proximal variant
+    :class:`nemos.solvers._svrg.ProxSVRG`,
which take advantage of batched computation. You can change the solver by passing
- `"SVRG"` as `solver_name` at model initialization.
+ ``"SVRG"`` as ``solver_name`` at model initialization.
- The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
+ The performance of the SVRG solver depends critically on the choice of ``batch_size`` and ``stepsize``
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
- To assist with this, for certain GLM configurations, we provide `batch_size` and `stepsize` default
+ To assist with this, for certain GLM configurations, we provide ``batch_size`` and ``stepsize`` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed default hyperparameters:
- | GLM / PopulationGLM Configuration | Stepsize | Batch Size |
- | --------------------------------- | :------: | :---------: |
- | Poisson + soft-plus + UnRegularized | ✅ | ❌ |
- | Poisson + soft-plus + Ridge | ✅ | ✅ |
- | Poisson + soft-plus + Lasso | ✅ | ❌ |
- | Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | GLM / PopulationGLM Configuration | Stepsize | Batch Size |
+ +=======================================+===========+=============+
+ | Poisson + soft-plus + UnRegularized | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + Ridge | ✅ | ✅ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + Lasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
Parameters
----------
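A sketch of switching to the SVRG solver with explicit hyperparameters, assuming the ``solver_name``/``solver_kwargs`` arguments documented below (the values are illustrative, not the theoretically derived defaults):

    import numpy as np
    import nemos as nmo

    X = np.random.normal(size=(1_000, 5))
    y = np.random.poisson(lam=1.0, size=1_000)

    model = nmo.glm.GLM(
        solver_name="SVRG",
        solver_kwargs={"batch_size": 32, "stepsize": 1e-3},  # illustrative values
    )
    model = model.fit(X, y)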
@@ -102,7 +111,7 @@ class GLM(BaseRegressor):
solver_name :
Solver to use for model optimization. Defines the optimization scheme and related parameters.
The solver must be an appropriate match for the chosen regularizer.
- Default is `None`. If no solver specified, one will be chosen based on the regularizer.
+        Default is ``None``. If no solver is specified, one will be chosen based on the regularizer.
Please see table above for regularizer/optimizer pairings.
solver_kwargs :
Optional dictionary for keyword arguments that are passed to the solver when instantiated.
@@ -113,15 +122,15 @@ class GLM(BaseRegressor):
----------
intercept_ :
Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline
- firing rate will be `jnp.exp(model.intercept_)`.
+ firing rate will be ``jnp.exp(model.intercept_)``.
coef_ :
Basis coefficients for the model.
solver_state_ :
State of the solver after fitting. May include details like optimization error.
scale_:
- Scale parameter for the model. The scale parameter is the constant $\Phi$, for which
- $\text{Var} \left( y \right) = \Phi V(\mu)$. This parameter, together with the estimate
- of the mean $\mu$ fully specifies the distribution of the activity $y$.
+ Scale parameter for the model. The scale parameter is the constant :math:`\Phi`, for which
+ :math:`\text{Var} \left( y \right) = \Phi V(\mu)`. This parameter, together with the estimate
+ of the mean :math:`\mu` fully specifies the distribution of the activity :math:`y`.
dof_resid_:
Degrees of freedom for the residuals.
@@ -129,20 +138,17 @@ class GLM(BaseRegressor):
Raises
------
TypeError
- If provided `regularizer` or `observation_model` are not valid.
+ If provided ``regularizer`` or ``observation_model`` are not valid.
Examples
--------
>>> import nemos as nmo
-
>>> # define single neuron GLM model
>>> model = nmo.glm.GLM()
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type:
>>> print("Observation model: ", type(model.observation_model))
Observation model:
-
-
>>> # define GLM model of PoissonObservations model with soft-plus NL
>>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS")
@@ -178,7 +184,7 @@ def __init__(
@property
def observation_model(self) -> Union[None, obs.Observations]:
- """Getter for the observation_model attribute."""
+ """Getter for the ``observation_model`` attribute."""
return self._observation_model
@observation_model.setter
@@ -260,9 +266,9 @@ def _check_input_and_params_consistency(
Raises
------
ValueError
- - If param and X have different structures.
- - if the number of features is inconsistent between params[1] and X
- (when provided).
+ If param and X have different structures.
+ ValueError
+            If the number of features is inconsistent between ``params[1]`` and ``X`` (when provided).
"""
if X is not None:
@@ -303,8 +309,8 @@ def _predict(
Predicts firing rates based on given parameters and design matrix.
This function computes the predicted firing rates using the provided parameters
- and model design matrix `X`. It is a streamlined version used internally within
- optimization routines, where it serves as the loss function. Unlike the `GLM.predict`
+ and model design matrix ``X``. It is a streamlined version used internally within
+ optimization routines, where it serves as the loss function. Unlike the ``GLM.predict``
method, it does not perform any input validation, assuming that the inputs are pre-validated.
@@ -336,47 +342,50 @@ def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray:
Parameters
----------
X :
- Predictors, array of shape (n_time_bins, n_features) or pytree of same.
+ Predictors, array of shape ``(n_time_bins, n_features)`` or pytree of same.
Returns
-------
:
- The predicted rates with shape (n_time_bins, ).
+ The predicted rates with shape ``(n_time_bins, )``.
Raises
------
NotFittedError
If ``fit`` has not been called first with this instance.
ValueError
- - If `params` is not a JAX pytree of size two.
- - If weights and bias terms in `params` don't have the expected dimensions.
- - If `X` is not three-dimensional.
- - If there's an inconsistent number of features between spike basis coefficients and `X`.
+ If ``params`` is not a JAX pytree of size two.
+ ValueError
+ If weights and bias terms in ``params`` don't have the expected dimensions.
+ ValueError
+ If ``X`` is not three-dimensional.
+ ValueError
+ If there's an inconsistent number of features between spike basis coefficients and ``X``.
Examples
--------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
-
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
-
>>> # predict new spike data
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_spikes = model.predict(Xnew)
See Also
--------
- - [score](./#nemos.glm.GLM.score)
+ :meth:`nemos.glm.GLM.score`
Score predicted rates against target spike counts.
- - [simulate (feed-forward only)](../glm/#nemos.glm.GLM.simulate)
- Simulate neural activity in response to a feed-forward input .
- - [simulate_recurrent (feed-forward + coupling)](../simulation/#nemos.simulation.simulate_recurrent)
+
+ :meth:`nemos.glm.GLM.simulate`
+ Simulate neural activity in response to a feed-forward input (feed-forward only).
+
+ :func:`nemos.simulation.simulate_recurrent`
Simulate neural activity in response to a feed-forward input
- using the GLM as a recurrent network.
+ using the GLM as a recurrent network (feed-forward + coupling).
"""
# check that the model is fitted
self._check_is_fit()
@@ -403,7 +412,7 @@ def _predict_and_compute_loss(
) -> jnp.ndarray:
r"""Predict the rate and compute the negative log-likelihood against neural activity.
- This method computes the negative log-likelihood up to a constant term. Unlike `score`,
+ This method computes the negative log-likelihood up to a constant term. Unlike ``score``,
it does not conduct parameter checks prior to evaluation. Passed directly to the solver,
it serves to establish the optimization objective for learning the model parameters.
@@ -437,7 +446,7 @@ def score(
r"""Evaluate the goodness-of-fit of the model to the observed neural data.
This method computes the goodness-of-fit score, which can either be the mean
- log-likelihood or of two versions of the pseudo-R^2.
+        log-likelihood or one of two versions of the pseudo-:math:`R^2`.
The scoring process includes validation of input compatibility with the model's
parameters, ensuring that the model has been previously fitted and the input data
are appropriate for scoring. A higher score indicates a better fit of the model
@@ -447,24 +456,26 @@ def score(
Parameters
----------
X :
- The exogenous variables. Shape (n_time_bins, n_features)
+ The exogenous variables. Shape ``(n_time_bins, n_features)``.
y :
- Neural activity. Shape (n_time_bins, ).
+ Neural activity. Shape ``(n_time_bins, )``.
score_type :
- Type of scoring: either log-likelihood or pseudo-r2.
+ Type of scoring: either log-likelihood or pseudo-:math:`R^2`.
aggregate_sample_scores :
Function that aggregates the score of all samples.
Returns
-------
score :
- The log-likelihood or the pseudo-$R^2$ of the current model.
+ The log-likelihood or the pseudo-:math:`R^2` of the current model.
Raises
------
NotFittedError
+
If ``fit`` has not been called first with this instance.
ValueError
+
If X structure doesn't match the params, and if X and y have different
number of samples.
@@ -473,14 +484,11 @@ def score(
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
-
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
-
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
-
>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')
@@ -490,25 +498,26 @@ def score(
among which the number of model parameters. The log-likelihood can assume both positive
and negative values.
- The Pseudo-$ R^2 $ is not equivalent to the $ R^2 $ value in linear regression. While both
+ The Pseudo-:math:`R^2` is not equivalent to the :math:`R^2` value in linear regression. While both
provide a measure of model fit, and assume values in the [0,1] range, the methods and
- interpretations can differ. The Pseudo-$ R^2 $ is particularly useful for generalized linear
- models when the interpretation of the $ R^2 $ as explained variance does not apply
+ interpretations can differ. The Pseudo-:math:`R^2` is particularly useful for generalized linear
+ models when the interpretation of the :math:`R^2` as explained variance does not apply
(i.e., when the observations are not Gaussian distributed).
- Why does the traditional $R^2$ is usually a poor measure of performance in GLMs?
+        Why is the traditional :math:`R^2` usually a poor measure of performance in GLMs?
1. In the context of GLMs the variance and the mean of the observations are related.
- Ignoring the relation between them can result in underestimating the model
- performance; for instance, when we model a Poisson variable with large mean we expect an
- equally large variance. In this scenario, even if our model perfectly captures the mean,
- the high-variance will result in large residuals and low $R^2$.
- Additionally, when the mean of the observations varies, the variance will vary too. This
- violates the "homoschedasticity" assumption, necessary for interpreting the $R^2$ as
- variance explained.
- 2. The $R^2$ capture the variance explained when the relationship between the observations and
- the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors
- and the mean of the observations, compromising the interpretation of the $R^2$.
+ Ignoring the relation between them can result in underestimating the model
+ performance; for instance, when we model a Poisson variable with large mean we expect an
+ equally large variance. In this scenario, even if our model perfectly captures the mean,
+ the high variance will result in large residuals and low :math:`R^2`.
+ Additionally, when the mean of the observations varies, the variance will vary too. This
+ violates the "homoscedasticity" assumption, necessary for interpreting the :math:`R^2` as
+ variance explained.
+
+ 2. The :math:`R^2` captures the variance explained when the relationship between the observations and
+ the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors
+ and the mean of the observations, compromising the interpretation of the :math:`R^2`.
Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional
to the model standard deviation (i.e. Pearson residuals). This "rescaled" residual distribution however
@@ -516,8 +525,8 @@ def score(
Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially
for GLM modeling counting data.
- Refer to the `nmo.observation_models.Observations` concrete subclasses for the likelihood and
- pseudo-$R^2$ equations.
+ Refer to the ``nmo.observation_models.Observations`` concrete subclasses for the likelihood and
+ pseudo-:math:`R^2` equations.
"""
self._check_is_fit()
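The two pseudo-:math:`R^2` variants discussed above can be compared directly on the same fit. A minimal sketch follows; the ``"pseudo-r2-Cohen"`` identifier is assumed by analogy with the ``"pseudo-r2-McFadden"`` string used in the example above, so check the rendered reference for the exact accepted values.

import numpy as np
import nemos as nmo

# fit a Poisson GLM on toy data, then score it three ways
X, y = np.random.normal(size=(100, 2)), np.random.poisson(size=100)
model = nmo.glm.GLM().fit(X, y)

log_likelihood = model.score(X, y)  # default score is the mean log-likelihood
r2_mcfadden = model.score(X, y, score_type="pseudo-r2-McFadden")
r2_cohen = model.score(X, y, score_type="pseudo-r2-Cohen")  # assumed identifier
print(log_likelihood, r2_mcfadden, r2_cohen)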
@@ -572,18 +581,19 @@ def _initialize_parameters(
This method initializes the coefficients (spike basis coefficients) and intercepts (bias terms)
required for the GLM. The coefficients are initialized to zeros with dimensions based on the input X.
- If X is a FeaturePytree, the coefficients retain the pytree structure with arrays of zeros shaped
- according to the features in X. If X is a simple ndarray, the coefficients are initialized as a 2D
- array. The intercepts are initialized based on the log mean of the target data y across the first
- axis, corresponding to the average log activity of the neuron.
+ If X is a :class:`nemos.pytrees.FeaturePytree`, the coefficients retain the pytree structure with
+ arrays of zeros shaped according to the features in X.
+ If X is a simple ndarray, the coefficients are initialized as a 2D array. The intercepts are initialized
+ based on the log mean of the target data y across the first axis, corresponding to the average log activity
+ of the neuron.
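As a rough illustration of this default initialization (not the library code itself), assuming an exponential inverse link so that ``exp(intercept)`` matches the mean activity:

import numpy as np

X = np.random.normal(size=(100, 3))
y = np.random.poisson(1.5, size=100)

coef_init = np.zeros(X.shape[1])                             # coefficients start at zero
intercept_init = np.log(np.mean(y, axis=0, keepdims=True))   # average log activity
print(coef_init, intercept_init)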
Parameters
----------
X :
- The input data which can be a FeaturePytree with n_features arrays of shape (n_timebins,
- n_features), or a simple ndarray of shape (n_timebins, n_features).
+ The input data which can be a :class:`nemos.pytrees.FeaturePytree` with n_features arrays of shape
+ ``(n_timebins, n_features)``, or a simple ndarray of shape ``(n_timebins, n_features)``.
y :
- The target data array of shape (n_timebins, ), representing
+ The target data array of shape ``(n_timebins, )``, representing
the neuron firing rates or similar metrics.
Returns
@@ -660,27 +670,30 @@ def fit(
Raises
------
ValueError
- - If `init_params` is not of length two.
- - If dimensionality of `init_params` are not correct.
- - If `X` is not two-dimensional.
- - If `y` is not one-dimensional.
- - If solver returns at least one NaN parameter, which means it found
+ If ``init_params`` is not of length two.
+ ValueError
+ If dimensionality of ``init_params`` are not correct.
+ ValueError
+ If ``X`` is not two-dimensional.
+ ValueError
+ If ``y`` is not one-dimensional.
+ ValueError
+ If solver returns at least one NaN parameter, which means it found
an invalid solution. Try tuning optimization hyperparameters.
TypeError
- - If `init_params` are not array-like
- - If `init_params[i]` cannot be converted to jnp.ndarray for all i
+ If ``init_params`` are not array-like
+ TypeError
+ If ``init_params[i]`` cannot be converted to ``jnp.ndarray`` for all ``i``
Examples
- -------
+ --------
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
-
>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
-
>>> # get model weights and intercept
>>> model_weights = model.coef_
>>> model_intercept = model.intercept_
@@ -767,14 +780,14 @@ def simulate(
-------
simulated_activity :
Simulated activity (spike counts for Poisson GLMs) for the neuron over time.
- Shape: (n_time_bins, ).
+ Shape: ``(n_time_bins, )``.
firing_rates :
- Simulated rates for the neuron over time. Shape, (n_time_bins, ).
+ Simulated rates for the neuron over time. Shape: ``(n_time_bins, )``.
Raises
------
NotFittedError
- If the model hasn't been fitted prior to calling this method.
+ - If the model hasn't been fitted prior to calling this method.
ValueError
- If the instance has not been previously fitted.
@@ -783,12 +796,10 @@ def simulate(
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
-
>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
-
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
@@ -796,8 +807,8 @@ def simulate(
See Also
--------
- [predict](./#nemos.glm.GLM.predict) :
- Method to predict rates based on the model's parameters.
+ :meth:`nemos.glm.GLM.predict`
+ Method to predict rates based on the model's parameters.
"""
# check if the model is fit
self._check_is_fit()
@@ -834,8 +845,8 @@ def _estimate_resid_degrees_of_freedom(
X :
The design matrix.
n_samples :
- The number of samples observed. If not provided, n_samples is set to `X.shape[0]`. If the fit is
- batched, the n_samples could be larger than `X.shape[0]`.
+ The number of samples observed. If not provided, n_samples is set to ``X.shape[0]``. If the fit is
+ batched, the n_samples could be larger than ``X.shape[0]``.
Returns
-------
@@ -907,14 +918,18 @@ def initialize_params(
Raises
------
ValueError
- - If `params` is not of length two.
- - If dimensionality of `init_params` are not correct.
- - If `X` is not two-dimensional.
- - If `y` is not correct (1D for GLM, 2D for populationGLM).
+ If ``params`` is not of length two.
+ ValueError
+ If dimensionality of ``init_params`` are not correct.
+ ValueError
+ If ``X`` is not two-dimensional.
+ ValueError
+ If ``y`` is not correct (1D for GLM, 2D for populationGLM).
TypeError
- - If `params` are not array-like when provided.
- - If `init_params[i]` cannot be converted to jnp.ndarray for all i
+ If ``params`` are not array-like when provided.
+ TypeError
+ If ``init_params[i]`` cannot be converted to jnp.ndarray for all i
Examples
--------
@@ -995,7 +1010,7 @@ def initialize_state(
)
self.regularizer.mask = jnp.ones((1, data.shape[1]))
- opt_solver_kwargs = self.optimize_solver_params(data, y)
+ opt_solver_kwargs = self._optimize_solver_params(data, y)
# set up the solver init/run/update attrs
self.instantiate_solver(solver_kwargs=opt_solver_kwargs)
@@ -1033,7 +1048,7 @@ def update(
step sizes, and other optimizer-specific metrics.
X :
The predictors used in the model fitting process, which may include feature matrices
- or FeaturePytree objects.
+ or :class:`nemos.pytrees.FeaturePytree` objects.
y :
The response variable or output data corresponding to the predictors, used in the model
fitting process.
@@ -1041,7 +1056,7 @@ def update(
Additional positional arguments to be passed to the solver's update method.
n_samples:
The total number of samples. Usually larger than the number of samples in an individual batch,
- the `n_samples` are used to estimate the scale parameter of the GLM.
+ the ``n_samples`` are used to estimate the scale parameter of the GLM.
**kwargs
Additional keyword arguments to be passed to the solver's update method.
@@ -1067,6 +1082,7 @@ def update(
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state_
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
+
"""
# find non-nans
is_valid = tree_utils.get_valid_multitree(X, y)
@@ -1095,7 +1111,7 @@ def update(
return opt_step
- def get_optimal_solver_params_config(self):
+ def _get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
return glm_compute_optimal_stepsize_configs(self)
@@ -1109,41 +1125,50 @@ class PopulationGLM(GLM):
combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model.
It is suitable for scenarios where the relationship between predictors and the response
variable might be non-linear, and the residuals don't follow a normal distribution. The predictors must be
- stored in tabular format, shape (n_timebins, num_features) or as [FeaturePytree](../pytrees).
+ stored in tabular format, shape (n_timebins, num_features) or as :class:`nemos.pytrees.FeaturePytree`.
Below is a table listing the default and available solvers for each regularizer.
+ +---------------+------------------+-------------------------------------------------------------+
| Regularizer | Default Solver | Available Solvers |
- | ------------- | ---------------- | ----------------------------------------------------------- |
- | UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
- | Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, SVRG, ProximalGradient, ProxSVRG |
- | Lasso | ProximalGradient | ProximalGradient, ProxSVRG |
- | GroupLasso | ProximalGradient | ProximalGradient, ProxSVRG |
-
+ +===============+==================+=============================================================+
+ | UnRegularized | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | Ridge | GradientDescent | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | Lasso | ProximalGradient | ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
+ | GroupLasso | ProximalGradient | ProximalGradient |
+ +---------------+------------------+-------------------------------------------------------------+
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
- ([SVRG](../solvers/_svrg/#nemos.solvers._svrg.SVRG)) or its proximal variant
- ([ProxSVRG](../solvers/_svrg/#nemos.solvers._svrg.ProxSVRG)) solver,
+ :class:`nemos.solvers._svrg.SVRG` or its proximal variant
+ (:class:`nemos.solvers._svrg.ProxSVRG`) solver,
which take advantage of batched computation. You can change the solver by passing
- `"SVRG"` or `"ProxSVRG"` as `solver_name` at model initialization.
+ ``"SVRG"`` or ``"ProxSVRG"`` as ``solver_name`` at model initialization.
- The performance of the SVRG solver depends critically on the choice of `batch_size` and `stepsize`
+ The performance of the SVRG solver depends critically on the choice of ``batch_size`` and ``stepsize``
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
- To assist with this, for certain GLM configurations, we provide `batch_size` and `stepsize` default
+ To assist with this, for certain GLM configurations, we provide ``batch_size`` and ``stepsize`` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed hyperparameters:
- | GLM / PopulationGLM Configuration | Stepsize | Batch Size |
- | --------------------------------- | :------: | :---------: |
- | Poisson + soft-plus + UnRegularized | ✅ | ❌ |
- | Poisson + soft-plus + Ridge | ✅ | ✅ |
- | Poisson + soft-plus + Lasso | ✅ | ❌ |
- | Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | GLM / PopulationGLM Configuration | Stepsize | Batch Size |
+ +=======================================+===========+=============+
+ | Poisson + soft-plus + UnRegularized | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + Ridge | ✅ | ✅ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + Lasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
+ | Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+ +---------------------------------------+-----------+-------------+
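A short sketch of switching a population model to the batched solver described above. Passing ``batch_size`` and ``stepsize`` through ``solver_kwargs`` is an assumption based on the ``solver_kwargs`` description; omitting them falls back to the computed defaults when a guaranteed configuration applies.

import numpy as np
import nemos as nmo

X = np.random.normal(size=(500, 3))
y = np.random.poisson(np.exp(0.1 * X @ np.random.normal(size=(3, 2))))

model = nmo.glm.PopulationGLM(
    regularizer="Ridge",
    regularizer_strength=0.1,
    solver_name="SVRG",
    solver_kwargs={"batch_size": 32, "stepsize": 1e-3},  # illustrative values, see text above
)
model = model.fit(X, y)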
Parameters
----------
@@ -1160,22 +1185,22 @@ class PopulationGLM(GLM):
solver_name :
Solver to use for model optimization. Defines the optimization scheme and related parameters.
The solver must be an appropriate match for the chosen regularizer.
- Default is `None`. If no solver specified, one will be chosen based on the regularizer.
+ Default is ``None``. If no solver specified, one will be chosen based on the regularizer.
Please see table above for regularizer/optimizer pairings.
solver_kwargs :
Optional dictionary for keyword arguments that are passed to the solver when instantiated.
E.g. stepsize, acceleration, value_and_grad, etc.
See the jaxopt documentation for details on each solver's kwargs: https://jaxopt.github.io/stable/
feature_mask :
- Either a matrix of shape (num_features, num_neurons) or a [FeaturePytree](../pytrees) of 0s and 1s, with
- `feature_mask[feature_name]` of shape (num_neurons, ).
+ Either a matrix of shape (num_features, num_neurons) or a :class:`nemos.pytrees.FeaturePytree` of 0s and 1s, with
+ ``feature_mask[feature_name]`` of shape (num_neurons, ).
The mask will be used to select which features are used as predictors for which neuron.
Attributes
----------
intercept_ :
Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline
- firing rate will be `jnp.exp(model.intercept_)`.
+ firing rate will be ``jnp.exp(model.intercept_)``.
coef_ :
Basis coefficients for the model.
solver_state_ :
@@ -1184,8 +1209,9 @@ class PopulationGLM(GLM):
Raises
------
TypeError
- - If provided `regularizer` or `observation_model` are not valid.
- - If provided `feature_mask` is not an array-like of dimension two.
+ If provided ``regularizer`` or ``observation_model`` are not valid.
+ TypeError
+ If provided ``feature_mask`` is not an array-like of dimension two.
Examples
--------
@@ -1221,12 +1247,10 @@ class PopulationGLM(GLM):
>>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"]))
>>> y = np.random.poisson(rate)
>>> # Define a feature mask with arrays of shape (num_neurons, )
-
>>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0]))
>>> print(feature_mask)
feature_1: shape (2,), dtype int32
feature_2: shape (2,), dtype int32
-
>>> # Fit a PopulationGLM
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> # Coefficients are stored in a dictionary with keys the feature labels
@@ -1256,7 +1280,7 @@ def __init__(
@property
def feature_mask(self) -> Union[jnp.ndarray, dict]:
- """Define a feature mask of shape (n_features, n_neurons)."""
+ """Define a feature mask of shape ``(n_features, n_neurons)``."""
return self._feature_mask
@feature_mask.setter
@@ -1470,10 +1494,10 @@ def fit(
):
"""Fit GLM to the activity of a population of neurons.
- Fit and store the model parameters as attributes `coef_` and `intercept_`.
- Each neuron can have different predictors. The `feature_mask` will determine which
+ Fit and store the model parameters as attributes ``coef_`` and ``intercept_``.
+ Each neuron can have different predictors. The ``feature_mask`` will determine which
feature will be used for which neurons. See the note below for more information on
- the `feature_mask`.
+ the ``feature_mask``.
Parameters
----------
@@ -1492,26 +1516,33 @@ def fit(
Raises
------
ValueError
- - If `init_params` is not of length two.
- - If dimensionality of `init_params` are not correct.
- - If `X` is not two-dimensional.
- - If `y` is not two-dimensional.
- - If the `feature_mask` is not of the right shape.
- - If solver returns at least one NaN parameter, which means it found
- an invalid solution. Try tuning optimization hyperparameters.
+ If ``init_params`` is not of length two.
+ ValueError
+ If dimensionality of ``init_params`` are not correct.
+ ValueError
+ If ``X`` is not two-dimensional.
+ ValueError
+ If ``y`` is not two-dimensional.
+ ValueError
+ If the ``feature_mask`` is not of the right shape.
+ ValueError
+ If solver returns at least one NaN parameter, which means it found
+ an invalid solution. Try tuning optimization hyperparameters.
+ TypeError
+ If ``init_params`` are not array-like
TypeError
- - If `init_params` are not array-like
- - If `init_params[i]` cannot be converted to jnp.ndarray for all i
+ If ``init_params[i]`` cannot be converted to jnp.ndarray for all i
Notes
-----
- The `feature_mask` is used to select features for each neuron, and it is
- an NDArray or a `FeaturePytree` of 0s and 1s. In particular,
+ The ``feature_mask`` is used to select features for each neuron, and it is
+ an NDArray or a :class:`nemos.pytrees.FeaturePytree` of 0s and 1s. In particular,
+
+ - If the mask is in array format, feature ``i`` is a predictor for neuron ``j`` if
+ ``feature_mask[i, j] == 1``.
- - If the mask is in array format, feature `i` is a predictor for neuron `j` if
- `feature_mask[i, j] == 1`.
- - If the mask is a `FeaturePytree`, then `"feature_name"` is a predictor of neuron `j` if
- `feature_mask["feature_name"][j] == 1`.
+ - If the mask is a :class:`nemos.pytrees.FeaturePytree`, then
+ ``"feature_name"`` is a predictor of neuron ``j`` if ``feature_mask["feature_name"][j] == 1``.
Examples
--------
@@ -1519,7 +1550,6 @@ def fit(
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from nemos.glm import PopulationGLM
-
>>> # Define predictors (X), weights, and neural activity (y)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
@@ -1527,10 +1557,8 @@ def fit(
>>> weights = np.array([[ 0.5, 0. ], [-0.5, -0.5], [ 0. , 1. ]])
>>> # Output y simulates a Poisson distribution based on a linear model between features X and weights
>>> y = np.random.poisson(np.exp(X.dot(weights)))
-
>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
-
>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> print(model.coef_.shape)
@@ -1559,8 +1587,8 @@ def _predict(
Predicts firing rates based on given parameters and design matrix.
This function computes the predicted firing rates using the provided parameters, the feature
- mask and model design matrix `X`. It is a streamlined version used internally within
- optimization routines, where it serves as the loss function. Unlike the `GLM.predict`
+ mask and model design matrix ``X``. It is a streamlined version used internally within
+ optimization routines, where it serves as the loss function. Unlike the ``GLM.predict``
method, it does not perform any input validation, assuming that the inputs are pre-validated.
The parameters are first element-wise multiplied with the mask, then the canonical
linear-non-linear GLM map is applied.
diff --git a/src/nemos/identifiability_constraints.py b/src/nemos/identifiability_constraints.py
index 79437c65..cb665ec8 100644
--- a/src/nemos/identifiability_constraints.py
+++ b/src/nemos/identifiability_constraints.py
@@ -1,5 +1,7 @@
"""Utility functions for applying identifiability constraints to rank deficient feature matrices."""
+from __future__ import annotations
+
from functools import partial
from typing import Callable, Tuple
@@ -180,9 +182,9 @@ def apply_identifiability_constraints(
warn_if_float32: bool = True,
) -> Tuple[NDArray, NDArray[int]]:
"""
- Apply identifiability constraints to a design matrix `X`.
+ Apply identifiability constraints to a design matrix ``X``.
- Removes columns from `X` until it is full rank to ensure the uniqueness
+ Removes columns from ``X`` until it is full rank to ensure the uniqueness
of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly
crucial for models using bases like BSplines and CyclicBspline, which, due to their
construction, sum to 1 and can cause rank deficiency when combined with an intercept.
@@ -192,7 +194,7 @@ def apply_identifiability_constraints(
in the absence of regularization.
For very large feature matrices generated by a sum of low-dimensional basis components, consider
- `apply_identifiability_constraints_by_basis_component`.
+ ``apply_identifiability_constraints_by_basis_component``.
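A small sketch of the behaviour described above, on toy data: a set of columns that sums to one row-wise becomes rank deficient once an intercept is included, and the constraint removes columns until the matrix is full rank. The unpacking of the returned tuple into the filtered matrix and the indices of the kept columns follows the signature above and should be double checked against the rendered docs.

import numpy as np
from nemos.identifiability_constraints import apply_identifiability_constraints

n_samples = 50
parts = np.random.uniform(size=(n_samples, 4))
feature_matrix = parts / parts.sum(axis=1, keepdims=True)  # each row sums to one, like a B-spline basis

constrained_x, kept_columns = apply_identifiability_constraints(feature_matrix)
print(feature_matrix.shape, "->", constrained_x.shape)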
Parameters
----------
@@ -256,12 +258,12 @@ def apply_identifiability_constraints_by_basis_component(
feature_matrix: NDArray,
add_intercept: bool = True,
) -> Tuple[NDArray, NDArray]:
- """Apply identifiability constraint to a design matrix to each component of an additive basis.
+ """Apply identifiability constraint to a design matrix for each component of an additive basis.
Parameters
----------
basis:
- The basis that computed X.
+ The basis that computed ``feature_matrix``.
feature_matrix:
The feature matrix before applying the identifiability constraints.
add_intercept:
diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py
index 9d683ae1..4a72d561 100644
--- a/src/nemos/observation_models.py
+++ b/src/nemos/observation_models.py
@@ -10,7 +10,7 @@
from . import utils
from .base_class import Base
-__all__ = ["PoissonObservations"]
+__all__ = ["PoissonObservations", "GammaObservations"]
def __dir__():
@@ -23,8 +23,9 @@ class Observations(Base, abc.ABC):
This is an abstract base class used to implement observation models for neural data.
Specific observation models that inherit from this class should define their versions
- of the abstract methods: _negative_log_likelihood, emission_probability, and
- residual_deviance.
+ of the abstract methods such as :meth:`~nemos.observation_models.Observations.log_likelihood`,
+ :meth:`~nemos.observation_models.Observations.sample_generator`, and
+ :meth:`~nemos.observation_models.Observations.deviance`.
Attributes
----------
@@ -33,8 +34,10 @@ class Observations(Base, abc.ABC):
See Also
--------
- [PoissonObservations](./#nemos.observation_models.PoissonObservations) : A specific implementation of a
- observation model using the Poisson distribution.
+ :class:`~nemos.observation_models.PoissonObservations`
+ A specific implementation of an observation model using the Poisson distribution.
+ :class:`~nemos.observation_models.GammaObservations`
+ A specific implementation of an observation model using the Gamma distribution.
"""
def __init__(self, inverse_link_function: Callable, **kwargs):
@@ -72,6 +75,7 @@ def check_inverse_link_function(inverse_link_function: Callable):
Check if the provided inverse_link_function is usable.
This function verifies if the inverse link function:
+
1. Is callable
2. Returns a jax.numpy.ndarray
3. Is differentiable (via jax)
@@ -210,9 +214,9 @@ def deviance(
Parameters
----------
spike_counts:
- The spike counts. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
+ The spike counts. Shape ``(n_time_bins, )`` or ``(n_time_bins, n_neurons)`` for population models.
predicted_rate:
- The predicted firing rates. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
+ The predicted firing rates. Shape ``(n_time_bins, )`` or ``(n_time_bins, n_neurons)`` for population models.
scale:
Scale parameter of the model.
@@ -232,17 +236,17 @@ def estimate_scale(
) -> Union[float, jnp.ndarray]:
r"""Estimate the scale parameter for the model.
- This method estimates the scale parameter, often denoted as $\phi$, which determines the dispersion
+ This method estimates the scale parameter, often denoted as :math:`\phi`, which determines the dispersion
of an exponential family distribution. The probability density function (pdf) for such a distribution
is generally expressed as
- $f(x; \theta, \phi) \propto \exp \left(a(\phi)\left( y\theta - \mathcal{k}(\theta) \right)\right)$.
+ :math:`f(x; \theta, \phi) \propto \exp \left(a(\phi)\left( y\theta - \mathcal{k}(\theta) \right)\right)`.
The relationship between variance and the scale parameter is given by:
- $$
- \text{var}(Y) = \frac{V(\mu)}{a(\phi)}.
- $$
- The scale parameter, $\phi$, is necessary for capturing the variance of the data accurately.
+ .. math::
+ \text{var}(Y) = \frac{V(\mu)}{a(\phi)}.
+
+ The scale parameter, :math:`\phi`, is necessary for capturing the variance of the data accurately.
Parameters
----------
@@ -265,61 +269,65 @@ def pseudo_r2(
scale: Union[float, jnp.ndarray, NDArray] = 1.0,
aggregate_sample_scores: Callable = jnp.mean,
) -> jnp.ndarray:
- r"""Pseudo-$R^2$ calculation for a GLM.
+ r"""Pseudo-:math:`R^2` calculation for a GLM.
- Compute the pseudo-$R^2$ metric for the GLM, as defined by McFadden et al.[$^{[1]}$](#references)
- or by Cohen et al.[$^{[2]}$](#references).
+ Compute the pseudo-:math:`R^2` metric for the GLM, as defined by McFadden et al. [1]_
+ or by Cohen et al. [2]_.
This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a
- constant mean for the observations. While the pseudo-$R^2$ is bounded between 0 and 1 for the training set,
- it can yield negative values on out-of-sample data, indicating potential over-fitting.
+ constant mean for the observations. While the pseudo-:math:`R^2` is bounded between 0 and 1 for the
+ training set, it can yield negative values on out-of-sample data, indicating potential over-fitting.
Parameters
----------
y:
- The neural activity. Expected shape: (n_time_bins, )
+ The neural activity. Expected shape: ``(n_time_bins, )``
predicted_rate:
- The mean neural activity. Expected shape: (n_time_bins, )
+ The mean neural activity. Expected shape: ``(n_time_bins, )``
score_type:
- The pseudo-R$^2$ type.
+ The pseudo-:math:`R^2` type.
scale:
The scale parameter of the model.
Returns
-------
:
- The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ The pseudo-:math:`R^2` of the model. A value closer to 1 indicates a better model fit,
whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
Notes
-----
- - The McFadden pseudo-$R^2$ is given by:
- $$
+ - The McFadden pseudo-:math:`R^2` is given by:
+
+ .. math::
R^2_{\text{mcf}} = 1 - \frac{\log(L_{M})}{\log(L_0)}.
- $$
- *Equivalent to statsmodels
- [`GLMResults.pseudo_rsquared(kind='mcf')`](https://www.statsmodels.org/dev/generated/statsmodels.genmod.generalized_linear_model.GLMResults.pseudo_rsquared.html).*
- - The Cohen pseudo-$R^2$ is given by:
- $$
+
+ *Equivalent to statsmodels*
+ `GLMResults.pseudo_rsquared(kind='mcf') <https://www.statsmodels.org/dev/generated/statsmodels.genmod.generalized_linear_model.GLMResults.pseudo_rsquared.html>`_.
+
+ - The Cohen pseudo-:math:`R^2` is given by:
+
+ .. math::
\begin{aligned}
R^2_{\text{Cohen}} &= \frac{D_0 - D_M}{D_0} \\\
&= 1 - \frac{\log(L_s) - \log(L_M)}{\log(L_s)-\log(L_0)},
\end{aligned}
- $$
- where $L_M$, $L_0$ and $L_s$ are the likelihood of the fitted model, the null model (a
- model with only the intercept term), and the saturated model (a model with one parameter per
- sample, i.e. the maximum value that the likelihood could possibly achieve). $D_M$ and $D_0$ are
- the model and the null deviance, $D_i = -2 \left[ \log(L_s) - \log(L_i) \right]$ for $i=M,0$.
-
- # References
- ------------
- [1] McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
- developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318).
- London: Croom Helm.
-
- [2] Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
- *Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
- 3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
+
+ where :math:`L_M`, :math:`L_0` and :math:`L_s` are the likelihood of the fitted model, the null model (a
+ model with only the intercept term), and the saturated model (a model with one parameter per
+ sample, i.e. the maximum value that the likelihood could possibly achieve). :math:`D_M` and :math:`D_0` are
+ the model and the null deviance, :math:`D_i = -2 \left[ \log(L_s) - \log(L_i) \right]` for :math:`i=M,0`.
+
+ References
+ ----------
+ .. [1] McFadden D (1979). Quantitative methods for analysing travel behavior of individuals: Some recent
+ developments. In D. A. Hensher & P. R. Stopher (Eds.), *Behavioural travel modelling* (pp. 279-318).
+ London: Croom Helm.
+
+ .. [2] Jacob Cohen, Patricia Cohen, Steven G. West, Leona S. Aiken.
+ *Applied Multiple Regression/Correlation Analysis for the Behavioral Sciences*.
+ 3rd edition. Routledge, 2002. p.502. ISBN 978-0-8058-2223-6. (May 2012)
"""
if score_type == "pseudo-r2-McFadden":
pseudo_r2 = self._pseudo_r2_mcfadden(
@@ -342,22 +350,22 @@ def _pseudo_r2_cohen(
predicted_rate: jnp.ndarray,
aggregate_sample_scores: Callable = jnp.mean,
) -> jnp.ndarray:
- r"""Cohen's pseudo-$R^2$.
+ r"""Cohen's pseudo-:math:`R^2`.
- Compute the pseudo-$R^2$ metric as defined by Cohen et al. (2002). See
- [`pseudo_r2`](#pseudo_r2) for additional information.
+ Compute the pseudo-:math:`R^2` metric as defined by Cohen et al. (2002). See
+ :meth:`nemos.observation_models.Observations.pseudo_r2` for additional information.
Parameters
----------
y:
- The neural activity. Expected shape: (n_time_bins, )
+ The neural activity. Expected shape: ``(n_time_bins, )``.
predicted_rate:
- The mean neural activity. Expected shape: (n_time_bins, )
+ The mean neural activity. Expected shape: ``(n_time_bins, )``.
Returns
-------
:
- The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ The pseudo-:math:`R^2` of the model. A value closer to 1 indicates a better model fit,
whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
"""
model_dev_t = self.deviance(y, predicted_rate)
@@ -376,10 +384,10 @@ def _pseudo_r2_mcfadden(
aggregate_sample_scores: Callable = jnp.mean,
):
"""
- McFadden's pseudo-$R^2$.
+ McFadden's pseudo-:math:`R^2`.
- Compute the pseudo-$R^2$ metric as defined by McFadden et al. (1979). See
- [`pseudo_r2`](#pseudo_r2) for additional information.
+ Compute the pseudo-:math:`R^2` metric as defined by McFadden et al. (1979). See
+ :meth:`nemos.observation_models.Observations.pseudo_r2` for additional information.
Parameters
----------
@@ -393,7 +401,7 @@ def _pseudo_r2_mcfadden(
Returns
-------
:
- The pseudo-$R^2$ of the model. A value closer to 1 indicates a better model fit,
+ The pseudo-:math:`R^2` of the model. A value closer to 1 indicates a better model fit,
whereas a value closer to 0 suggests that the model doesn't improve much over the null model.
"""
mean_y = jnp.ones(y.shape) * y.mean(axis=0)
@@ -420,11 +428,7 @@ class PoissonObservations(Observations):
Attributes
----------
inverse_link_function :
- A function that maps the predicted rate to the domain of the Poisson parameter. Defaults to jnp.exp.
-
- See Also
- --------
- [Observations](./#nemos.observation_models.Observations) : Base class for observation models.
+ A function that maps the predicted rate to the domain of the Poisson parameter. Defaults to ``jax.numpy.exp``.
"""
@@ -459,20 +463,19 @@ def _negative_log_likelihood(
-----
The formula for the Poisson mean log-likelihood is the following,
- $$
+ .. math::
\begin{aligned}
\text{LL}(\hat{\lambda} | y) &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T}
- [y\_{tn} \log(\hat{\lambda}\_{tn}) - \hat{\lambda}\_{tn} - \log({y\_{tn}!})] \\\
- &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
- \hat{\lambda}\_{tn} - \Gamma({y\_{tn}+1})] \\\
- &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
- \hat{\lambda}\_{tn}] + \\text{const}
+ [y_{tn} \log(\hat{\lambda}_{tn}) - \hat{\lambda}_{tn} - \log({y_{tn}!})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y_{tn} \log(\hat{\lambda}_{tn}) -
+ \hat{\lambda}_{tn} - \Gamma({y_{tn}+1})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y_{tn} \log(\hat{\lambda}_{tn}) -
+ \hat{\lambda}_{tn}] + \text{const}
\end{aligned}
- $$
- Because $\Gamma(k+1)=k!$, see [wikipedia](https://en.wikipedia.org/wiki/Gamma_function) for explanation.
+ Because :math:`\Gamma(k+1)=k!`, see `wikipedia <https://en.wikipedia.org/wiki/Gamma_function>`_ for explanation.
- The $\log({y\_{tn}!})$ term is not a function of the parameters and can be disregarded
+ The :math:`\log({y_{tn}!})` term is not a function of the parameters and can be disregarded
when computing the loss-function. This is why we incorporated it into the `const` term.
"""
predicted_rate = jnp.clip(
@@ -497,9 +500,9 @@ def log_likelihood(
Parameters
----------
y :
- The target spikes to compare against. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
+ The target spikes to compare against. Shape ``(n_time_bins, )``, or ``(n_time_bins, n_neurons)``.
predicted_rate :
- The predicted rate of the current model. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
+ The predicted rate of the current model. Shape ``(n_time_bins, )``, or ``(n_time_bins, n_neurons)``.
scale :
The scale parameter of the model.
aggregate_sample_scores :
@@ -514,20 +517,20 @@ def log_likelihood(
-----
The formula for the Poisson mean log-likelihood is the following,
- $$
- \begin{aligned}
- \text{LL}(\hat{\lambda} | y) &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T}
- [y\_{tn} \log(\hat{\lambda}\_{tn}) - \hat{\lambda}\_{tn} - \log({y\_{tn}!})] \\\
- &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
- \hat{\lambda}\_{tn} - \Gamma({y\_{tn}+1})] \\\
- &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y\_{tn} \log(\hat{\lambda}\_{tn}) -
- \hat{\lambda}\_{tn}] + \\text{const}
- \end{aligned}
- $$
+ .. math::
+ \begin{aligned}
+ \text{LL}(\hat{\lambda} | y) &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T}
+ [y_{tn} \log(\hat{\lambda}_{tn}) - \hat{\lambda}_{tn} - \log({y_{tn}!})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y_{tn} \log(\hat{\lambda}_{tn}) -
+ \hat{\lambda}_{tn} - \Gamma({y_{tn}+1})] \\\
+ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} [y_{tn} \log(\hat{\lambda}_{tn}) -
+ \hat{\lambda}_{tn}] + \text{const}
+ \end{aligned}
+
- Because $\Gamma(k+1)=k!$, see [wikipedia](https://en.wikipedia.org/wiki/Gamma_function) for explanation.
+ Because :math:`\Gamma(k+1)=k!`, see `wikipedia <https://en.wikipedia.org/wiki/Gamma_function>`_ for explanation.
- The $\log({y\_{tn}!})$ term is not a function of the parameters and can be disregarded
+ The :math:`\log({y_{tn}!})` term is not a function of the parameters and can be disregarded
when computing the loss-function. This is why we incorporated it into the `const` term.
"""
nll = self._negative_log_likelihood(y, predicted_rate, aggregate_sample_scores)
@@ -550,7 +553,8 @@ def sample_generator(
key :
Random key used for the generation of random numbers in JAX.
predicted_rate :
- Expected rate (lambda) of the Poisson distribution. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
+ Expected rate (lambda) of the Poisson distribution. Shape ``(n_time_bins, )``, or
+ ``(n_time_bins, n_neurons)``.
scale :
Scale parameter. For Poisson should be equal to 1.
@@ -572,9 +576,10 @@ def deviance(
Parameters
----------
spike_counts:
- The spike counts. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
+ The spike counts. Shape ``(n_time_bins, )`` or ``(n_time_bins, n_neurons)`` for population models.
predicted_rate:
- The predicted firing rates. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
+ The predicted firing rates. Shape ``(n_time_bins, )`` or ``(n_time_bins, n_neurons)`` for
+ population models.
scale:
Scale parameter of the model.
@@ -588,16 +593,15 @@ def deviance(
The deviance is a measure of the goodness of fit of a statistical model.
For a Poisson model, the residual deviance is computed as:
- $$
- \begin{aligned}
- D(y\_{tn}, \hat{y}\_{tn}) &= 2 \left[ y\_{tn} \log\left(\frac{y\_{tn}}{\hat{y}\_{tn}}\right)
- - (y\_{tn} - \hat{y}\_{tn}) \right]\\\
- &= 2 \left( \text{LL}\left(y\_{tn} | y\_{tn}\right) - \text{LL}\left(y\_{tn} | \hat{y}\_{tn}\right)\right)
- \end{aligned}
- $$
+ .. math::
+ \begin{aligned}
+ D(y_{tn}, \hat{y}_{tn}) &= 2 \left[ y_{tn} \log\left(\frac{y_{tn}}{\hat{y}_{tn}}\right)
+ - (y_{tn} - \hat{y}_{tn}) \right]\\\
+ &= 2 \left( \text{LL}\left(y_{tn} | y_{tn}\right) - \text{LL}\left(y_{tn} | \hat{y}_{tn}\right)\right)
+ \end{aligned}
- where $ y $ is the observed data, $ \hat{y} $ is the predicted data, and $\text{LL}$ is the model
- log-likelihood. Lower values of deviance indicate a better fit.
+ where :math:`y` is the observed data, :math:`\hat{y}` is the predicted data, and :math:`\text{LL}` is
+ the model log-likelihood. Lower values of deviance indicate a better fit.
"""
# this takes care of 0s in the log
ratio = jnp.clip(
@@ -615,13 +619,15 @@ def estimate_scale(
r"""
Assign 1 to the scale parameter of the Poisson model.
- For the Poisson exponential family distribution, the scale parameter $\phi$ is always 1.
+ For the Poisson exponential family distribution, the scale parameter :math:`\phi` is always 1.
This property is consistent with the fact that the variance equals the mean in a Poisson distribution.
As given in the general exponential family expression:
- $$
- \text{var}(Y) = \frac{V(\mu)}{a(\phi)},
- $$
- for the Poisson family, it simplifies to $\text{var}(Y) = \mu$ since $a(\phi) = 1$ and $V(\mu) = \mu$.
+
+ .. math::
+ \text{var}(Y) = \frac{V(\mu)}{a(\phi)},
+
+ for the Poisson family, it simplifies to :math:`\text{var}(Y) = \mu` since :math:`a(\phi) = 1`
+ and :math:`V(\mu) = \mu`.
Parameters
----------
@@ -649,10 +655,6 @@ class GammaObservations(Observations):
inverse_link_function :
A function that maps the predicted rate to the domain of the Poisson parameter. Defaults to jnp.exp.
- See Also
- --------
- [Observations](./#nemos.observation_models.Observations) : Base class for observation models.
-
"""
def __init__(self, inverse_link_function=lambda x: jnp.power(x, -1)):
@@ -786,16 +788,16 @@ def deviance(
The deviance is a measure of the goodness of fit of a statistical model.
For a Gamma model, the residual deviance is computed as:
- $$
- \begin{aligned}
- D(y\_{tn}, \hat{y}\_{tn}) &= 2 \left[ -\log \frac{ y\_{tn}}{\hat{y}\_{tn}} + \frac{y\_{tn} -
- \hat{y}\_{tn}}{\hat{y}\_{tn}}\right]\\\
- &= 2 \left( \text{LL}\left(y\_{tn} | y\_{tn}\right) - \text{LL}\left(y\_{tn} | \hat{y}\_{tn}\right) \right)
- \end{aligned}
- $$
+ .. math::
+ \begin{aligned}
+ D(y_{tn}, \hat{y}_{tn}) &= 2 \left[ -\log \frac{ y_{tn}}{\hat{y}_{tn}} + \frac{y_{tn} -
+ \hat{y}_{tn}}{\hat{y}_{tn}}\right]\\\
+ &= 2 \left( \text{LL}\left(y_{tn} | y_{tn}\right) - \text{LL}\left(y_{tn} | \hat{y}_{tn}\right) \right)
+ \end{aligned}
- where $ y $ is the observed data, $ \hat{y} $ is the predicted data, and $\text{LL}$ is the model
+ where :math:`y` is the observed data, :math:`\hat{y}` is the predicted data, and :math:`\text{LL}` is the model
log-likelihood. Lower values of deviance indicate a better fit.
+
"""
y_mu = jnp.clip(neural_activity / predicted_rate, min=jnp.finfo(float).eps)
resid_dev = 2 * (
@@ -812,11 +814,12 @@ def estimate_scale(
r"""
Estimate the scale of the model based on the GLM residuals.
- For $y \sim \Gamma$ the scale is equal to,
- $$
- \Phi = \frac{\text{Var(y)}}{V(\mu)}
- $$
- with $V(\mu) = \mu^2$.
+ For :math:`y \sim \Gamma` the scale is equal to,
+
+ .. math::
+ \Phi = \frac{\text{Var(y)}}{V(\mu)}
+
+ with :math:`V(\mu) = \mu^2`.
Therefore, the scale can be estimated as the ratio of the sample variance to the squared rate.
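As a rough numerical check of the estimator described above (a moment-based sketch, not the library implementation):

import numpy as np

rng = np.random.default_rng(1)
mu = np.full(5000, 2.0)
y = rng.gamma(shape=2.0, scale=mu / 2.0)            # Gamma observations with mean mu

scale_estimate = np.mean((y - mu) ** 2 / mu ** 2)   # Var(y) / V(mu), with V(mu) = mu^2
print(scale_estimate)                               # close to 1 / shape = 0.5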
@@ -833,7 +836,7 @@ def estimate_scale(
Returns
-------
:
- The scale parameter. If predicted_rate is (n_samples, n_neurons), this method will return a
+ The scale parameter. If predicted_rate is ``(n_samples, n_neurons)``, this method will return a
scale for each neuron.
"""
predicted_rate = jnp.clip(
@@ -865,9 +868,11 @@ def check_observation_model(observation_model):
If the `observation_model` does not have one of the required attributes.
TypeError
- - If an attribute is not a callable function.
- - If a function does not return a jax.numpy.ndarray.
- - If 'inverse_link_function' is not differentiable.
+ If an attribute is not a callable function.
+ TypeError
+ If a function does not return a jax.numpy.ndarray.
+ TypeError
+ If 'inverse_link_function' is not differentiable.
Examples
--------
diff --git a/src/nemos/proximal_operator.py b/src/nemos/proximal_operator.py
index 51602ed5..dc4f5b6d 100644
--- a/src/nemos/proximal_operator.py
+++ b/src/nemos/proximal_operator.py
@@ -11,7 +11,7 @@
More formally, proximal operators solve the minimization problem,
$$
-\\text{prox}\_f(\bm{v}) = \arg\min\_{\bm{x}} \left( f(\bm{x}) + \frac{1}{2}\Vert \bm{x} - \bm{v}\Vert_2 ^2 \right)
+\\text{prox}_f(\bm{v}) = \arg\min_{\bm{x}} \left( f(\bm{x}) + \frac{1}{2}\Vert \bm{x} - \bm{v}\Vert_2 ^2 \right)
$$
@@ -106,7 +106,7 @@ def prox_group_lasso(
The proximal operator equation are,
$$
- \text{prox}(\beta_g) = \text{min}_{\beta} \left[ \lambda \sum\_{g=1}^G \Vert \beta_g \Vert_2 +
+ \text{prox}(\beta_g) = \text{min}_{\beta} \left[ \lambda \sum_{g=1}^G \Vert \beta_g \Vert_2 +
\frac{1}{2} \Vert \hat{\beta} - \beta \Vert_2^2
\right],
$$
@@ -115,15 +115,15 @@ def prox_group_lasso(
The analytical solution[$^{[1]}$](#references). for the beta is,
$$
- \text{prox}(\beta\_g) = \max \left(1 - \frac{\lambda \sqrt{p\_g}}{\Vert \hat{\beta}\_g \Vert_2},
- 0\right) \cdot \hat{\beta}\_g,
+ \text{prox}(\beta_g) = \max \left(1 - \frac{\lambda \sqrt{p_g}}{\Vert \hat{\beta}_g \Vert_2},
+ 0\right) \cdot \hat{\beta}_g,
$$
- where $p_g$ is the dimensionality of $\beta\_g$ and $\hat{\beta}$ is typically the gradient step
+ where $p_g$ is the dimensionality of $\beta_g$ and $\hat{\beta}$ is typically the gradient step
of the un-regularized optimization objective function. It's easy to see how the group-Lasso
proximal operator acts as a shrinkage factor for the un-penalized update, and the half-rectification
non-linearity that effectively sets to zero groups of coefficients satisfying,
$$
- \Vert \hat{\beta}\_g \Vert_2 \le \frac{1}{\lambda \sqrt{p\_g}}.
+ \Vert \hat{\beta}_g \Vert_2 \le \lambda \sqrt{p_g}.
$$
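The closed-form update above amounts to a per-group shrinkage factor; below is a self-contained numpy sketch of the rule for one group, illustrative rather than the library's vectorized implementation.

import numpy as np

def prox_one_group(beta_g, lam):
    # shrink a single coefficient group toward zero; groups with a small enough
    # norm are set exactly to zero by the half-rectification
    p_g = beta_g.size
    norm = np.linalg.norm(beta_g)
    if norm == 0.0:
        return beta_g
    return max(1.0 - lam * np.sqrt(p_g) / norm, 0.0) * beta_g

print(prox_one_group(np.array([0.3, -0.2]), lam=0.1))    # shrunk toward zero
print(prox_one_group(np.array([0.05, 0.02]), lam=0.1))   # small group, zeroed out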
# References
@@ -154,8 +154,8 @@ def prox_lasso(x: Any, l1reg: Optional[Any] = None, scaling: float = 1.0) -> Any
Minimizes the following function:
$$
- \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||\_2^2
- + \text{scaling} \cdot \text{l1reg} \cdot ||y||\_1
+ \underset{y}{\text{argmin}} ~ \frac{1}{2} ||x - y||_2^2
+ + \text{scaling} \cdot \text{l1reg} \cdot ||y||_1
$$
When `l1reg` is a pytree, the weights are applied coordinate-wise.
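The minimizer of this objective has the familiar element-wise soft-thresholding form; a short numpy sketch (illustrative, not the jaxopt implementation):

import numpy as np

def soft_threshold(x, l1reg, scaling=1.0):
    # element-wise soft thresholding: shrink toward zero by scaling * l1reg
    return np.sign(x) * np.maximum(np.abs(x) - scaling * l1reg, 0.0)

print(soft_threshold(np.array([1.5, -0.3, 0.05]), l1reg=0.2))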
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py
index 6d6cf0bd..498eb031 100644
--- a/src/nemos/regularizer.py
+++ b/src/nemos/regularizer.py
@@ -96,22 +96,10 @@ def get_proximal_operator(
class UnRegularized(Regularizer):
"""
- Solver class for optimizing unregularized models.
+ Regularizer class for unregularized models.
- This class provides an interface to various optimization methods for models that
- do not involve regularization. The optimization methods that are allowed for this
- class are defined in the `allowed_solvers` attribute.
-
- Attributes
- ----------
- allowed_solvers : list of str
- List of solver names that are allowed for this regularizer class.
- default_solver :
- Default solver for this regularizer is GradientDescent.
-
- See Also
- --------
- [Regularizer](./#nemos.regularizer.Regularizer) : Base solver class from which this class inherits.
+ This class equips models with the identity proximal operator (no shrinkage) and the
+ unpenalized loss function.
"""
_allowed_solvers = (
@@ -133,32 +121,31 @@ def __init__(
def penalized_loss(self, loss: Callable, regularizer_strength: float):
"""
- Returns the original loss function unpenalized. Unregularized regularization method does not add any
- penalty.
+ Returns the original loss function unpenalized.
+
+ The unregularized method does not add any penalty.
"""
return loss
def get_proximal_operator(
self,
) -> ProximalOperator:
- """Unregularized method has no proximal operator."""
+ """
+ Returns the identity operator.
+
+ The unregularized method corresponds to an identity proximal operator, since no
+ shrinkage factor is applied.
+ """
return jaxopt.prox.prox_none
class Ridge(Regularizer):
"""
- Solver for Ridge regularization using various optimization algorithms.
-
- This class uses `jaxopt` optimizers to perform Ridge regularization. It extends
- the base Solver class, with the added feature of Ridge penalization.
+ Regularizer class for Ridge (L2 regularization).
- Attributes
- ----------
- allowed_solvers : List[..., str]
- A list of solver names that are allowed to be used with this regularizer.
- default_solver :
- Default solver for this regularizer is GradientDescent.
+ This class equips models with the Ridge proximal operator and the
+ Ridge penalized loss function.
"""
_allowed_solvers = (
@@ -240,10 +227,10 @@ def prox_op(params, l2reg, scaling=1.0):
class Lasso(Regularizer):
"""
- Optimization solver using the Lasso (L1 regularization) method with Proximal Gradient.
+ Regularizer class for Lasso (L1 regularization).
- This class is a specialized version of the ProxGradientSolver with the proximal operator
- set for L1 regularization (Lasso). It utilizes the `jaxopt` library's proximal gradient optimizer.
+ This class equips models with the Lasso proximal operator and the
+ Lasso penalized loss function.
"""
_allowed_solvers = (
@@ -318,33 +305,30 @@ def _penalized_loss(params, X, y):
class GroupLasso(Regularizer):
"""
- Optimization solver using the Group Lasso regularization method with Proximal Gradient.
+ Regularizer class for Group Lasso (group-L1) regularized models.
- This class is a specialized version of the ProxGradientSolver with the proximal operator
- set for Group Lasso regularization. The Group Lasso regularization induces sparsity on groups
- of features rather than individual features.
+ This class equips models with the group-lasso proximal operator and the
+ group-lasso penalized loss function.
Attributes
----------
mask :
- A 2d mask array indicating groups of features for regularization, shape (num_groups, num_features).
+ A 2d mask array indicating groups of features for regularization, shape ``(num_groups, num_features)``.
Each row represents a group of features.
Each column corresponds to a feature, where a value of 1 indicates that the feature belongs
to the group, and a value of 0 indicates it doesn't.
- Default is `mask = np.ones((1, num_features))`, grouping all features in a single group.
+ Default is ``mask = np.ones((1, num_features))``, grouping all features in a single group.
Examples
--------
>>> import numpy as np
>>> from nemos.regularizer import GroupLasso # Assuming the module is named group_lasso
>>> from nemos.glm import GLM
-
>>> # simulate some counts
>>> num_samples, num_features, num_groups = 1000, 5, 3
>>> X = np.random.normal(size=(num_samples, num_features)) # design matrix
>>> w = [0, 0.5, 1, 0, -0.5] # define some weights
>>> y = np.random.poisson(np.exp(X.dot(w))) # observed counts
-
>>> # Define a mask for 3 groups and 5 features
>>> mask = np.zeros((num_groups, num_features))
>>> mask[0] = [1, 0, 0, 1, 0] # Group 0 includes features 0 and 3
@@ -439,12 +423,14 @@ def _penalization(
Note: the penalty is being calculated according to the following formula:
- $$\\text{loss}(\beta_1,...,\beta_g) + \alpha \cdot \sum _{j=1...,g} \sqrt{\dim(\beta_j)} || \beta_j||_2$$
+ .. math::
- where $g$ is the number of groups, $\dim(\cdot)$ is the dimension of the vector,
- i.e. the number of coefficient in each $\beta_j$, and $||\cdot||_2$ is the euclidean norm.
+ \text{loss}(\beta_1,...,\beta_g) + \alpha \cdot \sum_{j=1,\dots,g} \sqrt{\dim(\beta_j)} ||\beta_j||_2
+
+ where :math:`g` is the number of groups, :math:`\dim(\cdot)` is the dimension of the vector,
+ i.e. the number of coefficients in each :math:`\beta_j`, and :math:`||\cdot||_2` is the Euclidean norm.
"""
+
# conform to shape (1, n_features) if param is (n_features,) or (n_neurons, n_features) if
# param is (n_features, n_neurons)
param_with_extra_axis = jnp.atleast_2d(params[0].T)
diff --git a/src/nemos/simulation.py b/src/nemos/simulation.py
index e698e702..48af7008 100644
--- a/src/nemos/simulation.py
+++ b/src/nemos/simulation.py
@@ -1,5 +1,7 @@
"""Utility functions for coupling filter definition."""
+from __future__ import annotations
+
from typing import Callable, Tuple, Union
import jax
@@ -20,7 +22,8 @@ def difference_of_gammas(
inhib_b: float = 1.0,
excit_b: float = 2.0,
) -> NDArray:
- r"""Generate coupling filter as a Gamma pdf difference.
+ r"""
+ Generate coupling filter as a Gamma pdf difference.
Parameters
----------
@@ -30,22 +33,24 @@ def difference_of_gammas(
Upper bound of the gamma range as a percentile. The gamma function
will be evaluated over the range [0, ppf(upper_percentile)].
inhib_a:
- The `a` constant for the gamma pdf of the inhibitory part of the filter.
+ The ``a`` constant for the gamma pdf of the inhibitory part of the filter.
excit_a:
- The `a` constant for the gamma pdf of the excitatory part of the filter.
+ The ``a`` constant for the gamma pdf of the excitatory part of the filter.
inhib_b:
- The `b` constant for the gamma pdf of the inhibitory part of the filter.
+ The ``b`` constant for the gamma pdf of the inhibitory part of the filter.
excit_b:
- The `a` constant for the gamma pdf of the excitatory part of the filter.
+ The ``b`` constant for the gamma pdf of the excitatory part of the filter.
Notes
-----
The probability density function of a gamma distribution is parametrized as
- follows[$^{[1]}$](#references):,
- $$
+ follows [1]_:
+
+ .. math::
+
p(x;\; a, b) = \frac{b^a x^{a-1} e^{-bx}}{\Gamma(a)},
- $$
- where $\Gamma(a)$ refers to the gamma function, see [[1]](#references):.
+
+ where :math:`\Gamma(a)` refers to the gamma function, see [1]_.
Returns
-------
@@ -55,12 +60,33 @@ def difference_of_gammas(
Raises
------
ValueError:
- - If any of the Gamma parameters is lesser or equal to 0.
- - If the upper_percentile is not in [0, 1).
+ If any of the Gamma parameters is less than or equal to 0.
+ ValueError:
+ If the upper_percentile is not in [0, 1).
+
+ References
+ ----------
+ .. [1] SciPy Docs -
+ `scipy.stats.gamma <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gamma.html>`_
+
+ Examples
+ --------
+ >>> import matplotlib.pyplot as plt
+ >>> from nemos.simulation import difference_of_gammas
+ >>> coupling_duration = 100
+ >>> inhib_a, inhib_b = 1.0, 1.0
+ >>> excit_a, excit_b = 2.0, 2.0
+ >>> coupling_filter = difference_of_gammas(
+ ... ws=coupling_duration,
+ ... inhib_a=inhib_a,
+ ... inhib_b=inhib_b,
+ ... excit_a=excit_a,
+ ... excit_b=excit_b
+ ... )
+ >>> _ = plt.plot(coupling_filter)
+ >>> _ = plt.title("Coupling filter from difference of gammas")
+ >>> _ = plt.show()
- # References
- ------------
- SciPy Docs - ["scipy.stats.gamma"](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gamma.html)
"""
# check that the gamma parameters are positive (scipy returns
# nans otherwise but no exception is raised)
@@ -102,21 +128,43 @@ def regress_filter(coupling_filters: NDArray, eval_basis: NDArray) -> NDArray:
Parameters
----------
coupling_filters:
- The coupling filters. Shape (window_size, n_neurons_receiver, n_neurons_sender)
+ The coupling filters. Shape ``(window_size, n_neurons_receiver, n_neurons_sender)``
eval_basis:
- The evaluated basis function, shape (window_size, n_basis_funcs)
+ The evaluated basis function, shape ``(window_size, n_basis_funcs)``
Returns
-------
weights:
- The weights for each neuron. Shape (n_neurons_receiver, n_neurons_sender, n_basis_funcs)
+ The weights for each neuron. Shape ``(n_basis_funcs, n_neurons_receiver, n_neurons_sender)``
Raises
------
ValueError
- - If eval_basis is not two-dimensional
- - If coupling_filters is not three-dimensional
- - If window_size differs between eval_basis and coupling_filters
+ If eval_basis is not two-dimensional.
+ ValueError
+ If coupling_filters is not three-dimensional.
+ ValueError
+ If window_size differs between eval_basis and coupling_filters.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> import matplotlib.pyplot as plt
+ >>> from nemos.simulation import regress_filter, difference_of_gammas
+ >>> from nemos.basis import RaisedCosineBasisLog
+ >>> filter_duration = 100
+ >>> n_basis_funcs = 20
+ >>> filter_bank = difference_of_gammas(filter_duration).reshape(filter_duration, 1, 1)
+ >>> _, basis = RaisedCosineBasisLog(10).evaluate_on_grid(filter_duration)
+ >>> weights = regress_filter(filter_bank, basis)[0, 0]
+ >>> print("Weights shape:", weights.shape)
+ Weights shape: (10,)
+ >>> _ = plt.plot(filter_bank[:, 0, 0], label=f"True filter")
+ >>> _ = plt.plot(basis.dot(weights), "--", label=f"Approx. filter")
+ >>> _ = plt.legend()
+ >>> _ = plt.title("True vs. Approximated Filters")
+ >>> _ = plt.show()
+
"""
# check shapes
if eval_basis.ndim != 2:
@@ -182,36 +230,69 @@ def simulate_recurrent(
Expected shape: (n_neurons (receiver), n_neurons (sender), n_basis_coupling).
feedforward_coef :
Coefficients for the feedforward inputs to each neuron.
- Expected shape: (n_neurons, n_basis_input).
+ Expected shape: ``(n_neurons, n_basis_input)``.
intercepts :
- Bias term for each neuron. Expected shape: (n_neurons,).
+ Bias term for each neuron. Expected shape: ``(n_neurons,)``.
random_key :
jax.random.key for seeding the simulation.
feedforward_input :
External input matrix to the model, representing factors like convolved currents,
light intensities, etc. When not provided, the simulation is done with coupling-only.
- Expected shape: (n_time_bins, n_neurons, n_basis_input).
+ Expected shape: ``(n_time_bins, n_neurons, n_basis_input)``.
init_y :
Initial observation (spike counts for PoissonGLM) matrix that kickstarts the simulation.
- Expected shape: (window_size, n_neurons).
+ Expected shape: ``(window_size, n_neurons)``.
coupling_basis_matrix :
Basis matrix for coupling, representing between-neuron couplings
- and auto-correlations. Expected shape: (window_size, n_basis_coupling).
-
+ and auto-correlations. Expected shape: ``(window_size, n_basis_coupling)``.
+ inverse_link_function :
+ The inverse link function for the observation model.
Returns
-------
simulated_activity :
Simulated activity (spike counts for PoissonGLMs) for each neuron over time.
- Shape, (n_time_bins, n_neurons).
+ Shape: ``(n_time_bins, n_neurons)``.
firing_rates :
- Simulated rates for each neuron over time. Shape, (n_time_bins, n_neurons,).
+ Simulated rates for each neuron over time. Shape: ``(n_time_bins, n_neurons)``.
Raises
------
ValueError
- - If there's an inconsistency between the number of neurons in model parameters.
- - If the number of neurons in input arguments doesn't match with model parameters.
+ If there's an inconsistency between the number of neurons in model parameters.
+ ValueError
+ If the number of neurons in input arguments doesn't match with model parameters.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> import jax
+ >>> import matplotlib.pyplot as plt
+ >>> from nemos.simulation import simulate_recurrent
+ >>>
+ >>> n_neurons = 2
+ >>> coupling_duration = 100
+ >>> feedforward_input = np.random.normal(size=(1000, n_neurons, 1))
+ >>> coupling_basis = np.random.normal(size=(coupling_duration, 10))
+ >>> coupling_coef = np.random.normal(size=(n_neurons, n_neurons, 10))
+ >>> intercept = -8 * np.ones(n_neurons)
+ >>> init_spikes = np.zeros((coupling_duration, n_neurons))
+ >>> random_key = jax.random.key(123)
+ >>> spikes, rates = simulate_recurrent(
+ ... coupling_coef=coupling_coef,
+ ... feedforward_coef=np.ones((n_neurons, 1)),
+ ... intercepts=intercept,
+ ... random_key=random_key,
+ ... feedforward_input=feedforward_input,
+ ... coupling_basis_matrix=coupling_basis,
+ ... init_y=init_spikes
+ ... )
+ >>> _ = plt.figure()
+ >>> _ = plt.plot(rates[:, 0], label="Neuron 0 rate")
+ >>> _ = plt.plot(rates[:, 1], label="Neuron 1 rate")
+ >>> _ = plt.legend()
+ >>> _ = plt.title("Simulated firing rates")
+ >>> _ = plt.show()
"""
if isinstance(feedforward_input, FeaturePytree):
raise ValueError(
@@ -303,7 +384,7 @@ def scan_fn(
# 1. The first dimension is time, and 1 is by construction since we are simulating 1
# sample
# 2. Flatten to shape (n_neuron * n_basis_coupling, )
- conv_act = convolve.reshape_convolve(activity, coupling_basis_matrix).reshape(
+ conv_act = convolve.tensor_convolve(activity, coupling_basis_matrix).reshape(
-1,
)
diff --git a/src/nemos/solvers/_svrg_defaults.py b/src/nemos/solvers/_svrg_defaults.py
index a12d1098..47aea11d 100644
--- a/src/nemos/solvers/_svrg_defaults.py
+++ b/src/nemos/solvers/_svrg_defaults.py
@@ -422,7 +422,7 @@ def _calculate_optimal_batch_size_svrg(
num_samples:
The number of samples.
l_smooth_max:
- The $L\_{\text{max}}$ smoothness constant.
+ The $L_{\text{max}}$ smoothness constant.
l_smooth:
The $L$ smoothness constant.
strong_convexity:
@@ -480,7 +480,7 @@ def _calculate_b_hat(num_samples: int, l_smooth_max: float, l_smooth: float):
num_samples :
Total number of data points.
l_smooth_max :
- Maximum smoothness constant $L\_{\text{max}}$.
+ Maximum smoothness constant $L_{\text{max}}$.
l_smooth :
Smoothness constant $L$.
diff --git a/src/nemos/typing.py b/src/nemos/typing.py
index dd9bc5a6..f1cfc4fc 100644
--- a/src/nemos/typing.py
+++ b/src/nemos/typing.py
@@ -4,7 +4,7 @@
import jax.numpy as jnp
import jaxopt
-from jax._src.typing import ArrayLike
+from jax.typing import ArrayLike
from .pytrees import FeaturePytree
diff --git a/src/nemos/utils.py b/src/nemos/utils.py
index 87ad0472..557b6c96 100644
--- a/src/nemos/utils.py
+++ b/src/nemos/utils.py
@@ -53,10 +53,11 @@ def validate_axis(tree: Any, axis: int):
Raises
------
ValueError
- - If the specified axis is equal to or greater than the number of dimensions (`ndim`) of any array
+ If the specified axis is equal to or greater than the number of dimensions (`ndim`) of any array
within the tree. This ensures that operations intended for a specific axis can be safely performed
on every array in the tree.
- - If the axis is negative or non-integer.
+ ValueError
+ If the axis is negative or non-integer.
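+
+    Examples
+    --------
+    A minimal usage sketch, assuming ``validate_axis`` is importable from
+    ``nemos.utils``, accepts a list of arrays as the pytree, and returns
+    ``None`` when the axis is valid:
+
+    >>> import numpy as np
+    >>> from nemos.utils import validate_axis
+    >>> validate_axis([np.zeros((10, 2)), np.zeros((5, 2))], axis=1)  # no error for a valid axis
+    >>> validate_axis([np.zeros((10, 2))], axis=-1)
+    Traceback (most recent call last):
+        ...
+    ValueError: `axis` must be a non negative integer.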
"""
if not isinstance(axis, int) or axis < 0:
raise ValueError("`axis` must be a non negative integer.")
diff --git a/tests/conftest.py b/tests/conftest.py
index cb39ee37..77af28b5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -82,7 +82,7 @@ def initialize_params(self, *args, **kwargs):
def _predict_and_compute_loss(self, params, X, y):
pass
- def get_optimal_solver_params_config(self):
+ def _get_optimal_solver_params_config(self):
return None, None, None
diff --git a/tests/test_base_class.py b/tests/test_base_class.py
index 2324434b..95680638 100644
--- a/tests/test_base_class.py
+++ b/tests/test_base_class.py
@@ -24,7 +24,7 @@ def predict(self, X: Union[NDArray, jnp.ndarray]) -> jnp.ndarray:
def score(self, X, y, score_type="pseudo-r2-McFadden"):
pass
- def get_optimal_solver_params_config(self):
+ def _get_optimal_solver_params_config(self):
return None, None, None
diff --git a/tests/test_glm.py b/tests/test_glm.py
index d49bb69b..9cdb38eb 100644
--- a/tests/test_glm.py
+++ b/tests/test_glm.py
@@ -1962,7 +1962,7 @@ def test_optimize_solver_params(
raise e
return
- kwargs = model.optimize_solver_params(X, y)
+ kwargs = model._optimize_solver_params(X, y)
if isinstance(batch_size, int) and "batch_size" in solver_kwargs:
# if batch size was provided, then it should be returned unchanged
assert batch_size == kwargs["batch_size"]
@@ -2049,7 +2049,7 @@ def test_optimize_solver_params_pytree(
raise e
return
- kwargs = model.optimize_solver_params(X, y)
+ kwargs = model._optimize_solver_params(X, y)
if isinstance(batch_size, int) and "batch_size" in solver_kwargs:
# if batch size was provided, then it should be returned unchanged
assert batch_size == kwargs["batch_size"]
@@ -4057,7 +4057,7 @@ def test_optimize_solver_params(
raise e
return
- kwargs = model.optimize_solver_params(X, y)
+ kwargs = model._optimize_solver_params(X, y)
if isinstance(batch_size, int) and "batch_size" in solver_kwargs:
# if batch size was provided, then it should be returned unchanged
assert batch_size == kwargs["batch_size"]
@@ -4144,7 +4144,7 @@ def test_optimize_solver_params_pytree(
raise e
return
- kwargs = model.optimize_solver_params(X, y)
+ kwargs = model._optimize_solver_params(X, y)
if isinstance(batch_size, int) and "batch_size" in solver_kwargs:
# if batch size was provided, then it should be returned unchanged
assert batch_size == kwargs["batch_size"]
diff --git a/tox.ini b/tox.ini
index 7df11c35..94bd7a4f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,6 +8,10 @@ envlist = py,fix
# and the linters from pyproject.toml
extras = dev
+# Non-interactive backend for doctests
+setenv =
+ MPLBACKEND = Agg
+
# Enable package caching
package_cache = .tox/cache