diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 61e92b0e..2394576e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,18 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + prevent_docs_absolute_links: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Make .sh executable + run: chmod +x bash_scripts/prevent_absolute_links_to_docs.sh + + - name: Check links + run: ./bash_scripts/prevent_absolute_links_to_docs.sh + check: if: ${{ !github.event.pull_request.draft }} needs: tox diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3bcc3e63..5cbb99cb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -99,6 +99,10 @@ Lastly, you should make sure that the existing tests all run successfully and th ```bash # run tests and make sure they all pass pytest tests/ + +# run doctest (run all examples in docstrings and match output) +pytest --doctest-modules src/nemos/ + # format the code base black src/ isort src @@ -184,38 +188,89 @@ properly documented as outlined below. #### Adding documentation -1) **Docstrings** - -All public-facing functions and classes should have complete docstrings, which start with a one-line short summary of the function, -a medium-length description of the function / class and what it does, and a complete description of all arguments and return values. -Math should be included in a `Notes` section when necessary to explain what the function is doing, and references to primary literature -should be included in a `References` section when appropriate. Docstrings should be relatively short, providing the information necessary -for a user to use the code. - -Private functions and classes should have sufficient explanation that other developers know what the function / class does and how to use it, -but do not need to be as extensive. - -We follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/) conventions for docstring structure. 
- -2) **Examples/Tutorials** - -If your changes are significant (add a new functionality or drastically change the current codebase), then the current examples may need to be updated or -a new example may need to be added. +1. **Docstrings** + + All public-facing functions and classes should have complete docstrings, which start with a one-line short summary of the function, a medium-length description of the function/class and what it does, a complete description of all arguments and return values, and an example to illustrate usage. Math should be included in a `Notes` section when necessary to explain what the function is doing, and references to primary literature should be included in a `References` section when appropriate. Docstrings should be relatively short, providing the information necessary for a user to use the code. + + Private functions and classes should have sufficient explanation that other developers know what the function/class does and how to use it, but do not need to be as extensive. + + We follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/) conventions for docstring structure. + +2. **Examples/Tutorials** + + If your changes are significant (add a new functionality or drastically change the current codebase), then the current examples may need to be updated or a new example may need to be added. + + All examples live within the `docs/` subfolder of `nemos`. These are written as `.py` files but are converted to notebooks by [`mkdocs-gallery`](https://smarie.github.io/mkdocs-gallery/), and have a special syntax, as demonstrated in this [example gallery](https://smarie.github.io/mkdocs-gallery/generated/gallery/). + + We avoid using `.ipynb` notebooks directly because their JSON-based format makes them difficult to read, interpret, and resolve merge conflicts in version control. + + To see if changes you have made break the current documentation, you can build the documentation locally. 
+ + ``` + # Clear the cached documentation pages + # This step is only necessary if your changes affected the src/ directory + rm -r docs/generated + # build the docs within the nemos repo + mkdocs build + ``` + + If the build fails, you will see line-specific errors that prompted the failure. + +3. **Doctest: Test the example code in your docs** + + Doctests are a great way to ensure that code examples in your documentation remain accurate as the codebase evolves. With doctests, we will test any docstrings, Markdown files, or any other text-based documentation that contains code formatted as interactive Python sessions. + + - **Docstrings:** + To include doctests in your function and class docstrings, you must add an `Examples` section. The examples should be formatted as if you were typing them into a Python interactive session, with `>>>` used to indicate commands and expected outputs listed immediately below. + + ```python + def add(a, b): + """ + The sum of two numbers. + + ...Other docstring sections (Parameters, Returns...) + + Examples + -------- + An expected output is required. + >>> add(1, 2) + 3 + + Unless the output is captured. + >>> out = add(1, 2) + + """ + return a + b + ``` + + To validate all your docstring examples, run `pytest` with the `--doctest-modules` flag: + + ``` + pytest --doctest-modules src/nemos/ + ``` + + This test is part of the Continuous Integration; every example must pass before we can merge a PR. + + - **Documentation Pages:** + Doctests can also be included in Markdown files by using code blocks with the `python` language identifier and interactive Python examples.
To enable this functionality, ensure that code blocks follow the standard Python doctest format: + + ```markdown + ```python + >>> # Add any code + >>> x = 3 ** 2 + >>> x + 1 + 10 + + ``` + ``` + + To run doctests on a text file, use the following command: + + ``` + python -m doctest -v path-to-your-text-file/file_name.md + ``` + + All Markdown files will be tested as part of the Continuous Integration. -All examples live within the `docs/` subfolder of `nemos`. These are written as `.py` files but are converted to -notebooks by [`mkdocs-gallery`](https://smarie.github.io/mkdocs-gallery/), and have a special syntax, as demonstrated in this [example -gallery](https://smarie.github.io/mkdocs-gallery/generated/gallery/). - -We avoid using `.ipynb` notebooks directly because their JSON-based format makes them difficult to read, interpret, and resolve merge conflicts in version control. - -To see if changes you have made break the current documentation, you can build the documentation locally. - -```bash -# Clear the cached documentation pages -# This step is only necessary if your changes affected the src/ directory -rm -r docs/generated -# build the docs within the nemos repo -mkdocs build -``` - -If the build fails, you will see line-specific errors that prompted the failure. +> [!NOTE] +> All internal links to NeMoS documentation pages **must be relative**. Using absolute links can lead to broken references whenever the documentation structure is altered. Any presence of absolute links to documentation pages will cause the continuous integration checks to fail. Please ensure all links follow the relative format before submitting your PR. diff --git a/bash_scripts/prevent_absolute_links_to_docs.sh b/bash_scripts/prevent_absolute_links_to_docs.sh new file mode 100644 index 00000000..51b5f559 --- /dev/null +++ b/bash_scripts/prevent_absolute_links_to_docs.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Check for any unallowed absolute links in documentation, excluding the badge.
+if grep -r -E https?://nemos.* docs/ | grep -v "badge"; then + echo "Error: Unallowed absolute links found in documentation." >&2 + exit 1 +else + echo "No unallowed absolute links found." +fi diff --git a/docs/assets/glm_features_scheme.svg b/docs/assets/glm_features_scheme.svg index dc0427a9..2ce956e7 100644 --- a/docs/assets/glm_features_scheme.svg +++ b/docs/assets/glm_features_scheme.svg @@ -60,15 +60,15 @@ inkscape:pageopacity="0.0" inkscape:pagecheckerboard="0" inkscape:deskcolor="#d1d1d1" - inkscape:zoom="2" - inkscape:cx="203" - inkscape:cy="97.75" - inkscape:window-width="1928" - inkscape:window-height="1212" + inkscape:zoom="3.2990946" + inkscape:cx="517.56624" + inkscape:cy="16.519684" + inkscape:window-width="2208" + inkscape:window-height="858" inkscape:window-x="0" inkscape:window-y="25" inkscape:window-maximized="0" - inkscape:current-layer="layer-oc1"> + inkscape:current-layer="g168"> - + d="m 591.00781,25.410157 h 4.29682 c 0.85155,0 1.53904,0.242184 2.06248,0.718741 0.51952,0.480462 0.78124,1.156235 0.78124,2.031224 0,0.742178 -0.23438,1.386701 -0.70312,1.937475 -0.46093,0.554681 -1.17186,0.828115 -2.1406,0.828115 h -2.99996 v 4.046823 h -1.29686 z m 5.82805,2.749965 c 0,-0.707022 -0.26171,-1.187485 -0.78124,-1.437482 -0.28124,-0.132811 -0.67186,-0.203122 -1.17186,-0.203122 h -2.57809 v 3.312457 h 2.57809 c 0.58203,0 1.05077,-0.117186 1.40623,-0.35937 0.36328,-0.249997 0.54687,-0.687491 0.54687,-1.312483 z m 5.54681,6.062423 c 0.78124,0 1.31248,-0.289059 1.59373,-0.874989 0.28906,-0.593743 0.43749,-1.249984 0.43749,-1.968725 0,-0.656242 -0.10546,-1.187485 -0.31249,-1.59373 -0.33594,-0.644523 -0.90234,-0.968738 -1.70311,-0.968738 -0.71874,0 -1.24217,0.277341 -1.56248,0.828115 -0.32421,0.542962 -0.48437,1.20311 -0.48437,1.98435 0,0.742178 0.16016,1.359357 0.48437,1.859351 0.32031,0.492181 0.83593,0.734366 1.54686,0.734366 z m 0.0469,-6.421794 c 0.89452,0 1.64842,0.304684 2.2656,0.906239 0.62499,0.593742 0.93749,1.476544 0.93749,2.640591 0,1.117174 
-0.27734,2.042943 -0.82812,2.781215 -0.54296,0.730459 -1.3867,1.093736 -2.53122,1.093736 -0.96092,0 -1.71873,-0.320308 -2.28122,-0.968738 -0.56249,-0.644523 -0.84374,-1.515605 -0.84374,-2.609342 0,-1.175766 0.29687,-2.109348 0.89062,-2.796839 0.59374,-0.695304 1.3906,-1.046862 2.39059,-1.046862 z m -0.0469,0.03125 z m 4.65228,0.203123 h 1.18749 v 6.937412 h -1.18749 z m 0,-2.624967 h 1.18749 v 1.328108 h -1.18749 z m 3.66402,7.374906 c 0.0312,0.398433 0.125,0.69921 0.28124,0.906239 0.30078,0.374995 0.8164,0.562493 1.54686,0.562493 0.42578,0 0.80468,-0.09375 1.14061,-0.281247 0.33203,-0.187497 0.49999,-0.476556 0.49999,-0.874989 0,-0.300777 -0.13671,-0.531243 -0.40624,-0.687491 -0.16797,-0.09375 -0.5,-0.203122 -0.99999,-0.328121 l -0.93749,-0.234372 c -0.59374,-0.144529 -1.03123,-0.312496 -1.31248,-0.499993 -0.49999,-0.312496 -0.74999,-0.749991 -0.74999,-1.312484 0,-0.656241 0.23437,-1.187485 0.70312,-1.593729 0.47655,-0.414058 1.11717,-0.624993 1.92185,-0.624993 1.05076,0 1.81247,0.30859 2.28122,0.921864 0.28906,0.398432 0.42968,0.820302 0.42187,1.265609 h -1.10936 c -0.0234,-0.25781 -0.11719,-0.499994 -0.28125,-0.718741 -0.27343,-0.312496 -0.74218,-0.468744 -1.40623,-0.468744 -0.4375,0 -0.77343,0.08984 -0.99999,0.265621 -0.23047,0.167967 -0.34375,0.39062 -0.34375,0.671867 0,0.304683 0.14844,0.542962 0.45312,0.718741 0.17578,0.117186 0.42969,0.214841 0.76562,0.296871 l 0.78124,0.187498 c 0.83202,0.199216 1.39451,0.398432 1.68748,0.593742 0.45702,0.292965 0.68749,0.761709 0.68749,1.406232 0,0.61718 -0.24219,1.152329 -0.71874,1.609355 -0.46875,0.445307 -1.18358,0.671866 -2.1406,0.671866 -1.04296,0 -1.78123,-0.234372 -2.21872,-0.703116 -0.42968,-0.468744 -0.65624,-1.050768 -0.68749,-1.749978 z M 612.4177,27.832 Z m 4.94915,4.953062 c 0.0313,0.398433 0.125,0.69921 0.28125,0.906239 0.30078,0.374995 0.81639,0.562493 1.54685,0.562493 0.42578,0 0.80468,-0.09375 1.14061,-0.281247 0.33203,-0.187497 0.5,-0.476556 0.5,-0.874989 0,-0.300777 -0.13672,-0.531243 -0.40625,-0.687491 
-0.16796,-0.09375 -0.49999,-0.203122 -0.99998,-0.328121 l -0.93749,-0.234372 c -0.59375,-0.144529 -1.03124,-0.312496 -1.31249,-0.499993 -0.49999,-0.312496 -0.74999,-0.749991 -0.74999,-1.312484 0,-0.656241 0.23438,-1.187485 0.70312,-1.593729 0.47656,-0.414058 1.11717,-0.624993 1.92185,-0.624993 1.05077,0 1.81248,0.30859 2.28122,0.921864 0.28906,0.398432 0.42968,0.820302 0.42187,1.265609 h -1.10936 c -0.0234,-0.25781 -0.11719,-0.499994 -0.28125,-0.718741 -0.27343,-0.312496 -0.74217,-0.468744 -1.40623,-0.468744 -0.43749,0 -0.77343,0.08984 -0.99999,0.265621 -0.23046,0.167967 -0.34374,0.39062 -0.34374,0.671867 0,0.304683 0.14843,0.542962 0.45312,0.718741 0.17578,0.117186 0.42968,0.214841 0.76561,0.296871 l 0.78124,0.187498 c 0.83202,0.199216 1.39452,0.398432 1.68748,0.593742 0.45703,0.292965 0.68749,0.761709 0.68749,1.406232 0,0.61718 -0.24218,1.152329 -0.71874,1.609355 -0.46874,0.445307 -1.18358,0.671866 -2.1406,0.671866 -1.04295,0 -1.78122,-0.234372 -2.21872,-0.703116 -0.42968,-0.468744 -0.65624,-1.050768 -0.68749,-1.749978 z M 619.08558,27.832 Z m 7.01163,6.390544 c 0.78124,0 1.31248,-0.289059 1.59373,-0.874989 0.28906,-0.593742 0.43749,-1.249984 0.43749,-1.968725 0,-0.656241 -0.10546,-1.187485 -0.31249,-1.59373 -0.33594,-0.644523 -0.90233,-0.968737 -1.70311,-0.968737 -0.71874,0 -1.24217,0.27734 -1.56248,0.828114 -0.32421,0.542962 -0.48436,1.20311 -0.48436,1.98435 0,0.742178 0.16015,1.359358 0.48436,1.859351 0.32031,0.492182 0.83593,0.734366 1.54686,0.734366 z m 0.0469,-6.421793 c 0.89452,0 1.64842,0.304683 2.2656,0.906238 0.62499,0.593743 0.93748,1.476544 0.93748,2.640592 0,1.117173 -0.27734,2.042942 -0.82811,2.781214 -0.54296,0.73046 -1.3867,1.093736 -2.53122,1.093736 -0.96092,0 -1.71873,-0.320308 -2.28122,-0.968737 -0.56249,-0.644523 -0.84374,-1.515606 -0.84374,-2.609342 0,-1.175766 0.29687,-2.109348 0.89061,-2.79684 0.59375,-0.695303 1.39061,-1.046861 2.3906,-1.046861 z m -0.0469,0.03125 z m 4.65228,0.171872 h 1.10936 v 0.984363 c 0.33203,-0.406245 
0.67968,-0.695304 1.04687,-0.874989 0.37499,-0.175779 0.78514,-0.265622 1.23436,-0.265622 0.98827,0 1.65622,0.343746 1.99997,1.031237 0.19531,0.374996 0.29687,0.917957 0.29687,1.62498 v 4.468694 h -1.20311 v -4.390569 c 0,-0.425776 -0.0625,-0.769521 -0.1875,-1.031237 -0.21093,-0.437494 -0.58593,-0.656242 -1.12498,-0.656242 -0.28125,0 -0.51171,0.03125 -0.68749,0.09375 -0.32422,0.09375 -0.60546,0.281246 -0.84374,0.562492 -0.19922,0.242185 -0.32812,0.484369 -0.39062,0.734366 -0.0547,0.242185 -0.0781,0.58593 -0.0781,1.031237 v 3.656204 h -1.17187 z M 633.49946,27.832 Z m 0,0" + style="fill:#000000;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:1.33333" + aria-label="Poisson" /> + random + variable + inkscape:current-layer="g137-8"> weights inverse link + id="g137-8" + transform="translate(-32.857239,-0.2436684)"> Poisson + id="tspan125-2">Poisson random + r.v. + id="tspan126-8-2">variable li > p:first-of-type img { margin-right: 8px; } + +/* this sets up the figure in the quickstart to have the appropriate size*/ +.custom-figure { + width: 70% !important; /* Set the desired width */ +} + +/* Apply this style to the image within the figure */ +.custom-figure img { + width: 100%; /* Ensure the image fits the figure */ +} \ No newline at end of file diff --git a/docs/how_to_guide/README.md b/docs/how_to_guide/README.md index 0f26d81b..fa4b58fb 100644 --- a/docs/how_to_guide/README.md +++ b/docs/how_to_guide/README.md @@ -1,4 +1,5 @@ -# How-to Guide + +# How-To Guide Familiarize with NeMoS modules and learn how to take advantage of the `pynapple` and `scikit-learn` compatibility. 
diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py index 84282477..70dac9cd 100644 --- a/docs/how_to_guide/plot_04_population_glm.py +++ b/docs/how_to_guide/plot_04_population_glm.py @@ -23,9 +23,10 @@ """ import jax.numpy as jnp -import nemos as nmo -import numpy as np import matplotlib.pyplot as plt +import numpy as np + +import nemos as nmo np.random.seed(123) diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py index f9e758fc..84f64d98 100644 --- a/docs/how_to_guide/plot_05_batch_glm.py +++ b/docs/how_to_guide/plot_05_batch_glm.py @@ -6,10 +6,11 @@ """ +import matplotlib.pyplot as plt +import numpy as np import pynapple as nap + import nemos as nmo -import numpy as np -import matplotlib.pyplot as plt nap.nap_config.suppress_conversion_warnings = True diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py index b7168e33..ca9b167a 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py @@ -71,20 +71,19 @@ # ## Combining basis transformations and GLM in a pipeline # Let's start by creating some toy data. 
-import nemos as nmo +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.stats -import matplotlib.pyplot as plt import seaborn as sns - -from sklearn.pipeline import Pipeline from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import Pipeline + +import nemos as nmo # some helper plotting functions from nemos import _documentation_utils as doc_plots - # predictors, shape (n_samples, n_features) X = np.random.uniform(low=0, high=1, size=(1000, 1)) # observed counts, shape (n_samples,) diff --git a/docs/index.md b/docs/index.md index 562491b5..d7c362d8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -35,6 +35,26 @@ We provide a **Poisson GLM** for analyzing spike counts, and a **Gamma GLM** for
+- :material-hammer-wrench:{ .lg .middle }   __Installation Instructions__ + + --- + + Run the following `pip` command in your __virtual environment__. + === "macOS/Linux" + + ```bash + pip install nemos + ``` + + === "Windows" + + ``` + python -m pip install nemos + ``` + + *For more information see:*
+ [:octicons-arrow-right-24: Install](installation) + - :material-clock-fast:{ .lg .middle } __Getting Started__ --- @@ -51,23 +71,23 @@ We provide a **Poisson GLM** for analyzing spike counts, and a **Gamma GLM** for [:octicons-arrow-right-24: Background](generated/background) -- :material-lightbulb-on-10:{ .lg .middle }   __How-To Guide__ +- :material-brain:{ .lg .middle}   __Neural Modeling__ --- - Already familiar with the concepts? Learn how you to process and analyze your data with NeMoS. + Explore fully worked examples to learn how to analyze neural recordings from scratch. *Requires familiarity with the theory.*
- [:octicons-arrow-right-24: How-To Guide](generated/how_to_guide) + [:octicons-arrow-right-24: Tutorials](generated/tutorials) -- :material-brain:{ .lg .middle}   __Neural Modeling__ + - :material-lightbulb-on-10:{ .lg .middle }   __How-To Guide__ --- - Explore fully worked examples to learn how to analyze neural recordings from scratch. + Already familiar with the concepts? Learn how to process and analyze your data with NeMoS. *Requires familiarity with the theory.*
- [:octicons-arrow-right-24: Tutorials](generated/tutorials) + [:octicons-arrow-right-24: How-To Guide](generated/how_to_guide) - :material-cog:{ .lg .middle }   __API Guide__ @@ -78,26 +98,6 @@ We provide a **Poisson GLM** for analyzing spike counts, and a **Gamma GLM** for *Requires familiarity with the theory.*
[:octicons-arrow-right-24: API Guide](reference/SUMMARY) -- :material-hammer-wrench:{ .lg .middle }   __Installation Instructions__ - - --- - - Run the following `pip` command in your virtual environment. - === "macOS/Linux" - - ```bash - pip install nemos - ``` - - === "Windows" - - ``` - python -m pip install nemos - ``` - - *For more information see:*
- [:octicons-arrow-right-24: Install](installation) -
## :material-scale-balance:{ .lg } License diff --git a/docs/quickstart.md b/docs/quickstart.md index d60e8c37..aa234489 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -2,198 +2,442 @@ hide: - navigation --- +## **Overview** -This tutorial will introduce the main NeMoS functionalities. This is intended for users that are -already familiar with the GLM framework but want to learn how to interact with the NeMoS API. -If you have used [scikit-learn](https://scikit-learn.org/stable/) before, we are compatible with the [estimator API](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html), so the quickstart should look -familiar. +NeMoS is a neural modeling software package designed to model neural spiking activity and other time-series data, + powered by [JAX](https://jax.readthedocs.io/en/latest/). -In the following sessions, you will learn: +At its core, NeMoS consists of two primary modules: the **`glm`** and the **`basis`** modules: -1. [How to define and fit a GLM model.](#basic-model-fitting) -2. [What are the GLM input arguments.](#model-arguments) -3. [How to use NeMoS with `pynapple` for pre-processing.](#pre-processing-with-pynapple) -4. [How to use NeMoS with `scikit-learn` for pipelines and cross-validation.](#compatibility-with-scikit-learn) +The **`glm`** module implements a Generalized Linear Model (GLM) to map features to neural activity, such as +spike counts or calcium transients. It supports learning GLM weights, evaluating model performance, and exploring +model behavior on new inputs. -Each of these sections can be run independently of the others. -### Basic Model Fitting +The **`basis`** module focuses on designing model features (inputs) for the GLM. +It includes a suite of composable feature constructors that accept time-series data, allowing users to model a wide +range of observed variables—such as stimuli, head direction, position, or spike counts—as inputs to the GLM.
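To make the division of labor between the two modules concrete, the forward model that the `glm` module fits can be sketched in a few lines of plain NumPy. This is an illustration only, not the NeMoS API (which is shown in the sections below); all names here are local to the sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

# design matrix built from observed variables (what `basis` produces),
# shape (num_samples, num_features)
X = rng.normal(size=(100, 3))

coef = rng.normal(size=3)   # GLM weights
intercept = -1.0            # baseline log-rate

# Poisson GLM forward model: rate = exp(X @ coef + intercept)
rate = np.exp(X @ coef + intercept)

# observations are spike counts drawn around that rate
counts = rng.poisson(rate)

# rate and counts both have shape (100,)
```

NeMoS learns `coef` and `intercept` from `X` and `counts`; the sections below show how.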
-Defining and fitting a NeMoS GLM model is straightforward: + +## **Generalized Linear Model** + +NeMoS provides two implementations of the GLM: one for fitting a single neuron, and one for fitting a neural population simultaneously. + +### **Single Neuron GLM** + +You can define a single neuron GLM by instantiating an `GLM` object. ```python -import nemos as nmo -import numpy as np -# predictors, shape (n_samples, n_features) -X = 0.2 * np.random.normal(size=(100, 1)) -# true coefficients, shape (n_features) -coef = np.random.normal(size=(1, )) -# observed counts, shape (n_samples, ) -y = np.random.poisson(np.exp(np.matmul(X, coef))) +>>> import nemos as nmo + +>>> # Instantiate the single model +>>> model = nmo.glm.GLM() -# model definition -model = nmo.glm.GLM() -# model fitting -model.fit(X, y) ``` +The coefficients can be learned by invoking the `fit` method of `GLM`. The method requires a design +matrix of shape `(num_samples, num_features)`, and the output neural activity of shape `(num_samples, )`. + +```python + +>>> import numpy as np +>>> num_samples, num_features = 100, 3 + +>>> # Generate a design matrix +>>> X = np.random.normal(size=(num_samples, num_features)) +>>> # generate some counts +>>> spike_counts = np.random.poisson(size=num_samples) -Once fit, you can retrieve model parameters as follows, +>>> # define fit the model +>>> model = model.fit(X, spike_counts) + +``` + +Once the model is fit, you can retrieve the model parameters as shown below. ```python ->>> # model coefficients, shape (n_features, ) ->>> print(f"Model coefficients: {model.coef_}") -Model coefficients: [-1.5791758] +>>> # model coefficients shape is (num_features, ) +>>> print(f"Model coefficients shape: {model.coef_.shape}") +Model coefficients shape: (3,) + +>>> # model intercept, shape (1,) since there is only one neuron. 
+>>> print(f"Model intercept shape: {model.intercept_.shape}") +Model intercept shape: (1,) ->>> # model coefficients, shape (1, ) ->>> print(f"Model intercept: {model.intercept_}") -Model intercept: [-0.0010547] ``` +Additionally, you can predict the firing rate and compute the model log-likelihood by calling the `predict` and `score` methods, respectively. -### Model Arguments +```python -During initialization, the `GLM` class accepts the following optional input arguments, +>>> # predict the rate +>>> predicted_rate = model.predict(X) +>>> # firing rate has shape: (num_samples,) +>>> predicted_rate.shape +(100,) -1. `model.observation_model`: The statistical model for the observed variable. The available option so far are `nemos.observation_models.PoissonObservation` and `nemos.observation_models.GammaObservations`, which are the most common choices for modeling spike counts and calcium imaging traces respectively. -2. `model.regularizer`: Determines the regularization type, defaulting to `nemos.regularizer.Unregularized`. This parameter can be provided either as a string ("unregularized", "ridge", "lasso", or "group_lasso") or as an instance of `nemos.regularizer.Regularizer`. +>>> # compute the log-likelihood of the model +>>> log_likelihood = model.score(X, spike_counts) -For more information on how to change default arguments, see the API guide for [`observation_models`](../reference/nemos/observation_models) and -[`regularizer`](../reference/nemos/regularizer). +``` + +### **Population GLM** + +You can set up a population GLM by instantiating a `PopulationGLM`. The API for the `PopulationGLM` is the same as for the single-neuron `GLM`; the only difference you'll notice is that some of the methods' inputs and outputs have an additional dimension for the different neurons.
```python -import nemos as nmo -# initialize a Gamma GLM with Ridge regularization -model = nmo.glm.GLM( - regularizer="ridge", - observation_model=nmo.observation_models.GammaObservations() -) +>>> import nemos as nmo +>>> population_model = nmo.glm.PopulationGLM() + ``` +As with the single-neuron GLM, you can learn the model parameters by invoking the `fit` method: the inputs of `fit` are the design matrix (with shape `(num_samples, num_features)`) and the population activity (with shape `(num_samples, num_neurons)`). +Once the model is fit, you can use `predict` and `score` to predict the firing rate and compute the log-likelihood. + +```python + +>>> import numpy as np +>>> num_samples, num_features, num_neurons = 100, 3, 5 + +>>> # simulate a design matrix +>>> X = np.random.normal(size=(num_samples, num_features)) +>>> # simulate some counts +>>> spike_counts = np.random.poisson(size=(num_samples, num_neurons)) + +>>> # fit the model +>>> population_model = population_model.fit(X, spike_counts) + +>>> # predict the rate of each neuron in the population +>>> predicted_rate = population_model.predict(X) +>>> predicted_rate.shape # expected shape: (num_samples, num_neurons) +(100, 5) + +>>> # compute the log-likelihood of the model +>>> log_likelihood = population_model.score(X, spike_counts) + +``` + +The learned coefficients and intercept will have shapes `(num_features, num_neurons)` and `(num_neurons, )`, respectively. + +```python +>>> # model coefficients shape is (num_features, num_neurons) +>>> print(f"Model coefficients shape: {population_model.coef_.shape}") +Model coefficients shape: (3, 5) + +>>> # model intercept, (num_neurons,) +>>> print(f"Model intercept shape: {population_model.intercept_.shape}") +Model intercept shape: (5,) + +``` + + +## **Basis: Feature Construction** + +The `basis` module includes objects that perform two types of transformations on the inputs: + +1.
**Non-linear Mapping:** This process transforms the input data through a non-linear function, + allowing it to capture complex, non-linear relationships between inputs and neuronal firing rates. + Importantly, this transformation preserves the properties that make the GLM easy to fit and guarantee a + single optimal solution (e.g., convexity). + +2. **Convolution:** This applies a convolution of the input data with a bank of filters, designed to + capture linear temporal effects. This transformation is particularly useful when analyzing data with + inherent time dependencies or when the temporal dynamics of the input are significant. + Both transformations produce a vector of features `X` that changes over time, with a shape + of `(n_time_points, n_features)`. + +### **Non-linear Mapping** + +
+ +
Figure 1: Basis as non-linear mappings. The figure demonstrates the use of basis functions to create complex non-linear features for a GLM.
+
+ +Non-linear mapping is the default mode of operation of any `basis` object. To instantiate a basis for non-linear mapping, +you need to specify the number of basis functions. For some `basis` objects, additional arguments may be required (see the [API Guide](../reference/nemos/basis) for detailed information). + +```python + +>>> import nemos as nmo + +>>> n_basis_funcs = 10 +>>> basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs) + +``` + +Once the basis is instantiated, you can apply it to your input data using the `compute_features` method. +This method takes an input array of shape `(n_samples, )` and transforms it into a two-dimensional array of +shape `(n_samples, n_basis_funcs)`, where each column represents a feature generated by the non-linear mapping. + +```python + +>>> import numpy as np + +>>> # generate an input +>>> x = np.arange(100) + +>>> # evaluate the basis +>>> X = basis.compute_features(x) +>>> X.shape +(100, 10) + +``` + +### **Convolution** + +
+ GLM Population Scheme +
Figure 2: Basis as a bank of convolutional filters. The figure shows a population GLM for functional connectivity analysis, a classical use-case for basis functions in convolutional mode.
+ +
+ +If you want to convolve a bank of basis functions with an input, you must set the mode of operation of a basis object to +`"conv"` and you must provide an integer `window_size` parameter, which defines the length of the filter bank in +number of sample points. + +```python + +>>> import nemos as nmo + +>>> n_basis_funcs = 10 +>>> # define a filter bank of 10 basis functions, 200 samples long. +>>> basis = nmo.basis.BSplineBasis(n_basis_funcs, mode="conv", window_size=200) + +``` + +Once the basis is initialized, you can call `compute_features` on an input of shape `(n_samples, )` or +`(n_samples, n_signals)` to perform the convolution. The output will be a 2-dimensional array of shape +`(n_samples, n_basis_funcs)` or `(n_samples, n_basis_funcs * n_signals)`, respectively. + +!!! warning "Signal length and window size" + The `window_size` must be shorter than the number of samples in the signal(s) being convolved. + +```python + +>>> import numpy as np + +>>> x = np.ones(500) + +>>> # convolve a single signal +>>> X = basis.compute_features(x) +>>> X.shape +(500, 10) + +>>> x_multi = np.ones((500, 3)) + +>>> # convolve multiple signals +>>> X_multi = basis.compute_features(x_multi) +>>> X_multi.shape +(500, 30) + +``` + +For additional information on one-dimensional convolutions, see [here](../generated/background/plot_03_1D_convolution). + +## **Continuous Observations** + + +By default, NeMoS' GLM uses [Poisson observations](../reference/nemos/observation_models/#nemos.observation_models.PoissonObservations), which are a natural choice for spike counts. However, the package also supports a [Gamma](../reference/nemos/observation_models/#nemos.observation_models.GammaObservations) GLM, which is more appropriate for modeling continuous, non-negative observations such as calcium transients.
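The practical difference between the two observation models is the support of their distributions, which a small plain-NumPy illustration (not NeMoS code) makes visible: Poisson samples are integer counts, while Gamma samples are continuous and non-negative.

```python
import numpy as np

rng = np.random.default_rng(1)

# positive mean rates, as a GLM would predict
rate = np.exp(rng.normal(size=5))

# Poisson observations: integer-valued spike counts
poisson_obs = rng.poisson(rate)

# Gamma observations: continuous, non-negative values (e.g. calcium transients);
# shape/scale chosen here so that the mean equals `rate`
gamma_obs = rng.gamma(shape=2.0, scale=rate / 2.0)
```

Inspecting the two arrays shows `poisson_obs` holding integers and `gamma_obs` holding continuous positive values.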
+ +To change the default observation model, set the `observation_model` argument during initialization: + + +```python + +>>> import nemos as nmo + +>>> # set up a Gamma GLM for modeling continuous non-negative data +>>> glm = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations()) + +``` + + +Take a look at our [tutorial](../generated/tutorials/plot_06_calcium_imaging) for a detailed example. + + +## **Regularization** + + +NeMoS supports various regularization schemes, including [Ridge](../reference/nemos/regularizer/#nemos.regularizer.Ridge) ($L_2$), [Lasso](../reference/nemos/regularizer/#nemos.regularizer.Lasso) ($L_1$), and [Group Lasso](../reference/nemos/regularizer/#nemos.regularizer.GroupLasso), to prevent overfitting and improve model generalization. + +You can specify the regularization scheme and its strength when initializing the GLM model: + + +```python + +>>> import nemos as nmo + +>>> # Instantiate a GLM with Ridge (L2) regularization +>>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1) + +``` + + + +## **Pre-processing with `pynapple`** -### Pre-processing with `pynapple` !!! warning - This section assumes some familiarity with the `pynapple` package for time series manipulation and data + + This section assumes some familiarity with the `pynapple` package for time series manipulation and data exploration. If you'd like to learn more about it, take a look at the [`pynapple` documentation](https://pynapple-org.github.io/pynapple/). -`pynapple` is an extremely helpful tool when working with time series data. You can easily perform operations such -as restricting your time series to specific epochs (sleep/wake, context A vs. context B, etc.), as well as common + +`pynapple` is an extremely helpful tool when working with time series data. You can easily perform operations such +as restricting your time series to specific epochs (sleep/wake, context A vs. 
context B, etc.), as well as common pre-processing steps in a robust and efficient manner. This includes bin-averaging, counting, convolving, smoothing and many others. All these operations can be easily concatenated for quick and easy data pre-processing.

-In NeMoS, if a transformation preserve the time axis and you use a `pynapple` time series as input, the result will
+In NeMoS, if a transformation preserves the time axis and you use a `pynapple` time series as input, the result will
also be a `pynapple` time series. A canonical example of this behavior is the `predict` method of `GLM`.
+

```python
->>> # Assume X is a pynapple TsdFrame
+
+>>> import numpy as np
+>>> import pynapple as nap
+>>> import nemos as nmo
+
+>>> # create a TsdFrame with the features and a Tsd with the counts
+>>> X = nap.TsdFrame(t=np.arange(100), d=np.random.normal(size=(100, 2)))
+>>> y = nap.Tsd(t=np.arange(100), d=np.random.poisson(size=(100, )))
+
+>>> # define the model to fit
+>>> model = nmo.glm.GLM()
+
>>> print(type(X)) # shape (num samples, num features)
->>> model.fit(X, y) # the following works
+>>> model = model.fit(X, y) # the following works
>>> firing_rate = model.predict(X) # predict the firing rate of the neuron
>>> # this will still be a pynapple time series
>>> print(type(firing_rate)) # shape (num_samples, )
+

```

Let's see how you can greatly streamline your analysis pipeline by integrating `pynapple` and NeMoS.
+
!!! note
    You can download this dataset by clicking [here](https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1).
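Conceptually, the `count` operation used in the snippet that follows amounts to histogramming event times over fixed-width bins. A minimal NumPy sketch of that idea (illustrative only — this is not pynapple's actual implementation, and the spike times below are made up):

```python
import numpy as np

# hypothetical spike times (in seconds) within a 1-second epoch
spike_times = np.array([0.005, 0.012, 0.013, 0.300, 0.301, 0.999])

# 100 bins of width 10 ms, in the spirit of pynapple's `count(0.01, ...)`
edges = np.linspace(0.0, 1.0, 101)
counts, _ = np.histogram(spike_times, bins=edges)

print(counts.shape)  # one count per 10 ms bin: (100,)
print(counts.sum())  # every spike lands in exactly one bin: 6
```

pynapple additionally keeps track of the time axis and epochs for you, which is exactly what makes the real `count` call below so convenient.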
+

```python
-import nemos as nmo
-import numpy as np
-import pynapple as nap
-data = nap.load_file("A2929-200711.nwb")
+>>> import nemos as nmo
+>>> import pynapple as nap
+
+>>> path = nmo.fetch.fetch_data("A2929-200711.nwb")
+>>> data = nap.load_file(path)
+
+>>> # load spikes and head direction
+>>> spikes = data["units"]
+>>> head_dir = data["ry"]
-spikes = data["units"]
-head_dir = data["ry"]
+>>> # restrict and bin
+>>> counts = spikes[6].count(0.01, ep=head_dir.time_support)
-counts = spikes[6].count(0.01, ep=head_dir.time_support) # restrict and bin
-upsampled_head_dir = head_dir.bin_average(0.01) # up-sample head direction
+>>> # down-sample head direction
+>>> binned_head_dir = head_dir.bin_average(0.01)
-# create your features
-X = nmo.basis.CyclicBSplineBasis(10).compute_features(upsampled_head_dir)
+>>> # create your features
+>>> X = nmo.basis.CyclicBSplineBasis(10).compute_features(binned_head_dir)
+
+>>> # add a neuron axis and fit model
+>>> model = nmo.glm.GLM().fit(X, counts)
-# add a neuron axis and fit model
-model = nmo.glm.GLM().fit(X, counts)
```
+
Finally, let's compare the tuning curves
+
```python
-import numpy as np
-import matplotlib.pyplot as plt
-raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6]
-model_tuning = nap.compute_1d_tuning_curves_continuous(
-    model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate
-    head_dir,
-    nb_bins=100
-)[0]
+>>> import numpy as np
+>>> import matplotlib.pyplot as plt
+
+>>> # tuning curves
+>>> raw_tuning = nap.compute_1d_tuning_curves(spikes, head_dir, nb_bins=100)[6]
+
+>>> # model based tuning curve
+>>> model_tuning = nap.compute_1d_tuning_curves_continuous(
+...     model.predict(X)[:, np.newaxis] * X.rate, # scale by the sampling rate
+...     head_dir,
+...     nb_bins=100
+... 
)[0] + + +>>> # plot results +>>> sub = plt.subplot(111, projection="polar") +>>> plt1 = plt.plot(raw_tuning.index, raw_tuning.values, label="raw") +>>> plt2 = plt.plot(model_tuning.index, model_tuning.values, label="glm") +>>> legend = plt.yticks([]) +>>> xlab = plt.xlabel("heading angle") -# plot results -plt.subplot(111, projection="polar") -plt.plot(raw_tuning.index, raw_tuning.values, label="raw") -plt.plot(model_tuning.index, model_tuning.values, label="glm") -plt.legend() -plt.yticks([]) -plt.xlabel("heading angle") -plt.show() ``` -![Alt text](head_dir_tuning.jpg) -### Compatibility with `scikit-learn` + + + +## **Compatibility with `scikit-learn`** -`scikit-learn` is a machine learning toolkit that offers advanced features like pipelines and cross-validation methods. -NeMoS takes advantage of these features, while still gaining the benefit of JAX's just-in-time -compilation and GPU-acceleration! +[`scikit-learn`](https://scikit-learn.org/stable/) is a machine learning toolkit that offers advanced features like pipelines and cross-validation methods. +NeMoS takes advantage of these features, while still gaining the benefit of JAX's just-in-time compilation and GPU-acceleration! -For example, if we would like to tune the critical hyper-parameter `regularizer_strength`, we -could easily run a `K-Fold` cross-validation using `scikit-learn`. +For example, if we would like to tune the critical hyper-parameter `regularizer_strength`, we could easily run a `K-Fold` cross-validation[^1] using `scikit-learn`. + +[^1]: For a detailed explanation and practical examples, refer to the [cross-validation page](https://scikit-learn.org/stable/modules/cross_validation.html) in the `scikit-learn` documentation. 
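The 5-fold scheme mentioned above simply partitions the samples into five held-out folds, training on the remaining four each time. A minimal NumPy sketch of the index bookkeeping (illustrative only — `GridSearchCV` handles this internally):

```python
import numpy as np

def kfold_indices(n_samples, n_splits=5):
    """Yield (train, test) index arrays for contiguous K-fold splits."""
    folds = np.array_split(np.arange(n_samples), n_splits)
    for i, test_idx in enumerate(folds):
        # train on every fold except the i-th, which is held out for validation
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_idx, test_idx

splits = list(kfold_indices(100, n_splits=5))
print(len(splits))            # 5 train/test partitions
print(splits[0][1].shape[0])  # 20 samples held out per fold
```

In practice you would never write this by hand: `scikit-learn` shuffles, stratifies, and aggregates the per-fold scores for you.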
```python
-import nemos as nmo
-from sklearn.model_selection import GridSearchCV
-# ...Assume X and counts are available or generated as shown above
+>>> # import the required packages
+>>> import nemos as nmo
+>>> import numpy as np
+
+>>> # generate data
+>>> X, counts = np.random.normal(size=(100, 3)), np.random.poisson(size=100)
-# model definition
-model = nmo.glm.GLM(regularizer="ridge")
+>>> # model definition
+>>> model = nmo.glm.GLM(regularizer="Ridge")
-# fit a 5-fold cross-validation scheme for comparing two different
-# regularizer strengths:
+```
-# - define the parameter grid
-param_grid = dict(regularizer__regularizer_strength=(0.01, 0.001))
+Fit a 5-fold cross-validation scheme to compare two different regularizer strengths:
-# - define the 5-fold cross-validation grid search from sklearn
-cls = GridSearchCV(model, param_grid=param_grid, cv=5)
+```python
+
+>>> from sklearn.model_selection import GridSearchCV
+
+>>> # define the parameter grid
+>>> param_grid = dict(regularizer_strength=(0.01, 0.001))
+
+>>> # define the 5-fold cross-validation grid search from sklearn
+>>> cls = GridSearchCV(model, param_grid=param_grid, cv=5)
+
+>>> # run the 5-fold cross-validation grid search
+>>> cls = cls.fit(X, counts)
-# - run the 5-fold cross-validation grid search
-cls.fit(X, counts)
```
+
-!!! info "Cross-Validation"
+
+!!! info "Cross-Validation in NeMoS"
+
+    For more information and a practical example on how to construct a parameter grid and cross-validate hyperparameters across an entire pipeline, please refer to the [tutorial on pipelining and cross-validation](../generated/how_to_guide/plot_06_sklearn_pipeline_cv_demo).
-Now we can print the best coefficient. 
+Finally, we can print the regularizer strength with the best cross-validated performance: ```python -# print best regularizer strength + +>>> # print best regularizer strength >>> print(cls.best_params_) -{'regularizer__regularizer_strength': 0.001} +{'regularizer_strength': 0.01} + ``` Enjoy modeling with NeMoS! diff --git a/mkdocs.yml b/mkdocs.yml index 09940f6a..b7deaeb7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,8 @@ markdown_extensions: - pymdownx.details # add notes toggleable notes ??? - pymdownx.tabbed: alternate_style: true + - toc: + title: On this page plugins: diff --git a/pyproject.toml b/pyproject.toml index a34bb780..5ee288b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,10 @@ dev = [ "pytest-cov", # Test coverage plugin for pytest "statsmodels", # Used to compare model pseudo-r2 in testing "scikit-learn", # Testing compatibility with CV & pipelines + "matplotlib>=3.7", # Needed by doctest to run docstrings examples + "pooch", # Required by doctest for fetch module + "dandi", # Required by doctest for fetch module + "seaborn", # Required by doctest for _documentation_utils module ] docs = [ "mkdocs", # Documentation generator @@ -112,7 +116,7 @@ testpaths = ["tests"] # Specify the directory where test files are l [tool.coverage.run] omit = [ "src/nemos/fetch/*", - "src/nemos/_documentation_utils/*" + "src/nemos/_documentation_utils/*", ] [tool.coverage.report] diff --git a/src/nemos/base_class.py b/src/nemos/base_class.py index 67b63240..ba4a015a 100644 --- a/src/nemos/base_class.py +++ b/src/nemos/base_class.py @@ -25,7 +25,7 @@ class Base: Additionally, it has methods for selecting target devices and sending arrays to them. """ - def get_params(self, deep=True): + def get_params(self, deep=True) -> dict: """ From scikit-learn, get parameters by inspecting init. 
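The `get_params` method annotated above follows the scikit-learn convention of discovering parameters by inspecting `__init__`. A minimal sketch of that contract with a toy class (illustrative only — not NeMoS's actual `Base` implementation):

```python
import inspect

class TinyEstimator:
    """Toy estimator following the scikit-learn get/set parameter contract."""

    def __init__(self, alpha=1.0, solver="GradientDescent"):
        self.alpha = alpha
        self.solver = solver

    def get_params(self, deep=True) -> dict:
        # discover parameter names by inspecting __init__,
        # as sklearn's BaseEstimator does
        names = inspect.signature(self.__init__).parameters
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}.")
            setattr(self, name, value)
        return self

est = TinyEstimator()
print(est.set_params(alpha=0.5).get_params())
# {'alpha': 0.5, 'solver': 'GradientDescent'}
```

Because every `__init__` keyword is round-trippable through `get_params`/`set_params`, tools like `GridSearchCV` can clone and re-parameterize the estimator freely; the `set_params` override in the next hunk builds on exactly this contract.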
diff --git a/src/nemos/base_regressor.py b/src/nemos/base_regressor.py index 5f651313..e4a425ce 100644 --- a/src/nemos/base_regressor.py +++ b/src/nemos/base_regressor.py @@ -139,6 +139,23 @@ def solver_run(self) -> Union[None, SolverRun]: """ return self._solver_run + def set_params(self, **params: Any): + """Manage warnings in case of multiple parameter settings.""" + # if both regularizer and regularizer_strength are set, then only + # warn in case the strength is not expected for the regularizer type + if "regularizer" in params and "regularizer_strength" in params: + reg = params.pop("regularizer") + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="Caution: regularizer strength.*" + "|Unused parameter `regularizer_strength`.*", + ) + super().set_params(regularizer=reg) + + return super().set_params(**params) + @property def regularizer(self) -> Union[None, Regularizer]: """Getter for the regularizer attribute.""" @@ -170,19 +187,16 @@ def regularizer_strength(self) -> float: @regularizer_strength.setter def regularizer_strength(self, strength: Union[float, None]): - # if using unregularized, strength will be None no matter what - if isinstance(self._regularizer, UnRegularized): - self._regularizer_strength = None # check regularizer strength - elif strength is None: + if strength is None and not isinstance(self._regularizer, UnRegularized): warnings.warn( UserWarning( "Caution: regularizer strength has not been set. Defaulting to 1.0. Please see " "the documentation for best practices in setting regularization strength." ) ) - self._regularizer_strength = 1.0 - else: + strength = 1.0 + elif strength is not None: try: # force conversion to float to prevent weird GPU issues strength = float(strength) @@ -191,7 +205,16 @@ def regularizer_strength(self, strength: Union[float, None]): raise ValueError( f"Could not convert the regularizer strength: {strength} to a float." 
)

-        self._regularizer_strength = strength
+        if isinstance(self._regularizer, UnRegularized):
+            warnings.warn(
+                UserWarning(
+                    "Unused parameter `regularizer_strength` for UnRegularized GLM. "
+                    "The regularizer strength parameter is not required and won't be used when the regularizer "
+                    "is set to UnRegularized."
+                )
+            )
+
+        self._regularizer_strength = strength

    @property
    def solver_name(self) -> str:
diff --git a/src/nemos/basis.py b/src/nemos/basis.py
index 2cc48f95..1b0c9f12 100644
--- a/src/nemos/basis.py
+++ b/src/nemos/basis.py
@@ -152,15 +152,15 @@ class TransformerBasis:
    >>> # transformer can be used in pipelines
    >>> transformer = TransformerBasis(basis)
    >>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),])
-    >>> pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API
-    >>> print(pipeline.predict(np.random.normal(size=(10, 1)))) # predict rate from new data
-
+    >>> pipeline = pipeline.fit(x[:, None], y) # x needs to be 2D for sklearn transformer API
+    >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new data
    >>> # TransformerBasis parameter can be cross-validated.
    >>> # 5-fold cross-validate the number of basis
    >>> param_grid = dict(compute_features__n_basis_funcs=[4, 10])
    >>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5)
-    >>> grid_cv.fit(x[:, None], y)
+    >>> grid_cv = grid_cv.fit(x[:, None], y)
    >>> print("Cross-validated number of basis:", grid_cv.best_params_)
+    Cross-validated number of basis: {'compute_features__n_basis_funcs': 10}
    """

    def __init__(self, basis: Basis):
@@ -289,7 +289,7 @@ def __getattr__(self, name: str):
        return getattr(self._basis, name)

    def __setattr__(self, name: str, value) -> None:
-        """
+        r"""
        Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax.

        Setting any other attribute is not allowed.
@@ -312,10 +312,11 @@ def __setattr__(self, name: str, value) -> None:
        >>> # allowed
        >>> trans_bas.n_basis_funcs = 20
        >>> # not allowed
-        >>> tran_bas.random_attribute_name = "some value"
-        Traceback (most recent call last):
-        ...
-        ValueError: Only setting _basis or existing attributes of _basis is allowed.
+        >>> try:
+        ...     trans_bas.random_attribute_name = "some value"
+        ... except ValueError as e:
+        ...     print(repr(e))
+        ValueError('Only setting _basis or existing attributes of _basis is allowed.')
        """
        # allow self._basis = basis
        if name == "_basis":
@@ -357,12 +358,16 @@ def set_params(self, **parameters) -> TransformerBasis:

        >>> # setting parameters of _basis is allowed
        >>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs)
-
+        8
        >>> # setting _basis directly is allowed
-        >>> print(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)
-
+        >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis))
+        <class 'nemos.basis.BSplineBasis'>
        >>> # mixing is not allowed, this will raise an exception
-        >>> transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2)
+        >>> try:
+        ...     transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2)
+        ... except ValueError as e:
+        ...     print(repr(e))
+        ValueError('Set either new _basis object or parameters for existing _basis, not both.')
        """
        new_basis = parameters.pop("_basis", None)
        if new_basis is not None:
@@ -479,37 +484,40 @@ def __init__(
    ) -> None:
        self.n_basis_funcs = n_basis_funcs
        self._n_input_dimensionality = 0
-        self._check_n_basis_min()
        self._conv_kwargs = kwargs
-        self.bounds = bounds

        # check mode
        if mode not in ["conv", "eval"]:
            raise ValueError(
                f"`mode` should be either 'conv' or 'eval'. '{mode}' provided instead!"
            )
-        if mode == "conv":
-            if window_size is None:
-                raise ValueError(
-                    "If the basis is in `conv` mode, you must provide a window_size!"
-                )
-            elif not (isinstance(window_size, int) and window_size > 0):
-                raise ValueError(
-                    f"`window_size` must be a positive integer. 
{window_size} provided instead!"
-                )
-            if bounds is not None:
-                raise ValueError("`bounds` should only be set when `mode=='eval'`.")
-        else:
-            if kwargs:
-                raise ValueError(
-                    f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
-                )
-        self._window_size = window_size
        self._mode = mode
+        self.window_size = window_size
+        self.bounds = bounds
+
+        if mode == "eval" and kwargs:
+            raise ValueError(
+                f"kwargs should only be set when mode=='conv', but '{mode}' provided instead!"
+            )
+
        self.kernel_ = None
        self._identifiability_constraints = False

+    @property
+    def n_basis_funcs(self):
+        return self._n_basis_funcs
+
+    @n_basis_funcs.setter
+    def n_basis_funcs(self, value):
+        orig_n_basis = copy.deepcopy(getattr(self, "_n_basis_funcs", None))
+        self._n_basis_funcs = value
+        try:
+            self._check_n_basis_min()
+        except ValueError as e:
+            self._n_basis_funcs = orig_n_basis
+            raise e
+
    @property
    def bounds(self):
        return self._bounds
@@ -517,16 +525,26 @@ def bounds(self):
    @bounds.setter
    def bounds(self, values: Union[None, Tuple[float, float]]):
        """Setter for bounds."""
+
+        if values is not None and self.mode == "conv":
+            raise ValueError("`bounds` should only be set when `mode=='eval'`.")
+
        if values is not None and len(values) != 2:
            raise ValueError(
                f"The provided `bounds` must be of length two. Length {len(values)} provided instead!"
            )
+
        # convert to float and store
        try:
            self._bounds = values if values is None else tuple(map(float, values))
        except (ValueError, TypeError):
            raise TypeError("Could not convert `bounds` to float.")

+        if values is not None and values[1] <= values[0]:
+            raise ValueError(
+                f"Invalid bound {values}. Lower bound is greater than or equal to the upper bound."
+ ) + @property def mode(self): return self._mode @@ -535,6 +553,28 @@ def mode(self): def window_size(self): return self._window_size + @window_size.setter + def window_size(self, window_size): + """Setter for the window size parameter.""" + if self.mode == "eval": + if window_size: + raise ValueError( + "If basis is in `mode=='eval'`, `window_size` should be None." + ) + + else: + if window_size is None: + raise ValueError( + "If the basis is in `conv` mode, you must provide a window_size!" + ) + + elif not (isinstance(window_size, int) and window_size > 0): + raise ValueError( + f"`window_size` must be a positive integer. {window_size} provided instead!" + ) + + self._window_size = window_size + @property def identifiability_constraints(self): return self._identifiability_constraints @@ -996,9 +1036,9 @@ def to_transformer(self) -> TransformerBasis: >>> from sklearn.pipeline import Pipeline >>> from sklearn.model_selection import GridSearchCV >>> # load some data - >>> X, y = ... # X: features, y: neural activity - >>> basis = nmo.basis.RaisedCosineBasisLinear(10) - >>> glm = nmo.glm.GLM(regularizer="Ridge") + >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30) + >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer() + >>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.) >>> pipeline = Pipeline([("basis", basis), ("glm", glm)]) >>> param_grid = dict( ... glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6), @@ -1009,7 +1049,7 @@ def to_transformer(self) -> TransformerBasis: ... param_grid=param_grid, ... cv=5, ... 
) - >>> gridsearch.fit(X, y) + >>> gridsearch = gridsearch.fit(X, y) """ return TransformerBasis(copy.deepcopy(self)) @@ -1278,10 +1318,34 @@ def __init__( bounds=bounds, **kwargs, ) + self._n_input_dimensionality = 1 - if self.order < 1: + + @property + def order(self): + return self._order + + @order.setter + def order(self, value): + """Setter for the order parameter.""" + + if value < 1: raise ValueError("Spline order must be positive!") + # Set to None only the first time the setter is called. + orig_order = copy.deepcopy(getattr(self, "_order", None)) + + # Set the order + self._order = value + + # If the order was already initialized, re-check basis + if orig_order is not None: + try: + self._check_n_basis_min() + except ValueError as e: + self._order = orig_order + raise e + def _generate_knots( self, sample_pts: NDArray, @@ -1346,7 +1410,7 @@ def _check_n_basis_min(self) -> None: class MSplineBasis(SplineBasis): - r""" + """ M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation. M-splines are a type of spline basis function used for smooth curve fitting @@ -1502,12 +1566,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: >>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3) >>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100) >>> for i in range(4): - ... plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') + ... 
p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}') >>> plt.title('M-Spline Basis Functions') + Text(0.5, 1.0, 'M-Spline Basis Functions') >>> plt.xlabel('Domain') + Text(0.5, 0, 'Domain') >>> plt.ylabel('Basis Function Value') - >>> plt.legend() - >>> plt.show() + Text(0, 0.5, 'Basis Function Value') + >>> l = plt.legend() """ return super().evaluate_on_grid(n_samples) @@ -2021,19 +2087,23 @@ def __init__( # The samples are scaled appropriately in the self._transform_samples which scales # and applies the log-stretch, no additional transform is needed. self._rescale_samples = False + if time_scaling is None: + time_scaling = 50.0 + self.time_scaling = time_scaling self.enforce_decay_to_zero = enforce_decay_to_zero - if time_scaling is None: - self._time_scaling = 50.0 - else: - self._check_time_scaling(time_scaling) - self._time_scaling = time_scaling @property def time_scaling(self): """Getter property for time_scaling.""" return self._time_scaling + @time_scaling.setter + def time_scaling(self, time_scaling): + """Setter property for time_scaling.""" + self._check_time_scaling(time_scaling) + self._time_scaling = time_scaling + @staticmethod def _check_time_scaling(time_scaling: float) -> None: if time_scaling <= 0: diff --git a/src/nemos/exceptions.py b/src/nemos/exceptions.py index 4537aafb..8e3caa29 100644 --- a/src/nemos/exceptions.py +++ b/src/nemos/exceptions.py @@ -15,6 +15,5 @@ class NotFittedError(ValueError, AttributeError): ... GLM().predict([[[1, 2], [2, 3], [3, 4]]]) ... except NotFittedError as e: ... print(repr(e)) - ... # NotFittedError("This GLM instance is not fitted yet. Call 'fit' with - ... # appropriate arguments.") + NotFittedError("This GLM instance is not fitted yet. 
Call 'fit' with appropriate arguments.") """ diff --git a/src/nemos/fetch/fetch_data.py b/src/nemos/fetch/fetch_data.py index a5b76b5c..1246c993 100644 --- a/src/nemos/fetch/fetch_data.py +++ b/src/nemos/fetch/fetch_data.py @@ -37,6 +37,7 @@ "Achilles_10252013.nwb": "42857015aad4c2f7f6f3d4022611a69bc86d714cf465183ce30955731e614990", "allen_478498617.nwb": "262393d7485a5b39cc80fb55011dcf21f86133f13d088e35439c2559fd4b49fa", "m691l1.nwb": "1990d8d95a70a29af95dade51e60ffae7a176f6207e80dbf9ccefaf418fe22b6", + "A2929-200711.nwb": "f698d7319efa5dfeb18fb5fe718ec1a84fdf96b85a158177849a759cd5e396fe", } DOWNLOADABLE_FILES = list(REGISTRY_DATA.keys()) @@ -50,6 +51,7 @@ "Achilles_10252013.nwb": OSF_TEMPLATE.format("hu5ma"), "allen_478498617.nwb": OSF_TEMPLATE.format("vf2nj"), "m691l1.nwb": OSF_TEMPLATE.format("xesdm"), + "A2929-200711.nwb": OSF_TEMPLATE.format("y7zwd"), } _NEMOS_ENV = "NEMOS_DATA_DIR" @@ -143,10 +145,21 @@ def download_dandi_data(dandiset_id: str, filepath: str) -> NWBHDF5IO: Examples -------- >>> import nemos as nmo + >>> import pynapple as nap >>> io = nmo.fetch.download_dandi_data("000582", - "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb") + ... 
"sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb") >>> nwb = nap.NWBFile(io.read(), lazy_loading=False) >>> print(nwb) + 07020602 + ┍━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━┑ + │ Keys │ Type │ + ┝━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━┥ + │ units │ TsGroup │ + │ ElectricalSeriesLFP │ Tsd │ + │ SpatialSeriesLED2 │ TsdFrame │ + │ SpatialSeriesLED1 │ TsdFrame │ + │ ElectricalSeries │ Tsd │ + ┕━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━┙ """ if dandi is None: diff --git a/src/nemos/glm.py b/src/nemos/glm.py index bec4304f..03684859 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -523,11 +523,11 @@ def _initialize_parameters( >>> import numpy as np >>> X = np.zeros((100, 5)) # Example input >>> y = np.exp(np.random.normal(size=(100, ))) # Simulated firing rates - >>> coeff, intercept = nmo.glm.GLM._initialize_parameters(X, y) + >>> coeff, intercept = nmo.glm.GLM()._initialize_parameters(X, y) >>> coeff.shape - (5, ) + (5,) >>> intercept.shape - (1, ) + (1,) """ if isinstance(X, FeaturePytree): data = X.data @@ -823,9 +823,12 @@ def initialize_params( Examples -------- - >>> X, y = load_data() # Hypothetical function to load data + >>> import numpy as np + >>> import nemos as nmo + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> model = nmo.glm.GLM() >>> params = model.initialize_params(X, y) - >>> opt_state = model.initialize_state(X, y) + >>> opt_state = model.initialize_state(X, y, params) >>> # Now ready to run optimization or update steps """ if init_params is None: @@ -950,7 +953,10 @@ def update( Examples -------- - >>> # Assume glm_instance is an instance of GLM that has been previously fitted. 
+ >>> import nemos as nmo + >>> import numpy as np + >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10) + >>> glm_instance = nmo.glm.GLM().fit(X, y) >>> params = glm_instance.coef_, glm_instance.intercept_ >>> opt_state = glm_instance.solver_state_ >>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y) @@ -1057,15 +1063,15 @@ class PopulationGLM(GLM): >>> y = np.random.poisson(np.exp(X.dot(weights))) >>> # Define a feature mask, shape (num_features, num_neurons) >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]]) - >>> print("Feature mask:") - >>> print(feature_mask) + >>> feature_mask + Array([[1, 0], + [1, 1], + [0, 1]], dtype=int32) >>> # Create and fit the model - >>> model = PopulationGLM(feature_mask=feature_mask) - >>> model.fit(X, y) - >>> # Check the fitted coefficients and intercepts - >>> print("Model coefficients:") - >>> print(model.coef_) - + >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y) + >>> # Check the fitted coefficients + >>> print(model.coef_.shape) + (3, 2) >>> # Example with a FeaturePytree mask >>> from nemos.pytrees import FeaturePytree >>> # Define two features @@ -1078,14 +1084,17 @@ class PopulationGLM(GLM): >>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"])) >>> y = np.random.poisson(rate) >>> # Define a feature mask with arrays of shape (num_neurons, ) + >>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0])) - >>> print("Feature mask:") >>> print(feature_mask) + feature_1: shape (2,), dtype int32 + feature_2: shape (2,), dtype int32 + >>> # Fit a PopulationGLM - >>> model = PopulationGLM(feature_mask=feature_mask) - >>> model.fit(X, y) - >>> print("Model coefficients:") - >>> print(model.coef_) + >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y) + >>> # Coefficients are stored in a dictionary with keys the feature labels + >>> print(model.coef_.keys()) + dict_keys(['feature_1', 
'feature_2']) """ def __init__( diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py index ecb7e76b..9d683ae1 100644 --- a/src/nemos/observation_models.py +++ b/src/nemos/observation_models.py @@ -846,7 +846,7 @@ def estimate_scale( def check_observation_model(observation_model): - """ + r""" Check the attributes of an observation model for compliance. This function ensures that the observation model has the required attributes and that each @@ -877,10 +877,10 @@ def check_observation_model(observation_model): ... def _negative_log_likelihood(self, params, y_true, aggregate_sample_scores=jnp.mean): ... return -aggregate_sample_scores(y_true * jax.scipy.special.logit(params) + \ ... (1 - y_true) * jax.scipy.special.logit(1 - params)) - ... def pseudo_r2(self, params, y_true, aggregate_sample_scores): + ... def pseudo_r2(self, params, y_true, aggregate_sample_scores=jnp.mean): ... return 1 - (self._negative_log_likelihood(y_true, params, aggregate_sample_scores) / ... jnp.sum((y_true - y_true.mean()) ** 2)) - ... def sample_generator(self, key, params): + ... def sample_generator(self, key, params, scale=1.): ... return jax.random.bernoulli(key, params) >>> model = MyObservationModel() >>> check_observation_model(model) # Should pass without error if the model is correctly implemented. 
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py index 91e59f51..6d6cf0bd 100644 --- a/src/nemos/regularizer.py +++ b/src/nemos/regularizer.py @@ -352,10 +352,11 @@ class GroupLasso(Regularizer): >>> mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4 >>> # Create the GroupLasso regularizer instance - >>> group_lasso = GroupLasso(regularizer_strength=0.1, mask=mask) + >>> group_lasso = GroupLasso(mask=mask) >>> # fit a group-lasso glm - >>> model = GLM(regularizer=group_lasso).fit(X, y) - >>> print(f"coeff: {model.coef_}") + >>> model = GLM(regularizer=group_lasso, regularizer_strength=0.1).fit(X, y) + >>> print(f"coeff shape: {model.coef_.shape}") + coeff shape: (5,) """ _allowed_solvers = ( @@ -433,7 +434,7 @@ def _check_mask(mask: jnp.ndarray): def _penalization( self, params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], regularizer_strength: float ) -> jnp.ndarray: - """ + r""" Calculate the penalization. Note: the penalty is being calculated according to the following formula: diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py index d1b2deeb..4c060609 100644 --- a/src/nemos/solvers.py +++ b/src/nemos/solvers.py @@ -80,11 +80,13 @@ class ProxSVRG: Examples -------- - >>> def loss_fn(params, X, y): - >>> ... - >>> - >>> svrg = ProxSVRG(loss_fn, prox_fun) - >>> params, state = svrg.run(init_params, hyperparams_prox, X, y) + >>> import numpy as np + >>> from jaxopt.prox import prox_lasso + >>> loss_fn = lambda params, X, y: ((X.dot(params) - y)**2).sum() + >>> svrg = ProxSVRG(loss_fn, prox_lasso) + >>> hyperparams_prox = 0.1 + >>> params, state = svrg.run(np.zeros(2), hyperparams_prox, np.ones((10, 2)), np.zeros(10)) + References ---------- @@ -615,11 +617,10 @@ class SVRG(ProxSVRG): Examples -------- - >>> def loss_fn(params, X, y): - >>> ... 
- >>> + >>> import numpy as np + >>> loss_fn = lambda params, X, y: ((X.dot(params) - y)**2).sum() >>> svrg = SVRG(loss_fn) - >>> params, state = svrg.run(init_params, X, y) + >>> params, state = svrg.run(np.zeros(2), np.ones((10, 2)), np.zeros(10)) References ---------- diff --git a/tests/conftest.py b/tests/conftest.py index 08815bf0..8a36fadf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,10 +12,14 @@ import jax import jax.numpy as jnp import numpy as np +import pynapple as nap import pytest import nemos as nmo +# shut-off conversion warnings +nap.nap_config.suppress_conversion_warnings = True + # Sample subclass to test instantiation and methods class MockRegressor(nmo.base_regressor.BaseRegressor): @@ -394,7 +398,7 @@ def example_data_prox_operator(): ), ) regularizer_strength = 0.1 - mask = jnp.array([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=jnp.float32) + mask = jnp.array([[1, 0, 1, 0], [0, 1, 0, 1]]).astype(float) scaling = 0.5 return params, regularizer_strength, mask, scaling diff --git a/tests/test_basis.py b/tests/test_basis.py index 6e81d142..3e21db33 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -2,15 +2,14 @@ import inspect import pickle from contextlib import nullcontext as does_not_raise +from typing import Literal import jax.numpy import numpy as np import pynapple as nap import pytest -import sklearn.pipeline as pipeline import utils_testing from sklearn.base import clone as sk_clone -from sklearn.model_selection import GridSearchCV import nemos.basis as basis import nemos.convolve as convolve @@ -18,10 +17,11 @@ # automatic define user accessible basis and check the methods + def list_all_basis_classes() -> list[type]: """ Return all the classes in nemos.basis which are a subclass of Basis, - which should be all concrete classes except TransformerBasis. + which should be all concrete classes except TransformerBasis. 
""" return [ class_obj @@ -146,18 +146,32 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) with expectation: bas(samples) - @pytest.mark.parametrize("n_basis_funcs", [-1, 0, 1, 3, 10, 20]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) def test_minimum_number_of_basis_required_is_matched( @@ -181,7 +195,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -389,8 +405,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -489,12 +509,108 @@ def test_init_mode(self, mode, expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "enforce_decay_to_zero, time_scaling, width, window_size, n_basis_funcs, bounds, mode", + [ + (False, 15, 4, None, 10, (1, 2), "eval"), + (False, 15, 4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, + enforce_decay_to_zero, + time_scaling, + width, + window_size, + n_basis_funcs, + bounds, + mode: Literal["eval", "conv"], + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + enforce_decay_to_zero=enforce_decay_to_zero, + time_scaling=time_scaling, + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + enforce_decay_to_zero=enforce_decay_to_zero, + time_scaling=time_scaling, + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + mode=mode, + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas = bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + 
with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -523,24 +639,46 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ (None, does_not_raise()), - ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), - ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + 
((1, None), pytest.raises(TypeError, match=r"Could not convert")), ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(3, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", @@ -548,8 +686,8 @@ def test_vmin_vmax_init(self, bounds, expectation): (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -564,12 +702,14 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): - bas_no_range = self.cls(3, mode="eval", window_size=10, bounds=None) - bas = self.cls(3, mode="eval", window_size=10, bounds=(vmin, vmax)) + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): + bas_no_range = self.cls(3, mode="eval", bounds=None) + bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) _, out2 = bas_no_range.evaluate_on_grid(10) assert np.allclose(out1, out2) @@ -580,8 +720,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -596,8 +736,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), 
np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -685,11 +825,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -719,7 +874,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -781,12 +938,12 @@ def test_evaluate_on_grid_input_number(self, n_input): if n_input == 0: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) missing 1 required positional argument", + match=r"evaluate_on_grid\(\) missing 1 required positional argument", ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", ) else: expectation = does_not_raise() @@ -868,8 +1025,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -885,11 +1046,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -982,12 +1158,94 @@ def test_init_mode(self, mode, expectation): 
1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "width, window_size, n_basis_funcs, bounds, mode", + [ + (4, None, 10, (1, 2), "eval"), + (4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, width, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + width=width, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + width=width, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("conv", 
does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -1024,22 +1282,46 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(5, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -1054,10 +1336,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) - def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx): + def test_vmin_vmax_eval_on_grid_no_effect_on_eval( + self, vmin, vmax, samples, nan_idx + ): bas_no_range = self.cls(3, mode="eval", bounds=None) bas = self.cls(3, mode="eval", bounds=(vmin, vmax)) _, out1 = bas.evaluate_on_grid(10) @@ -1070,8 +1354,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -1086,8 +1370,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` 
should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -1175,11 +1459,26 @@ def test_sample_size_of_compute_features_matches_that_of_input( "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -1233,7 +1532,9 @@ def test_samples_range_matches_compute_features_requirements( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 2)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """ Confirms that the compute_features() method correctly handles the number of input samples that are provided. 
""" @@ -1293,12 +1594,12 @@ def test_evaluate_on_grid_input_number(self, n_input): if n_input == 0: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) missing 1 required positional argument", + match=r"evaluate_on_grid\(\) missing 1 required positional argument", ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", ) else: expectation = does_not_raise() @@ -1363,8 +1664,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", "3", "4", "5"]), + pytest.raises(TypeError, match="Input samples must"), + ), + ], + ) def test_call_input_type(self, samples, expectation): bas = self.cls(5) with expectation: @@ -1380,11 +1685,26 @@ def test_call_equivalent_in_conv(self): "samples, vmin, vmax, expectation", [ (0.5, 0, 1, does_not_raise()), - (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + -0.5, + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_call_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax)) @@ -1472,12 +1792,94 @@ def test_init_mode(self, mode, 
expectation): 1.5, pytest.raises(ValueError, match="`window_size` must be a positive "), ), + ("eval", None, does_not_raise()), + ( + "eval", + 10, + pytest.raises( + ValueError, + match=r"If basis is in `mode=='eval'`, `window_size` should be None", + ), + ), ], ) def test_init_window_size(self, mode, ws, expectation): with expectation: self.cls(5, mode=mode, window_size=ws) + @pytest.mark.parametrize( + "order, window_size, n_basis_funcs, bounds, mode", + [ + (4, None, 10, (1, 2), "eval"), + (4, 10, 10, None, "conv"), + ], + ) + def test_set_params( + self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"] + ): + """Test the read-only and read/write property of the parameters.""" + pars = dict( + order=order, + window_size=window_size, + n_basis_funcs=n_basis_funcs, + bounds=bounds, + ) + keys = list(pars.keys()) + bas = self.cls( + order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode + ) + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]} + bas.set_params(**par_set) + assert isinstance(bas, self.cls) + + for i in range(len(pars)): + for j in range(i + 1, len(pars)): + with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "): + par_set = { + keys[i]: pars[keys[i]], + keys[j]: pars[keys[j]], + "mode": mode, + } + bas.set_params(**par_set) + + @pytest.mark.parametrize( + "mode, expectation", + [ + ("eval", does_not_raise()), + ("conv", pytest.raises(ValueError, match="`bounds` should only be set")), + ], + ) + def test_set_bounds(self, mode, expectation): + ws = dict(eval=None, conv=10) + with expectation: + self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2)) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None) + with pytest.raises(ValueError, match="`bounds` should only be set"): + bas.set_params(bounds=(1, 2)) + + @pytest.mark.parametrize( + "mode, expectation", 
+ [ + ("conv", does_not_raise()), + ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")), + ], + ) + def test_set_window_size(self, mode, expectation): + """Test window size set behavior.""" + with expectation: + self.cls(window_size=10, n_basis_funcs=10, mode=mode) + + bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv") + with pytest.raises(ValueError, match="If the basis is in `conv` mode"): + bas.set_params(window_size=None) + + bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval") + with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"): + bas.set_params(window_size=10) + def test_convolution_is_performed(self): bas = self.cls(5, mode="conv", window_size=10) x = np.random.normal(size=100) @@ -1505,7 +1907,6 @@ def test_conv_kwargs_error(self): with pytest.raises(ValueError, match="kwargs should only be set"): self.cls(5, mode="eval", test="hi") - @pytest.mark.parametrize( "bounds, expectation", [ @@ -1515,22 +1916,52 @@ def test_conv_kwargs_error(self): ((1, 3), does_not_raise()), (("a", 3), pytest.raises(TypeError, match="Could not convert")), ((1, "a"), pytest.raises(TypeError, match="Could not convert")), - (("a", "a"), pytest.raises(TypeError, match="Could not convert")) - ] + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). 
Lower bound is greater" + ), + ), + ], ) def test_vmin_vmax_init(self, bounds, expectation): with expectation: bas = self.cls(3, bounds=bounds) assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( + "bounds, expectation", + [ + (None, does_not_raise()), + ((None, 3), pytest.raises(TypeError, match=r"Could not convert")), + ((1, None), pytest.raises(TypeError, match=r"Could not convert")), + ((1, 3), does_not_raise()), + (("a", 3), pytest.raises(TypeError, match="Could not convert")), + ((1, "a"), pytest.raises(TypeError, match="Could not convert")), + (("a", "a"), pytest.raises(TypeError, match="Could not convert")), + ( + (2, 1), + pytest.raises( + ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater" + ), + ), + ], + ) + def test_vmin_vmax_setter(self, bounds, expectation): + bas = self.cls(3, bounds=(1, 3)) + with expectation: + bas.set_params(bounds=bounds) + assert bounds == bas.bounds if bounds else bas.bounds is None + @pytest.mark.parametrize( "vmin, vmax, samples, nan_idx", [ (None, None, np.arange(5), []), (0, 3, np.arange(5), [4]), (1, 4, np.arange(5), [0]), - (1, 3, np.arange(5), [0, 4]) - ] + (1, 3, np.arange(5), [0, 4]), + ], ) def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): bounds = None if vmin is None else (vmin, vmax) @@ -1545,10 +1976,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx): [ (None, np.arange(5), [4], 1), ((1, 4), np.arange(5), [0], 3), - ((1, 3), np.arange(5), [0, 4], 2) - ] + ((1, 3), np.arange(5), [0, 4], 2), + ], ) - def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval(self, bounds, samples, nan_idx, scaling): + def test_vmin_vmax_eval_on_grid_scaling_effect_on_eval( + self, bounds, samples, nan_idx, scaling + ): """Check that the MSpline has the expected scaling property.""" bas_no_range = self.cls(3, mode="eval", bounds=None) bas = self.cls(3, mode="eval", bounds=bounds) @@ -1565,8 +1998,8 @@ def 
test_vmin_vmax_eval_on_grid_scaling_effect_on_eval(self, bounds, samples, na (None, np.arange(5), [4], 0, 1), ((0, 3), np.arange(5), [4], 0, 3), ((1, 4), np.arange(5), [0], 1, 4), - ((1, 3), np.arange(5), [0, 4], 1, 3) - ] + ((1, 3), np.arange(5), [0, 4], 1, 3), + ], ) def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx): bas_no_range = self.cls(3, mode="eval", bounds=None) @@ -1581,8 +2014,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx (None, np.arange(5), does_not_raise()), ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), - ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")) - ] + ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")), + ], ) def test_vmin_vmax_mode_conv(self, bounds, samples, exception): with exception: @@ -1596,6 +2029,7 @@ def test_transformer_get_params(self): params_basis = bas.get_params() assert params_transf == params_basis + class TestOrthExponentialBasis(BasisFuncsTesting): cls = basis.OrthExponentialBasis @@ -1684,13 +2118,38 @@ def test_sample_size_of_compute_features_matches_that_of_input( @pytest.mark.parametrize( "samples, vmin, vmax, expectation", [ - (np.linspace(-0.5, -0.001, 7), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - (np.linspace(1.5, 2., 7), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), - ([-0.5, -0.1, -0.01, 1.5, 2 , 3], 0, 1, pytest.raises(ValueError, match="All the samples lie outside")), + ( + np.linspace(-0.5, -0.001, 7), + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), + ( + np.linspace(1.5, 2.0, 7), + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), + ( + [-0.5, -0.1, -0.01, 1.5, 2, 3], + 0, + 1, + pytest.raises(ValueError, match="All the samples lie outside"), + ), (np.linspace(-1, 1, 10), 
0, 1, does_not_raise()), - (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")), - ] + ( + np.linspace(-1, 0, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ( + np.linspace(1, 2, 10), + 0, + 1, + pytest.warns(UserWarning, match="More than 90% of the samples"), + ), + ], ) def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation): bas = self.cls(5, bounds=(vmin, vmax), decay_rates=np.linspace(0.1, 1, 5)) @@ -1727,7 +2186,9 @@ def test_minimum_number_of_basis_required_is_matched( @pytest.mark.parametrize("n_input", [0, 1, 2, 3]) @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)]) - def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size): + def test_number_of_required_inputs_compute_features( + self, n_input, mode, window_size + ): """Tests whether the compute_features method correctly processes the number of required inputs.""" basis_obj = self.cls( n_basis_funcs=5, @@ -1787,12 +2248,12 @@ def test_evaluate_on_grid_input_number(self, n_input): if n_input == 0: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) missing 1 required positional argument", + match=r"evaluate_on_grid\(\) missing 1 required positional argument", ) elif n_input != basis_obj._n_input_dimensionality: expectation = pytest.raises( TypeError, - match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", + match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given", ) else: expectation = does_not_raise() @@ -1902,8 +2363,12 @@ def test_call_nan(self, mode, window_size): "samples, expectation", [ (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()), - (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")), - ]) + ( + np.array(["a", "1", "2", 
"3", "4", "5"]),
+                pytest.raises(TypeError, match="Input samples must"),
+            ),
+        ],
+    )
     def test_call_input_type(self, samples, expectation):
         bas = self.cls(5, np.linspace(0.1, 1, 5))
         with expectation:
@@ -1918,14 +2383,29 @@ def test_call_equivalent_in_conv(self):
     @pytest.mark.parametrize(
         "samples, vmin, vmax, expectation",
         [
-            (np.linspace(-1,-0.5, 10), 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
+            (
+                np.linspace(-1, -0.5, 10),
+                0,
+                1,
+                pytest.raises(ValueError, match="All the samples lie outside"),
+            ),
             (np.linspace(-1, 1, 10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
+            (
+                np.linspace(-1, 0, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+            (
+                np.linspace(1, 2, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+        ],
     )
     def test_call_vmin_vmax(self, samples, vmin, vmax, expectation):
-        bas = self.cls(5, decay_rates=np.linspace(0,1,5), bounds=(vmin, vmax))
+        bas = self.cls(5, decay_rates=np.linspace(0, 1, 5), bounds=(vmin, vmax))
         with expectation:
             bas(samples)
@@ -2005,31 +2485,145 @@ def test_transform_fails(self):
             ),
         ],
     )
-    def test_init_mode(self, mode, expectation):
-        window_size = None if mode == "eval" else 10
+    def test_init_mode(self, mode, expectation):
+        window_size = None if mode == "eval" else 10
+        with expectation:
+            self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6))
+
+    @pytest.mark.parametrize(
+        "mode, ws, expectation",
+        [
+            ("conv", 2, does_not_raise()),
+            ("conv", 10, does_not_raise()),
+            (
+                "conv",
+                -1,
+                pytest.raises(ValueError, match="`window_size` must be a positive "),
+            ),
+            (
+                "conv",
+                1.5,
+                pytest.raises(ValueError, match="`window_size` must be a positive "),
+            ),
+            ("eval", None, does_not_raise()),
+            (
+                "eval",
+                10,
+                pytest.raises(
+                    ValueError,
+                    match=r"If basis is in `mode=='eval'`, `window_size` should be None",
+                ),
+            ),
+        ],
+    )
+    def test_init_window_size(self, mode, ws, expectation):
+        with expectation:
+            self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6))
+
+    @pytest.mark.parametrize(
+        "decay_rates, window_size, n_basis_funcs, bounds, mode",
+        [
+            (np.arange(1, 11), None, 10, (1, 2), "eval"),
+            (np.arange(1, 11), 10, 10, None, "conv"),
+        ],
+    )
+    def test_set_params(
+        self,
+        decay_rates,
+        window_size,
+        n_basis_funcs,
+        bounds,
+        mode: Literal["eval", "conv"],
+    ):
+        """Test the read-only and read/write property of the parameters."""
+        pars = dict(
+            decay_rates=decay_rates,
+            window_size=window_size,
+            n_basis_funcs=n_basis_funcs,
+            bounds=bounds,
+        )
+        keys = list(pars.keys())
+        bas = self.cls(
+            decay_rates=decay_rates,
+            window_size=window_size,
+            n_basis_funcs=n_basis_funcs,
+            mode=mode,
+        )
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]}
+                bas.set_params(**par_set)
+                assert isinstance(bas, self.cls)
+
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "):
+                    par_set = {
+                        keys[i]: pars[keys[i]],
+                        keys[j]: pars[keys[j]],
+                        "mode": mode,
+                    }
+                    bas.set_params(**par_set)
+
+    @pytest.mark.parametrize(
+        "mode, expectation",
+        [
+            ("eval", does_not_raise()),
+            ("conv", pytest.raises(ValueError, match="`bounds` should only be set")),
+        ],
+    )
+    def test_set_bounds(self, mode, expectation):
+        ws = dict(eval=None, conv=10)
         with expectation:
-            self.cls(5, mode=mode, window_size=window_size, decay_rates=np.arange(1, 6))
+            self.cls(
+                decay_rates=np.arange(1, 11),
+                window_size=ws[mode],
+                n_basis_funcs=10,
+                mode=mode,
+                bounds=(1, 2),
+            )
+
+        bas = self.cls(
+            decay_rates=np.arange(1, 11),
+            window_size=10,
+            n_basis_funcs=10,
+            mode="conv",
+            bounds=None,
+        )
+        with pytest.raises(ValueError, match="`bounds` should only be set"):
+            bas.set_params(bounds=(1, 2))

     @pytest.mark.parametrize(
-        "mode, ws, expectation",
+        "mode, expectation",
         [
-            ("conv", 2, does_not_raise()),
-            ("conv", 10, does_not_raise()),
-            (
-                "conv",
-                -1,
-                pytest.raises(ValueError, match="`window_size` must be a positive "),
-            ),
-            (
-                "conv",
-                1.5,
-                pytest.raises(ValueError, match="`window_size` must be a positive "),
-            ),
+            ("conv", does_not_raise()),
+            ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")),
         ],
     )
-    def test_init_window_size(self, mode, ws, expectation):
+    def test_set_window_size(self, mode, expectation):
+        """Test window size set behavior."""
         with expectation:
-            self.cls(5, mode=mode, window_size=ws, decay_rates=np.arange(1, 6))
+            self.cls(
+                decay_rates=np.arange(1, 11),
+                window_size=10,
+                n_basis_funcs=10,
+                mode=mode,
+            )
+
+        bas = self.cls(
+            decay_rates=np.arange(1, 11), window_size=10, n_basis_funcs=10, mode="conv"
+        )
+        with pytest.raises(ValueError, match="If the basis is in `conv` mode"):
+            bas.set_params(window_size=None)
+
+        bas = self.cls(
+            decay_rates=np.arange(1, 11),
+            window_size=None,
+            n_basis_funcs=10,
+            mode="eval",
+        )
+        with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"):
+            bas.set_params(window_size=10)

     def test_convolution_is_performed(self):
         bas = self.cls(5, mode="conv", window_size=10, decay_rates=np.arange(1, 6))
@@ -2069,6 +2663,7 @@ def test_transformer_get_params(self):
         assert params_transf == params_basis
         assert np.all(rates_transf == rates_basis)

+
 class TestBSplineBasis(BasisFuncsTesting):
     cls = basis.BSplineBasis
@@ -2143,11 +2738,26 @@ def test_sample_size_of_compute_features_matches_that_of_input(
         "samples, vmin, vmax, expectation",
         [
             (0.5, 0, 1, does_not_raise()),
-            (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
+            (
+                -0.5,
+                0,
+                1,
+                pytest.raises(ValueError, match="All the samples lie outside"),
+            ),
             (np.linspace(-1, 1, 10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
+            (
+                np.linspace(-1, 0, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+            (
+                np.linspace(1, 2, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+        ],
     )
     def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation):
         bas = self.cls(5, bounds=(vmin, vmax))
@@ -2216,7 +2826,9 @@ def test_samples_range_matches_compute_features_requirements(

     @pytest.mark.parametrize("n_input", [0, 1, 2, 3])
     @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)])
-    def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size):
+    def test_number_of_required_inputs_compute_features(
+        self, n_input, mode, window_size
+    ):
         """
         Confirms that the compute_features() method correctly handles the number of input samples that are provided.
         """
@@ -2280,12 +2892,12 @@ def test_evaluate_on_grid_input_number(self, n_input):
         if n_input == 0:
             expectation = pytest.raises(
                 TypeError,
-                match="evaluate_on_grid\(\) missing 1 required positional argument",
+                match=r"evaluate_on_grid\(\) missing 1 required positional argument",
             )
         elif n_input != basis_obj._n_input_dimensionality:
             expectation = pytest.raises(
                 TypeError,
-                match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",
+                match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",
             )
         else:
             expectation = does_not_raise()
@@ -2350,8 +2962,12 @@ def test_call_nan(self, mode, window_size):
         "samples, expectation",
         [
             (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()),
-            (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")),
-        ])
+            (
+                np.array(["a", "1", "2", "3", "4", "5"]),
+                pytest.raises(TypeError, match="Input samples must"),
+            ),
+        ],
+    )
     def test_call_input_type(self, samples, expectation):
         bas = self.cls(5)
         with expectation:
@@ -2367,11 +2983,26 @@ def test_call_equivalent_in_conv(self):
         "samples, vmin, vmax, expectation",
         [
             (0.5, 0, 1, does_not_raise()),
-            (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
+            (
+                -0.5,
+                0,
+                1,
+                pytest.raises(ValueError, match="All the samples lie outside"),
+            ),
             (np.linspace(-1, 1, 10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
+            (
+                np.linspace(-1, 0, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+            (
+                np.linspace(1, 2, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+        ],
     )
     def test_call_vmin_vmax(self, samples, vmin, vmax, expectation):
         bas = self.cls(5, bounds=(vmin, vmax))
@@ -2461,12 +3092,94 @@ def test_init_mode(self, mode, expectation):
                 1.5,
                 pytest.raises(ValueError, match="`window_size` must be a positive "),
             ),
+            ("eval", None, does_not_raise()),
+            (
+                "eval",
+                10,
+                pytest.raises(
+                    ValueError,
+                    match=r"If basis is in `mode=='eval'`, `window_size` should be None",
+                ),
+            ),
         ],
     )
     def test_init_window_size(self, mode, ws, expectation):
         with expectation:
             self.cls(5, mode=mode, window_size=ws)

+    @pytest.mark.parametrize(
+        "order, window_size, n_basis_funcs, bounds, mode",
+        [
+            (3, None, 10, (1, 2), "eval"),
+            (3, 10, 10, None, "conv"),
+        ],
+    )
+    def test_set_params(
+        self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"]
+    ):
+        """Test the read-only and read/write property of the parameters."""
+        pars = dict(
+            order=order,
+            window_size=window_size,
+            n_basis_funcs=n_basis_funcs,
+            bounds=bounds,
+        )
+        keys = list(pars.keys())
+        bas = self.cls(
+            order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode
+        )
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]}
+                bas.set_params(**par_set)
+                assert isinstance(bas, self.cls)
+
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "):
+                    par_set = {
+                        keys[i]: pars[keys[i]],
+                        keys[j]: pars[keys[j]],
+                        "mode": mode,
+                    }
+                    bas.set_params(**par_set)
+
+    @pytest.mark.parametrize(
+        "mode, expectation",
+        [
+            ("eval", does_not_raise()),
+            ("conv", pytest.raises(ValueError, match="`bounds` should only be set")),
+        ],
+    )
+    def test_set_bounds(self, mode, expectation):
+        ws = dict(eval=None, conv=10)
+        with expectation:
+            self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2))
+
+        bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None)
+        with pytest.raises(ValueError, match="`bounds` should only be set"):
+            bas.set_params(bounds=(1, 2))
+
+    @pytest.mark.parametrize(
+        "mode, expectation",
+        [
+            ("conv", does_not_raise()),
+            ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")),
+        ],
+    )
+    def test_set_window_size(self, mode, expectation):
+        """Test window size set behavior."""
+        with expectation:
+            self.cls(window_size=10, n_basis_funcs=10, mode=mode)
+
+        bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv")
+        with pytest.raises(ValueError, match="If the basis is in `conv` mode"):
+            bas.set_params(window_size=None)
+
+        bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval")
+        with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"):
+            bas.set_params(window_size=10)
+
     def test_convolution_is_performed(self):
         bas = self.cls(5, mode="conv", window_size=10)
         x = np.random.normal(size=100)
@@ -2494,7 +3207,6 @@ def test_conv_kwargs_error(self):
         with pytest.raises(ValueError, match="kwargs should only be set"):
             self.cls(5, mode="eval", test="hi")

-
     @pytest.mark.parametrize(
         "bounds, expectation",
         [
@@ -2504,22 +3216,46 @@ def test_conv_kwargs_error(self):
             ((1, 3), does_not_raise()),
             (("a", 3), pytest.raises(TypeError, match="Could not convert")),
             ((1, "a"), pytest.raises(TypeError, match="Could not convert")),
-            (("a", "a"), pytest.raises(TypeError, match="Could not convert"))
-        ]
+            (("a", "a"), pytest.raises(TypeError, match="Could not convert")),
+        ],
     )
     def test_vmin_vmax_init(self, bounds, expectation):
         with expectation:
             bas = self.cls(5, bounds=bounds)
             assert bounds == bas.bounds if bounds else bas.bounds is None

+    @pytest.mark.parametrize(
+        "bounds, expectation",
+        [
+            (None, does_not_raise()),
+            ((None, 3), pytest.raises(TypeError, match=r"Could not convert")),
+            ((1, None), pytest.raises(TypeError, match=r"Could not convert")),
+            ((1, 3), does_not_raise()),
+            (("a", 3), pytest.raises(TypeError, match="Could not convert")),
+            ((1, "a"), pytest.raises(TypeError, match="Could not convert")),
+            (("a", "a"), pytest.raises(TypeError, match="Could not convert")),
+            (
+                (2, 1),
+                pytest.raises(
+                    ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater"
+                ),
+            ),
+        ],
+    )
+    def test_vmin_vmax_setter(self, bounds, expectation):
+        bas = self.cls(5, bounds=(1, 3))
+        with expectation:
+            bas.set_params(bounds=bounds)
+            assert bounds == bas.bounds if bounds else bas.bounds is None
+
     @pytest.mark.parametrize(
         "vmin, vmax, samples, nan_idx",
         [
             (None, None, np.arange(5), []),
             (0, 3, np.arange(5), [4]),
             (1, 4, np.arange(5), [0]),
-            (1, 3, np.arange(5), [0, 4])
-        ]
+            (1, 3, np.arange(5), [0, 4]),
+        ],
     )
     def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx):
         bounds = None if vmin is None else (vmin, vmax)
@@ -2534,10 +3270,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx):
         [
             (0, 3, np.arange(5), [4]),
             (1, 4, np.arange(5), [0]),
-            (1, 3, np.arange(5), [0, 4])
-        ]
+            (1, 3, np.arange(5), [0, 4]),
+        ],
     )
-    def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx):
+    def test_vmin_vmax_eval_on_grid_no_effect_on_eval(
+        self, vmin, vmax, samples, nan_idx
+    ):
         bas_no_range = self.cls(5, mode="eval", bounds=None)
         bas = self.cls(5, mode="eval", bounds=(vmin, vmax))
         _, out1 = bas.evaluate_on_grid(10)
@@ -2550,8 +3288,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan
             (None, np.arange(5), [4], 0, 1),
             ((0, 3), np.arange(5), [4], 0, 3),
             ((1, 4), np.arange(5), [0], 1, 4),
-            ((1, 3), np.arange(5), [0, 4], 1, 3)
-        ]
+            ((1, 3), np.arange(5), [0, 4], 1, 3),
+        ],
     )
     def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx):
         bas_no_range = self.cls(5, mode="eval", bounds=None)
@@ -2566,8 +3304,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx
             (None, np.arange(5), does_not_raise()),
             ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
             ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
-            ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should"))
-        ]
+            ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
+        ],
     )
     def test_vmin_vmax_mode_conv(self, bounds, samples, exception):
         with exception:
@@ -2656,11 +3394,26 @@ def test_sample_size_of_compute_features_matches_that_of_input(
         "samples, vmin, vmax, expectation",
         [
             (0.5, 0, 1, does_not_raise()),
-            (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
+            (
+                -0.5,
+                0,
+                1,
+                pytest.raises(ValueError, match="All the samples lie outside"),
+            ),
             (np.linspace(-1, 1, 10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
+            (
+                np.linspace(-1, 0, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+            (
+                np.linspace(1, 2, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+        ],
     )
     def test_compute_features_vmin_vmax(self, samples, vmin, vmax, expectation):
         bas = self.cls(5, bounds=(vmin, vmax))
@@ -2747,7 +3500,9 @@ def test_samples_range_matches_compute_features_requirements(

     @pytest.mark.parametrize("n_input", [0, 1, 2, 3])
     @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 10)])
-    def test_number_of_required_inputs_compute_features(self, n_input, mode, window_size):
+    def test_number_of_required_inputs_compute_features(
+        self, n_input, mode, window_size
+    ):
         """
         Confirms that the compute_features() method correctly handles the number of input samples that are provided.
         """
@@ -2811,12 +3566,12 @@ def test_evaluate_on_grid_input_number(self, n_input):
         if n_input == 0:
             expectation = pytest.raises(
                 TypeError,
-                match="evaluate_on_grid\(\) missing 1 required positional argument",
+                match=r"evaluate_on_grid\(\) missing 1 required positional argument",
            )
         elif n_input != basis_obj._n_input_dimensionality:
             expectation = pytest.raises(
                 TypeError,
-                match="evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",
+                match=r"evaluate_on_grid\(\) takes [0-9] positional arguments but [0-9] were given",
             )
         else:
             expectation = does_not_raise()
@@ -2881,8 +3636,12 @@ def test_call_nan(self, mode, window_size):
         "samples, expectation",
         [
             (np.array([0, 1, 2, 3, 4, 5]), does_not_raise()),
-            (np.array(['a', '1', '2', '3', '4', '5']), pytest.raises(TypeError, match="Input samples must")),
-        ])
+            (
+                np.array(["a", "1", "2", "3", "4", "5"]),
+                pytest.raises(TypeError, match="Input samples must"),
+            ),
+        ],
+    )
     def test_call_input_type(self, samples, expectation):
         bas = self.cls(5)
         with expectation:
@@ -2898,26 +3657,26 @@ def test_call_equivalent_in_conv(self):
         "samples, vmin, vmax, expectation",
         [
             (0.5, 0, 1, does_not_raise()),
-            (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
+            (
+                -0.5,
+                0,
+                1,
+                pytest.raises(ValueError, match="All the samples lie outside"),
+            ),
             (np.linspace(-1, 1, 10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
-    )
-    def test_call_vmin_vmax(self, samples, vmin, vmax, expectation):
-        bas = self.cls(5, bounds=(vmin, vmax))
-        with expectation:
-            bas(samples)
-
-    @pytest.mark.parametrize(
-        "samples, vmin, vmax, expectation",
-        [
-            (0.5, 0, 1, does_not_raise()),
-            (-0.5, 0, 1, pytest.raises(ValueError, match="All the samples lie outside")),
-            (np.linspace(-1,1,10), 0, 1, does_not_raise()),
-            (np.linspace(-1, 0, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-            (np.linspace(1, 2, 10), 0, 1, pytest.warns(UserWarning, match="More than 90% of the samples")),
-        ]
+            (
+                np.linspace(-1, 0, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+            (
+                np.linspace(1, 2, 10),
+                0,
+                1,
+                pytest.warns(UserWarning, match="More than 90% of the samples"),
+            ),
+        ],
     )
     def test_call_vmin_vmax(self, samples, vmin, vmax, expectation):
         bas = self.cls(5, bounds=(vmin, vmax))
@@ -2989,12 +3748,94 @@ def test_transform_fails(self):
                 1.5,
                 pytest.raises(ValueError, match="`window_size` must be a positive "),
             ),
+            ("eval", None, does_not_raise()),
+            (
+                "eval",
+                10,
+                pytest.raises(
+                    ValueError,
+                    match=r"If basis is in `mode=='eval'`, `window_size` should be None",
+                ),
+            ),
         ],
     )
     def test_init_window_size(self, mode, ws, expectation):
         with expectation:
             self.cls(5, mode=mode, window_size=ws)

+    @pytest.mark.parametrize(
+        "order, window_size, n_basis_funcs, bounds, mode",
+        [
+            (3, None, 10, (1, 2), "eval"),
+            (3, 10, 10, None, "conv"),
+        ],
+    )
+    def test_set_params(
+        self, order, window_size, n_basis_funcs, bounds, mode: Literal["eval", "conv"]
+    ):
+        """Test the read-only and read/write property of the parameters."""
+        pars = dict(
+            order=order,
+            window_size=window_size,
+            n_basis_funcs=n_basis_funcs,
+            bounds=bounds,
+        )
+        keys = list(pars.keys())
+        bas = self.cls(
+            order=order, window_size=window_size, n_basis_funcs=n_basis_funcs, mode=mode
+        )
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                par_set = {keys[i]: pars[keys[i]], keys[j]: pars[keys[j]]}
+                bas.set_params(**par_set)
+                assert isinstance(bas, self.cls)
+
+        for i in range(len(pars)):
+            for j in range(i + 1, len(pars)):
+                with pytest.raises(AttributeError, match="can't set attribute 'mode'|property 'mode' of "):
+                    par_set = {
+                        keys[i]: pars[keys[i]],
+                        keys[j]: pars[keys[j]],
+                        "mode": mode,
+                    }
+                    bas.set_params(**par_set)
+
+    @pytest.mark.parametrize(
+        "mode, expectation",
+        [
+            ("eval", does_not_raise()),
+            ("conv", pytest.raises(ValueError, match="`bounds` should only be set")),
+        ],
+    )
+    def test_set_bounds(self, mode, expectation):
+        ws = dict(eval=None, conv=10)
+        with expectation:
+            self.cls(window_size=ws[mode], n_basis_funcs=10, mode=mode, bounds=(1, 2))
+
+        bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv", bounds=None)
+        with pytest.raises(ValueError, match="`bounds` should only be set"):
+            bas.set_params(bounds=(1, 2))
+
+    @pytest.mark.parametrize(
+        "mode, expectation",
+        [
+            ("conv", does_not_raise()),
+            ("eval", pytest.raises(ValueError, match="If basis is in `mode=='eval'`")),
+        ],
+    )
+    def test_set_window_size(self, mode, expectation):
+        """Test window size set behavior."""
+        with expectation:
+            self.cls(window_size=10, n_basis_funcs=10, mode=mode)
+
+        bas = self.cls(window_size=10, n_basis_funcs=10, mode="conv")
+        with pytest.raises(ValueError, match="If the basis is in `conv` mode"):
+            bas.set_params(window_size=None)
+
+        bas = self.cls(window_size=None, n_basis_funcs=10, mode="eval")
+        with pytest.raises(ValueError, match="If basis is in `mode=='eval'`"):
+            bas.set_params(window_size=10)
+
     def test_convolution_is_performed(self):
         bas = self.cls(5, mode="conv", window_size=10)
         x = np.random.normal(size=100)
@@ -3022,7 +3863,6 @@ def test_conv_kwargs_error(self):
         with pytest.raises(ValueError, match="kwargs should only be set"):
             self.cls(5, mode="eval", test="hi")

-
     @pytest.mark.parametrize(
         "bounds, expectation",
         [
@@ -3032,22 +3872,46 @@ def test_conv_kwargs_error(self):
             ((1, 3), does_not_raise()),
             (("a", 3), pytest.raises(TypeError, match="Could not convert")),
             ((1, "a"), pytest.raises(TypeError, match="Could not convert")),
-            (("a", "a"), pytest.raises(TypeError, match="Could not convert"))
-        ]
+            (("a", "a"), pytest.raises(TypeError, match="Could not convert")),
+        ],
     )
     def test_vmin_vmax_init(self, bounds, expectation):
         with expectation:
             bas = self.cls(5, bounds=bounds)
             assert bounds == bas.bounds if bounds else bas.bounds is None

+    @pytest.mark.parametrize(
+        "bounds, expectation",
+        [
+            (None, does_not_raise()),
+            ((None, 3), pytest.raises(TypeError, match=r"Could not convert")),
+            ((1, None), pytest.raises(TypeError, match=r"Could not convert")),
+            ((1, 3), does_not_raise()),
+            (("a", 3), pytest.raises(TypeError, match="Could not convert")),
+            ((1, "a"), pytest.raises(TypeError, match="Could not convert")),
+            (("a", "a"), pytest.raises(TypeError, match="Could not convert")),
+            (
+                (2, 1),
+                pytest.raises(
+                    ValueError, match=r"Invalid bound \(2, 1\). Lower bound is greater"
+                ),
+            ),
+        ],
+    )
+    def test_vmin_vmax_setter(self, bounds, expectation):
+        bas = self.cls(5, bounds=(1, 3))
+        with expectation:
+            bas.set_params(bounds=bounds)
+            assert bounds == bas.bounds if bounds else bas.bounds is None
+
     @pytest.mark.parametrize(
         "vmin, vmax, samples, nan_idx",
         [
             (None, None, np.arange(5), []),
             (0, 3, np.arange(5), [4]),
             (1, 4, np.arange(5), [0]),
-            (1, 3, np.arange(5), [0, 4])
-        ]
+            (1, 3, np.arange(5), [0, 4]),
+        ],
     )
     def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx):
         bounds = None if vmin is None else (vmin, vmax)
@@ -3062,10 +3926,12 @@ def test_vmin_vmax_range(self, vmin, vmax, samples, nan_idx):
         [
             (0, 3, np.arange(5), [4]),
             (1, 4, np.arange(5), [0]),
-            (1, 3, np.arange(5), [0, 4])
-        ]
+            (1, 3, np.arange(5), [0, 4]),
+        ],
     )
-    def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan_idx):
+    def test_vmin_vmax_eval_on_grid_no_effect_on_eval(
+        self, vmin, vmax, samples, nan_idx
+    ):
         bas_no_range = self.cls(5, mode="eval", bounds=None)
         bas = self.cls(5, mode="eval", bounds=(vmin, vmax))
         _, out1 = bas.evaluate_on_grid(10)
@@ -3078,8 +3944,8 @@ def test_vmin_vmax_eval_on_grid_no_effect_on_eval(self, vmin, vmax, samples, nan
             (None, np.arange(5), [4], 0, 1),
             ((0, 3), np.arange(5), [4], 0, 3),
             ((1, 4), np.arange(5), [0], 1, 4),
-            ((1, 3), np.arange(5), [0, 4], 1, 3)
-        ]
+            ((1, 3), np.arange(5), [0, 4], 1, 3),
+        ],
     )
     def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx):
         bas_no_range = self.cls(5, mode="eval", bounds=None)
@@ -3094,8 +3960,8 @@ def test_vmin_vmax_eval_on_grid_affects_x(self, bounds, samples, nan_idx, mn, mx
             (None, np.arange(5), does_not_raise()),
             ((0, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
             ((1, 4), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
-            ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should"))
-        ]
+            ((1, 3), np.arange(5), pytest.raises(ValueError, match="`bounds` should")),
+        ],
     )
     def test_vmin_vmax_mode_conv(self, bounds, samples, exception):
         with exception:
@@ -3109,6 +3975,7 @@ def test_transformer_get_params(self):
         params_basis = bas.get_params()
         assert params_transf == params_basis

+
 class CombinedBasis(BasisFuncsTesting):
     """
     This class is used to run tests on combination operations (e.g., addition, multiplication) among Basis functions.
@@ -3122,6 +3989,10 @@ class CombinedBasis(BasisFuncsTesting):
     @staticmethod
     def instantiate_basis(n_basis, basis_class, mode="eval", window_size=10):
         """Instantiate and return two basis of the type specified."""
+
+        if mode == "eval":
+            window_size = None
+
         if basis_class == basis.MSplineBasis:
             basis_obj = basis_class(
                 n_basis_funcs=n_basis, order=4, mode=mode, window_size=window_size
             )
@@ -3244,7 +4115,7 @@ def test_sample_size_of_compute_features_matches_that_of_input(
         self, n_basis_a, n_basis_b, sample_size, basis_a, basis_b, mode, window_size
     ):
         """
-        Test whether the output sample size from the `AdditiveBasis` compute_features function matches the input sample size.
+        Test whether the output sample size of `AdditiveBasis.compute_features` matches the input sample size.
         """
         basis_a_obj = self.instantiate_basis(
             n_basis_a, basis_a, mode=mode, window_size=window_size
@@ -3258,7 +4129,8 @@ def test_sample_size_of_compute_features_matches_that_of_input(
         )
         if eval_basis.shape[0] != sample_size:
             raise ValueError(
-                f"Dimensions do not agree: The window size should match the second dimension of the output features basis."
+                f"Dimensions do not agree: The window size should match the second dimension of the "
+                f"output features basis. "
                 f"The window size is {sample_size}",
                 f"The second dimension of the output features basis is {eval_basis.shape[0]}",
             )
@@ -3574,7 +4446,11 @@ def test_call_non_empty(

     @pytest.mark.parametrize(
         "mn, mx, expectation",
-        [(0, 1, does_not_raise()), (-2, 2, does_not_raise()), (0.1, 2, does_not_raise())],
+        [
+            (0, 1, does_not_raise()),
+            (-2, 2, does_not_raise()),
+            (0.1, 2, does_not_raise()),
+        ],
     )
     @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)])
     @pytest.mark.parametrize("basis_a", list_all_basis_classes())
@@ -4095,7 +4971,11 @@ def test_call_non_empty(

     @pytest.mark.parametrize(
         "mn, mx, expectation",
-        [(0, 1, does_not_raise()), (-2, 2, does_not_raise()), (0.1, 2, does_not_raise())],
+        [
+            (0, 1, does_not_raise()),
+            (-2, 2, does_not_raise()),
+            (0.1, 2, does_not_raise()),
+        ],
     )
     @pytest.mark.parametrize("mode, window_size", [("eval", None), ("conv", 3)])
     @pytest.mark.parametrize("basis_a", list_all_basis_classes())
@@ -4243,6 +5123,7 @@ def test_basis_to_transformer(basis_cls):
     for k in bas.__dict__.keys():
         assert getattr(bas, k) == getattr(trans_bas, k)

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4289,7 +5170,11 @@ def test_to_transformer_and_constructor_are_equivalent(basis_cls):
     trans_bas_b = basis.TransformerBasis(bas)

     # they both just have a _basis
-    assert list(trans_bas_a.__dict__.keys()) == list(trans_bas_b.__dict__.keys()) == ["_basis"]
+    assert (
+        list(trans_bas_a.__dict__.keys())
+        == list(trans_bas_b.__dict__.keys())
+        == ["_basis"]
+    )
     # and those bases are the same
     assert trans_bas_a._basis.__dict__ == trans_bas_b._basis.__dict__
@@ -4349,7 +5234,7 @@ def test_transformerbasis_getattr(basis_cls, n_basis_funcs):
 @pytest.mark.parametrize("n_basis_funcs_new", [6, 10, 20])
 def test_transformerbasis_set_params(basis_cls, n_basis_funcs_init, n_basis_funcs_new):
     trans_basis = basis.TransformerBasis(basis_cls(n_basis_funcs_init))
-    trans_basis.set_params(n_basis_funcs = n_basis_funcs_new)
+    trans_basis.set_params(n_basis_funcs=n_basis_funcs_new)
     assert trans_basis.n_basis_funcs == n_basis_funcs_new
     assert trans_basis._basis.n_basis_funcs == n_basis_funcs_new
@@ -4374,6 +5259,7 @@ def test_transformerbasis_setattr_basis(basis_cls):
     assert trans_bas._basis.n_basis_funcs == 20
     assert isinstance(trans_bas._basis, basis_cls)

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4394,6 +5280,7 @@ def test_transformerbasis_setattr_basis_attribute(basis_cls):
     assert trans_bas._basis.n_basis_funcs == 20
     assert isinstance(trans_bas._basis, basis_cls)

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4415,7 +5302,7 @@ def test_transformerbasis_copy_basis_on_contsruct(basis_cls):
     assert trans_bas._basis.n_basis_funcs == 20
     assert trans_bas._basis.n_basis_funcs == 20
     assert isinstance(trans_bas._basis, basis_cls)
-
+

 @pytest.mark.parametrize(
     "basis_cls",
@@ -4432,7 +5319,10 @@ def test_transformerbasis_setattr_illegal_attribute(basis_cls):
     # is not allowed
     trans_bas = basis.TransformerBasis(basis_cls(10))

-    with pytest.raises(ValueError, match="Only setting _basis or existing attributes of _basis is allowed."):
+    with pytest.raises(
+        ValueError,
+        match="Only setting _basis or existing attributes of _basis is allowed.",
+    ):
         trans_bas.random_attr = "random value"
@@ -4454,11 +5344,18 @@ def test_transformerbasis_addition(basis_cls):
     trans_bas_sum = trans_bas_a + trans_bas_b
     assert isinstance(trans_bas_sum, basis.TransformerBasis)
     assert isinstance(trans_bas_sum._basis, basis.AdditiveBasis)
-    assert trans_bas_sum.n_basis_funcs == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs
-    assert trans_bas_sum._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
+    assert (
+        trans_bas_sum.n_basis_funcs
+        == trans_bas_a.n_basis_funcs + trans_bas_b.n_basis_funcs
+    )
+    assert (
+        trans_bas_sum._n_input_dimensionality
+        == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
+    )
     assert trans_bas_sum._basis1.n_basis_funcs == n_basis_funcs_a
     assert trans_bas_sum._basis2.n_basis_funcs == n_basis_funcs_b

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4477,11 +5374,18 @@ def test_transformerbasis_multiplication(basis_cls):
     trans_bas_prod = trans_bas_a * trans_bas_b
     assert isinstance(trans_bas_prod, basis.TransformerBasis)
     assert isinstance(trans_bas_prod._basis, basis.MultiplicativeBasis)
-    assert trans_bas_prod.n_basis_funcs == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs
-    assert trans_bas_prod._n_input_dimensionality == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
+    assert (
+        trans_bas_prod.n_basis_funcs
+        == trans_bas_a.n_basis_funcs * trans_bas_b.n_basis_funcs
+    )
+    assert (
+        trans_bas_prod._n_input_dimensionality
+        == trans_bas_a._n_input_dimensionality + trans_bas_b._n_input_dimensionality
+    )
     assert trans_bas_prod._basis1.n_basis_funcs == n_basis_funcs_a
     assert trans_bas_prod._basis2.n_basis_funcs == n_basis_funcs_b

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4493,23 +5397,26 @@ def test_transformerbasis_multiplication(basis_cls):
     ],
 )
 @pytest.mark.parametrize(
-    "exponent, error_type, error_message",
-    [
-        (2, does_not_raise, None),
-        (5, does_not_raise, None),
-        (0.5, TypeError, "Exponent should be an integer"),
-        (-1, ValueError, "Exponent should be a non-negative integer")
-    ]
+    "exponent, error_type, error_message",
+    [
+        (2, does_not_raise, None),
+        (5, does_not_raise, None),
+        (0.5, TypeError, "Exponent should be an integer"),
+        (-1, ValueError, "Exponent should be a non-negative integer"),
+    ],
 )
-def test_transformerbasis_exponentiation(basis_cls, exponent: int, error_type, error_message):
+def test_transformerbasis_exponentiation(
+    basis_cls, exponent: int, error_type, error_message
+):
     trans_bas = basis.TransformerBasis(basis_cls(5))

     if not isinstance(exponent, int):
         with pytest.raises(error_type, match=error_message):
-            trans_bas_exp = trans_bas ** exponent
+            trans_bas_exp = trans_bas**exponent
             assert isinstance(trans_bas_exp, basis.TransformerBasis)
             assert isinstance(trans_bas_exp._basis, basis.MultiplicativeBasis)

+
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4522,11 +5429,17 @@ def test_transformerbasis_exponentiation(basis_cls, exponent: int, error_type, e
 )
 def test_transformerbasis_dir(basis_cls):
     trans_bas = basis.TransformerBasis(basis_cls(5))
-    for attr_name in ("fit", "transform", "fit_transform", "n_basis_funcs", "mode", "window_size"):
+    for attr_name in (
+        "fit",
+        "transform",
+        "fit_transform",
+        "n_basis_funcs",
+        "mode",
+        "window_size",
+    ):
         assert attr_name in dir(trans_bas)

-
 @pytest.mark.parametrize(
     "basis_cls",
     [
@@ -4602,7 +5515,7 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs):
         (2, False, "anti-causal", [20, 75]),
         (2, None, "anti-causal", [20, 19, 75, 74]),
         (3, False, "acausal", [0, 20, 50, 75]),
-        (2, False, "acausal", [20, 75]),
+        (5, False, "acausal", [0, 1, 19, 20, 50, 51, 74, 75]),
     ],
 )
 @pytest.mark.parametrize(
@@ -4690,7 +5603,7 @@ def test_multi_epoch_pynapple_basis(
         (2, False, "anti-causal", [20, 75]),
         (2, None, "anti-causal", [20, 19, 75, 74]),
         (3, False, "acausal", [0, 20, 50, 75]),
-        (2, False, "acausal", [20, 75]),
+        (5, False, "acausal", [0, 1, 19, 20, 50, 51, 74, 75]),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tests/test_convergence.py b/tests/test_convergence.py
index 881ee0ac..5a8814e1 100644
--- a/tests/test_convergence.py
+++ b/tests/test_convergence.py
@@ -76,11 +76,12 @@ def test_ridge_convergence(solver_names):
     y = np.random.poisson(rate)

     # instantiate and fit ridge GLM with GradientDescent
-    model_GD = nmo.glm.GLM(regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
+    model_GD = nmo.glm.GLM(regularizer_strength=1., regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
     model_GD.fit(X, y)

     # instantiate and fit ridge GLM with ProximalGradient
     model_PG = nmo.glm.GLM(
+        regularizer_strength=1.,
         regularizer="Ridge",
         solver_name="ProximalGradient",
         solver_kwargs=dict(tol=10**-12),
@@ -108,6 +109,7 @@ def test_lasso_convergence(solver_name):
     # instantiate and fit GLM with ProximalGradient
     model_PG = nmo.glm.GLM(
         regularizer="Lasso",
+        regularizer_strength=1.,
         solver_name="ProximalGradient",
         solver_kwargs=dict(tol=10**-12),
     )
diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index a23f8bb0..4c5b75ec 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -329,7 +329,7 @@ def test_tree_structure_match(self, trial_counts, axis):
         conv = convolve.create_convolutional_predictor(
             basis_matrix, trial_counts, axis=axis
         )
-        assert jax.tree_util.tree_structure(trial_counts) == jax.tree_structure(conv)
+        assert jax.tree_util.tree_structure(trial_counts) == jax.tree_util.tree_structure(conv)

     @pytest.mark.parametrize("axis", [0, 1, 2])
     @pytest.mark.parametrize(
@@ -346,7 +346,6 @@ def test_tree_structure_match(self, trial_counts, axis):
             (2, False, "anti-causal", [29]),
             (2, None, "anti-causal", [29, 28]),
             (3, False, "acausal", [29, 0]),
-            (2, False, "acausal", [29]),
         ],
     )
     def test_expected_nan(self, axis, window_size, shift, predictor_causality, nan_idx):
@@ -394,7 +393,6 @@ def test_expected_nan(self, axis, window_size, shift, predictor_causality, nan_i
             (2, False, "anti-causal", [20, 75]),
             (2, None, "anti-causal", [20, 19, 75, 74]),
             (3, False, "acausal", [0, 20, 50, 75]),
-            (2, False, "acausal", [20, 75]),
         ],
     )
     def test_multi_epoch_pynapple(
diff --git a/tests/test_glm.py b/tests/test_glm.py
index 5c979a94..d17ae3a9 100644
--- a/tests/test_glm.py
+++ b/tests/test_glm.py
@@ -73,7 +73,7 @@ def test_solver_type(self, regularizer, solver_name, expectation, glm_class):
         Test that an error is raised if a non-compatible solver is passed.
         """
         with expectation:
-            glm_class(regularizer=regularizer, solver_name=solver_name)
+            glm_class(regularizer=regularizer, solver_name=solver_name, regularizer_strength=1)

     @pytest.mark.parametrize(
         "observation, expectation",
@@ -166,7 +166,7 @@ def test_get_params(self):
         assert list(model.get_params().values()) == expected_values

         # changing regularizer
-        model.regularizer = "Ridge"
+        model.set_params(regularizer="Ridge", regularizer_strength=1.)

         expected_values = [
             model.observation_model.inverse_link_function,
@@ -491,6 +491,7 @@ def test_fit_mask_grouplasso(self, group_sparse_poisson_glm_model_instantiation)
         """Test that the group lasso fit goes through"""
         X, y, model, params, rate, mask = group_sparse_poisson_glm_model_instantiation
         model.set_params(
+            regularizer_strength=1.,
             regularizer=nmo.regularizer.GroupLasso(mask=mask),
             solver_name="ProximalGradient",
         )
@@ -1176,6 +1177,7 @@ def test_initialize_solver_mask_grouplasso(
         model.set_params(
             regularizer=nmo.regularizer.GroupLasso(mask=mask),
             solver_name="ProximalGradient",
+            regularizer_strength=1.,
         )
         params = model.initialize_params(X, y)
         model.initialize_state(X, y, params)
@@ -1481,10 +1483,11 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm
             n_features = sum(x.shape[1] for x in jax.tree.leaves(X))
             regularizer_kwargs["mask"] = (np.random.randn(n_features) > 0).reshape(1, -1).astype(float)

+        reg = regularizer_class(**regularizer_kwargs)
+        strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1.
         glm = glm_class(
-            regularizer=regularizer_class(
-                **regularizer_kwargs,
-            ),
+            regularizer=reg,
+            regularizer_strength=strength,
             solver_name=solver_name,
             solver_kwargs={
                 "batch_size": batch_size,
@@ -1495,9 +1498,7 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm
             },
         )
         glm2 = glm_class(
-            regularizer=regularizer_class(
-                **regularizer_kwargs,
-            ),
+            regularizer=reg,
             solver_name=solver_name,
             solver_kwargs={
                 "batch_size": batch_size,
@@ -1506,6 +1507,7 @@ def test_glm_update_consistent_with_fit_with_svrg(self, request, regr_setup, glm
                 "maxiter": maxiter,
                 "key": key,
             },
+            regularizer_strength=strength,
         )
         glm2.fit(X, y)
@@ -1623,7 +1625,8 @@ def test_estimate_dof_resid(
         Test that the dof is an integer.
         """
         X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
-        model.regularizer = reg
+        strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1.
+        model.set_params(regularizer=reg, regularizer_strength=strength)
         model.solver_name = model.regularizer.default_solver
         model.fit(X, y)
         num = model._estimate_resid_degrees_of_freedom(X, n_samples=n_samples)
@@ -1642,15 +1645,68 @@ def test_warning_solver_reg_str(self, reg):
         model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0)

         # reset to unregularized
-        model.regularizer = "UnRegularized"
+        model.set_params(regularizer="UnRegularized", regularizer_strength=None)

         with pytest.warns(UserWarning):
             nmo.glm.GLM(regularizer=reg)

     @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"])
     def test_reg_strength_reset(self, reg):
         model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0)
-        model.regularizer = "UnRegularized"
-        assert model.regularizer_strength is None
+        with pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM"):
+            model.regularizer = "UnRegularized"
+        model.regularizer_strength = None
+        with pytest.warns(UserWarning, match="Caution: regularizer strength has not been set"):
+            model.regularizer
= "Ridge" + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer + ({"regularizer": "Ridge"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso"}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "UnRegularized"}, does_not_raise()), + # set both None or number + ({"regularizer": "Ridge", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Ridge", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "Lasso", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "GroupLasso", "regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": None}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": 1.}, + pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + # set regularizer str only + ({"regularizer_strength": 1.}, + pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + ({"regularizer_strength": None}, does_not_raise()), + ] + ) + def test_reg_set_params(self, params, warns): + model = nmo.glm.GLM() + with warns: + model.set_params(**params) + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer str only + ({"regularizer_strength": 1.}, does_not_raise()), + ({"regularizer_strength": None}, + pytest.warns(UserWarning, match="Caution: 
regularizer strength has not been set")), + ] + ) + @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) + def test_reg_set_params_reg_str_only(self, params, warns, reg): + model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1) + with warns: + model.set_params(**params) class TestPopulationGLM: @@ -1684,7 +1740,7 @@ def test_solver_type(self, regularizer, expectation, population_glm_class): Test that an error is raised if a non-compatible solver is passed. """ with expectation: - population_glm_class(regularizer=regularizer) + population_glm_class(regularizer=regularizer, regularizer_strength=1.) def test_get_params(self): """ @@ -1732,7 +1788,7 @@ def test_get_params(self): assert list(model.get_params().values()) == expected_values # changing regularizer - model.regularizer = "Ridge" + model.set_params(regularizer="Ridge", regularizer_strength=1.) expected_values = [ model.feature_mask, @@ -1792,7 +1848,7 @@ def test_init_observation_type( """ with expectation: population_glm_class( - regularizer=ridge_regularizer, observation_model=observation + regularizer=ridge_regularizer, observation_model=observation, regularizer_strength=1. ) @pytest.mark.parametrize( @@ -1857,7 +1913,8 @@ def test_estimate_dof_resid( Test that the dof is an integer. """ X, y, model, true_params, firing_rate = poisson_population_GLM_model - model.regularizer = reg + strength = None if isinstance(reg, nmo.regularizer.UnRegularized) else 1. 
+ model.set_params(regularizer=reg, regularizer_strength=strength) model.solver_name = model.regularizer.default_solver model.fit(X, y) num = model._estimate_resid_degrees_of_freedom(X, n_samples=n_samples) @@ -2123,6 +2180,7 @@ def test_fit_mask_grouplasso(self, group_sparse_poisson_glm_model_instantiation) model.set_params( regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", + regularizer_strength=1., ) model.fit(X, y) @@ -2527,6 +2585,7 @@ def test_initialize_solver_mask_grouplasso( """Test that the group lasso initialize_solver goes through""" X, y, model, params, rate, mask = group_sparse_poisson_glm_model_instantiation model.set_params( + regularizer_strength=1., regularizer=nmo.regularizer.GroupLasso(mask=mask), solver_name="ProximalGradient", ) @@ -3210,13 +3269,13 @@ def test_feature_mask_compatibility_fit_tree( [ ( nmo.regularizer.UnRegularized(), - 0.001, + None, "LBFGS", {"stepsize": 0.1, "tol": 10**-14}, ), ( nmo.regularizer.UnRegularized(), - 1.0, + None, "GradientDescent", {"tol": 10**-14}, ), @@ -3262,7 +3321,7 @@ def test_masked_fit_vs_loop( ): jax.config.update("jax_enable_x64", True) if isinstance(mask, dict): - X, y, model, true_params, firing_rate = poisson_population_GLM_model_pytree + X, y, _, true_params, firing_rate = poisson_population_GLM_model_pytree def map_neu(k, coef_): key_ind = {"input_1": [0, 1, 2], "input_2": [3, 4]} @@ -3275,7 +3334,7 @@ def map_neu(k, coef_): return ind_array, coef_stack else: - X, y, model, true_params, firing_rate = poisson_population_GLM_model + X, y, _, true_params, firing_rate = poisson_population_GLM_model def map_neu(k, coef_): ind_array = np.where(mask[:, k])[0] @@ -3284,11 +3343,14 @@ def map_neu(k, coef_): mask_bool = jax.tree_util.tree_map(lambda x: np.asarray(x.T, dtype=bool), mask) # fit pop glm - model.feature_mask = mask - model.regularizer = regularizer - model.regularizer_strength = regularizer_strength - model.solver_name = solver_name - model.solver_kwargs = 
solver_kwargs + kwargs = dict( + feature_mask=mask, + regularizer=regularizer, + regularizer_strength=regularizer_strength, + solver_name=solver_name, + solver_kwargs=solver_kwargs, + ) + model = nmo.glm.PopulationGLM(**kwargs) model.fit(X, y) coef_vectorized = np.vstack(jax.tree_util.tree_leaves(model.coef_)) @@ -3331,12 +3393,56 @@ def test_waning_solver_reg_str(self, reg): model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) # reset to unregularized - model.regularizer = "UnRegularized" + model.set_params(regularizer="UnRegularized", regularizer_strength=None) with pytest.warns(UserWarning): nmo.glm.GLM(regularizer=reg) @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) def test_reg_strength_reset(self, reg): - model = nmo.glm.GLM(regularizer=reg, regularizer_strength=1.0) - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + model = nmo.glm.PopulationGLM(regularizer=reg, regularizer_strength=1.0) + with pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM"): + model.regularizer = "UnRegularized" + model.regularizer_strength = None + with pytest.warns(UserWarning, match="Caution: regularizer strength has not been set"): + model.regularizer = "Ridge" + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer + ({"regularizer": "Ridge"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso"}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "UnRegularized"}, does_not_raise()), + # set both None or number + ({"regularizer": "Ridge", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Ridge", "regularizer_strength": 1.}, does_not_raise()), + 
({"regularizer": "Lasso", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "Lasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "GroupLasso", "regularizer_strength": None}, pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ({"regularizer": "GroupLasso", "regularizer_strength": 1.}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": None}, does_not_raise()), + ({"regularizer": "UnRegularized", "regularizer_strength": 1.}, pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + # set regularizer str only + ({"regularizer_strength": 1.}, pytest.warns(UserWarning, match="Unused parameter `regularizer_strength` for UnRegularized GLM")), + ({"regularizer_strength": None},does_not_raise()), + ] + ) + def test_reg_set_params(self, params, warns): + model = nmo.glm.PopulationGLM() + with warns: + model.set_params(**params) + + @pytest.mark.parametrize( + "params, warns", + [ + # set regularizer str only + ({"regularizer_strength": 1.}, does_not_raise()), + ({"regularizer_strength": None},pytest.warns(UserWarning, match="Caution: regularizer strength has not been set")), + ] + ) + @pytest.mark.parametrize("reg", ["Ridge", "Lasso", "GroupLasso"]) + def test_reg_set_params_reg_str_only(self, params, warns, reg): + model = nmo.glm.PopulationGLM(regularizer=reg, regularizer_strength=1) + with warns: + model.set_params(**params) \ No newline at end of file diff --git a/tests/test_observation_models.py b/tests/test_observation_models.py index ffc27c46..25262035 100644 --- a/tests/test_observation_models.py +++ b/tests/test_observation_models.py @@ -1,3 +1,4 @@ +import warnings from contextlib import nullcontext as does_not_raise import jax @@ -501,7 +502,9 @@ def test_pseudo_r2_vs_statsmodels(self, gammaGLM_model_instantiation): X, y, model, _, firing_rate = 
gammaGLM_model_instantiation # statsmodels mcfadden - mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The InversePower link function does") + mdl = sm.GLM(y, sm.add_constant(X), family=sm.families.Gamma()).fit() pr2_sms = mdl.pseudo_rsquared("mcf") # set params diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 38754143..26b58d14 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,4 @@ +import joblib import numpy as np import pynapple as nap import pytest @@ -39,8 +40,8 @@ def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation): X, y, model, _, _ = poissonGLM_model_instantiation bas = basis.TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) - param_grid = dict(basis__n_basis_funcs=(3, 5, 10)) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3) + param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -60,9 +61,11 @@ def test_sklearn_transformer_pipeline_cv_multiprocess( X, y, model, _, _ = poissonGLM_model_instantiation bas = basis.TransformerBasis(bas) pipe = pipeline.Pipeline([("basis", bas), ("fit", model)]) - param_grid = dict(basis__n_basis_funcs=(3, 5, 10)) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3) - gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) + param_grid = dict(basis__n_basis_funcs=(4, 5, 10)) + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3, error_score='raise') + # use threading instead of fork (this avoids conflicts with jax) + with joblib.parallel_backend("threading"): + gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @pytest.mark.parametrize( @@ -82,7 +85,7 @@ def test_sklearn_transformer_pipeline_cv_directly_over_basis( bas = 
basis.TransformerBasis(bas_cls(5)) pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3) + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y) @@ -104,9 +107,9 @@ def test_sklearn_transformer_pipeline_cv_illegal_combination( pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)]) param_grid = dict( transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)), - transformerbasis__n_basis_funcs=(3, 5, 10), + transformerbasis__n_basis_funcs=(4, 5, 10), ) - gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3) + gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, error_score='raise') with pytest.raises( ValueError, match="Set either new _basis object or parameters for existing _basis, not both." ): diff --git a/tests/test_proximal_operator.py b/tests/test_proximal_operator.py index 59d162bc..a6a65bfb 100644 --- a/tests/test_proximal_operator.py +++ b/tests/test_proximal_operator.py @@ -121,6 +121,7 @@ def test_prox_operator_shrinks_only_masked(example_data_prox_operator): def test_prox_operator_shrinks_only_masked_multineuron(example_data_prox_operator_multineuron): params, _, mask, _ = example_data_prox_operator_multineuron + mask = mask.astype(float) mask = mask.at[:, 1].set(jnp.zeros(2)) params_new = prox_group_lasso(params, 0.05, mask) assert jnp.all(params_new[0][1] == params[0][1]) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index 32565d07..5abba876 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -1,4 +1,5 @@ import copy +import warnings import jax import jax.numpy as jnp @@ -6,6 +7,7 @@ import pytest import statsmodels.api as sm from sklearn.linear_model import GammaRegressor, PoissonRegressor +from statsmodels.tools.sm_exceptions 
import DomainWarning import nemos as nmo @@ -218,9 +220,9 @@ def test_regularizer_strength_none(self): assert model.regularizer_strength == 1.0 - # assert change back to unregularized is none - model.regularizer = regularizer - assert model.regularizer_strength is None + with pytest.warns(UserWarning): + model.regularizer = regularizer + assert model.regularizer_strength == 1. def test_get_params(self): """Test get_params() returns expected values.""" @@ -275,7 +277,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.instantiate_solver() model.solver_run((true_params[0] * 0.0, true_params[1]), X, y) @@ -290,7 +292,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.instantiate_solver() model.solver_run( @@ -307,7 +309,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 # set model params - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -338,7 +340,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation, solver_name) X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs 
= {"tol": 10**-12} model.instantiate_solver() @@ -363,7 +365,7 @@ def test_solver_match_sklearn_gamma( # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} model.instantiate_solver() @@ -396,17 +398,18 @@ def test_solver_match_statsmodels_gamma( # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = inv_link_jax - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-13} model.instantiate_solver() weights_bfgs, intercepts_bfgs = model.solver_run( model._initialize_parameters(X, y), X, y )[0] - - model_sm = sm.GLM( - endog=y, exog=sm.add_constant(X), family=sm.families.Gamma(link=link_sm) - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The InversePower link function does ") + model_sm = sm.GLM( + endog=y, exog=sm.add_constant(X), family=sm.families.Gamma(link=link_sm) + ) res_sm = model_sm.fit(cnvrg_tol=10**-12) @@ -429,7 +432,7 @@ def test_solver_match_statsmodels_gamma( ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls()) model.solver_name = solver_name model.fit(X, y) @@ -465,9 +468,9 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) 
else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1.) @pytest.mark.parametrize( "solver_name", @@ -494,7 +497,7 @@ def test_set_solver_name_allowed(self, solver_name): "ProxSVRG", ] regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -518,12 +521,14 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) def test_regularizer_strength_none(self): @@ -535,13 +540,13 @@ def test_regularizer_strength_none(self): assert model.regularizer_strength == 1.0 - # if changed to regularized, should go to None - model.regularizer = "UnRegularized" - assert model.regularizer_strength is None + with pytest.warns(UserWarning): + # if changed to regularized, is kept to 1. + model.regularizer = "UnRegularized" + assert model.regularizer_strength == 1.0 # if changed back, should warn and set to 1.0 - with pytest.warns(UserWarning): - model.regularizer = "Ridge" + model.regularizer = "Ridge" assert model.regularizer_strength == 1.0 @@ -556,7 +561,7 @@ def test_loss_is_callable(self, loss): """Test Ridge callable loss.""" raise_exception = not callable(loss) regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.) 
model._predict_and_compute_loss = loss if raise_exception: with pytest.raises(TypeError, match="The `loss` must be a Callable"): @@ -574,7 +579,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) @@ -589,7 +594,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner( @@ -607,7 +612,7 @@ def test_solver_output_match(self, poissonGLM_model_instantiation, solver_name): model.data_type = jnp.float64 # set model params - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} @@ -638,7 +643,7 @@ def test_solver_match_sklearn(self, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) 
model.solver_kwargs = {"tol": 10**-12} model.solver_name = "BFGS" @@ -665,7 +670,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 model.observation_model.inverse_link_function = jnp.exp - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_kwargs = {"tol": 10**-12} model.regularizer_strength = 0.1 model.solver_name = "BFGS" @@ -697,7 +702,7 @@ def test_solver_match_sklearn_gamma(self, gammaGLM_model_instantiation): ) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) model.solver_name = solver_name model.fit(X, y) @@ -728,9 +733,9 @@ def test_init_solver_name(self, solver_name): with pytest.raises( ValueError, match=f"The solver: {solver_name} is not allowed for " ): - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) else: - nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name) + nmo.glm.GLM(regularizer=self.cls(), solver_name=solver_name, regularizer_strength=1) @pytest.mark.parametrize( "solver_name", @@ -751,7 +756,7 @@ def test_set_solver_name_allowed(self, solver_name): "ProxSVRG", ] regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1) raise_exception = solver_name not in acceptable_solvers if raise_exception: with pytest.raises( @@ -775,25 +780,27 @@ def test_init_solver_kwargs(self, solver_kwargs, solver_name): regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. 
) else: nmo.glm.GLM( regularizer=regularizer, solver_name=solver_name, solver_kwargs=solver_kwargs, + regularizer_strength=1. ) def test_regularizer_strength_none(self): """Assert regularizer strength handled appropriately.""" # if no strength given, should warn and set to 1.0 + regularizer = self.cls() with pytest.warns(UserWarning): - regularizer = self.cls() model = nmo.glm.GLM(regularizer=regularizer) assert model.regularizer_strength == 1.0 # if changed to regularized, should go to None - model.regularizer = "UnRegularized" + model.set_params(regularizer="UnRegularized", regularizer_strength=None) assert model.regularizer_strength is None # if changed back, should warn and set to 1.0 @@ -813,7 +820,7 @@ def test_loss_callable(self, loss): """Test that the loss function is a callable""" raise_exception = not callable(loss) regularizer = self.cls() - model = nmo.glm.GLM(regularizer=regularizer) + model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1) model._predict_and_compute_loss = loss if raise_exception: with pytest.raises(TypeError, match="The `loss` must be a Callable"): @@ -827,7 +834,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner((true_params[0] * 0.0, true_params[1]), X, y) @@ -839,7 +846,7 @@ def test_run_solver_tree(self, solver_name, poissonGLM_model_instantiation_pytre X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree # set regularizer and solver name - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name runner = model.instantiate_solver().solver_run runner( @@ -857,7 +864,7 @@ def test_solver_match_statsmodels( X, y, model, true_params, firing_rate 
= poissonGLM_model_instantiation # set precision to float64 for accurate matching of the results model.data_type = jnp.float64 - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1) model.solver_name = solver_name model.solver_kwargs = {"tol": 10**-12} @@ -885,7 +892,7 @@ def test_solver_match_statsmodels( def test_lasso_pytree(self, poissonGLM_model_instantiation_pytree): """Check pytree X can be fit.""" X, y, model, true_params, firing_rate = poissonGLM_model_instantiation_pytree - model.regularizer = nmo.regularizer.Lasso() + model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=1.) model.solver_name = "ProximalGradient" model.fit(X, y) @@ -903,10 +910,9 @@ def test_lasso_pytree_match( X, _, model, _, _ = poissonGLM_model_instantiation_pytree X_array, y, model_array, _, _ = poissonGLM_model_instantiation - model.regularizer_strength = reg_str - model_array.regularizer_strength = reg_str - model.regularizer = nmo.regularizer.Lasso() - model_array.regularizer = nmo.regularizer.Lasso() + + model.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) + model_array.set_params(regularizer=nmo.regularizer.Lasso(), regularizer_strength=reg_str) model.solver_name = solver_name model_array.solver_name = solver_name model.fit(X, y) @@ -918,7 +924,7 @@ def test_lasso_pytree_match( @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"]) def test_solver_combination(self, solver_name, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation - model.regularizer = self.cls() + model.set_params(regularizer=self.cls(), regularizer_strength=1.) 
         model.solver_name = solver_name
         model.fit(X, y)

@@ -956,9 +962,9 @@ def test_init_solver_name(self, solver_name):
             with pytest.raises(
                 ValueError, match=f"The solver: {solver_name} is not allowed for "
             ):
-                nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name)
+                nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1)
         else:
-            nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name)
+            nmo.glm.GLM(regularizer=self.cls(mask=mask), solver_name=solver_name, regularizer_strength=1)

     @pytest.mark.parametrize(
         "solver_name",
@@ -985,7 +991,7 @@ def test_set_solver_name_allowed(self, solver_name):
         mask = jnp.asarray(mask)
         regularizer = self.cls(mask=mask)
         raise_exception = solver_name not in acceptable_solvers
-        model = nmo.glm.GLM(regularizer=regularizer)
+        model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1)
         if raise_exception:
             with pytest.raises(
                 ValueError, match=f"The solver: {solver_name} is not allowed for "
@@ -1016,26 +1022,27 @@ def test_init_solver_kwargs(self, solver_name, solver_kwargs):
                 regularizer=regularizer,
                 solver_name=solver_name,
                 solver_kwargs=solver_kwargs,
+                regularizer_strength=1.
             )
         else:
             nmo.glm.GLM(
                 regularizer=regularizer,
                 solver_name=solver_name,
                 solver_kwargs=solver_kwargs,
+                regularizer_strength=1.
             )

     def test_regularizer_strength_none(self):
         """Assert regularizer strength handled appropriately."""
         # if no strength given, should warn and set to 1.0
+        regularizer = self.cls()
         with pytest.warns(UserWarning):
-            regularizer = self.cls()
             model = nmo.glm.GLM(regularizer=regularizer)

         assert model.regularizer_strength == 1.0

         # if changed to regularized, should go to None
-        model.regularizer = "UnRegularized"
-        assert model.regularizer_strength is None
+        model.set_params(regularizer="UnRegularized", regularizer_strength=None)

         # if changed back, should warn and set to 1.0
         with pytest.warns(UserWarning):
@@ -1061,7 +1068,7 @@ def test_loss_callable(self, loss):
         mask = jnp.asarray(mask)
         regularizer = self.cls(mask=mask)

-        model = nmo.glm.GLM(regularizer=regularizer)
+        model = nmo.glm.GLM(regularizer=regularizer, regularizer_strength=1.)
         model._predict_and_compute_loss = loss

         if raise_exception:
@@ -1082,7 +1089,7 @@ def test_run_solver(self, solver_name, poissonGLM_model_instantiation):
         mask[1, 2:] = 1
         mask = jnp.asarray(mask)

-        model.regularizer = self.cls(mask=mask)
+        model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         model.solver_name = solver_name

         model.instantiate_solver()
@@ -1100,7 +1107,7 @@ def test_init_solver(self, solver_name, poissonGLM_model_instantiation):
         mask[1, 2:] = 1
         mask = jnp.asarray(mask)

-        model.regularizer = self.cls(mask=mask)
+        model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         model.solver_name = solver_name

         model.instantiate_solver()
@@ -1126,7 +1133,7 @@ def test_update_solver(self, solver_name, poissonGLM_model_instantiation):
         mask[1, 2:] = 1
         mask = jnp.asarray(mask)

-        model.regularizer = self.cls(mask=mask)
+        model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         model.solver_name = solver_name

         model.instantiate_solver()
@@ -1186,9 +1193,9 @@ def test_mask_validity_groups(
             with pytest.raises(
                 ValueError, match="Incorrect group assignment. " "Some of the features"
             ):
-                model.regularizer = self.cls(mask=mask)
+                model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         else:
-            model.regularizer = self.cls(mask=mask)
+            model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)

     @pytest.mark.parametrize("set_entry", [0, 1, -1, 2, 2.5])
     def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation):
@@ -1206,9 +1213,9 @@ def test_mask_validity_entries(self, set_entry, poissonGLM_model_instantiation):

         if raise_exception:
             with pytest.raises(ValueError, match="Mask elements be 0s and 1s"):
-                model.regularizer = self.cls(mask=mask)
+                model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         else:
-            model.regularizer = self.cls(mask=mask)
+            model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)

     @pytest.mark.parametrize("n_dim", [0, 1, 2, 3])
     def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation):
@@ -1235,9 +1242,9 @@ def test_mask_dimension_1(self, n_dim, poissonGLM_model_instantiation):

         if raise_exception:
             with pytest.raises(ValueError, match="`mask` must be 2-dimensional"):
-                model.regularizer = self.cls(mask=mask)
+                model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         else:
-            model.regularizer = self.cls(mask=mask)
+            model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)

     @pytest.mark.parametrize("n_groups", [0, 1, 2])
     def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation):
@@ -1256,9 +1263,9 @@ def test_mask_n_groups(self, n_groups, poissonGLM_model_instantiation):

         if raise_exception:
             with pytest.raises(ValueError, match=r"Empty mask provided! Mask has "):
-                model.regularizer = self.cls(mask=mask)
+                model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         else:
-            model.regularizer = self.cls(mask=mask)
+            model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)

     def test_group_sparsity_enforcement(
         self, group_sparse_poisson_glm_model_instantiation
@@ -1278,7 +1285,7 @@ def test_group_sparsity_enforcement(
         mask[1, ~zeros_true] = 1
         mask = jnp.asarray(mask, dtype=jnp.float32)

-        model.regularizer = self.cls(mask=mask)
+        model.set_params(regularizer=self.cls(mask=mask), regularizer_strength=1.)
         model.solver_name = "ProximalGradient"

         runner = model.instantiate_solver().solver_run
@@ -1415,14 +1422,15 @@ def test_mask_none(self, poissonGLM_model_instantiation):
         X, y, model, true_params, firing_rate = poissonGLM_model_instantiation

         with pytest.warns(UserWarning):
-            model.regularizer = self.cls()
+            model.regularizer = self.cls(mask=np.ones((1, X.shape[1])).astype(float))
             model.solver_name = "ProximalGradient"
             model.fit(X, y)

     @pytest.mark.parametrize("solver_name", ["ProximalGradient", "ProxSVRG"])
     def test_solver_combination(self, solver_name, poissonGLM_model_instantiation):
         X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
-        model.regularizer = self.cls()
+        model.set_params(regularizer=self.cls(mask=np.ones((1, X.shape[1])).astype(float)),
+                         regularizer_strength=None if self.cls == nmo.regularizer.UnRegularized else 1.)
         model.solver_name = solver_name
         model.fit(X, y)

diff --git a/tests/test_solvers.py b/tests/test_solvers.py
index 72970397..745e6cf9 100644
--- a/tests/test_solvers.py
+++ b/tests/test_solvers.py
@@ -128,7 +128,7 @@ def test_svrg_glm_instantiate_solver(regularizer_name, solver_class, mask):
     if mask is not None:
         kwargs["mask"] = mask

-    glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name)
+    glm = nmo.glm.GLM(regularizer=regularizer_name, solver_name=solver_name, regularizer_strength=None if regularizer_name == "UnRegularized" else 1)
     glm.instantiate_solver()

     solver = inspect.getclosurevars(glm._solver_run).nonlocals["solver"]
@@ -161,6 +161,7 @@ def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_class):
         regularizer=regularizer_name,
         solver_name=solver_name,
         solver_kwargs=solver_kwargs,
+        regularizer_strength=None if regularizer_name == "UnRegularized" else 1,
         **kwargs,
     )
     glm.instantiate_solver()
@@ -177,9 +178,9 @@ def test_svrg_glm_passes_solver_kwargs(regularizer_name, solver_name, mask, glm_class):
         (
             "GroupLasso",
             ProxSVRG,
-            np.array([[0], [0], [1]]),
+            np.array([[0.], [0.], [1.]]),
         ),
-        ("GroupLasso", ProxSVRG, None),
+        ("GroupLasso", ProxSVRG, np.array([[1.], [0.], [0.]])),
         ("Ridge", SVRG, None),
         ("UnRegularized", SVRG, None),
     ],
@@ -196,15 +197,22 @@ def test_svrg_glm_initialize_state(
     if glm_class == nmo.glm.PopulationGLM:
         y = np.expand_dims(y, 1)

+    reg_cls = getattr(nmo.regularizer, regularizer_name)
+    if regularizer_name == "GroupLasso":
+        reg = reg_cls(mask=mask)
+    else:
+        reg = reg_cls()
+
     # only pass mask if it's not None
     kwargs = {}
     if mask is not None and glm_class == nmo.glm.PopulationGLM:
         kwargs["feature_mask"] = mask

     glm = glm_class(
-        regularizer=regularizer_name,
+        regularizer=reg,
        solver_name=solver_class.__name__,
         observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus),
+        regularizer_strength=None if regularizer_name == "UnRegularized" else 1,
         **kwargs,
     )

@@ -225,7 +233,7 @@ def test_svrg_glm_initialize_state(
         (
             "GroupLasso",
             ProxSVRG,
-            np.array([[0], [0], [1]]),
+            np.array([[0.], [0.], [1.]]),
         ),
         ("Ridge", SVRG, None),
         ("UnRegularized", SVRG, None),
@@ -244,13 +252,20 @@ def test_svrg_glm_update(

     # only pass mask if it's not None
     kwargs = {}
-    if mask is not None and glm_class == nmo.glm.PopulationGLM:
+    if glm_class == nmo.glm.PopulationGLM:
         kwargs["feature_mask"] = mask

+    reg_cls = getattr(nmo.regularizer, regularizer_name)
+    if regularizer_name == "GroupLasso":
+        reg = reg_cls(mask=mask)
+    else:
+        reg = reg_cls()
+
     glm = glm_class(
-        regularizer=regularizer_name,
+        regularizer=reg,
         solver_name=solver_class.__name__,
         observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus),
+        regularizer_strength=None if regularizer_name == "UnRegularized" else 1,
         **kwargs,
     )
@@ -276,15 +291,15 @@ def test_svrg_glm_update(
         (
             "GroupLasso",
             "ProxSVRG",
-            np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float),
+            np.array([[0, 1, 0, 1, 1], [1, 0, 1, 0, 0]]).astype(float),
         ),
-        ("GroupLasso", "ProxSVRG", None),
+        ("GroupLasso", "ProxSVRG", np.array([[1, 1, 1, 1, 1]]).astype(float)),
         (
             "GroupLasso",
             "ProximalGradient",
-            np.array([[0, 1, 0], [0, 0, 1]]).reshape(2, -1).astype(float),
+            np.array([[0, 1, 0, 1, 1], [1, 0, 1, 0, 0]]).astype(float),
         ),
-        ("GroupLasso", "ProximalGradient", None),
+        ("GroupLasso", "ProximalGradient", np.array([[1, 1, 1, 1, 1]]).astype(float)),
         ("Ridge", "SVRG", None),
         ("UnRegularized", "SVRG", None),
     ],
@@ -314,15 +329,23 @@ def test_svrg_glm_fit(
     }

     # only pass mask if it's not None
+    reg_cls = getattr(nmo.regularizer, regularizer_name)
+    if regularizer_name == "GroupLasso":
+        reg = reg_cls(mask=mask)
+    else:
+        reg = reg_cls()
+
     kwargs = {}
-    if mask is not None:
-        kwargs["feature_mask"] = mask
+    if glm_class == nmo.glm.PopulationGLM:
+        kwargs["feature_mask"] = np.ones((X.shape[1], 1))

     glm = glm_class(
-        regularizer=regularizer_name,
+        regularizer=reg,
         solver_name=solver_name,
         observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus),
         solver_kwargs=solver_kwargs,
+        regularizer_strength=None if regularizer_name == "UnRegularized" else 1,
+        **kwargs
     )

     if isinstance(glm, nmo.glm.PopulationGLM):
@@ -339,7 +362,7 @@ def test_svrg_glm_fit(
     "regularizer_name, solver_class, mask",
     [
         ("Lasso", ProxSVRG, None),
-        ("GroupLasso", ProxSVRG, np.array([0, 1, 0]).reshape(1, -1).astype(float)),
+        ("GroupLasso", ProxSVRG, np.array([0, 1, 0]).reshape(-1, 1).astype(float)),
         ("Ridge", SVRG, None),
         ("UnRegularized", SVRG, None),
     ],
@@ -356,16 +379,23 @@ def test_svrg_glm_update_needs_full_grad_at_reference_point(
         y = np.expand_dims(y, 1)

     # only pass mask if it's not None
-    kwargs = {}
-    if mask is not None and glm_class == nmo.glm.PopulationGLM:
-        kwargs["feature_mask"] = mask
-
-    glm = glm_class(
-        regularizer=regularizer_name,
+    reg_cls = getattr(nmo.regularizer, regularizer_name)
+    if regularizer_name == "GroupLasso":
+        reg = reg_cls(mask=mask)
+    else:
+        reg = reg_cls()
+    kwargs = dict(
+        regularizer=reg,
         solver_name=solver_class.__name__,
         observation_model=nmo.observation_models.PoissonObservations(jax.nn.softplus),
+        regularizer_strength=None if regularizer_name == "UnRegularized" else 0.1,
     )

+    if mask is not None and glm_class == nmo.glm.PopulationGLM:
+        kwargs["feature_mask"] = np.array([0, 1, 0]).reshape(-1, 1).astype(float)
+
+    glm = glm_class(**kwargs)
+
     with pytest.raises(
         ValueError,
         match=r"Full gradient at the anchor point \(state\.full_grad_at_reference_point\) has to be set",
diff --git a/tests/test_tree_utils.py b/tests/test_tree_utils.py
index f4e10431..33c58850 100644
--- a/tests/test_tree_utils.py
+++ b/tests/test_tree_utils.py
@@ -1,5 +1,5 @@
-import numpy as np
 import jax.numpy as jnp
+import numpy as np
 import pytest

 from nemos import tree_utils
diff --git a/tests/test_type_casting.py b/tests/test_type_casting.py
index bed5b1f0..b7cbbe48 100644
--- a/tests/test_type_casting.py
+++ b/tests/test_type_casting.py
@@ -370,7 +370,7 @@ def func(*x):
         (
             [
                 nap.Tsd(t=np.arange(10), d=np.arange(10)),
-                nap.Tsd(t=np.arange(1), d=np.arange(1)),
+                nap.Tsd(t=np.arange(1), d=np.arange(1), time_support=nap.IntervalSet(0, 10)),
                 nap.Tsd(t=np.arange(10), d=np.arange(10)),
             ],
             pytest.raises(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c3babd71..abe46314 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext as does_not_raise

 import jax
@@ -107,7 +108,10 @@ def test_conv_type(self, iterable, predictor_causality):
             with pytest.raises(ValueError, match="predictor_causality must be one of"):
                 utils.nan_pad(iterable, 3, predictor_causality)
         else:
-            utils.nan_pad(iterable, 3, predictor_causality)
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning,
+                                        message="With acausal filter, pad_size should probably be even")
+                utils.nan_pad(iterable, 3, predictor_causality)

     @pytest.mark.parametrize("iterable", [[np.zeros([2, 4, 5]), np.zeros([2, 4, 6])]])
     @pytest.mark.parametrize("pad_size", [0.1, -1, 0, 1, 2, 3, 5, 6])
@@ -159,7 +163,7 @@ def test_padding_nan_anti_causal(self, pad_size, iterable):
         ), "Size after padding doesn't match expectation. Should be T + window_size - 1."
     @pytest.mark.parametrize("iterable", [[np.zeros([2, 5, 4]), np.zeros([2, 6, 4])]])
-    @pytest.mark.parametrize("pad_size", [-1, 0.2, 0, 1, 2, 3, 5, 6])
+    @pytest.mark.parametrize("pad_size", [-1, 0.2, 0, 1, 3, 5])
     def test_padding_nan_acausal(self, pad_size, iterable):
         raise_exception = (not isinstance(pad_size, int)) or (pad_size <= 0)
         if raise_exception:
@@ -170,7 +174,10 @@ def test_padding_nan_acausal(self, pad_size, iterable):
         else:
             init_nan, end_nan = pad_size // 2, pad_size - pad_size // 2

-            padded = utils.nan_pad(iterable, pad_size, "acausal")
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning,
+                                        message="With acausal filter, pad_size should probably be even")
+                padded = utils.nan_pad(iterable, pad_size, "acausal")
             for trial in padded:
                 print(trial.shape, pad_size)
             assert all(np.isnan(trial[:init_nan]).all() for trial in padded), (
@@ -252,8 +259,11 @@ def test_nan_pad_conv_dtype(self, dtype, expectation):
         ],
     )
     def test_axis_compatibility(self, pad_size, array, causality, axis, expectation):
-        with expectation:
-            utils.nan_pad(array, pad_size, causality, axis=axis)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning,
+                                    message="With acausal filter, pad_size should probably be even")
+            with expectation:
+                utils.nan_pad(array, pad_size, causality, axis=axis)

     @pytest.mark.parametrize("causality", ["causal", "acausal", "anti-causal"])
     @pytest.mark.parametrize(
@@ -273,8 +283,11 @@ def test_axis_compatibility(self, pad_size, array, causality, axis, expectation):
     )
     @pytest.mark.parametrize("array", [jnp.zeros((10,)), np.zeros((10, 11))])
     def test_pad_size_type(self, pad_size, array, causality, expectation):
-        with expectation:
-            utils.nan_pad(array, pad_size, causality, axis=0)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning,
+                                    message="With acausal filter, pad_size should probably be even")
+            with expectation:
+                utils.nan_pad(array, pad_size, causality, axis=0)

     @pytest.mark.parametrize(
         "causality, pad_size, expectation",
diff --git a/tox.ini b/tox.ini
index d7761318..277799b6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,6 +2,7 @@
 isolated_build = True
 envlist = py310, py311, py312

+
 [testenv]
 # means we'll run the equivalent of `pip install .[dev]`, also installing pytest
 # and the linters from pyproject.toml
@@ -19,8 +20,10 @@ commands =
     isort docs/background --profile=black
     isort docs/tutorials --profile=black
    flake8 --config={toxinidir}/tox.ini src
+    pytest --doctest-modules src/nemos/
     pytest --cov=nemos --cov-config=pyproject.toml --cov-report=xml

+
 [gh-actions]
 python =
     3.10: py310