From a6d60d75910bbbc3a60fc1b66e340b8254227851 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:05:05 -0500 Subject: [PATCH 01/21] add svgs basis --- docs/background/basis/images/EvalBSpline.svg | 360 +++++++++++ .../basis/images/EvalCyclicBSpline.svg | 427 +++++++++++++ docs/background/basis/images/EvalMSpline.svg | 289 +++++++++ .../basis/images/EvalOrthExponential.svg | 564 ++++++++++++++++++ .../basis/images/EvalRaisedCosineLinear.svg | 367 ++++++++++++ .../basis/images/EvalRaisedCosineLog.svg | 310 ++++++++++ 6 files changed, 2317 insertions(+) create mode 100644 docs/background/basis/images/EvalBSpline.svg create mode 100644 docs/background/basis/images/EvalCyclicBSpline.svg create mode 100644 docs/background/basis/images/EvalMSpline.svg create mode 100644 docs/background/basis/images/EvalOrthExponential.svg create mode 100644 docs/background/basis/images/EvalRaisedCosineLinear.svg create mode 100644 docs/background/basis/images/EvalRaisedCosineLog.svg diff --git a/docs/background/basis/images/EvalBSpline.svg b/docs/background/basis/images/EvalBSpline.svg new file mode 100644 index 00000000..66776ff8 --- /dev/null +++ b/docs/background/basis/images/EvalBSpline.svg @@ -0,0 +1,360 @@ + + + + + + + + 2024-12-02T16:54:50.980497 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/background/basis/images/EvalCyclicBSpline.svg b/docs/background/basis/images/EvalCyclicBSpline.svg new file mode 100644 index 00000000..35fa9481 --- /dev/null +++ b/docs/background/basis/images/EvalCyclicBSpline.svg @@ -0,0 +1,427 @@ + + + + + + + + 2024-12-02T16:54:50.989322 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/background/basis/images/EvalMSpline.svg b/docs/background/basis/images/EvalMSpline.svg new file mode 100644 index 00000000..36945902 --- /dev/null +++ b/docs/background/basis/images/EvalMSpline.svg @@ -0,0 +1,289 @@ + + + + + + + + 2024-12-02T16:54:50.998126 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/background/basis/images/EvalOrthExponential.svg b/docs/background/basis/images/EvalOrthExponential.svg new file mode 100644 index 00000000..0326c6e6 --- /dev/null +++ b/docs/background/basis/images/EvalOrthExponential.svg @@ -0,0 +1,564 @@ + + + + + + + + 2024-12-02T16:54:51.006464 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/background/basis/images/EvalRaisedCosineLinear.svg b/docs/background/basis/images/EvalRaisedCosineLinear.svg new file mode 100644 index 00000000..fbdbd641 --- /dev/null +++ b/docs/background/basis/images/EvalRaisedCosineLinear.svg @@ -0,0 +1,367 @@ + + + + + + + + 2024-12-02T16:54:51.015187 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/background/basis/images/EvalRaisedCosineLog.svg b/docs/background/basis/images/EvalRaisedCosineLog.svg new file mode 100644 index 00000000..6389fc18 --- /dev/null +++ b/docs/background/basis/images/EvalRaisedCosineLog.svg @@ -0,0 +1,310 @@ + + + + + + + + 2024-12-02T16:54:51.023103 + image/svg+xml + + + Matplotlib v3.9.2, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 4be4c63245cda3352e530f319659a32ffa8c135b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:05:15 -0500 Subject: [PATCH 02/21] added table for basis --- docs/assets/stylesheets/custom.css | 36 ++++++---- docs/background/README.md | 17 +---- docs/background/basis/README.md | 71 +++++++++++++++++++ .../{ => basis}/plot_01_1D_basis_function.md | 0 .../{ => basis}/plot_02_ND_basis_function.md | 0 docs/conf.py | 4 ++ docs/tutorials/plot_02_head_direction.md | 1 + docs/tutorials/plot_03_grid_cells.md | 11 ++- docs/tutorials/plot_05_place_cells.md | 1 + 9 files changed, 109 insertions(+), 32 deletions(-) create mode 100644 docs/background/basis/README.md rename docs/background/{ => basis}/plot_01_1D_basis_function.md (100%) rename docs/background/{ => basis}/plot_02_ND_basis_function.md (100%) diff --git a/docs/assets/stylesheets/custom.css b/docs/assets/stylesheets/custom.css index f7dc7a81..463be8a4 100644 --- a/docs/assets/stylesheets/custom.css +++ b/docs/assets/stylesheets/custom.css @@ -94,17 +94,25 @@ html[data-theme=light]{ font-weight: normal; } -/*!* Style the brackets *!*/ -/*span.fn-bracket {*/ -/* color: #666; !* Dim the brackets *!*/ -/*}*/ - -/*!* Style the links within the footnotes *!*/ -/*aside.footnote a {*/ -/* color: #007BFF; !* Blue link color *!*/ -/* text-decoration: none; !* Remove underline *!*/ -/*}*/ - -/*aside.footnote a:hover {*/ -/* text-decoration: underline; !* Add underline on hover *!*/ -/*}*/ + #table-basis { + table-layout: auto; + width: 100%; +} + +#table-basis th:nth-child(1), #table-basis td:nth-child(1) { + width: 22%; +} + +#table-basis th:nth-child(2), #table-basis td:nth-child(2) { + width: 22%; +} +#table-basis th:nth-child(3), #table-basis td:nth-child(3) { + width: 10%; +} + +#table-basis th:nth-child(4), #table-basis td:nth-child(4) { + width: 20%; +} +#table-basis th:nth-child(5), #table-basis td:nth-child(5) { + width: 10%; +} diff --git a/docs/background/README.md b/docs/background/README.md index 3215c329..331a9b8c 100644 --- a/docs/background/README.md +++ b/docs/background/README.md @@ -34,26 +34,13 @@ plot_00_conceptual_intro.md :::{grid-item-card}
-One-Dimensional Basis. +Basis Functions
```{toctree} :maxdepth: 2 -plot_01_1D_basis_function.md -``` -::: - -:::{grid-item-card} - -
-N-Dimensional Basis. -
- -```{toctree} -:maxdepth: 2 - -plot_02_ND_basis_function.md +basis/README.md ``` ::: diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md new file mode 100644 index 00000000..3ed0c087 --- /dev/null +++ b/docs/background/basis/README.md @@ -0,0 +1,71 @@ +# Basis Function + +```{table} +:name: table-basis + +| **Basis** | **Kernel Visualization** | **Examples** | **Evaluation/Convolution** | **Preferred Mode** | +|:---------------------------------:|:----------------------------------------------------------------------------:|:--------------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------:|:--------------------:| +| **B-Spline** | B-spline. | [Grid cells](grid_cells_nemos) | [EvalBSpline](nemos.basis.basis.EvalBSpline)
[ConvBSpline](nemos.basis.basis.ConvBSpline) | 🟢 Eval | +| **Cyclic B-Spline** | Cyclic B-spline. | [Place cells](basis_eval_place_cells) | [EvalCyclicBSpline](nemos.basis.basis.EvalCyclicBSpline)
[ConvCyclicBSpline](nemos.basis.basis.ConvCyclicBSpline) | 🟢 Eval | +| **M-Spline** | M-spline. | [Place cells](basis_eval_place_cells) | [EvalMSpline](nemos.basis.basis.EvalMSpline)
[ConvMSpline](nemos.basis.basis.ConvMSpline) | 🟢 Eval | +| **Linearly Spaced Raised Cosine** | Raised Cosine Linear. | | [EvalRaisedCosineLinear](nemos.basis.basis.EvalRaisedCosineLinear)
[ConvRaisedCosineLinear](nemos.basis.basis.ConvRaisedCosineLinear) | 🟢 Eval | +| **Log Spaced Raised Cosine** | Raised Cosine Log. | [Head Direction](head_direction_reducing_dimensionality) | [EvalRaisedCosineLog](nemos.basis.basis.EvalRaisedCosineLog)
[ConvRaisedCosineLog](nemos.basis.basis.ConvRaisedCosineLog) | 🔵 Conv | +``` + +## Overview + +A basis function is a collection of simple building blocks—functions that, when combined (weighted and summed together), can represent more complex, non-linear relationships. Think of them as tools for constructing predictors in GLMs, helping to model: + +1. **Non-linear mappings** between task variables (like velocity or position) and firing rates. +2. **Linear temporal effects**, such as spike history, neuron-to-neuron couplings, or how stimuli are integrated over time. + +In a GLM, we assume a non-linear mapping exists between task variables and neuronal firing rates. This mapping isn’t something we can directly observe—what we do see are the inputs (task covariates) and the resulting neural activity. The challenge is to infer a "good" approximation of this hidden relationship. + +Basis functions help simplify this process by representing the non-linearity as a weighted sum of fixed functions, $\psi_1(x), \dots, \psi_n(x)$, with weights $\alpha_1, \dots, \alpha_n$. Mathematically: + +$$ +f(x) \approx \alpha_1 \psi_1(x) + \dots + \alpha_n \psi_n(x) +$$ + +Here, $\approx$ means "approximately equal". Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. + + + +## Basis in NeMoS + +NeMoS provides a variety of basis function objects, each tailored for specific shapes and use cases. These objects make it easy to define both non-linear features and temporal predictors. Depending on the type of modeling you need, NeMoS offers: + +- **Eval-basis objects**: For creating non-linear features. (Names start with `Eval`.) +- **Conv-basis objects**: For defining temporal predictors. (Names start with `Conv`.) + +If you want to know how to create and use one-dimensional bases or combining them to build multi-dimensional predictors, check out these resources: + +::::{grid} 1 2 2 2 + +:::{grid-item-card} + +
+One-Dimensional Basis. +
+ +```{toctree} +:maxdepth: 2 + +plot_01_1D_basis_function.md +``` +::: + +:::{grid-item-card} + +
+N-Dimensional Basis. +
+ +```{toctree} +:maxdepth: 2 + +plot_02_ND_basis_function.md +``` +::: + +:::: \ No newline at end of file diff --git a/docs/background/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md similarity index 100% rename from docs/background/plot_01_1D_basis_function.md rename to docs/background/basis/plot_01_1D_basis_function.md diff --git a/docs/background/plot_02_ND_basis_function.md b/docs/background/basis/plot_02_ND_basis_function.md similarity index 100% rename from docs/background/plot_02_ND_basis_function.md rename to docs/background/basis/plot_02_ND_basis_function.md diff --git a/docs/conf.py b/docs/conf.py index c8f5e3a2..6204fa1e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -121,6 +121,10 @@ "logo": { "image_light": "_static/NeMoS_Logo_CMYK_Full.svg", "image_dark": "_static/NeMoS_Logo_CMYK_White.svg", + }, + "secondary_sidebar_items": { + "**": ["page-toc", "sourcelink"], + "background/basis/README": [], } } diff --git a/docs/tutorials/plot_02_head_direction.md b/docs/tutorials/plot_02_head_direction.md index e4402053..ee7c478c 100644 --- a/docs/tutorials/plot_02_head_direction.md +++ b/docs/tutorials/plot_02_head_direction.md @@ -374,6 +374,7 @@ worst if we needed a finer temporal resolution, such 1ms time bins (which would require 800 coefficients instead of 80). What can we do to mitigate over-fitting now? +(head_direction_reducing_dimensionality)= #### Reducing feature dimensionality One way to proceed is to find a lower-dimensional representation of the response by parametrizing the decay effect. For instance, we could try to model it diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index a7f767ef..d0b58e0a 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -123,6 +123,8 @@ for i in range(len(spikes)): plt.tight_layout() ``` + +(grid_cells_nemos)= ## NeMoS It's time to use NeMoS. Let's try to predict the spikes as a function of position and see if we can generate better tuning curves @@ -146,9 +148,9 @@ We can define a two-dimensional basis for position by multiplying two one-dimens see [here](../../background/plot_02_ND_basis_function) for more details. ```{code-cell} ipython3 -basis_2d = nmo.basis.EvalRaisedCosineLinear( +basis_2d = nmo.basis.EvalBspline( n_basis_funcs=10 -) * nmo.basis.EvalRaisedCosineLinear(n_basis_funcs=10) +) * nmo.basis.EvalBspline(n_basis_funcs=10) ``` Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid. @@ -219,7 +221,10 @@ Here we will focus on the last neuron (neuron 7) who has a nice grid pattern ```{code-cell} ipython3 model = nmo.glm.GLM( regularizer="Ridge", - regularizer_strength=0.001 + regularizer_strength=0.001, + # lowering the tolerance means that the solution will be closer to the optimum + # (at the cost of increasing execution time) + solver_kwargs=dict(tol=10**-12), ) ``` diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index 597959c3..88eb081f 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -331,6 +331,7 @@ print(speed.shape) print(count.shape) ``` +(basis_eval_place_cells)= ## Basis evaluation For each feature, we will use a different set of basis : From 98eaf253ecc6597bed736d663fd7023adb626d73 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:10:18 -0500 Subject: [PATCH 03/21] fix notebook --- docs/tutorials/plot_03_grid_cells.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index d0b58e0a..92a045bc 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -148,9 +148,9 @@ We can define a two-dimensional basis for position by multiplying two one-dimens see [here](../../background/plot_02_ND_basis_function) for more details. ```{code-cell} ipython3 -basis_2d = nmo.basis.EvalBspline( +basis_2d = nmo.basis.EvalBSpline( n_basis_funcs=10 -) * nmo.basis.EvalBspline(n_basis_funcs=10) +) * nmo.basis.EvalBSpline(n_basis_funcs=10) ``` Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid. From 7a13bcdd20156b3ae654838fc52aab756d6b53d6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:18:10 -0500 Subject: [PATCH 04/21] fix notebook --- docs/tutorials/plot_03_grid_cells.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index 92a045bc..d2e8da7a 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -221,7 +221,7 @@ Here we will focus on the last neuron (neuron 7) who has a nice grid pattern ```{code-cell} ipython3 model = nmo.glm.GLM( regularizer="Ridge", - regularizer_strength=0.001, + regularizer_strength=0.0001, # lowering the tolerance means that the solution will be closer to the optimum # (at the cost of increasing execution time) solver_kwargs=dict(tol=10**-12), From 0a53a594064342bf8f3946a190cc224eddab7655 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:43:13 -0500 Subject: [PATCH 05/21] ADDED ORTH exp --- docs/background/basis/README.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 3ed0c087..507060b0 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -1,5 +1,6 @@ # Basis Function +(table_basis)= ```{table} :name: table-basis @@ -10,6 +11,7 @@ | **M-Spline** | M-spline. | [Place cells](basis_eval_place_cells) | [EvalMSpline](nemos.basis.basis.EvalMSpline)
[ConvMSpline](nemos.basis.basis.ConvMSpline) | 🟢 Eval | | **Linearly Spaced Raised Cosine** | Raised Cosine Linear. | | [EvalRaisedCosineLinear](nemos.basis.basis.EvalRaisedCosineLinear)
[ConvRaisedCosineLinear](nemos.basis.basis.ConvRaisedCosineLinear) | 🟢 Eval | | **Log Spaced Raised Cosine** | Raised Cosine Log. | [Head Direction](head_direction_reducing_dimensionality) | [EvalRaisedCosineLog](nemos.basis.basis.EvalRaisedCosineLog)
[ConvRaisedCosineLog](nemos.basis.basis.ConvRaisedCosineLog) | 🔵 Conv | +| **Orthogonalized Exponential Decays** | Orth Exponential Decays | [Head Direction](head_direction_reducing_dimensionality) | [EvalOrthExponential](nemos.basis.basis.EvalOrthExponential)
[ConvOrthExponential](nemos.basis.basis.ConvOrthExponential) | 🟢 Eval | ``` ## Overview @@ -27,18 +29,21 @@ $$ f(x) \approx \alpha_1 \psi_1(x) + \dots + \alpha_n \psi_n(x) $$ -Here, $\approx$ means "approximately equal". Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. +Here, $\approx$ means "approximately equal". +Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. ## Basis in NeMoS -NeMoS provides a variety of basis function objects, each tailored for specific shapes and use cases. These objects make it easy to define both non-linear features and temporal predictors. Depending on the type of modeling you need, NeMoS offers: +NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two key uses described in the overview: -- **Eval-basis objects**: For creating non-linear features. (Names start with `Eval`.) -- **Conv-basis objects**: For defining temporal predictors. (Names start with `Conv`.) +- **Eval-basis objects**: For representing non-linear mappings between task variables and outputs. These objects are identified by names starting with `Eval`. +- **Conv-basis objects**: For linear temporal effects. These objects are identified by names starting with `Conv`. -If you want to know how to create and use one-dimensional bases or combining them to build multi-dimensional predictors, check out these resources: +`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling complex feature construction. + +## Learn More ::::{grid} 1 2 2 2 From 4777a252ea79296cfdbd443b079ca8efa4787d4b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 2 Dec 2024 17:44:41 -0500 Subject: [PATCH 06/21] fixed entry --- docs/background/basis/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 507060b0..cbfbbab9 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -11,7 +11,7 @@ | **M-Spline** | M-spline. | [Place cells](basis_eval_place_cells) | [EvalMSpline](nemos.basis.basis.EvalMSpline)
[ConvMSpline](nemos.basis.basis.ConvMSpline) | 🟢 Eval | | **Linearly Spaced Raised Cosine** | Raised Cosine Linear. | | [EvalRaisedCosineLinear](nemos.basis.basis.EvalRaisedCosineLinear)
[ConvRaisedCosineLinear](nemos.basis.basis.ConvRaisedCosineLinear) | 🟢 Eval | | **Log Spaced Raised Cosine** | Raised Cosine Log. | [Head Direction](head_direction_reducing_dimensionality) | [EvalRaisedCosineLog](nemos.basis.basis.EvalRaisedCosineLog)
[ConvRaisedCosineLog](nemos.basis.basis.ConvRaisedCosineLog) | 🔵 Conv | -| **Orthogonalized Exponential Decays** | Orth Exponential Decays | [Head Direction](head_direction_reducing_dimensionality) | [EvalOrthExponential](nemos.basis.basis.EvalOrthExponential)
[ConvOrthExponential](nemos.basis.basis.ConvOrthExponential) | 🟢 Eval | +| **Orthogonalized Exponential Decays** | Orth Exponential Decays | | [EvalOrthExponential](nemos.basis.basis.EvalOrthExponential)
[ConvOrthExponential](nemos.basis.basis.ConvOrthExponential) | 🟢 Eval | ``` ## Overview From be197c13db1c91d620832261d71fd015d8b3bcbc Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 09:17:30 -0500 Subject: [PATCH 07/21] add script to gen figs --- docs/conf.py | 7 +- docs/scripts/basis_table_figs.py | 50 ++++ src/nemos/_inspect_utils/__init__.py | 34 +++ .../nemos/_inspect_utils/inpsect_utils.py | 41 ++- tests/test_basis.py | 265 +++++++++--------- tox.ini | 3 + 6 files changed, 261 insertions(+), 139 deletions(-) create mode 100644 docs/scripts/basis_table_figs.py create mode 100644 src/nemos/_inspect_utils/__init__.py rename tests/utils_testing.py => src/nemos/_inspect_utils/inpsect_utils.py (85%) diff --git a/docs/conf.py b/docs/conf.py index 6204fa1e..f4a40b05 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,6 +40,7 @@ 'sphinx.ext.mathjax', 'sphinx_autodoc_typehints', 'sphinx_togglebutton', + 'matplotlib.sphinxext.plot_directive', ] myst_enable_extensions = [ @@ -68,7 +69,7 @@ 'inherited-members': True, 'undoc-members': True, 'show-inheritance': True, - 'special-members': '__call__, __add__, __mul__, __pow__' + 'special-members': ' __add__, __mul__, __pow__' } # # napolean configs @@ -121,10 +122,6 @@ "logo": { "image_light": "_static/NeMoS_Logo_CMYK_Full.svg", "image_dark": "_static/NeMoS_Logo_CMYK_White.svg", - }, - "secondary_sidebar_items": { - "**": ["page-toc", "sourcelink"], - "background/basis/README": [], } } diff --git a/docs/scripts/basis_table_figs.py b/docs/scripts/basis_table_figs.py new file mode 100644 index 00000000..7e3b503f --- /dev/null +++ b/docs/scripts/basis_table_figs.py @@ -0,0 +1,50 @@ +import matplotlib.pyplot as plt +import numpy as np + +import nemos as nmo +from nemos._inspect_utils import trim_kwargs + +KWARGS = dict( + n_basis_funcs=10, + decay_rates=np.arange(1, 10 + 1), + enforce_decay_to_zero=True, + order=4, + width=2, +) + + +def plot_basis(cls): + cls_params = cls._get_param_names() + new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) + bas = cls(**new_kwargs) + fig, ax = plt.subplots(1, 1, figsize=(5 / 4, 2.5 / 4)) + ax.plot(*bas.evaluate_on_grid(300), lw=0.8) + for side in ["left", "right", "top", "bottom"]: + ax.spines[side].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02) + + +def plot_raised_cosine_linear(): + plot_basis(nmo.basis.RaisedCosineLinearEval) + + +def plot_raised_cosine_log(): + plot_basis(nmo.basis.RaisedCosineLogEval) + + +def plot_mspline(): + plot_basis(nmo.basis.MSplineEval) + + +def plot_bspline(): + plot_basis(nmo.basis.BSplineEval) + + +def plot_cyclic_bspline(): + plot_basis(nmo.basis.CyclicBSplineEval) + + +def plot_orth_exp_basis(): + plot_basis(nmo.basis.OrthExponentialEval) diff --git a/src/nemos/_inspect_utils/__init__.py b/src/nemos/_inspect_utils/__init__.py new file mode 100644 index 00000000..21743bd6 --- /dev/null +++ b/src/nemos/_inspect_utils/__init__.py @@ -0,0 +1,34 @@ +""" +This module provides utilities for inspecting class hierarchies, +abstract methods, and subclass method implementations. + +Modules +------- +inspect_utils : module + Contains utility functions to analyze abstract and concrete class methods, + identify abstract classes, and verify method compliance in subclasses. +""" + +from .inpsect_utils import ( + check_all_abstract_methods_compliance, + get_abstract_classes, + get_non_abstract_classes, + get_subclass_methods, + get_superclass_abstract_methods, + is_abstract, + list_abstract_methods, + reimplements_method, + trim_kwargs, +) + +__all__ = [ + "reimplements_method", + "get_subclass_methods", + "list_abstract_methods", + "is_abstract", + "get_non_abstract_classes", + "get_abstract_classes", + "get_superclass_abstract_methods", + "check_all_abstract_methods_compliance", + "trim_kwargs", +] diff --git a/tests/utils_testing.py b/src/nemos/_inspect_utils/inpsect_utils.py similarity index 85% rename from tests/utils_testing.py rename to src/nemos/_inspect_utils/inpsect_utils.py index 344ad451..eab425ef 100644 --- a/tests/utils_testing.py +++ b/src/nemos/_inspect_utils/inpsect_utils.py @@ -40,7 +40,7 @@ def get_subclass_methods(class_obj: type) -> List[Tuple[str, type]]: Returns ------- - List[Tuple[str, type]] + : A list of tuples representing the methods that are specific to the subclass. Each tuple contains the method name (str) and the corresponding method object. """ @@ -113,7 +113,7 @@ def get_non_abstract_classes(module) -> List[Tuple[str, type]]: Returns ------- - List[Tuple[str, type]] + : A list of tuples representing the non-abstract classes in the module. Each tuple contains the class name (str) and the corresponding class object. """ @@ -208,3 +208,40 @@ def check_all_abstract_methods_compliance(module) -> None: raise ValueError( f"Abstract method {method} not implemented in {base_class} sub-class!" ) + + +def trim_kwargs(cls: type, kwargs: dict, class_specific_params: dict): + """ + Filter a dictionary of keyword arguments to include only those specific to a given class. + + Parameters + ---------- + cls : + The class object for which the keyword arguments are filtered. + kwargs : + A dictionary of keyword arguments to be filtered. + class_specific_params : + A mapping of class names to sets or lists of allowed parameter names. + + Returns + ------- + : + A dictionary containing only the keyword arguments specific to the given class. + + Example + ------- + >>> class_specific_params = { + ... 'MyClass': {'param1', 'param2'}, + ... 'OtherClass': {'param3', 'param4'} + ... } + >>> kwargs = {'param1': 10, 'param3': 20, 'param5': 30} + >>> class MyClass: + ... pass + >>> trim_kwargs(MyClass, kwargs, class_specific_params) + {'param1': 10} + """ + return { + key: value + for key, value in kwargs.items() + if key in class_specific_params[cls.__name__] + } diff --git a/tests/test_basis.py b/tests/test_basis.py index 3f86ebe7..b85f2d9e 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -11,10 +11,10 @@ import numpy as np import pynapple as nap import pytest -import utils_testing from sklearn.base import clone as sk_clone import nemos as nmo +import nemos._inspect_utils as inspect_utils import nemos.basis.basis as basis import nemos.convolve as convolve from nemos.basis._basis import AdditiveBasis, Basis, MultiplicativeBasis, add_docstring @@ -27,14 +27,6 @@ from nemos.utils import pynapple_concatenate_numpy -def trim_kwargs(cls, kwargs, class_specific_params): - return { - key: value - for key, value in kwargs.items() - if key in class_specific_params[cls.__name__] - } - - def extra_decay_rates(cls, n_basis): name = cls.__name__ if "OrthExp" in name: @@ -50,11 +42,11 @@ def list_all_basis_classes(filter_basis="all") -> list[type]: """ all_basis = [ class_obj - for _, class_obj in utils_testing.get_non_abstract_classes(basis) + for _, class_obj in inspect_utils.get_non_abstract_classes(basis) if issubclass(class_obj, Basis) ] + [ bas - for _, bas in utils_testing.get_non_abstract_classes(nmo.basis._basis) + for _, bas in inspect_utils.get_non_abstract_classes(nmo.basis._basis) if bas != basis.TransformerBasis ] if filter_basis != "all": @@ -132,7 +124,7 @@ def test_all_basis_are_tested() -> None: ("evaluate_on_grid", "The number of points in the uniformly spaced grid"), ( "compute_features", - "Compute the basis functions and transform input data into model features", + "Apply the basis transformation to the input data", ), ( "split_by_feature", @@ -167,7 +159,7 @@ def test_example_docstrings_add( continue if basis_name == basis_instance.__class__.__name__: continue - assert basis_name not in doc_components[1] + assert f" {basis_name}" not in doc_components[1] def test_add_docstring(): @@ -191,19 +183,19 @@ def method(self): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10), BSplineBasis), - (basis.ConvBSpline(10, window_size=11), BSplineBasis), - (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), - (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), - (basis.EvalMSpline(10), MSplineBasis), - (basis.ConvMSpline(10, window_size=11), MSplineBasis), - (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), - (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), - (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), - (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.BSplineEval(10), BSplineBasis), + (basis.BSplineConv(10, window_size=11), BSplineBasis), + (basis.CyclicBSplineEval(10), CyclicBSplineBasis), + (basis.CyclicBSplineConv(10, window_size=11), CyclicBSplineBasis), + (basis.MSplineEval(10), MSplineBasis), + (basis.MSplineConv(10, window_size=11), MSplineBasis), + (basis.RaisedCosineLinearEval(10), RaisedCosineBasisLinear), + (basis.RaisedCosineLinearConv(10, window_size=11), RaisedCosineBasisLinear), + (basis.RaisedCosineLogEval(10), RaisedCosineBasisLog), + (basis.RaisedCosineLogConv(10, window_size=11), RaisedCosineBasisLog), + (basis.OrthExponentialEval(10, np.arange(1, 11)), OrthExponentialBasis), ( - basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + basis.OrthExponentialConv(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis, ), ], @@ -218,19 +210,19 @@ def test_expected_output_eval_on_grid(basis_instance, super_class): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10), BSplineBasis), - (basis.ConvBSpline(10, window_size=11), BSplineBasis), - (basis.EvalCyclicBSpline(10), CyclicBSplineBasis), - (basis.ConvCyclicBSpline(10, window_size=11), CyclicBSplineBasis), - (basis.EvalMSpline(10), MSplineBasis), - (basis.ConvMSpline(10, window_size=11), MSplineBasis), - (basis.EvalRaisedCosineLinear(10), RaisedCosineBasisLinear), - (basis.ConvRaisedCosineLinear(10, window_size=11), RaisedCosineBasisLinear), - (basis.EvalRaisedCosineLog(10), RaisedCosineBasisLog), - (basis.ConvRaisedCosineLog(10, window_size=11), RaisedCosineBasisLog), - (basis.EvalOrthExponential(10, np.arange(1, 11)), OrthExponentialBasis), + (basis.BSplineEval(10), BSplineBasis), + (basis.BSplineConv(10, window_size=11), BSplineBasis), + (basis.CyclicBSplineEval(10), CyclicBSplineBasis), + (basis.CyclicBSplineConv(10, window_size=11), CyclicBSplineBasis), + (basis.MSplineEval(10), MSplineBasis), + (basis.MSplineConv(10, window_size=11), MSplineBasis), + (basis.RaisedCosineLinearEval(10), RaisedCosineBasisLinear), + (basis.RaisedCosineLinearConv(10, window_size=11), RaisedCosineBasisLinear), + (basis.RaisedCosineLogEval(10), RaisedCosineBasisLog), + (basis.RaisedCosineLogConv(10, window_size=11), RaisedCosineBasisLog), + (basis.OrthExponentialEval(10, np.arange(1, 11)), OrthExponentialBasis), ( - basis.ConvOrthExponential(10, decay_rates=np.arange(1, 11), window_size=12), + basis.OrthExponentialConv(10, decay_rates=np.arange(1, 11), window_size=12), OrthExponentialBasis, ), ], @@ -246,31 +238,31 @@ def test_expected_output_compute_features(basis_instance, super_class): @pytest.mark.parametrize( "basis_instance, super_class", [ - (basis.EvalBSpline(10, label="label"), BSplineBasis), - (basis.ConvBSpline(10, window_size=11, label="label"), BSplineBasis), - (basis.EvalCyclicBSpline(10, label="label"), CyclicBSplineBasis), + (basis.BSplineEval(10, label="label"), BSplineBasis), + (basis.BSplineConv(10, window_size=11, label="label"), BSplineBasis), + (basis.CyclicBSplineEval(10, label="label"), CyclicBSplineBasis), ( - basis.ConvCyclicBSpline(10, window_size=11, label="label"), + basis.CyclicBSplineConv(10, window_size=11, label="label"), CyclicBSplineBasis, ), - (basis.EvalMSpline(10, label="label"), MSplineBasis), - (basis.ConvMSpline(10, window_size=11, label="label"), MSplineBasis), - (basis.EvalRaisedCosineLinear(10, label="label"), RaisedCosineBasisLinear), + (basis.MSplineEval(10, label="label"), MSplineBasis), + (basis.MSplineConv(10, window_size=11, label="label"), MSplineBasis), + (basis.RaisedCosineLinearEval(10, label="label"), RaisedCosineBasisLinear), ( - basis.ConvRaisedCosineLinear(10, window_size=11, label="label"), + basis.RaisedCosineLinearConv(10, window_size=11, label="label"), RaisedCosineBasisLinear, ), - (basis.EvalRaisedCosineLog(10, label="label"), RaisedCosineBasisLog), + (basis.RaisedCosineLogEval(10, label="label"), RaisedCosineBasisLog), ( - basis.ConvRaisedCosineLog(10, window_size=11, label="label"), + basis.RaisedCosineLogConv(10, window_size=11, label="label"), RaisedCosineBasisLog, ), ( - basis.EvalOrthExponential(10, np.arange(1, 11), label="label"), + basis.OrthExponentialEval(10, np.arange(1, 11), label="label"), OrthExponentialBasis, ), ( - basis.ConvOrthExponential( + basis.OrthExponentialConv( 10, decay_rates=np.arange(1, 11), window_size=12, label="label" ), OrthExponentialBasis, @@ -305,12 +297,12 @@ def cls(self): @pytest.mark.parametrize( "cls", [ - {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog}, - {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear}, - {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline}, - {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline}, - {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline}, - {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential}, + {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv}, + {"eval": basis.RaisedCosineLinearEval, "conv": basis.RaisedCosineLinearConv}, + {"eval": basis.BSplineEval, "conv": basis.BSplineConv}, + {"eval": basis.CyclicBSplineEval, "conv": basis.CyclicBSplineConv}, + {"eval": basis.MSplineEval, "conv": basis.MSplineConv}, + {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv}, ], ) class TestSharedMethods: @@ -345,7 +337,7 @@ def test_call_vmin_vmax(self, samples, vmin, vmax, expectation, cls): return bas = cls["eval"](5, bounds=(vmin, vmax), **extra_decay_rates(cls["eval"], 5)) with expectation: - bas(samples) + bas._evaluate(samples) @pytest.mark.parametrize( "attribute, value", @@ -532,7 +524,7 @@ def test_call_basis_number(self, n_basis, mode, kwargs, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) x = np.linspace(0, 1, 10) - assert bas(x).shape[1] == n_basis + assert bas._evaluate(x).shape[1] == n_basis @pytest.mark.parametrize("n_basis", [6]) def test_call_equivalent_in_conv(self, n_basis, cls): @@ -545,7 +537,7 @@ def test_call_equivalent_in_conv(self, n_basis, cls): n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) ) x = np.linspace(0, 1, 10) - assert np.all(bas_con(x) == bas_eval(x)) + assert np.all(bas_con._evaluate(x) == bas_eval._evaluate(x)) @pytest.mark.parametrize( "num_input, expectation", @@ -564,7 +556,7 @@ def test_call_input_num(self, num_input, n_basis, mode, kwargs, expectation, cls n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with expectation: - bas(*([np.linspace(0, 1, 10)] * num_input)) + bas._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -582,7 +574,7 @@ def test_call_input_shape(self, inp, mode, kwargs, expectation, n_basis, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with expectation: - bas(inp) + bas._evaluate(inp) @pytest.mark.parametrize( "samples, expectation", @@ -600,7 +592,7 @@ def test_call_input_type(self, samples, expectation, n_basis, cls): n_basis_funcs=n_basis, **extra_decay_rates(cls["eval"], n_basis) ) # Only eval mode is relevant here with expectation: - bas(samples) + bas._evaluate(samples) @pytest.mark.parametrize( "mode, kwargs", [("eval", {}), ("conv", {"window_size": 3})] @@ -609,7 +601,7 @@ def test_call_nan(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x[3] = np.nan - assert all(np.isnan(bas(x)[3])) + assert all(np.isnan(bas._evaluate(x)[3])) @pytest.mark.parametrize("n_basis", [6, 7]) @pytest.mark.parametrize( @@ -620,7 +612,7 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): n_basis_funcs=n_basis, **kwargs, **extra_decay_rates(cls[mode], n_basis) ) with pytest.raises(ValueError, match="All sample provided must"): - bas(np.array([])) + bas._evaluate(np.array([])) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize( @@ -628,7 +620,10 @@ def test_call_non_empty(self, n_basis, mode, kwargs, cls): ) def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) - assert bas(np.linspace(0, 1, time_axis_shape)).shape[0] == time_axis_shape + assert ( + bas._evaluate(np.linspace(0, 1, time_axis_shape)).shape[0] + == time_axis_shape + ) @pytest.mark.parametrize( "mn, mx, expectation", @@ -643,7 +638,7 @@ def test_call_sample_axis(self, time_axis_shape, mode, kwargs, cls): def test_call_sample_range(self, mn, mx, expectation, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) with expectation: - bas(np.linspace(mn, mx, 10)) + bas._evaluate(np.linspace(mn, mx, 10)) @pytest.mark.parametrize( "kwargs, input1_shape, expectation", @@ -718,7 +713,7 @@ def test_compute_features_conv_input( ) # figure out which kwargs needs to be removed - kwargs = trim_kwargs(cls["conv"], kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs(cls["conv"], kwargs, class_specific_params) basis_obj = cls["conv"](**kwargs) out = basis_obj.compute_features(x) @@ -1020,9 +1015,9 @@ def test_init_window_size(self, mode, ws, expectation, cls): # @pytest.mark.parametrize("mode, kwargs", [("eval", {}), ("conv", {"window_size": 2})]) # def test_minimum_number_of_basis_required_is_matched(self, n_basis_funcs, mode, kwargs, order, cls): # min_per_basis = { - # "EvalMSpline": (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs), - # "EvalRaisedCosineLog": lambda x: x < 2, - # "EvalBSpline": lambda x: order > x, + # "MSplineEval": (order < 1) | (n_basis_funcs < 1) | (order > n_basis_funcs), + # "RaisedCosineLogEval": lambda x: x < 2, + # "BSplineEval": lambda x: order > x, # } # if n_basis_funcs < 2: # with pytest.raises( @@ -1070,7 +1065,7 @@ def test_number_of_required_inputs_compute_features( ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( - TypeError, match="takes 2 positional arguments but \d were given" + TypeError, match=r"takes 2 positional arguments but \d were given" ) else: expectation = does_not_raise() @@ -1085,8 +1080,8 @@ def test_pynapple_support(self, mode, kwargs, cls): bas = cls[mode](n_basis_funcs=5, **kwargs, **extra_decay_rates(cls[mode], 5)) x = np.linspace(0, 1, 10) x_nap = nap.Tsd(t=np.arange(10), d=x) - y = bas(x) - y_nap = bas(x_nap) + y = bas._evaluate(x) + y_nap = bas._evaluate(x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap.t) @@ -1258,7 +1253,7 @@ def test_transformer_get_params(self, cls): class TestRaisedCosineLogBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalRaisedCosineLog, "conv": basis.ConvRaisedCosineLog} + cls = {"eval": basis.RaisedCosineLogEval, "conv": basis.RaisedCosineLogConv} @pytest.mark.parametrize("width", [1.5, 2, 2.5]) def test_decay_to_zero_basis_number_match(self, width): @@ -1332,7 +1327,7 @@ def test_set_width(self, width, expectation, mode, kwargs): def test_time_scaling_property(self): time_scaling = [0.1, 10, 100] n_basis_funcs = 5 - _, lin_ev = basis.EvalRaisedCosineLinear(n_basis_funcs).evaluate_on_grid(100) + _, lin_ev = basis.RaisedCosineLinearEval(n_basis_funcs).evaluate_on_grid(100) corr = np.zeros(len(time_scaling)) for idx, ts in enumerate(time_scaling): basis_log = self.cls["eval"]( @@ -1395,7 +1390,7 @@ def test_width_values(self, width, expectation, mode, kwargs): class TestRaisedCosineLinearBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalRaisedCosineLinear, "conv": basis.ConvRaisedCosineLinear} + cls = {"eval": basis.RaisedCosineLinearEval, "conv": basis.RaisedCosineLinearConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize( @@ -1469,7 +1464,7 @@ def test_width_values(self, width, expectation, mode, kwargs): class TestMSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalMSpline, "conv": basis.ConvMSpline} + cls = {"eval": basis.MSplineEval, "conv": basis.MSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [-1, 0, 1, 2, 3, 4, 5]) @@ -1573,7 +1568,7 @@ def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( class TestOrthExponentialBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalOrthExponential, "conv": basis.ConvOrthExponential} + cls = {"eval": basis.OrthExponentialEval, "conv": basis.OrthExponentialConv} @pytest.mark.parametrize( "decay_rates", [[1, 2, 3], [0.01, 0.02, 0.001], [2, 1, 1, 2.4]] @@ -1645,7 +1640,7 @@ def test_minimum_number_of_basis_required_is_matched( class TestBSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalBSpline, "conv": basis.ConvBSpline} + cls = {"eval": basis.BSplineEval, "conv": basis.BSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [1, 2, 3, 4, 5]) @@ -1733,7 +1728,7 @@ def test_samples_range_matches_compute_features_requirements( class TestCyclicBSplineBasis(BasisFuncsTesting): - cls = {"eval": basis.EvalCyclicBSpline, "conv": basis.ConvCyclicBSpline} + cls = {"eval": basis.CyclicBSplineEval, "conv": basis.CyclicBSplineConv} @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("order", [2, 3, 4, 5]) @@ -1871,28 +1866,28 @@ def instantiate_basis( kwargs = {**default_kwargs, **kwargs} if basis_class == AdditiveBasis: - kwargs_mspline = trim_kwargs( - basis.EvalMSpline, kwargs, class_specific_params + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params ) - kwargs_raised_cosine = trim_kwargs( - basis.ConvRaisedCosineLinear, kwargs, class_specific_params + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params ) - b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) basis_obj = b1 + b2 elif basis_class == MultiplicativeBasis: - kwargs_mspline = trim_kwargs( - basis.EvalMSpline, kwargs, class_specific_params + kwargs_mspline = inspect_utils.trim_kwargs( + basis.MSplineEval, kwargs, class_specific_params ) - kwargs_raised_cosine = trim_kwargs( - basis.ConvRaisedCosineLinear, kwargs, class_specific_params + kwargs_raised_cosine = inspect_utils.trim_kwargs( + basis.RaisedCosineLinearConv, kwargs, class_specific_params ) - b1 = basis.EvalMSpline(**kwargs_mspline) - b2 = basis.ConvRaisedCosineLinear(**kwargs_raised_cosine) + b1 = basis.MSplineEval(**kwargs_mspline) + b2 = basis.RaisedCosineLinearConv(**kwargs_raised_cosine) basis_obj = b1 * b2 else: basis_obj = basis_class( - **trim_kwargs(basis_class, kwargs, class_specific_params) + **inspect_utils.trim_kwargs(basis_class, kwargs, class_specific_params) ) return basis_obj @@ -1901,10 +1896,10 @@ class TestAdditiveBasis(CombinedBasis): cls = {"eval": AdditiveBasis, "conv": AdditiveBasis} @pytest.mark.parametrize("samples", [[[0], []], [[], [0]], [[0, 0], [0, 0]]]) - @pytest.mark.parametrize("base_cls", [basis.EvalBSpline, basis.ConvBSpline]) + @pytest.mark.parametrize("base_cls", [basis.BSplineEval, basis.BSplineConv]) def test_non_empty_samples(self, base_cls, samples, class_specific_params): kwargs = {"window_size": 2, "n_basis_funcs": 5} - kwargs = trim_kwargs(base_cls, kwargs, class_specific_params) + kwargs = inspect_utils.trim_kwargs(base_cls, kwargs, class_specific_params) basis_obj = base_cls(**kwargs) + base_cls(**kwargs) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( @@ -1928,7 +1923,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. """ - basis_obj = basis.EvalMSpline(5) + basis.EvalMSpline(5) + basis_obj = basis.MSplineEval(5) + basis.MSplineEval(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -2183,7 +2178,7 @@ def test_call_input_num( TypeError, match="Input dimensionality mismatch" ) with expectation: - basis_obj(*([np.linspace(0, 1, 10)] * num_input)) + basis_obj._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -2216,7 +2211,7 @@ def test_call_input_shape( ) basis_obj = basis_a_obj + basis_b_obj with expectation: - basis_obj(*([inp] * basis_obj._n_input_dimensionality)) + basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize(" window_size", [3]) @@ -2242,7 +2237,7 @@ def test_call_sample_axis( ) basis_obj = basis_a_obj + basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality - assert basis_obj(*inp).shape[0] == time_axis_shape + assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2267,7 +2262,7 @@ def test_call_nan( inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality for x in inp: x[3] = np.nan - assert all(np.isnan(basis_obj(*inp)[3])) + assert all(np.isnan(basis_obj._evaluate(*inp)[3])) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2293,7 +2288,7 @@ def test_call_equivalent_in_conv( bas_con = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality - assert np.all(bas_con(*x) == bas_eva(*x)) + assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2313,8 +2308,8 @@ def test_pynapple_support( x = np.linspace(0, 1, 10) x_nap = [nap.Tsd(t=np.arange(10), d=x)] * bas._n_input_dimensionality x = [x] * bas._n_input_dimensionality - y = bas(*x) - y_nap = bas(*x_nap) + y = bas._evaluate(*x) + y_nap = bas._evaluate(*x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) @@ -2335,7 +2330,10 @@ def test_call_basis_number( ) bas = basis_a_obj + basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality - assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs + assert ( + bas._evaluate(*x).shape[1] + == basis_a_obj.n_basis_funcs + basis_b_obj.n_basis_funcs + ) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2353,7 +2351,7 @@ def test_call_non_empty( ) bas = basis_a_obj + basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): - bas(*([np.array([])] * bas._n_input_dimensionality)) + bas._evaluate(*([np.array([])] * bas._n_input_dimensionality)) @pytest.mark.parametrize( "mn, mx, expectation", @@ -2398,7 +2396,7 @@ def test_call_sample_range( ) bas = basis_a_obj + basis_b_obj with expectation: - bas(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) + bas._evaluate(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2457,8 +2455,8 @@ def test_transform_fails( @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(11, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -2469,8 +2467,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 + bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -2488,8 +2486,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas = bas1 + bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -2505,7 +2503,7 @@ class TestMultiplicativeBasis(CombinedBasis): ) @pytest.mark.parametrize(" ws", [3]) def test_non_empty_samples(self, samples, ws): - basis_obj = basis.EvalMSpline(5) * basis.EvalRaisedCosineLinear(5) + basis_obj = basis.MSplineEval(5) * basis.RaisedCosineLinearEval(5) if any(tuple(len(s) == 0 for s in samples)): with pytest.raises( ValueError, match="All sample provided must be non empty" @@ -2528,7 +2526,7 @@ def test_compute_features_input(self, eval_input): """ Checks that the sample size of the output from the compute_features() method matches the input sample size. """ - basis_obj = basis.EvalMSpline(5) * basis.EvalMSpline(5) + basis_obj = basis.MSplineEval(5) * basis.MSplineEval(5) basis_obj.compute_features(*eval_input) @pytest.mark.parametrize("n_basis_a", [5, 6]) @@ -2821,7 +2819,7 @@ def test_call_input_num( TypeError, match="Input dimensionality mismatch" ) with expectation: - basis_obj(*([np.linspace(0, 1, 10)] * num_input)) + basis_obj._evaluate(*([np.linspace(0, 1, 10)] * num_input)) @pytest.mark.parametrize( "inp, expectation", @@ -2854,7 +2852,7 @@ def test_call_input_shape( ) basis_obj = basis_a_obj * basis_b_obj with expectation: - basis_obj(*([inp] * basis_obj._n_input_dimensionality)) + basis_obj._evaluate(*([inp] * basis_obj._n_input_dimensionality)) @pytest.mark.parametrize("time_axis_shape", [10, 11, 12]) @pytest.mark.parametrize(" window_size", [3]) @@ -2880,7 +2878,7 @@ def test_call_sample_axis( ) basis_obj = basis_a_obj * basis_b_obj inp = [np.linspace(0, 1, time_axis_shape)] * basis_obj._n_input_dimensionality - assert basis_obj(*inp).shape[0] == time_axis_shape + assert basis_obj._evaluate(*inp).shape[0] == time_axis_shape @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2905,7 +2903,7 @@ def test_call_nan( inp = [np.linspace(0, 1, 10)] * basis_obj._n_input_dimensionality for x in inp: x[3] = np.nan - assert all(np.isnan(basis_obj(*inp)[3])) + assert all(np.isnan(basis_obj._evaluate(*inp)[3])) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -2931,7 +2929,7 @@ def test_call_equivalent_in_conv( bas_con = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas_con._n_input_dimensionality - assert np.all(bas_con(*x) == bas_eva(*x)) + assert np.all(bas_con._evaluate(*x) == bas_eva._evaluate(*x)) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2951,8 +2949,8 @@ def test_pynapple_support( x = np.linspace(0, 1, 10) x_nap = [nap.Tsd(t=np.arange(10), d=x)] * bas._n_input_dimensionality x = [x] * bas._n_input_dimensionality - y = bas(*x) - y_nap = bas(*x_nap) + y = bas._evaluate(*x) + y_nap = bas._evaluate(*x_nap) assert isinstance(y_nap, nap.TsdFrame) assert np.all(y == y_nap.d) assert np.all(y_nap.t == x_nap[0].t) @@ -2973,7 +2971,10 @@ def test_call_basis_number( ) bas = basis_a_obj * basis_b_obj x = [np.linspace(0, 1, 10)] * bas._n_input_dimensionality - assert bas(*x).shape[1] == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs + assert ( + bas._evaluate(*x).shape[1] + == basis_a_obj.n_basis_funcs * basis_b_obj.n_basis_funcs + ) @pytest.mark.parametrize(" window_size", [3]) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @@ -2991,7 +2992,7 @@ def test_call_non_empty( ) bas = basis_a_obj * basis_b_obj with pytest.raises(ValueError, match="All sample provided must"): - bas(*([np.array([])] * bas._n_input_dimensionality)) + bas._evaluate(*([np.array([])] * bas._n_input_dimensionality)) @pytest.mark.parametrize( "mn, mx, expectation", @@ -3036,7 +3037,7 @@ def test_call_sample_range( ) bas = basis_a_obj * basis_b_obj with expectation: - bas(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) + bas._evaluate(*([np.linspace(mn, mx, 10)] * bas._n_input_dimensionality)) @pytest.mark.parametrize("basis_a", list_all_basis_classes()) @pytest.mark.parametrize("basis_b", list_all_basis_classes()) @@ -3095,8 +3096,8 @@ def test_transform_fails( @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_output_features(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(11, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(11, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_output_features is None bas_add.compute_features( @@ -3107,8 +3108,8 @@ def test_set_num_output_features(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_add = bas1 * bas2 assert bas_add.n_basis_input is None bas_add.compute_features( @@ -3126,8 +3127,8 @@ def test_set_num_basis_input(self, n_basis_input1, n_basis_input2): ], ) def test_expected_input_number(self, n_input, expectation): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas = bas1 * bas2 x = np.random.randn(20, 2), np.random.randn(20, 3) bas.compute_features(*x) @@ -3137,8 +3138,8 @@ def test_expected_input_number(self, n_input, expectation): @pytest.mark.parametrize("n_basis_input1", [1, 2, 3]) @pytest.mark.parametrize("n_basis_input2", [1, 2, 3]) def test_n_basis_input(self, n_basis_input1, n_basis_input2): - bas1 = basis.ConvRaisedCosineLinear(10, window_size=10) - bas2 = basis.ConvBSpline(10, window_size=10) + bas1 = basis.RaisedCosineLinearConv(10, window_size=10) + bas2 = basis.BSplineConv(10, window_size=10) bas_prod = bas1 * bas2 bas_prod.compute_features( np.ones((20, n_basis_input1)), np.ones((20, n_basis_input2)) @@ -3147,7 +3148,7 @@ def test_n_basis_input(self, n_basis_input1, n_basis_input2): @pytest.mark.parametrize( - "exponent", [-1, 0, 0.5, basis.EvalRaisedCosineLog(4), 1, 2, 3] + "exponent", [-1, 0, 0.5, basis.RaisedCosineLogEval(4), 1, 2, 3] ) @pytest.mark.parametrize("basis_class", list_all_basis_classes()) def test_power_of_basis(exponent, basis_class, class_specific_params): diff --git a/tox.ini b/tox.ini index 94bd7a4f..8464ba95 100644 --- a/tox.ini +++ b/tox.ini @@ -30,6 +30,9 @@ commands= black tests isort tests --profile=black flake8 --config={toxinidir}/tox.ini src # convenient instead of remembering to run fix followed by check + black docs/scripts + isort docs/scripts --profile=black + flake8 --config={toxinidir}/tox.ini docs/scripts [testenv:check] commands= From e39b1c09d225807221b6d3a72c132326b1dca9ad Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 10:17:48 -0500 Subject: [PATCH 08/21] Fix table layout --- docs/assets/stylesheets/custom.css | 4 + docs/background/README.md | 2 +- docs/background/basis/README.md | 26 +- docs/background/basis/images/EvalBSpline.svg | 360 ----------- .../basis/images/EvalCyclicBSpline.svg | 427 ------------- docs/background/basis/images/EvalMSpline.svg | 289 --------- .../basis/images/EvalOrthExponential.svg | 564 ------------------ .../basis/images/EvalRaisedCosineLinear.svg | 367 ------------ .../basis/images/EvalRaisedCosineLog.svg | 310 ---------- docs/conf.py | 4 + docs/scripts/basis_table_figs.py | 11 +- 11 files changed, 36 insertions(+), 2328 deletions(-) delete mode 100644 docs/background/basis/images/EvalBSpline.svg delete mode 100644 docs/background/basis/images/EvalCyclicBSpline.svg delete mode 100644 docs/background/basis/images/EvalMSpline.svg delete mode 100644 docs/background/basis/images/EvalOrthExponential.svg delete mode 100644 docs/background/basis/images/EvalRaisedCosineLinear.svg delete mode 100644 docs/background/basis/images/EvalRaisedCosineLog.svg diff --git a/docs/assets/stylesheets/custom.css b/docs/assets/stylesheets/custom.css index 463be8a4..e557aa08 100644 --- a/docs/assets/stylesheets/custom.css +++ b/docs/assets/stylesheets/custom.css @@ -94,6 +94,10 @@ html[data-theme=light]{ font-weight: normal; } +table.table-center{ + text-align: center; +} + #table-basis { table-layout: auto; width: 100%; diff --git a/docs/background/README.md b/docs/background/README.md index 331a9b8c..530c454a 100644 --- a/docs/background/README.md +++ b/docs/background/README.md @@ -34,7 +34,7 @@ plot_00_conceptual_intro.md :::{grid-item-card}
-Basis Functions +Basis Functions
```{toctree} diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 59dd1dee..60516ad4 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -3,9 +3,13 @@ (table_basis)= ```{eval-rst} +.. role:: raw-html(raw) + :format: html + .. list-table:: :header-rows: 1 - :widths: 20 30 20 30 20 + :name: table-basis + :align: center * - **Basis** - **Kernel Visualization** @@ -14,38 +18,44 @@ - **Preferred Mode** * - **B-Spline** - .. plot:: scripts/basis_table_figs.py plot_bspline + :show-source-link: False - `Grid cells `_ - - `EvalBSpline `_ - `ConvBSpline `_ + - `EvalBSpline `_ :raw-html:`
` + `BSplineConv `_ - 🟢 Eval * - **Cyclic B-Spline** - .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline + :show-source-link: False - `Place cells `_ - - `EvalCyclicBSpline `_ + - `EvalCyclicBSpline `_ :raw-html:`
` `ConvCyclicBSpline `_ - 🟢 Eval * - **M-Spline** - .. plot:: scripts/basis_table_figs.py plot_mspline + :show-source-link: False - `Place cells `_ - - `EvalMSpline `_ + - `EvalMSpline `_ :raw-html:`
` `ConvMSpline `_ - 🟢 Eval * - **Linearly Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear + :show-source-link: False - - - `EvalRaisedCosineLinear `_ + - `EvalRaisedCosineLinear `_ :raw-html:`
` `ConvRaisedCosineLinear `_ - 🟢 Eval * - **Log Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log + :show-source-link: False - `Head Direction `_ - - `EvalRaisedCosineLog `_ + - `EvalRaisedCosineLog `_ :raw-html:`
` `ConvRaisedCosineLog `_ - 🔵 Conv * - **Orthogonalized Exponential Decays** - .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis + :show-source-link: False - - - `EvalOrthExponential `_ + - `EvalOrthExponential `_ :raw-html:`
` `ConvOrthExponential `_ - 🟢 Eval ``` diff --git a/docs/background/basis/images/EvalBSpline.svg b/docs/background/basis/images/EvalBSpline.svg deleted file mode 100644 index 66776ff8..00000000 --- a/docs/background/basis/images/EvalBSpline.svg +++ /dev/null @@ -1,360 +0,0 @@ - - - - - - - - 2024-12-02T16:54:50.980497 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/background/basis/images/EvalCyclicBSpline.svg b/docs/background/basis/images/EvalCyclicBSpline.svg deleted file mode 100644 index 35fa9481..00000000 --- a/docs/background/basis/images/EvalCyclicBSpline.svg +++ /dev/null @@ -1,427 +0,0 @@ - - - - - - - - 2024-12-02T16:54:50.989322 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/background/basis/images/EvalMSpline.svg b/docs/background/basis/images/EvalMSpline.svg deleted file mode 100644 index 36945902..00000000 --- a/docs/background/basis/images/EvalMSpline.svg +++ /dev/null @@ -1,289 +0,0 @@ - - - - - - - - 2024-12-02T16:54:50.998126 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/background/basis/images/EvalOrthExponential.svg b/docs/background/basis/images/EvalOrthExponential.svg deleted file mode 100644 index 0326c6e6..00000000 --- a/docs/background/basis/images/EvalOrthExponential.svg +++ /dev/null @@ -1,564 +0,0 @@ - - - - - - - - 2024-12-02T16:54:51.006464 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/background/basis/images/EvalRaisedCosineLinear.svg b/docs/background/basis/images/EvalRaisedCosineLinear.svg deleted file mode 100644 index fbdbd641..00000000 --- a/docs/background/basis/images/EvalRaisedCosineLinear.svg +++ /dev/null @@ -1,367 +0,0 @@ - - - - - - - - 2024-12-02T16:54:51.015187 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/background/basis/images/EvalRaisedCosineLog.svg b/docs/background/basis/images/EvalRaisedCosineLog.svg deleted file mode 100644 index 6389fc18..00000000 --- a/docs/background/basis/images/EvalRaisedCosineLog.svg +++ /dev/null @@ -1,310 +0,0 @@ - - - - - - - - 2024-12-02T16:54:51.023103 - image/svg+xml - - - Matplotlib v3.9.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/conf.py b/docs/conf.py index f4a40b05..f4d73d48 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,6 +41,7 @@ 'sphinx_autodoc_typehints', 'sphinx_togglebutton', 'matplotlib.sphinxext.plot_directive', + "matplotlib.sphinxext.mathmpl", ] myst_enable_extensions = [ @@ -161,3 +162,6 @@ nb_execution_excludepatterns = ["tutorials/**", "how_to_guide/**", "background/**"] viewcode_follow_imported_members = True + +# option for mpl extension +plot_html_show_formats = False \ No newline at end of file diff --git a/docs/scripts/basis_table_figs.py b/docs/scripts/basis_table_figs.py index 7e3b503f..721f6635 100644 --- a/docs/scripts/basis_table_figs.py +++ b/docs/scripts/basis_table_figs.py @@ -1,9 +1,16 @@ +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import nemos as nmo from nemos._inspect_utils import trim_kwargs +plt.rcParams.update( + { + "figure.dpi": 300, + } +) + KWARGS = dict( n_basis_funcs=10, decay_rates=np.arange(1, 10 + 1), @@ -17,8 +24,8 @@ def plot_basis(cls): cls_params = cls._get_param_names() new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) bas = cls(**new_kwargs) - fig, ax = plt.subplots(1, 1, figsize=(5 / 4, 2.5 / 4)) - ax.plot(*bas.evaluate_on_grid(300), lw=0.8) + fig, ax = plt.subplots(1, 1, figsize=(5, 2.5)) + ax.plot(*bas.evaluate_on_grid(300), lw=4) for side in ["left", "right", "top", "bottom"]: ax.spines[side].set_visible(False) ax.set_xticks([]) From aafa384ff2cddab2cd0e27839f970fb0532acb1d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 12:12:25 -0500 Subject: [PATCH 09/21] fix background note path --- docs/background/basis/plot_01_1D_basis_function.md | 2 +- docs/background/basis/plot_02_ND_basis_function.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index fa5c5879..cf534004 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -96,7 +96,7 @@ else: path = Path("../_build/html/_static/thumbnails/background") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../../_build/html/_static").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/background/basis/plot_02_ND_basis_function.md b/docs/background/basis/plot_02_ND_basis_function.md index a9636285..8bb8015e 100644 --- a/docs/background/basis/plot_02_ND_basis_function.md +++ b/docs/background/basis/plot_02_ND_basis_function.md @@ -304,7 +304,7 @@ if root: path = Path(root) / "html/_static/thumbnails/background" # if local store in ../_build/html/... else: - path = Path("../_build/html/_static/thumbnails/background") + path = Path("../../_build/html/_static/thumbnails/background") # make sure the folder exists if run from build if root or Path("../_build/html/_static").exists(): From 0fa061f0f85a4303693aa6cd97f729d262c480fc Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 12:20:10 -0500 Subject: [PATCH 10/21] fix background note path --- docs/background/basis/plot_01_1D_basis_function.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index cf534004..bddf798d 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -93,7 +93,7 @@ if root: path = Path(root) / "html/_static/thumbnails/background" # if local store in ../_build/html/... else: - path = Path("../_build/html/_static/thumbnails/background") + path = Path("../../_build/html/_static/thumbnails/background") # make sure the folder exists if run from build if root or Path("../../_build/html/_static").exists(): @@ -101,6 +101,9 @@ if root or Path("../../_build/html/_static").exists(): if path.exists(): fig.savefig(path / "plot_01_1D_basis_function.svg") + + +print(path.resolve(), path.exists()) ``` ## Feature Computation From 44e61cc4d960c1a3c9bf26637fb5e3d9c80870e3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 14:28:11 -0500 Subject: [PATCH 11/21] fixing links --- docs/background/basis/README.md | 32 +++++++++---------- .../basis/plot_01_1D_basis_function.md | 2 +- docs/conf.py | 6 +++- docs/scripts/basis_table_figs.py | 2 +- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 60516ad4..60789b5c 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -19,44 +19,44 @@ * - **B-Spline** - .. plot:: scripts/basis_table_figs.py plot_bspline :show-source-link: False - - `Grid cells `_ - - `EvalBSpline `_ :raw-html:`
` - `BSplineConv `_ + - :ref:`Grid cells ` + - :class:`~nemos.basis.BSplineEval` :raw-html:`
` + :class:`~nemos.basis.BSplineConv` - 🟢 Eval * - **Cyclic B-Spline** - .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline :show-source-link: False - - `Place cells `_ - - `EvalCyclicBSpline `_ :raw-html:`
` - `ConvCyclicBSpline `_ + - :ref:`Place cells ` + - :class:`~nemos.basis.CyclicBSplineEval` :raw-html:`
` + :class:`~nemos.basis.CyclicBSplineConv` - 🟢 Eval * - **M-Spline** - .. plot:: scripts/basis_table_figs.py plot_mspline :show-source-link: False - - `Place cells `_ - - `EvalMSpline `_ :raw-html:`
` - `ConvMSpline `_ + - :ref:`Place cells ` + - :class:`~nemos.basis.MSplineEval` :raw-html:`
` + :class:`~nemos.basis.MSplineConv` - 🟢 Eval * - **Linearly Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear :show-source-link: False - - - `EvalRaisedCosineLinear `_ :raw-html:`
` - `ConvRaisedCosineLinear `_ + - :class:`~nemos.basis.RaisedCosineLinearEval` :raw-html:`
` + :class:`~nemos.basis.RaisedCosineLinearConv` - 🟢 Eval * - **Log Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log :show-source-link: False - - `Head Direction `_ - - `EvalRaisedCosineLog `_ :raw-html:`
` - `ConvRaisedCosineLog `_ + - :ref:`Head Direction ` + - :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`
` + :class:`nemos.basis.RaisedCosineLogConv` - 🔵 Conv * - **Orthogonalized Exponential Decays** - .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis :show-source-link: False - - - `EvalOrthExponential `_ :raw-html:`
` - `ConvOrthExponential `_ + - :class:`~nemos.basis.OrthExponentialEval` :raw-html:`
` + :class:`~nemos.basis.OrthExponentialConv` - 🟢 Eval ``` diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index 29f32993..36ab66d6 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -156,7 +156,7 @@ Convolution is performed in "valid" mode, and then NaN-padded. The default behav is padding left, which makes the output feature causal. This is why the first half of the `conv_feature` is full of NaNs and appears as white. If you want to learn more about convolutions, as well as how and when to change defaults -check out the tutorial on [1D convolutions](plot_03_1D_convolution). +check out the tutorial on [1D convolutions](convolution_background). ::: diff --git a/docs/conf.py b/docs/conf.py index f4d73d48..2e752238 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -123,7 +123,11 @@ "logo": { "image_light": "_static/NeMoS_Logo_CMYK_Full.svg", "image_dark": "_static/NeMoS_Logo_CMYK_White.svg", - } + }, + "secondary_sidebar_items": { + "**": ["page-toc", "sourcelink"], + "background/basis/README": [], + }, } html_sidebars = { diff --git a/docs/scripts/basis_table_figs.py b/docs/scripts/basis_table_figs.py index 721f6635..3a442c6e 100644 --- a/docs/scripts/basis_table_figs.py +++ b/docs/scripts/basis_table_figs.py @@ -24,7 +24,7 @@ def plot_basis(cls): cls_params = cls._get_param_names() new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) bas = cls(**new_kwargs) - fig, ax = plt.subplots(1, 1, figsize=(5, 2.5)) + fig, ax = plt.subplots(1, 1, figsize=(5/4, 2.5/4)) ax.plot(*bas.evaluate_on_grid(300), lw=4) for side in ["left", "right", "top", "bottom"]: ax.spines[side].set_visible(False) From 09e0c8e419cfc75a67cc69a3e14f2ea16f687bd0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 4 Dec 2024 16:39:24 -0500 Subject: [PATCH 12/21] use plot directive for thumbnail --- docs/background/README.md | 9 ++++++--- docs/background/basis/README.md | 6 ++++++ docs/scripts/basis_table_figs.py | 3 +-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/docs/background/README.md b/docs/background/README.md index 530c454a..f67f4eda 100644 --- a/docs/background/README.md +++ b/docs/background/README.md @@ -33,9 +33,12 @@ plot_00_conceptual_intro.md :::{grid-item-card} -
-Basis Functions -
+```{eval-rst} + +.. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear + :show-source-link: False + :height: 100px +``` ```{toctree} :maxdepth: 2 diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 60789b5c..bec07e73 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -19,6 +19,7 @@ * - **B-Spline** - .. plot:: scripts/basis_table_figs.py plot_bspline :show-source-link: False + :height: 80px - :ref:`Grid cells ` - :class:`~nemos.basis.BSplineEval` :raw-html:`
` :class:`~nemos.basis.BSplineConv` @@ -26,6 +27,7 @@ * - **Cyclic B-Spline** - .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline :show-source-link: False + :height: 80px - :ref:`Place cells ` - :class:`~nemos.basis.CyclicBSplineEval` :raw-html:`
` :class:`~nemos.basis.CyclicBSplineConv` @@ -33,6 +35,7 @@ * - **M-Spline** - .. plot:: scripts/basis_table_figs.py plot_mspline :show-source-link: False + :height: 80px - :ref:`Place cells ` - :class:`~nemos.basis.MSplineEval` :raw-html:`
` :class:`~nemos.basis.MSplineConv` @@ -40,6 +43,7 @@ * - **Linearly Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear :show-source-link: False + :height: 80px - - :class:`~nemos.basis.RaisedCosineLinearEval` :raw-html:`
` :class:`~nemos.basis.RaisedCosineLinearConv` @@ -47,6 +51,7 @@ * - **Log Spaced Raised Cosine** - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log :show-source-link: False + :height: 80px - :ref:`Head Direction ` - :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`
` :class:`nemos.basis.RaisedCosineLogConv` @@ -54,6 +59,7 @@ * - **Orthogonalized Exponential Decays** - .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis :show-source-link: False + :height: 80px - - :class:`~nemos.basis.OrthExponentialEval` :raw-html:`
` :class:`~nemos.basis.OrthExponentialConv` diff --git a/docs/scripts/basis_table_figs.py b/docs/scripts/basis_table_figs.py index 3a442c6e..f60bed8c 100644 --- a/docs/scripts/basis_table_figs.py +++ b/docs/scripts/basis_table_figs.py @@ -1,4 +1,3 @@ -import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np @@ -24,7 +23,7 @@ def plot_basis(cls): cls_params = cls._get_param_names() new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) bas = cls(**new_kwargs) - fig, ax = plt.subplots(1, 1, figsize=(5/4, 2.5/4)) + fig, ax = plt.subplots(1, 1, figsize=(5, 2.5)) ax.plot(*bas.evaluate_on_grid(300), lw=4) for side in ["left", "right", "top", "bottom"]: ax.spines[side].set_visible(False) From 57e2bd2e91867ebfaf94927fb540a4f33f3acc54 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 5 Dec 2024 09:27:03 -0500 Subject: [PATCH 13/21] fixed description of basis --- docs/background/basis/plot_01_1D_basis_function.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index 36ab66d6..0ccbc611 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -47,10 +47,10 @@ warnings.filterwarnings( ## Defining a 1D Basis Object We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval). -The hyperparameters required to initialize this class are: +The hyperparameters needed to initialize this class are: -- The number of basis functions, which should be a positive integer. -- The order of the spline, which should be an integer greater than 1. +- The number of basis functions, which should be a positive integer (required). +- The order of the spline, which should be an integer greater than 1 (optional, default 4 for a cubic spline). ```{code-cell} ipython3 import matplotlib.pylab as plt @@ -107,11 +107,12 @@ print(path.resolve(), path.exists()) ``` ## Feature Computation -The bases in the `nemos.basis` module can be grouped into two categories: +All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features). +We can be group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies: -1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names starting with "Eval," such as `BSplineEval`. +1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ends with "Eval," such as `BSplineEval`. -2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names starting with "Conv," such as `BSplineConv`. +2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`. Let's see how this two modalities operate. From ddc357bfbb3dc59b1a27cacece2986c812b7fe73 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Wed, 11 Dec 2024 11:13:04 -0500 Subject: [PATCH 14/21] tweak text in plot 1d --- .../basis/plot_01_1D_basis_function.md | 57 ++++++------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index e23d15d8..31131bf2 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -46,7 +46,7 @@ warnings.filterwarnings( ## Defining a 1D Basis Object -We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval). +We'll start by defining a 1D basis function object of the type [`BSplineEval`](nemos.basis.BSplineEval). The hyperparameters needed to initialize this class are: - The number of basis functions, which should be a positive integer (required). @@ -106,26 +106,25 @@ if path.exists(): print(path.resolve(), path.exists()) ``` -## Feature Computation +## Computing Features All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features). -We can be group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies: +We can group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies: -1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ends with "Eval," such as `BSplineEval`. +1. **Evaluation Bases**: These bases use `compute_features` to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`. -2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`. +2. **Convolution Bases**: These bases use `compute_features` to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`. - -Let's see how this two modalities operate. +Let's see how these two categories operate: ```{code-cell} ipython3 -eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis) -conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100) +eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis) +conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100) # define an input angles = np.linspace(0, np.pi*4, 201) y = np.cos(angles) -# compute features in the two modalities +# compute features eval_feature = eval_mode.compute_features(y) conv_feature = conv_mode.compute_features(y) @@ -162,15 +161,17 @@ check out the tutorial on [1D convolutions](convolution_background). ::: -Plotting the Basis Function Elements: +Plotting the Basis Function Elements -------------------------------------- We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns -the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become -particularly evident when working with multidimensional basis functions. You can find more details and visual -background in the -[2D basis elements plotting section](plotting-2d-additive-basis-elements). +the equi-spaced samples along with the evaluated basis functions. + +:::{admonition} Note +The array returned by `evaluate_on_grid(n_samples)` is the same as the kernel that is used by the Conv bases initialized with `window_sizes=n_samples`! + +::: ```{code-cell} ipython3 # Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples @@ -184,12 +185,13 @@ plt.plot(equispaced_samples, eval_basis) plt.show() ``` +The benefits of using `evaluate_on_grid` become particularly evident when working with multidimensional basis functions. You can find more details in the [2D basis elements plotting section](plotting-2d-additive-basis-elements). ## Setting the basis support (Eval only) Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that your basis covers the same range across multiple experimental sessions. You can specify a range for the support of your basis by setting the `bounds` -parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions). +parameter at initialization of Eval bases. Evaluating the basis at any sample outside the bounds will result in a NaN. @@ -215,26 +217,3 @@ axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato") axs[1].set_title("bounds=[0.2, 0.8]") plt.tight_layout() ``` - -Other Basis Types ------------------ -Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, -please refer to the [API Guide](nemos_basis). After instantiation, all classes -share the same syntax for basis evaluation. The following is an example of how to instantiate and -evaluate a log-spaced cosine raised function basis. - - -```{code-cell} ipython3 -# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter -raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50) - -# Evaluate the raised cosine basis at the equi-spaced sample points -# (same method in all Basis elements) -samples, eval_basis = raised_cosine_log.evaluate_on_grid(100) - -# Plot the evaluated log-spaced raised cosine basis -plt.figure() -plt.title(f"Log-spaced Raised Cosine basis with {eval_basis.shape[1]} elements") -plt.plot(samples, eval_basis) -plt.show() -``` From e384d10ecbccb801a280eb49831501c20055cf21 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Wed, 11 Dec 2024 11:13:10 -0500 Subject: [PATCH 15/21] try to add summary to plot 2d --- docs/background/basis/plot_02_ND_basis_function.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/background/basis/plot_02_ND_basis_function.md b/docs/background/basis/plot_02_ND_basis_function.md index 8bb8015e..828e0ea8 100644 --- a/docs/background/basis/plot_02_ND_basis_function.md +++ b/docs/background/basis/plot_02_ND_basis_function.md @@ -51,10 +51,9 @@ combination of some multidimensional basis elements. In this document, we introduce two strategies for defining a high-dimensional basis function by combining two lower-dimensional bases. We refer to these strategies as "addition" and "multiplication" of bases, -and the resulting basis objects will be referred to as additive or multiplicative basis respectively. +and the resulting basis objects will be referred to as additive or multiplicative basis respectively: additive bases have their component bases operate *independently*, whereas multiplicative bases take the *outer product*. And these composite basis objects can be constructed using other composite bases, so that you can combine them as much as you'd like! - -Consider we have two inputs $\mathbf{x} \in \mathbb{R}^N,\; \mathbf{y}\in \mathbb{R}^M$. +More precisely, let's say we have two inputs $\mathbf{x} \in \mathbb{R}^N,\; \mathbf{y}\in \mathbb{R}^M$. Let's say we've defined two basis functions for these inputs: - $ [ a_0 (\mathbf{x}), ..., a_{k-1} (\mathbf{x}) ] $ for $\mathbf{x}$ @@ -106,6 +105,7 @@ In the subsequent sections, we will: 1. Demonstrate the definition, evaluation, and visualization of 2D additive and multiplicative bases. 2. Illustrate how to iteratively apply addition and multiplication operations to extend to dimensions beyond two. +(composite_basis_2d)= ## 2D Basis Functions Consider an instance where we want to capture a neuron's response to an animal's position within a given arena. From a1e6ecd39e140053013a0da34b3d48e69423859b Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Wed, 11 Dec 2024 11:14:43 -0500 Subject: [PATCH 16/21] correct filename typo --- src/nemos/_inspect_utils/{inpsect_utils.py => inspect_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/nemos/_inspect_utils/{inpsect_utils.py => inspect_utils.py} (100%) diff --git a/src/nemos/_inspect_utils/inpsect_utils.py b/src/nemos/_inspect_utils/inspect_utils.py similarity index 100% rename from src/nemos/_inspect_utils/inpsect_utils.py rename to src/nemos/_inspect_utils/inspect_utils.py From acda0f24e4c5f94ed0ed99e13befae68cba432fb Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Wed, 11 Dec 2024 11:26:19 -0500 Subject: [PATCH 17/21] small fixes in basis/readme --- docs/background/basis/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index bec07e73..74e9dde6 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -54,7 +54,7 @@ :height: 80px - :ref:`Head Direction ` - :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`
` - :class:`nemos.basis.RaisedCosineLogConv` + :class:`~nemos.basis.RaisedCosineLogConv` - 🔵 Conv * - **Orthogonalized Exponential Decays** - .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis @@ -88,12 +88,12 @@ Instead of tackling the hard problem of learning an unknown function $f(x)$ dire ## Basis in NeMoS -NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two key uses described in the overview: +NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two uses described above: -- **Eval-basis objects**: For representing non-linear mappings between task variables and outputs. These objects are identified by names starting with `Eval`. -- **Conv-basis objects**: For linear temporal effects. These objects are identified by names starting with `Conv`. +- **Eval basis objects**: For representing non-linear mappings between task variables and outputs. These objects all have names ending with `Eval`. +- **Conv basis objects**: For linear temporal effects. These objects all have names ending with `Conv`. -`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling complex feature construction. +`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling [complex feature construction](composing_basis_function). ## Learn More @@ -125,4 +125,4 @@ plot_02_ND_basis_function.md ``` ::: -:::: \ No newline at end of file +:::: From 387aeff5692eaebe2400546bfb29238380143e29 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Wed, 11 Dec 2024 23:09:41 -0500 Subject: [PATCH 18/21] Update docs/background/basis/README.md Co-authored-by: William F. Broderick --- docs/background/basis/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 74e9dde6..11d7a7ca 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -83,7 +83,7 @@ $$ Here, $\approx$ means "approximately equal". -Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. +Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. This preserves convexity, resulting in a much simpler optimization problem. ## Basis in NeMoS From 0fa80d443d6ebc9d7c4df101b8d483c21357cd4c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 00:00:16 -0500 Subject: [PATCH 19/21] fix paths and module name --- docs/background/README.md | 2 +- docs/background/basis/README.md | 30 +++-- .../basis/plot_01_1D_basis_function.md | 24 ---- .../basis/plot_02_ND_basis_function.md | 22 ---- docs/background/plot_03_1D_convolution.md | 2 +- docs/how_to_guide/plot_02_glm_demo.md | 2 +- docs/how_to_guide/plot_03_population_glm.md | 2 +- docs/how_to_guide/plot_04_batch_glm.md | 2 +- .../plot_05_sklearn_pipeline_cv_demo.md | 2 +- docs/how_to_guide/plot_06_glm_pytree.md | 2 +- docs/scripts/basis_figs.py | 108 ++++++++++++++++++ docs/scripts/basis_table_figs.py | 56 --------- docs/tutorials/plot_01_current_injection.md | 2 +- docs/tutorials/plot_02_head_direction.md | 2 +- docs/tutorials/plot_03_grid_cells.md | 2 +- docs/tutorials/plot_04_v1_cells.md | 2 +- docs/tutorials/plot_05_place_cells.md | 2 +- docs/tutorials/plot_06_calcium_imaging.md | 2 +- src/nemos/_inspect_utils/__init__.py | 2 +- 19 files changed, 140 insertions(+), 128 deletions(-) create mode 100644 docs/scripts/basis_figs.py delete mode 100644 docs/scripts/basis_table_figs.py diff --git a/docs/background/README.md b/docs/background/README.md index f67f4eda..3a21d06d 100644 --- a/docs/background/README.md +++ b/docs/background/README.md @@ -35,7 +35,7 @@ plot_00_conceptual_intro.md ```{eval-rst} -.. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear +.. plot:: scripts/basis_figs.py plot_raised_cosine_linear :show-source-link: False :height: 100px ``` diff --git a/docs/background/basis/README.md b/docs/background/basis/README.md index 74e9dde6..72a8a598 100644 --- a/docs/background/basis/README.md +++ b/docs/background/basis/README.md @@ -17,7 +17,7 @@ - **Evaluation/Convolution** - **Preferred Mode** * - **B-Spline** - - .. plot:: scripts/basis_table_figs.py plot_bspline + - .. plot:: scripts/basis_figs.py plot_bspline :show-source-link: False :height: 80px - :ref:`Grid cells ` @@ -25,7 +25,7 @@ :class:`~nemos.basis.BSplineConv` - 🟢 Eval * - **Cyclic B-Spline** - - .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline + - .. plot:: scripts/basis_figs.py plot_cyclic_bspline :show-source-link: False :height: 80px - :ref:`Place cells ` @@ -33,7 +33,7 @@ :class:`~nemos.basis.CyclicBSplineConv` - 🟢 Eval * - **M-Spline** - - .. plot:: scripts/basis_table_figs.py plot_mspline + - .. plot:: scripts/basis_figs.py plot_mspline :show-source-link: False :height: 80px - :ref:`Place cells ` @@ -41,7 +41,7 @@ :class:`~nemos.basis.MSplineConv` - 🟢 Eval * - **Linearly Spaced Raised Cosine** - - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear + - .. plot:: scripts/basis_figs.py plot_raised_cosine_linear :show-source-link: False :height: 80px - @@ -49,7 +49,7 @@ :class:`~nemos.basis.RaisedCosineLinearConv` - 🟢 Eval * - **Log Spaced Raised Cosine** - - .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log + - .. plot:: scripts/basis_figs.py plot_raised_cosine_log :show-source-link: False :height: 80px - :ref:`Head Direction ` @@ -57,7 +57,7 @@ :class:`~nemos.basis.RaisedCosineLogConv` - 🔵 Conv * - **Orthogonalized Exponential Decays** - - .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis + - .. plot:: scripts/basis_figs.py plot_orth_exp_basis :show-source-link: False :height: 80px - @@ -101,9 +101,12 @@ NeMoS provides a variety of basis functions (see the [table](table_basis) above) :::{grid-item-card} -
-One-Dimensional Basis. -
+```{eval-rst} + +.. plot:: scripts/basis_figs.py plot_1d_basis_thumbnail + :show-source-link: False + :height: 100px +``` ```{toctree} :maxdepth: 2 @@ -114,9 +117,12 @@ plot_01_1D_basis_function.md :::{grid-item-card} -
-N-Dimensional Basis. -
+```{eval-rst} + +.. plot:: scripts/basis_figs.py plot_nd_basis_thumbnail + :show-source-link: False + :height: 100px +``` ```{toctree} :maxdepth: 2 diff --git a/docs/background/basis/plot_01_1D_basis_function.md b/docs/background/basis/plot_01_1D_basis_function.md index 31131bf2..d4ac0070 100644 --- a/docs/background/basis/plot_01_1D_basis_function.md +++ b/docs/background/basis/plot_01_1D_basis_function.md @@ -81,30 +81,6 @@ plt.plot(x, y, lw=2) plt.title("B-Spline Basis") ``` -```{code-cell} ipython3 -:tags: [hide-input] - -# save image for thumbnail -from pathlib import Path -import os - -root = os.environ.get("READTHEDOCS_OUTPUT") -if root: - path = Path(root) / "html/_static/thumbnails/background" -# if local store in ../_build/html/... -else: - path = Path("../../_build/html/_static/thumbnails/background") - -# make sure the folder exists if run from build -if root or Path("../../_build/html/_static").exists(): - path.mkdir(parents=True, exist_ok=True) - -if path.exists(): - fig.savefig(path / "plot_01_1D_basis_function.svg") - - -print(path.resolve(), path.exists()) -``` ## Computing Features All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features). diff --git a/docs/background/basis/plot_02_ND_basis_function.md b/docs/background/basis/plot_02_ND_basis_function.md index 828e0ea8..9efeff40 100644 --- a/docs/background/basis/plot_02_ND_basis_function.md +++ b/docs/background/basis/plot_02_ND_basis_function.md @@ -292,28 +292,6 @@ axs[2, 1].set_xlabel('y-coord') plt.tight_layout() ``` -```{code-cell} ipython3 -:tags: [hide-input] - -# save image for thumbnail -from pathlib import Path -import os - -root = os.environ.get("READTHEDOCS_OUTPUT") -if root: - path = Path(root) / "html/_static/thumbnails/background" -# if local store in ../_build/html/... -else: - path = Path("../../_build/html/_static/thumbnails/background") - -# make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): - path.mkdir(parents=True, exist_ok=True) - -if path.exists(): - fig.savefig(path / "plot_02_ND_basis_function.svg") -``` - :::{note} Basis objects of different types can be combined through multiplication or addition. This feature is particularly useful when one of the axes represents a periodic variable and another is non-periodic. diff --git a/docs/background/plot_03_1D_convolution.md b/docs/background/plot_03_1D_convolution.md index 1967148d..a237a0f7 100644 --- a/docs/background/plot_03_1D_convolution.md +++ b/docs/background/plot_03_1D_convolution.md @@ -179,7 +179,7 @@ else: path = Path("../_build/html/_static/thumbnails/background") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/how_to_guide/plot_02_glm_demo.md b/docs/how_to_guide/plot_02_glm_demo.md index 83591da4..511a9d8c 100644 --- a/docs/how_to_guide/plot_02_glm_demo.md +++ b/docs/how_to_guide/plot_02_glm_demo.md @@ -450,7 +450,7 @@ else: path = Path("../_build/html/_static/thumbnails/how_to_guide") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/how_to_guide/plot_03_population_glm.md b/docs/how_to_guide/plot_03_population_glm.md index 7c27acd5..a129a65e 100644 --- a/docs/how_to_guide/plot_03_population_glm.md +++ b/docs/how_to_guide/plot_03_population_glm.md @@ -231,7 +231,7 @@ else: path = Path("../_build/html/_static/thumbnails/how_to_guide") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/how_to_guide/plot_04_batch_glm.md b/docs/how_to_guide/plot_04_batch_glm.md index 217de9ba..36bd31b8 100644 --- a/docs/how_to_guide/plot_04_batch_glm.md +++ b/docs/how_to_guide/plot_04_batch_glm.md @@ -208,7 +208,7 @@ else: path = Path("../_build/html/_static/thumbnails/how_to_guide") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md index 166a5cbb..9f5a9652 100644 --- a/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md +++ b/docs/how_to_guide/plot_05_sklearn_pipeline_cv_demo.md @@ -435,7 +435,7 @@ else: path = Path("../_build/html/_static/thumbnails/how_to_guide") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/how_to_guide/plot_06_glm_pytree.md b/docs/how_to_guide/plot_06_glm_pytree.md index e5949f58..4c82be40 100644 --- a/docs/how_to_guide/plot_06_glm_pytree.md +++ b/docs/how_to_guide/plot_06_glm_pytree.md @@ -261,7 +261,7 @@ else: path = Path("../_build/html/_static/thumbnails/how_to_guide") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/scripts/basis_figs.py b/docs/scripts/basis_figs.py new file mode 100644 index 00000000..5113a577 --- /dev/null +++ b/docs/scripts/basis_figs.py @@ -0,0 +1,108 @@ +import matplotlib.pyplot as plt +import numpy as np + +import nemos as nmo +from nemos._inspect_utils.inspect_utils import trim_kwargs + +plt.rcParams.update( + { + "figure.dpi": 300, + } +) + +KWARGS = dict( + n_basis_funcs=10, + decay_rates=np.arange(1, 10 + 1), + enforce_decay_to_zero=True, + order=4, + width=2, +) + + +def plot_basis(cls): + cls_params = cls._get_param_names() + new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) + bas = cls(**new_kwargs) + fig, ax = plt.subplots(1, 1, figsize=(5, 2.5)) + ax.plot(*bas.evaluate_on_grid(300), lw=4) + for side in ["left", "right", "top", "bottom"]: + ax.spines[side].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02) + + +def plot_raised_cosine_linear(): + plot_basis(nmo.basis.RaisedCosineLinearEval) + + +def plot_raised_cosine_log(): + plot_basis(nmo.basis.RaisedCosineLogEval) + + +def plot_mspline(): + plot_basis(nmo.basis.MSplineEval) + + +def plot_bspline(): + plot_basis(nmo.basis.BSplineEval) + + +def plot_cyclic_bspline(): + plot_basis(nmo.basis.CyclicBSplineEval) + + +def plot_orth_exp_basis(): + plot_basis(nmo.basis.OrthExponentialEval) + + +def plot_nd_basis_thumbnail(): + a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3) + b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14) + prod_basis = a_basis * b_basis + + x_coord = np.linspace(0, 1, 1000) + y_coord = np.linspace(0, 1, 1000) + + X, Y, Z = prod_basis.evaluate_on_grid(200, 200) + + # basis element pairs + element_pairs = [[0, 0], [5, 1], [10, 5]] + + # plot the 1D basis element and their product + fig, axs = plt.subplots(3,3,figsize=(8, 6)) + cc = 0 + for i, j in element_pairs: + # plot the element form a_basis + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b") + axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b') + + # plot the element form b_basis + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3) + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b") + axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b') + + # select & plot the corresponding product basis element + k = i * b_basis.n_basis_funcs + j + axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap='Blues') + axs[cc, 2].set_title(fr"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b') + axs[cc, 2].set_xlabel('x-coord') + axs[cc, 2].set_ylabel('y-coord') + axs[cc, 2].set_aspect("equal") + + cc += 1 + axs[2, 0].set_xlabel('x-coord') + axs[2, 1].set_xlabel('y-coord') + + plt.tight_layout() + +def plot_1d_basis_thumbnail(): + order = 4 + n_basis = 10 + bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order) + x, y = bspline.evaluate_on_grid(100) + + plt.figure(figsize=(5, 3)) + plt.plot(x, y, lw=2) + plt.title("B-Spline Basis") \ No newline at end of file diff --git a/docs/scripts/basis_table_figs.py b/docs/scripts/basis_table_figs.py deleted file mode 100644 index f60bed8c..00000000 --- a/docs/scripts/basis_table_figs.py +++ /dev/null @@ -1,56 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np - -import nemos as nmo -from nemos._inspect_utils import trim_kwargs - -plt.rcParams.update( - { - "figure.dpi": 300, - } -) - -KWARGS = dict( - n_basis_funcs=10, - decay_rates=np.arange(1, 10 + 1), - enforce_decay_to_zero=True, - order=4, - width=2, -) - - -def plot_basis(cls): - cls_params = cls._get_param_names() - new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params}) - bas = cls(**new_kwargs) - fig, ax = plt.subplots(1, 1, figsize=(5, 2.5)) - ax.plot(*bas.evaluate_on_grid(300), lw=4) - for side in ["left", "right", "top", "bottom"]: - ax.spines[side].set_visible(False) - ax.set_xticks([]) - ax.set_yticks([]) - plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02) - - -def plot_raised_cosine_linear(): - plot_basis(nmo.basis.RaisedCosineLinearEval) - - -def plot_raised_cosine_log(): - plot_basis(nmo.basis.RaisedCosineLogEval) - - -def plot_mspline(): - plot_basis(nmo.basis.MSplineEval) - - -def plot_bspline(): - plot_basis(nmo.basis.BSplineEval) - - -def plot_cyclic_bspline(): - plot_basis(nmo.basis.CyclicBSplineEval) - - -def plot_orth_exp_basis(): - plot_basis(nmo.basis.OrthExponentialEval) diff --git a/docs/tutorials/plot_01_current_injection.md b/docs/tutorials/plot_01_current_injection.md index bffd5a35..361f513b 100644 --- a/docs/tutorials/plot_01_current_injection.md +++ b/docs/tutorials/plot_01_current_injection.md @@ -635,7 +635,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/tutorials/plot_02_head_direction.md b/docs/tutorials/plot_02_head_direction.md index 6a44fdd7..c038c0fa 100644 --- a/docs/tutorials/plot_02_head_direction.md +++ b/docs/tutorials/plot_02_head_direction.md @@ -717,7 +717,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/tutorials/plot_03_grid_cells.md b/docs/tutorials/plot_03_grid_cells.md index 6d244f28..75659a5e 100644 --- a/docs/tutorials/plot_03_grid_cells.md +++ b/docs/tutorials/plot_03_grid_cells.md @@ -348,7 +348,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/tutorials/plot_04_v1_cells.md b/docs/tutorials/plot_04_v1_cells.md index c9faaa82..3ee0bcff 100644 --- a/docs/tutorials/plot_04_v1_cells.md +++ b/docs/tutorials/plot_04_v1_cells.md @@ -247,7 +247,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/tutorials/plot_05_place_cells.md b/docs/tutorials/plot_05_place_cells.md index 4657c0f5..9726f5cb 100644 --- a/docs/tutorials/plot_05_place_cells.md +++ b/docs/tutorials/plot_05_place_cells.md @@ -159,7 +159,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/docs/tutorials/plot_06_calcium_imaging.md b/docs/tutorials/plot_06_calcium_imaging.md index c95987a9..426f9b6a 100644 --- a/docs/tutorials/plot_06_calcium_imaging.md +++ b/docs/tutorials/plot_06_calcium_imaging.md @@ -371,7 +371,7 @@ else: path = Path("../_build/html/_static/thumbnails/tutorials") # make sure the folder exists if run from build -if root or Path("../_build/html/_static").exists(): +if root or Path("../assets/stylesheets").exists(): path.mkdir(parents=True, exist_ok=True) if path.exists(): diff --git a/src/nemos/_inspect_utils/__init__.py b/src/nemos/_inspect_utils/__init__.py index 21743bd6..0fec61ec 100644 --- a/src/nemos/_inspect_utils/__init__.py +++ b/src/nemos/_inspect_utils/__init__.py @@ -9,7 +9,7 @@ identify abstract classes, and verify method compliance in subclasses. """ -from .inpsect_utils import ( +from .inspect_utils import ( check_all_abstract_methods_compliance, get_abstract_classes, get_non_abstract_classes, From 01e6e7dc29b830d30d7a9a26271001f73a7d2ed3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 00:11:31 -0500 Subject: [PATCH 20/21] linted --- docs/scripts/basis_figs.py | 27 +++++++++++++++------------ tests/test_basis.py | 4 +++- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/docs/scripts/basis_figs.py b/docs/scripts/basis_figs.py index 5113a577..af8b6701 100644 --- a/docs/scripts/basis_figs.py +++ b/docs/scripts/basis_figs.py @@ -70,33 +70,36 @@ def plot_nd_basis_thumbnail(): element_pairs = [[0, 0], [5, 1], [10, 5]] # plot the 1D basis element and their product - fig, axs = plt.subplots(3,3,figsize=(8, 6)) + fig, axs = plt.subplots(3, 3, figsize=(8, 6)) cc = 0 for i, j in element_pairs: # plot the element form a_basis - axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3) + axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=0.3) axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b") - axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b') + axs[cc, 0].set_title(f"$a_{{{i}}}(x)$", color="b") # plot the element form b_basis - axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3) + axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=0.3) axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b") - axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b') + axs[cc, 1].set_title(f"$b_{{{j}}}(y)$", color="b") # select & plot the corresponding product basis element k = i * b_basis.n_basis_funcs + j - axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap='Blues') - axs[cc, 2].set_title(fr"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b') - axs[cc, 2].set_xlabel('x-coord') - axs[cc, 2].set_ylabel('y-coord') + axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap="Blues") + axs[cc, 2].set_title( + rf"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color="b" + ) + axs[cc, 2].set_xlabel("x-coord") + axs[cc, 2].set_ylabel("y-coord") axs[cc, 2].set_aspect("equal") cc += 1 - axs[2, 0].set_xlabel('x-coord') - axs[2, 1].set_xlabel('y-coord') + axs[2, 0].set_xlabel("x-coord") + axs[2, 1].set_xlabel("y-coord") plt.tight_layout() + def plot_1d_basis_thumbnail(): order = 4 n_basis = 10 @@ -105,4 +108,4 @@ def plot_1d_basis_thumbnail(): plt.figure(figsize=(5, 3)) plt.plot(x, y, lw=2) - plt.title("B-Spline Basis") \ No newline at end of file + plt.title("B-Spline Basis") diff --git a/tests/test_basis.py b/tests/test_basis.py index 47db8dd0..0f4574db 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -971,7 +971,9 @@ def test_fit_kernel_shape(self, cls): ( "conv", -1, - pytest.raises(ValueError, match="`window_size` must be a positive integer"), + pytest.raises( + ValueError, match="`window_size` must be a positive integer" + ), ), ( "conv", From afdf8a8e70e8bb5f838f5dc8d8865267b1b083f2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 12 Dec 2024 08:31:11 -0500 Subject: [PATCH 21/21] added pages in readme subsection --- docs/background/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/background/README.md b/docs/background/README.md index 3a21d06d..435e5bdb 100644 --- a/docs/background/README.md +++ b/docs/background/README.md @@ -41,10 +41,11 @@ plot_00_conceptual_intro.md ``` ```{toctree} -:maxdepth: 2 +:maxdepth: 3 basis/README.md ``` + ::: :::{grid-item-card}